summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py77
1 files changed, 21 insertions, 56 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index a202f45e..d8ea4fad 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -26,7 +26,7 @@ from ...webdriver import get_browser
from ...typing import AsyncResult, Messages, Cookies, ImageType, AsyncIterator
from ...requests import get_args_from_browser, raise_for_status
from ...requests.aiohttp import StreamSession
-from ...image import to_image, to_bytes, ImageResponse, ImageRequest
+from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, ResponseError
from ...providers.conversation import BaseConversation
from ..helper import format_cookies
@@ -138,23 +138,22 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
An ImageRequest object that contains the download URL, file name, and other data
"""
# Convert the image to a PIL Image object and get the extension
- image = to_image(image)
- extension = image.format.lower()
- # Convert the image to a bytes object and get the size
data_bytes = to_bytes(image)
+ image = to_image(data_bytes)
+ extension = image.format.lower()
data = {
- "file_name": image_name if image_name else f"{image.width}x{image.height}.{extension}",
+ "file_name": "" if image_name is None else image_name,
"file_size": len(data_bytes),
"use_case": "multimodal"
}
# Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
- cls._update_request_args()
+ cls._update_request_args(session)
await raise_for_status(response)
image_data = {
**data,
**await response.json(),
- "mime_type": f"image/{extension}",
+ "mime_type": is_accepted_format(data_bytes),
"extension": extension,
"height": image.height,
"width": image.width
@@ -275,7 +274,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
first_part = line["message"]["content"]["parts"][0]
if "asset_pointer" not in first_part or "metadata" not in first_part:
return
- if first_part["metadata"] is None:
+ if first_part["metadata"] is None or first_part["metadata"]["dalle"] is None:
return
prompt = first_part["metadata"]["dalle"]["prompt"]
file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
@@ -365,49 +364,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
) as session:
if cls._expires is not None and cls._expires < time.time():
cls._headers = cls._api_key = None
- if cls._headers is None or cookies is not None:
- cls._create_request_args(cookies)
- api_key = kwargs["access_token"] if "access_token" in kwargs else api_key
- if api_key is not None:
- cls._set_api_key(api_key)
-
- if cls.default_model is None and (not cls.needs_auth or cls._api_key is not None):
- if cls._api_key is None:
- cls._create_request_args(cookies)
- async with session.get(
- f"{cls.url}/",
- headers=DEFAULT_HEADERS
- ) as response:
- cls._update_request_args(session)
- await raise_for_status(response)
- try:
- if not model:
- cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
- else:
- cls.default_model = cls.get_model(model)
- except MissingAuthError:
- pass
- except Exception as e:
- api_key = cls._api_key = None
- cls._create_request_args()
- if debug.logging:
- print("OpenaiChat: Load default model failed")
- print(f"{e.__class__.__name__}: {e}")
-
arkose_token = None
proofTokens = None
- if cls.default_model is None:
- error = None
- try:
- arkose_token, api_key, cookies, headers, proofTokens = await getArkoseAndAccessToken(proxy)
- cls._create_request_args(cookies, headers)
- cls._set_api_key(api_key)
- except NoValidHarFileError as e:
- error = e
- if cls._api_key is None:
- await cls.nodriver_access_token(proxy)
+ try:
+ arkose_token, api_key, cookies, headers, proofTokens = await getArkoseAndAccessToken(proxy)
+ cls._create_request_args(cookies, headers)
+ cls._set_api_key(api_key)
+ except NoValidHarFileError as e:
if cls._api_key is None and cls.needs_auth:
- raise error
+ raise e
+
+ if cls.default_model is None:
cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
try:
@@ -461,7 +428,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
)
ws = None
if need_arkose:
- async with session.post("https://chatgpt.com/backend-api/register-websocket", headers=cls._headers) as response:
+ async with session.post(f"{cls.url}/backend-api/register-websocket", headers=cls._headers) as response:
wss_url = (await response.json()).get("wss_url")
if wss_url:
ws = await session.ws_connect(wss_url)
@@ -490,7 +457,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if proofofwork is not None:
headers["Openai-Sentinel-Proof-Token"] = proofofwork
async with session.post(
- f"{cls.url}/backend-anon/conversation" if cls._api_key is None else
+ f"{cls.url}/backend-anon/conversation"
+ if cls._api_key is None else
f"{cls.url}/backend-api/conversation",
json=data,
headers=headers
@@ -580,12 +548,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
raise RuntimeError(line["error"])
if "message_type" not in line["message"]["metadata"]:
return
- try:
- image_response = await cls.get_generated_image(session, cls._headers, line)
- if image_response is not None:
- yield image_response
- except Exception as e:
- yield e
+ image_response = await cls.get_generated_image(session, cls._headers, line)
+ if image_response is not None:
+ yield image_response
if line["message"]["author"]["role"] != "assistant":
return
if line["message"]["content"]["content_type"] != "text":