summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py114
1 files changed, 69 insertions, 45 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 386e7d0a..c0e29dfb 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-import uuid, json, time, os
+import uuid, json, time, asyncio
from py_arkose_generator.arkose import get_values_for_request
from ..base_provider import AsyncGeneratorProvider
@@ -24,6 +24,7 @@ class OpenaiChat(AsyncGeneratorProvider):
proxy: str = None,
timeout: int = 120,
access_token: str = None,
+ auto_continue: bool = False,
cookies: dict = None,
**kwargs
) -> AsyncResult:
@@ -34,50 +35,73 @@ class OpenaiChat(AsyncGeneratorProvider):
"Accept": "text/event-stream",
"Authorization": f"Bearer {access_token}",
}
- async with StreamSession(
- proxies=proxies,
- headers=headers,
- impersonate="chrome107",
- timeout=timeout
- ) as session:
- messages = [
- {
- "id": str(uuid.uuid4()),
- "author": {"role": "user"},
- "content": {"content_type": "text", "parts": [format_prompt(messages)]},
- },
- ]
- data = {
- "action": "next",
- "arkose_token": await get_arkose_token(proxy),
- "messages": messages,
- "conversation_id": None,
- "parent_message_id": str(uuid.uuid4()),
- "model": "text-davinci-002-render-sha",
- "history_and_training_disabled": True,
- }
- async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response:
- response.raise_for_status()
- last_message = ""
- async for line in response.iter_lines():
- if line.startswith(b"data: "):
- line = line[6:]
- if line == b"[DONE]":
- break
- try:
- line = json.loads(line)
- except:
- continue
- if "message" not in line:
- continue
- if "error" in line and line["error"]:
- raise RuntimeError(line["error"])
- if "message_type" not in line["message"]["metadata"]:
- continue
- if line["message"]["metadata"]["message_type"] == "next":
- new_message = line["message"]["content"]["parts"][0]
- yield new_message[len(last_message):]
- last_message = new_message
+ messages = [
+ {
+ "id": str(uuid.uuid4()),
+ "author": {"role": "user"},
+ "content": {"content_type": "text", "parts": [format_prompt(messages)]},
+ },
+ ]
+ message_id = str(uuid.uuid4())
+ data = {
+ "action": "next",
+ "arkose_token": await get_arkose_token(proxy),
+ "messages": messages,
+ "conversation_id": None,
+ "parent_message_id": message_id,
+ "model": "text-davinci-002-render-sha",
+ "history_and_training_disabled": not auto_continue,
+ }
+ conversation_id = None
+ while not end_turn:
+ if not auto_continue:
+ end_turn = True
+ async with StreamSession(
+ proxies=proxies,
+ headers=headers,
+ impersonate="chrome107",
+ timeout=timeout
+ ) as session:
+ async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response:
+ try:
+ response.raise_for_status()
+ except:
+ raise RuntimeError(f"Response: {await response.text()}")
+ last_message = ""
+ async for line in response.iter_lines():
+ if line.startswith(b"data: "):
+ line = line[6:]
+ if line == b"[DONE]":
+ break
+ try:
+ line = json.loads(line)
+ except:
+ continue
+ if "message" not in line:
+ continue
+ if "error" in line and line["error"]:
+ raise RuntimeError(line["error"])
+ end_turn = line["message"]["end_turn"]
+ message_id = line["message"]["id"]
+ if line["conversation_id"]:
+ conversation_id = line["conversation_id"]
+ if "message_type" not in line["message"]["metadata"]:
+ continue
+ if line["message"]["metadata"]["message_type"] in ("next", "continue"):
+ new_message = line["message"]["content"]["parts"][0]
+ yield new_message[len(last_message):]
+ last_message = new_message
+ if end_turn:
+ return
+ data = {
+ "action": "continue",
+ "arkose_token": await get_arkose_token(proxy),
+ "conversation_id": conversation_id,
+ "parent_message_id": message_id,
+ "model": "text-davinci-002-render-sha",
+ "history_and_training_disabled": False,
+ }
+ await asyncio.sleep(5)
@classmethod
async def browse_access_token(cls) -> str: