summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/OpenaiChat.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py65
1 files changed, 13 insertions, 52 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 15a87f38..97515ec4 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -65,6 +65,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
default_vision_model = "gpt-4o"
fallback_models = ["auto", "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1-preview", "o1-mini"]
vision_models = fallback_models
+ image_models = fallback_models
_api_key: str = None
_headers: dict = None
@@ -330,7 +331,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
api_key: str = None,
cookies: Cookies = None,
auto_continue: bool = False,
- history_disabled: bool = True,
+ history_disabled: bool = False,
action: str = "next",
conversation_id: str = None,
conversation: Conversation = None,
@@ -425,12 +426,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}",
f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
)]
- ws = None
- if need_arkose:
- async with session.post(f"{cls.url}/backend-api/register-websocket", headers=cls._headers) as response:
- wss_url = (await response.json()).get("wss_url")
- if wss_url:
- ws = await session.ws_connect(wss_url)
data = {
"action": action,
"messages": None,
@@ -474,7 +469,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
await asyncio.sleep(5)
continue
await raise_for_status(response)
- async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation, ws):
+ async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation):
if return_conversation:
history_disabled = False
return_conversation = False
@@ -489,44 +484,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if history_disabled and auto_continue:
await cls.delete_conversation(session, cls._headers, conversation.conversation_id)
- @staticmethod
- async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str, is_curl: bool) -> AsyncIterator:
- while True:
- if is_curl:
- message = json.loads(ws.recv()[0])
- else:
- message = await ws.receive_json()
- if message["conversation_id"] == conversation_id:
- yield base64.b64decode(message["body"])
-
@classmethod
async def iter_messages_chunk(
cls,
messages: AsyncIterator,
session: StreamSession,
fields: Conversation,
- ws = None
) -> AsyncIterator:
async for message in messages:
- if message.startswith(b'{"wss_url":'):
- message = json.loads(message)
- ws = await session.ws_connect(message["wss_url"]) if ws is None else ws
- try:
- async for chunk in cls.iter_messages_chunk(
- cls.iter_messages_ws(ws, message["conversation_id"], hasattr(ws, "recv")),
- session, fields
- ):
- yield chunk
- finally:
- await ws.aclose() if hasattr(ws, "aclose") else await ws.close()
- break
async for chunk in cls.iter_messages_line(session, message, fields):
- if fields.finish_reason is not None:
- break
- else:
- yield chunk
- if fields.finish_reason is not None:
- break
+ yield chunk
@classmethod
async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator:
@@ -542,9 +509,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return
if isinstance(line, dict) and "v" in line:
v = line.get("v")
- if isinstance(v, str):
+ if isinstance(v, str) and fields.is_recipient:
yield v
- elif isinstance(v, list):
+ elif isinstance(v, list) and fields.is_recipient:
for m in v:
if m.get("p") == "/message/content/parts/0":
yield m.get("v")
@@ -556,25 +523,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
fields.conversation_id = v.get("conversation_id")
debug.log(f"OpenaiChat: New conversation: {fields.conversation_id}")
m = v.get("message", {})
- if m.get("author", {}).get("role") == "assistant":
- fields.message_id = v.get("message", {}).get("id")
+ fields.is_recipient = m.get("recipient") == "all"
+ if fields.is_recipient:
c = m.get("content", {})
if c.get("content_type") == "multimodal_text":
generated_images = []
for element in c.get("parts"):
- if isinstance(element, str):
- debug.log(f"No image or text: {line}")
- elif element.get("content_type") == "image_asset_pointer":
+ if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
generated_images.append(
cls.get_generated_image(session, cls._headers, element)
)
- elif element.get("content_type") == "text":
- for part in element.get("parts", []):
- yield part
for image_response in await asyncio.gather(*generated_images):
yield image_response
- else:
- debug.log(f"OpenaiChat: {line}")
+ if m.get("author", {}).get("role") == "assistant":
+ fields.message_id = v.get("message", {}).get("id")
return
if "error" in line and line.get("error"):
raise RuntimeError(line.get("error"))
@@ -652,7 +614,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
cls._headers = cls.get_default_headers() if headers is None else headers
if user_agent is not None:
cls._headers["user-agent"] = user_agent
- cls._cookies = {} if cookies is None else {k: v for k, v in cookies.items() if k != "access_token"}
+ cls._cookies = {} if cookies is None else cookies
cls._update_cookie_header()
@classmethod
@@ -671,8 +633,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
def _update_cookie_header(cls):
cls._headers["cookie"] = format_cookies(cls._cookies)
- if "oai-did" in cls._cookies:
- cls._headers["oai-device-id"] = cls._cookies["oai-did"]
class Conversation(BaseConversation):
"""
@@ -682,6 +642,7 @@ class Conversation(BaseConversation):
self.conversation_id = conversation_id
self.message_id = message_id
self.finish_reason = finish_reason
+ self.is_recipient = False
class Response():
"""