diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/Pi.py | 21 |
1 files changed, 9 insertions, 12 deletions
diff --git a/g4f/Provider/Pi.py b/g4f/Provider/Pi.py index 2f7dc436..5a1e9f0e 100644 --- a/g4f/Provider/Pi.py +++ b/g4f/Provider/Pi.py @@ -4,12 +4,13 @@ import json from ..typing import CreateResult, Messages from .base_provider import AbstractProvider, format_prompt -from ..requests import Session, get_session_from_browser +from ..requests import Session, get_session_from_browser, raise_for_status class Pi(AbstractProvider): url = "https://pi.ai/talk" working = True supports_stream = True + _session = None @classmethod def create_completion( @@ -17,20 +18,19 @@ class Pi(AbstractProvider): model: str, messages: Messages, stream: bool, - session: Session = None, proxy: str = None, timeout: int = 180, conversation_id: str = None, **kwargs ) -> CreateResult: - if not session: - session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout) + if cls._session is None: + cls._session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout) if not conversation_id: - conversation_id = cls.start_conversation(session) + conversation_id = cls.start_conversation(cls._session) prompt = format_prompt(messages) else: prompt = messages[-1]["content"] - answer = cls.ask(session, prompt, conversation_id) + answer = cls.ask(cls._session, prompt, conversation_id) for line in answer: if "text" in line: yield line["text"] @@ -41,8 +41,7 @@ class Pi(AbstractProvider): 'accept': 'application/json', 'x-api-version': '3' }) - if 'Just a moment' in response.text: - raise RuntimeError('Error: Cloudflare detected') + raise_for_status(response) return response.json()['conversations'][0]['sid'] def get_chat_history(session: Session, conversation_id: str): @@ -50,8 +49,7 @@ class Pi(AbstractProvider): 'conversation': conversation_id, } response = session.get('https://pi.ai/api/chat/history', params=params) - if 'Just a moment' in response.text: - raise RuntimeError('Error: Cloudflare detected') + raise_for_status(response) return response.json() def ask(session: Session, prompt: str, conversation_id: str): @@ -61,9 +59,8 @@ class Pi(AbstractProvider): 'mode': 'BASE', } response = session.post('https://pi.ai/api/chat', json=json_data, stream=True) + raise_for_status(response) for line in response.iter_lines(): - if b'Just a moment' in line: - raise RuntimeError('Error: Cloudflare detected') if line.startswith(b'data: {"text":'): yield json.loads(line.split(b'data: ')[1]) elif line.startswith(b'data: {"title":'): |