From 91feb34054f529c37e10d98d2471c8c0c6780147 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 23 Jan 2024 19:44:48 +0100 Subject: Add ProviderModelMixin for model selection --- g4f/Provider/needs_auth/OpenaiChat.py | 66 +++++++++++++++-------------------- 1 file changed, 29 insertions(+), 37 deletions(-) (limited to 'g4f/Provider/needs_auth') diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index abf5b8d9..85866272 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -10,22 +10,15 @@ from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC -from ..base_provider import AsyncGeneratorProvider +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt, get_cookies from ...webdriver import get_browser, get_driver_cookies from ...typing import AsyncResult, Messages from ...requests import StreamSession from ...image import to_image, to_bytes, ImageType, ImageResponse -# Aliases for model names -MODELS = { - "gpt-3.5": "text-davinci-002-render-sha", - "gpt-3.5-turbo": "text-davinci-002-render-sha", - "gpt-4": "gpt-4", - "gpt-4-gizmo": "gpt-4-gizmo" -} -class OpenaiChat(AsyncGeneratorProvider): +class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """A class for creating and managing conversations with OpenAI chat service""" url = "https://chat.openai.com" @@ -33,6 +26,11 @@ class OpenaiChat(AsyncGeneratorProvider): needs_auth = True supports_gpt_35_turbo = True supports_gpt_4 = True + default_model = None + models = ["text-davinci-002-render-sha", "gpt-4", "gpt-4-gizmo"] + model_aliases = { + "gpt-3.5-turbo": "text-davinci-002-render-sha", + } _cookies: dict = {} _default_model: str = None @@ -91,7 +89,7 @@ class OpenaiChat(AsyncGeneratorProvider): ) @classmethod - async def _upload_image( + async def upload_image( cls, session: StreamSession, headers: dict, @@ -150,7 +148,7 @@ class OpenaiChat(AsyncGeneratorProvider): return ImageResponse(download_url, image_data["file_name"], image_data) @classmethod - async def _get_default_model(cls, session: StreamSession, headers: dict): + async def get_default_model(cls, session: StreamSession, headers: dict): """ Get the default model name from the service @@ -161,20 +159,17 @@ class OpenaiChat(AsyncGeneratorProvider): Returns: The default model name as a string """ - # Check the cache for the default model - if cls._default_model: - return cls._default_model - # Get the models data from the service - async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: - data = await response.json() - if "categories" in data: - cls._default_model = data["categories"][-1]["default_model"] - else: - raise RuntimeError(f"Response: {data}") - return cls._default_model + if not cls.default_model: + async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: + data = await response.json() + if "categories" in data: + cls.default_model = data["categories"][-1]["default_model"] + else: + raise RuntimeError(f"Response: {data}") + return cls.default_model @classmethod - def _create_messages(cls, prompt: str, image_response: ImageResponse = None): + def create_messages(cls, prompt: str, image_response: ImageResponse = None): """ Create a list of messages for the user input @@ -222,7 +217,7 @@ class OpenaiChat(AsyncGeneratorProvider): return messages @classmethod - async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse: + async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse: """ Retrieves the image response based on the message content. @@ -257,7 +252,7 @@ class OpenaiChat(AsyncGeneratorProvider): raise RuntimeError(f"Error in downloading image: {e}") @classmethod - async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str): + async def delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str): """ Deletes a conversation by setting its visibility to False. @@ -322,7 +317,6 @@ class OpenaiChat(AsyncGeneratorProvider): Raises: RuntimeError: If an error occurs during processing. """ - model = MODELS.get(model, model) if not parent_id: parent_id = str(uuid.uuid4()) if not cookies: @@ -333,7 +327,7 @@ class OpenaiChat(AsyncGeneratorProvider): login_url = os.environ.get("G4F_LOGIN_URL") if login_url: yield f"Please login: [ChatGPT]({login_url})\n\n" - access_token, cookies = cls._browse_access_token(proxy) + access_token, cookies = cls.browse_access_token(proxy) cls._cookies = cookies headers = {"Authorization": f"Bearer {access_token}"} @@ -344,12 +338,10 @@ class OpenaiChat(AsyncGeneratorProvider): timeout=timeout, cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"]) ) as session: - if not model: - model = await cls._get_default_model(session, headers) try: image_response = None if image: - image_response = await cls._upload_image(session, headers, image) + image_response = await cls.upload_image(session, headers, image) yield image_response except Exception as e: yield e @@ -357,15 +349,15 @@ class OpenaiChat(AsyncGeneratorProvider): while not end_turn.is_end: data = { "action": action, - "arkose_token": await cls._get_arkose_token(session), + "arkose_token": await cls.get_arkose_token(session), "conversation_id": conversation_id, "parent_message_id": parent_id, - "model": model, + "model": cls.get_model(model or await cls.get_default_model(session, headers)), "history_and_training_disabled": history_disabled and not auto_continue, } if action != "continue": prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"] - data["messages"] = cls._create_messages(prompt, image_response) + data["messages"] = cls.create_messages(prompt, image_response) async with session.post( f"{cls.url}/backend-api/conversation", json=data, @@ -391,7 +383,7 @@ class OpenaiChat(AsyncGeneratorProvider): if "message_type" not in line["message"]["metadata"]: continue try: - image_response = await cls._get_generated_image(session, headers, line) + image_response = await cls.get_generated_image(session, headers, line) if image_response: yield image_response except Exception as e: @@ -422,10 +414,10 @@ class OpenaiChat(AsyncGeneratorProvider): action = "continue" await asyncio.sleep(5) if history_disabled and auto_continue: - await cls._delete_conversation(session, headers, conversation_id) + await cls.delete_conversation(session, headers, conversation_id) @classmethod - def _browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: + def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: """ Browse to obtain an access token. @@ -452,7 +444,7 @@ class OpenaiChat(AsyncGeneratorProvider): driver.quit() @classmethod - async def _get_arkose_token(cls, session: StreamSession) -> str: + async def get_arkose_token(cls, session: StreamSession) -> str: """ Obtain an Arkose token for the session. -- cgit v1.2.3