From ac86e576d20b42a983c9c81dddd067d6e4d51cf4 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 28 Feb 2024 07:36:43 +0100 Subject: Add websocket support in OpenaiChat --- g4f/Provider/needs_auth/OpenaiChat.py | 228 ++++++++++++++++++++-------------- 1 file changed, 133 insertions(+), 95 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 8154bc44..0fa433a4 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -4,6 +4,8 @@ import asyncio import uuid import json import os +import base64 +from aiohttp import ClientWebSocketResponse try: from py_arkose_generator.arkose import get_values_for_request @@ -22,7 +24,7 @@ except ImportError: from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import get_cookies from ...webdriver import get_browser -from ...typing import AsyncResult, Messages, Cookies, ImageType, Union +from ...typing import AsyncResult, Messages, Cookies, ImageType, Union, AsyncIterator from ...requests import get_args_from_browser from ...requests.aiohttp import StreamSession from ...image import to_image, to_bytes, ImageResponse, ImageRequest @@ -38,10 +40,14 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): supports_gpt_35_turbo = True supports_gpt_4 = True supports_message_history = True + supports_system_message = True default_model = None models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"] - model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"} - _args: dict = None + model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"} + _api_key: str = None + _headers: dict = None + _cookies: Cookies = None + _last_message: int = 0 @classmethod async def create( @@ -299,6 +305,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): conversation_id: str = None, parent_id: str = None, image: ImageType = None, + image_name: str = None, response_fields: bool = False, **kwargs ) -> AsyncResult: @@ -332,67 +339,64 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if not parent_id: parent_id = str(uuid.uuid4()) - # Read api_key from args + # Read api_key from arguments api_key = kwargs["access_token"] if "access_token" in kwargs else api_key - # If no cached args - if cls._args is None: - if api_key is None: - # Read api_key from cookies - cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies - api_key = cookies["access_token"] if "access_token" in cookies else api_key - cls._args = cls._create_request_args(cookies) - else: - # Read api_key from cache - api_key = cls._args["headers"]["Authorization"] if "Authorization" in cls._args["headers"] else None async with StreamSession( proxies={"https": proxy}, impersonate="chrome", timeout=timeout ) as session: - # Read api_key from session cookies + # Read api_key and cookies from cache / browser config + if cls._headers is None: + if api_key is None: + # Read api_key from cookies + cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies + api_key = cookies["access_token"] if "access_token" in cookies else api_key + cls._create_request_args(cookies) + else: + api_key = cls._api_key if api_key is None else api_key + # Read api_key with session cookies if api_key is None and cookies: - api_key = await cls.fetch_access_token(session, cls._args["headers"]) + api_key = await cls.fetch_access_token(session, cls._headers) # Load default model - if cls.default_model is None: + if cls.default_model is None and api_key is not None: try: - if cookies and not model and api_key is not None: - cls._args["headers"]["Authorization"] = api_key - cls.default_model = cls.get_model(await cls.get_default_model(session, cls._args["headers"])) - elif api_key: - cls.default_model = cls.get_model(model or "gpt-3.5-turbo") + if not model: + cls._set_api_key(api_key) + cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) + else: + cls.default_model = cls.get_model(model) except Exception as e: if debug.logging: print("OpenaiChat: Load default_model failed") print(f"{e.__class__.__name__}: {e}") - # Browse api_key and update default model + # Browse api_key and default model if api_key is None or cls.default_model is None: login_url = os.environ.get("G4F_LOGIN_URL") if login_url: yield f"Please login: [ChatGPT]({login_url})\n\n" try: - cls._args = cls.browse_access_token(proxy) + cls.browse_access_token(proxy) except MissingRequirementsError: raise MissingAuthError(f'Missing "access_token". Add a "api_key" please') - cls.default_model = cls.get_model(await cls.get_default_model(session, cls._args["headers"])) + cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) else: - cls._args["headers"]["Authorization"] = api_key + cls._set_api_key(api_key) try: - image_response = await cls.upload_image( - session, - cls._args["headers"], - image, - kwargs.get("image_name") - ) if image else None + image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None except Exception as e: - yield e + if debug.logging: + print("OpenaiChat: Upload image failed") + print(f"{e.__class__.__name__}: {e}") - end_turn = EndTurn() - model = cls.get_model(model) - model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model - while not end_turn.is_end: + model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha") + fields = ResponseFields() + while fields.finish_reason is None: arkose_token = await cls.get_arkose_token(session) + conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id + parent_id = parent_id if fields.message_id is None else fields.message_id data = { "action": action, "arkose_token": arkose_token, @@ -405,8 +409,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "history_and_training_disabled": history_disabled and not auto_continue, } if action != "continue": - messages = messages if not conversation_id else [messages[-1]] - data["messages"] = cls.create_messages(messages, image_response) + messages = messages if conversation_id is None else [messages[-1]] + data["messages"] = cls.create_messages(messages, image_request) async with session.post( f"{cls.url}/backend-api/conversation", @@ -414,63 +418,88 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): headers={ "Accept": "text/event-stream", "OpenAI-Sentinel-Arkose-Token": arkose_token, - **cls._args["headers"] + **cls._headers } ) as response: cls._update_request_args(session) if not response.ok: - message = f"{await response.text()} headers:\n{json.dumps(cls._args['headers'], indent=4)}" - raise RuntimeError(f"Response {response.status}: {message}") - last_message: int = 0 - async for line in response.iter_lines(): - if not line.startswith(b"data: "): - continue - elif line.startswith(b"data: [DONE]"): - break - try: - line = json.loads(line[6:]) - except: - continue - if "message" not in line: - continue - if "error" in line and line["error"]: - raise RuntimeError(line["error"]) - if "message_type" not in line["message"]["metadata"]: - continue - try: - image_response = await cls.get_generated_image(session, cls._args["headers"], line) - if image_response is not None: - yield image_response - except Exception as e: - yield e - if line["message"]["author"]["role"] != "assistant": - continue - if line["message"]["content"]["content_type"] != "text": - continue - if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"): - continue - conversation_id = line["conversation_id"] - parent_id = line["message"]["id"] + raise RuntimeError(f"Response {response.status}: {await response.text()}") + async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields): if response_fields: response_fields = False - yield ResponseFields(conversation_id, parent_id, end_turn) - if "parts" in line["message"]["content"]: - new_message = line["message"]["content"]["parts"][0] - if len(new_message) > last_message: - yield new_message[last_message:] - last_message = len(new_message) - if "finish_details" in line["message"]["metadata"]: - if line["message"]["metadata"]["finish_details"]["type"] == "stop": - end_turn.end() + yield fields + yield chunk if not auto_continue: break action = "continue" await asyncio.sleep(5) if history_disabled and auto_continue: - await cls.delete_conversation(session, cls._args["headers"], conversation_id) + await cls.delete_conversation(session, cls._headers, conversation_id) + + @staticmethod + async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator: + while True: + yield base64.b64decode((await ws.receive_json())["body"]) + + @classmethod + async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator: + last_message: int = 0 + async for message in messages: + if message.startswith(b'{"wss_url":'): + async with session.ws_connect(json.loads(message)["wss_url"]) as ws: + async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields): + yield chunk + break + async for chunk in cls.iter_messages_line(session, message, fields): + if fields.finish_reason is not None: + break + elif isinstance(chunk, str): + if len(chunk) > last_message: + yield chunk[last_message:] + last_message = len(chunk) + else: + yield chunk + if fields.finish_reason is not None: + break + + @classmethod + async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: ResponseFields) -> AsyncIterator: + if not line.startswith(b"data: "): + return + elif line.startswith(b"data: [DONE]"): + return + try: + line = json.loads(line[6:]) + except: + return + if "message" not in line: + return + if "error" in line and line["error"]: + raise RuntimeError(line["error"]) + if "message_type" not in line["message"]["metadata"]: + return + try: + image_response = await cls.get_generated_image(session, cls._headers, line) + if image_response is not None: + yield image_response + except Exception as e: + yield e + if line["message"]["author"]["role"] != "assistant": + return + if line["message"]["content"]["content_type"] != "text": + return + if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"): + return + if fields.conversation_id is None: + fields.conversation_id = line["conversation_id"] + fields.message_id = line["message"]["id"] + if "parts" in line["message"]["content"]: + yield line["message"]["content"]["parts"][0] + if "finish_details" in line["message"]["metadata"]: + fields.finish_reason = line["message"]["metadata"]["finish_details"]["type"] @classmethod - def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: + def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> None: """ Browse to obtain an access token. @@ -493,9 +522,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "return accessToken;" ) args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False) - args["headers"]["Authorization"] = f"Bearer {access_token}" - args["headers"]["Cookie"] = cls._format_cookies(args["cookies"]) - return args + cls._headers = args["headers"] + cls._cookies = args["cookies"] + cls._update_cookie_header() + cls._set_api_key(access_token) finally: driver.close() @@ -546,16 +576,24 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def _create_request_args(cls, cookies: Union[Cookies, None]): - return { - "headers": {} if cookies is None else {"Cookie": cls._format_cookies(cookies)}, - "cookies": {} if cookies is None else cookies - } + cls._headers = {} + cls._cookies = {} if cookies is None else cookies + cls._update_cookie_header() @classmethod def _update_request_args(cls, session: StreamSession): for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar: - cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value - cls._args["headers"]["Cookie"] = cls._format_cookies(cls._args["cookies"]) + cls._cookies[c.name if hasattr(c, "name") else c.key] = c.value + cls._update_cookie_header() + + @classmethod + def _set_api_key(cls, api_key: str): + cls._api_key = api_key + cls._headers["Authorization"] = f"Bearer {api_key}" + + @classmethod + def _update_cookie_header(cls): + cls._headers["Cookie"] = cls._format_cookies(cls._cookies) class EndTurn: """ @@ -571,10 +609,10 @@ class ResponseFields: """ Class to encapsulate response fields. """ - def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn): + def __init__(self, conversation_id: str = None, message_id: str = None, finish_reason: str = None): self.conversation_id = conversation_id self.message_id = message_id - self._end_turn = end_turn + self.finish_reason = finish_reason class Response(): """ @@ -608,7 +646,7 @@ class Response(): self._message = "".join(chunks) if not self._fields: raise RuntimeError("Missing response fields") - self.is_end = self._fields._end_turn.is_end + self.is_end = self._fields.end_turn def __aiter__(self): return self.generator() -- cgit v1.2.3