summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/Pi.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/Pi.py')
-rw-r--r--g4f/Provider/Pi.py80
1 files changed, 42 insertions, 38 deletions
diff --git a/g4f/Provider/Pi.py b/g4f/Provider/Pi.py
index 68a7357f..6aabe7b1 100644
--- a/g4f/Provider/Pi.py
+++ b/g4f/Provider/Pi.py
@@ -2,20 +2,21 @@ from __future__ import annotations
import json
-from ..typing import CreateResult, Messages
-from .base_provider import AbstractProvider, format_prompt
-from ..requests import Session, get_session_from_browser, raise_for_status
+from ..typing import AsyncResult, Messages, Cookies
+from .base_provider import AsyncGeneratorProvider, format_prompt
+from ..requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
-class Pi(AbstractProvider):
+class Pi(AsyncGeneratorProvider):
url = "https://pi.ai/talk"
working = True
supports_stream = True
- _session = None
default_model = "pi"
models = [default_model]
+ _headers: dict = None
+ _cookies: Cookies = {}
@classmethod
- def create_completion(
+ async def create_async_generator(
cls,
model: str,
messages: Messages,
@@ -23,49 +24,52 @@ class Pi(AbstractProvider):
proxy: str = None,
timeout: int = 180,
conversation_id: str = None,
- webdriver: WebDriver = None,
**kwargs
- ) -> CreateResult:
- if cls._session is None:
- cls._session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout)
- if not conversation_id:
- conversation_id = cls.start_conversation(cls._session)
- prompt = format_prompt(messages)
- else:
- prompt = messages[-1]["content"]
- answer = cls.ask(cls._session, prompt, conversation_id)
- for line in answer:
- if "text" in line:
- yield line["text"]
-
+ ) -> AsyncResult:
+ if cls._headers is None:
+ args = await get_args_from_nodriver(cls.url, proxy=proxy, timeout=timeout)
+ cls._cookies = args.get("cookies", {})
+ cls._headers = args.get("headers")
+ async with StreamSession(headers=cls._headers, cookies=cls._cookies, proxy=proxy) as session:
+ if not conversation_id:
+ conversation_id = await cls.start_conversation(session)
+ prompt = format_prompt(messages)
+ else:
+ prompt = messages[-1]["content"]
+ answer = cls.ask(session, prompt, conversation_id)
+ async for line in answer:
+ if "text" in line:
+ yield line["text"]
+
@classmethod
- def start_conversation(cls, session: Session) -> str:
- response = session.post('https://pi.ai/api/chat/start', data="{}", headers={
+ async def start_conversation(cls, session: StreamSession) -> str:
+ async with session.post('https://pi.ai/api/chat/start', data="{}", headers={
'accept': 'application/json',
'x-api-version': '3'
- })
- raise_for_status(response)
- return response.json()['conversations'][0]['sid']
+ }) as response:
+ await raise_for_status(response)
+ return (await response.json())['conversations'][0]['sid']
- def get_chat_history(session: Session, conversation_id: str):
+ async def get_chat_history(session: StreamSession, conversation_id: str):
params = {
'conversation': conversation_id,
}
- response = session.get('https://pi.ai/api/chat/history', params=params)
- raise_for_status(response)
- return response.json()
+ async with session.get('https://pi.ai/api/chat/history', params=params) as response:
+ await raise_for_status(response)
+ return await response.json()
- def ask(session: Session, prompt: str, conversation_id: str):
+ @classmethod
+ async def ask(cls, session: StreamSession, prompt: str, conversation_id: str):
json_data = {
'text': prompt,
'conversation': conversation_id,
'mode': 'BASE',
}
- response = session.post('https://pi.ai/api/chat', json=json_data, stream=True)
- raise_for_status(response)
- for line in response.iter_lines():
- if line.startswith(b'data: {"text":'):
- yield json.loads(line.split(b'data: ')[1])
- elif line.startswith(b'data: {"title":'):
- yield json.loads(line.split(b'data: ')[1])
-
+ async with session.post('https://pi.ai/api/chat', json=json_data) as response:
+ await raise_for_status(response)
+ cls._cookies = merge_cookies(cls._cookies, response)
+ async for line in response.iter_lines():
+ if line.startswith(b'data: {"text":'):
+ yield json.loads(line.split(b'data: ')[1])
+ elif line.startswith(b'data: {"title":'):
+ yield json.loads(line.split(b'data: ')[1])