summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/Bard.py12
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py52
2 files changed, 40 insertions, 24 deletions
diff --git a/g4f/Provider/needs_auth/Bard.py b/g4f/Provider/needs_auth/Bard.py
index aea67874..09ed1c3c 100644
--- a/g4f/Provider/needs_auth/Bard.py
+++ b/g4f/Provider/needs_auth/Bard.py
@@ -2,10 +2,14 @@ from __future__ import annotations
import time
import os
-from selenium.webdriver.common.by import By
-from selenium.webdriver.support.ui import WebDriverWait
-from selenium.webdriver.support import expected_conditions as EC
-from selenium.webdriver.common.keys import Keys
+
+try:
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver.support.ui import WebDriverWait
+ from selenium.webdriver.support import expected_conditions as EC
+ from selenium.webdriver.common.keys import Keys
+except ImportError:
+ pass
from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 85866272..b07bd49b 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -1,21 +1,32 @@
from __future__ import annotations
+
import asyncio
import uuid
import json
import os
-from py_arkose_generator.arkose import get_values_for_request
-from async_property import async_cached_property
-from selenium.webdriver.common.by import By
-from selenium.webdriver.support.ui import WebDriverWait
-from selenium.webdriver.support import expected_conditions as EC
+try:
+ from py_arkose_generator.arkose import get_values_for_request
+ from async_property import async_cached_property
+ has_requirements = True
+except ImportError:
+ async_cached_property = property
+ has_requirements = False
+try:
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver.support.ui import WebDriverWait
+ from selenium.webdriver.support import expected_conditions as EC
+ has_webdriver = True
+except ImportError:
+ has_webdriver = False
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies
from ...webdriver import get_browser, get_driver_cookies
-from ...typing import AsyncResult, Messages
+from ...typing import AsyncResult, Messages, Cookies, ImageType
from ...requests import StreamSession
-from ...image import to_image, to_bytes, ImageType, ImageResponse
+from ...image import to_image, to_bytes, ImageResponse, ImageRequest
+from ...errors import MissingRequirementsError, MissingAccessToken
class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
@@ -27,12 +38,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
supports_gpt_35_turbo = True
supports_gpt_4 = True
default_model = None
- models = ["text-davinci-002-render-sha", "gpt-4", "gpt-4-gizmo"]
- model_aliases = {
- "gpt-3.5-turbo": "text-davinci-002-render-sha",
- }
+ models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
_cookies: dict = {}
- _default_model: str = None
@classmethod
async def create(
@@ -94,7 +101,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
session: StreamSession,
headers: dict,
image: ImageType
- ) -> ImageResponse:
+ ) -> ImageRequest:
"""
Upload an image to the service and get the download URL
@@ -104,7 +111,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
image: The image to upload, either a PIL Image object or a bytes object
Returns:
- An ImageResponse object that contains the download URL, file name, and other data
+ An ImageRequest object that contains the download URL, file name, and other data
"""
# Convert the image to a PIL Image object and get the extension
image = to_image(image)
@@ -145,7 +152,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
) as response:
response.raise_for_status()
download_url = (await response.json())["download_url"]
- return ImageResponse(download_url, image_data["file_name"], image_data)
+ return ImageRequest(download_url, image_data["file_name"], image_data)
@classmethod
async def get_default_model(cls, session: StreamSession, headers: dict):
@@ -169,7 +176,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return cls.default_model
@classmethod
- def create_messages(cls, prompt: str, image_response: ImageResponse = None):
+ def create_messages(cls, prompt: str, image_response: ImageRequest = None):
"""
Create a list of messages for the user input
@@ -282,7 +289,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
proxy: str = None,
timeout: int = 120,
access_token: str = None,
- cookies: dict = None,
+ cookies: Cookies = None,
auto_continue: bool = False,
history_disabled: bool = True,
action: str = "next",
@@ -317,12 +324,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises:
RuntimeError: If an error occurs during processing.
"""
+ if not has_requirements:
+ raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package')
if not parent_id:
parent_id = str(uuid.uuid4())
if not cookies:
- cookies = cls._cookies or get_cookies("chat.openai.com")
+ cookies = cls._cookies or get_cookies("chat.openai.com", False)
if not access_token and "access_token" in cookies:
access_token = cookies["access_token"]
+ if not access_token and not has_webdriver:
+ raise MissingAccessToken(f'Missing "access_token"')
if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
@@ -331,7 +342,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
cls._cookies = cookies
headers = {"Authorization": f"Bearer {access_token}"}
-
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome110",
@@ -346,13 +356,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
except Exception as e:
yield e
end_turn = EndTurn()
+ model = cls.get_model(model or await cls.get_default_model(session, headers))
+ model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model
while not end_turn.is_end:
data = {
"action": action,
"arkose_token": await cls.get_arkose_token(session),
"conversation_id": conversation_id,
"parent_message_id": parent_id,
- "model": cls.get_model(model or await cls.get_default_model(session, headers)),
+ "model": model,
"history_and_training_disabled": history_disabled and not auto_continue,
}
if action != "continue":