diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/OpenaiChat.py | 41 |
1 files changed, 24 insertions, 17 deletions
diff --git a/g4f/Provider/OpenaiChat.py b/g4f/Provider/OpenaiChat.py index c023c898..9ca0cd58 100644 --- a/g4f/Provider/OpenaiChat.py +++ b/g4f/Provider/OpenaiChat.py @@ -4,8 +4,11 @@ try: except ImportError: has_module = False -from .base_provider import AsyncGeneratorProvider, get_cookies -from ..typing import AsyncGenerator +from .base_provider import AsyncGeneratorProvider, get_cookies, format_prompt +from ..typing import AsyncGenerator +from httpx import AsyncClient +import json + class OpenaiChat(AsyncGeneratorProvider): url = "https://chat.openai.com" @@ -14,6 +17,7 @@ class OpenaiChat(AsyncGeneratorProvider): supports_gpt_35_turbo = True supports_gpt_4 = True supports_stream = True + _access_token = None @classmethod async def create_async_generator( @@ -21,9 +25,9 @@ class OpenaiChat(AsyncGeneratorProvider): model: str, messages: list[dict[str, str]], proxy: str = None, - access_token: str = None, + access_token: str = _access_token, cookies: dict = None, - **kwargs + **kwargs: dict ) -> AsyncGenerator: config = {"access_token": access_token, "model": model} @@ -37,21 +41,12 @@ class OpenaiChat(AsyncGeneratorProvider): ) if not access_token: - cookies = cookies if cookies else get_cookies("chat.openai.com") - response = await bot.session.get("https://chat.openai.com/api/auth/session", cookies=cookies) - access_token = response.json()["accessToken"] - bot.set_access_token(access_token) - - if len(messages) > 1: - formatted = "\n".join( - ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] - ) - prompt = f"{formatted}\nAssistant:" - else: - prompt = messages.pop()["content"] + cookies = cookies if cookies else get_cookies("chat.openai.com") + cls._access_token = await get_access_token(bot.session, cookies) + bot.set_access_token(cls._access_token) returned = None - async for message in bot.ask(prompt): + async for message in bot.ask(format_prompt(messages)): message = message["message"] if returned: if message.startswith(returned): @@ -61,6 +56,9 @@ class OpenaiChat(AsyncGeneratorProvider): else: yield message returned = message + + await bot.delete_conversation(bot.conversation_id) + @classmethod @property @@ -73,3 +71,12 @@ class OpenaiChat(AsyncGeneratorProvider): ] param = ", ".join([": ".join(p) for p in params]) return f"g4f.provider.{cls.__name__} supports: ({param})" + + +async def get_access_token(session: AsyncClient, cookies: dict): + response = await session.get("https://chat.openai.com/api/auth/session", cookies=cookies) + response.raise_for_status() + try: + return response.json()["accessToken"] + except json.decoder.JSONDecodeError: + raise RuntimeError(f"Response: {response.text}")
\ No newline at end of file |