summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/OpenaiChat.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth/OpenaiChat.py')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py339
1 files changed, 229 insertions, 110 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index a790f0de..7d352a46 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -1,6 +1,9 @@
from __future__ import annotations
+import asyncio
+import uuid
+import json
+import os
-import uuid, json, asyncio, os
from py_arkose_generator.arkose import get_values_for_request
from async_property import async_cached_property
from selenium.webdriver.common.by import By
@@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession
from ...image import to_image, to_bytes, ImageType, ImageResponse
-models = {
+# Aliases for model names
+MODELS = {
"gpt-3.5": "text-davinci-002-render-sha",
"gpt-3.5-turbo": "text-davinci-002-render-sha",
"gpt-4": "gpt-4",
@@ -22,13 +26,15 @@ models = {
}
class OpenaiChat(AsyncGeneratorProvider):
- url = "https://chat.openai.com"
- working = True
- needs_auth = True
+ """A class for creating and managing conversations with OpenAI chat service"""
+
+ url = "https://chat.openai.com"
+ working = True
+ needs_auth = True
supports_gpt_35_turbo = True
- supports_gpt_4 = True
- _cookies: dict = {}
- _default_model: str = None
+ supports_gpt_4 = True
+ _cookies: dict = {}
+ _default_model: str = None
@classmethod
async def create(
@@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider):
image: ImageType = None,
**kwargs
) -> Response:
+ """Create a new conversation or continue an existing one
+
+ Args:
+ prompt: The user input to start or continue the conversation
+ model: The name of the model to use for generating responses
+ messages: The list of previous messages in the conversation
+ history_disabled: A flag indicating if the history and training should be disabled
+ action: The type of action to perform, either "next", "continue", or "variant"
+ conversation_id: The ID of the existing conversation, if any
+ parent_id: The ID of the parent message, if any
+ image: The image to include in the user input, if any
+ **kwargs: Additional keyword arguments to pass to the generator
+
+ Returns:
+ A Response object that contains the generator, action, messages, and options
+ """
+ # Add the user input to the messages list
if prompt:
messages.append({
"role": "user",
@@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider):
)
@classmethod
- async def upload_image(
+ async def _upload_image(
cls,
session: StreamSession,
headers: dict,
image: ImageType
) -> ImageResponse:
+ """Upload an image to the service and get the download URL
+
+ Args:
+ session: The StreamSession object to use for requests
+ headers: The headers to include in the requests
+ image: The image to upload, either a PIL Image object or a bytes object
+
+ Returns:
+ An ImageResponse object that contains the download URL, file name, and other data
+ """
+ # Convert the image to a PIL Image object and get the extension
image = to_image(image)
extension = image.format.lower()
+ # Convert the image to a bytes object and get the size
data_bytes = to_bytes(image)
data = {
"file_name": f"{image.width}x{image.height}.{extension}",
"file_size": len(data_bytes),
"use_case": "multimodal"
}
+ # Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
response.raise_for_status()
image_data = {
@@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider):
"height": image.height,
"width": image.width
}
+ # Put the image bytes to the upload URL and check the status
async with session.put(
image_data["upload_url"],
data=data_bytes,
@@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider):
}
) as response:
response.raise_for_status()
+ # Post the file ID to the service and get the download URL
async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
json={},
@@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider):
return ImageResponse(download_url, image_data["file_name"], image_data)
@classmethod
- async def get_default_model(cls, session: StreamSession, headers: dict):
+ async def _get_default_model(cls, session: StreamSession, headers: dict):
+ """Get the default model name from the service
+
+ Args:
+ session: The StreamSession object to use for requests
+ headers: The headers to include in the requests
+
+ Returns:
+ The default model name as a string
+ """
+ # Check the cache for the default model
if cls._default_model:
- model = cls._default_model
- else:
- async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
- data = await response.json()
- if "categories" in data:
- model = data["categories"][-1]["default_model"]
- else:
- RuntimeError(f"Response: {data}")
- cls._default_model = model
- return model
+ return cls._default_model
+ # Get the models data from the service
+ async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+ data = await response.json()
+ if "categories" in data:
+ cls._default_model = data["categories"][-1]["default_model"]
+ else:
+ raise RuntimeError(f"Response: {data}")
+ return cls._default_model
@classmethod
- def create_messages(cls, prompt: str, image_response: ImageResponse = None):
+ def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
+ """Create a list of messages for the user input
+
+ Args:
+ prompt: The user input as a string
+ image_response: The image response object, if any
+
+ Returns:
+ A list of messages with the user input and the image, if any
+ """
+ # Check if there is an image response
if not image_response:
+ # Create a content object with the text type and the prompt
content = {"content_type": "text", "parts": [prompt]}
else:
+ # Create a content object with the multimodal text type and the image and the prompt
content = {
"content_type": "multimodal_text",
"parts": [{
@@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider):
"width": image_response.get("width"),
}, prompt]
}
+ # Create a message object with the user role and the content
messages = [{
"id": str(uuid.uuid4()),
"author": {"role": "user"},
"content": content,
}]
+ # Check if there is an image response
if image_response:
+ # Add the metadata object with the attachments
messages[0]["metadata"] = {
"attachments": [{
"height": image_response.get("height"),
@@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider):
return messages
@classmethod
- async def get_image_response(cls, session: StreamSession, headers: dict, line: dict):
- if "parts" in line["message"]["content"]:
- part = line["message"]["content"]["parts"][0]
- if "asset_pointer" in part and part["metadata"]:
- file_id = part["asset_pointer"].split("file-service://", 1)[1]
- prompt = part["metadata"]["dalle"]["prompt"]
- async with session.get(
- f"{cls.url}/backend-api/files/{file_id}/download",
- headers=headers
- ) as response:
- response.raise_for_status()
- download_url = (await response.json())["download_url"]
- return ImageResponse(download_url, prompt)
+ async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
+ """
+ Retrieves the image response based on the message content.
+
+ :param session: The StreamSession object.
+ :param headers: HTTP headers for the request.
+ :param line: The line of response containing image information.
+ :return: An ImageResponse object with the image details.
+ """
+ if "parts" not in line["message"]["content"]:
+ return
+ first_part = line["message"]["content"]["parts"][0]
+ if "asset_pointer" not in first_part or "metadata" not in first_part:
+ return
+ file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
+ prompt = first_part["metadata"]["dalle"]["prompt"]
+ try:
+ async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
+ response.raise_for_status()
+ download_url = (await response.json())["download_url"]
+ return ImageResponse(download_url, prompt)
+ except Exception as e:
+ raise RuntimeError(f"Error in downloading image: {e}")
+
+ @classmethod
+ async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
+ async with session.patch(
+ f"{cls.url}/backend-api/conversation/{conversation_id}",
+ json={"is_visible": False},
+ headers=headers
+ ) as response:
+ response.raise_for_status()
@classmethod
async def create_async_generator(
@@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider):
response_fields: bool = False,
**kwargs
) -> AsyncResult:
- if model in models:
- model = models[model]
+ """
+ Create an asynchronous generator for the conversation.
+
+ Args:
+ model (str): The model name.
+ messages (Messages): The list of previous messages.
+ proxy (str): Proxy to use for requests.
+ timeout (int): Timeout for requests.
+ access_token (str): Access token for authentication.
+ cookies (dict): Cookies to use for authentication.
+ auto_continue (bool): Flag to automatically continue the conversation.
+ history_disabled (bool): Flag to disable history and training.
+ action (str): Type of action ('next', 'continue', 'variant').
+ conversation_id (str): ID of the conversation.
+ parent_id (str): ID of the parent message.
+ image (ImageType): Image to include in the conversation.
+ response_fields (bool): Flag to include response fields in the output.
+ **kwargs: Additional keyword arguments.
+
+ Yields:
+ AsyncResult: Asynchronous results from the generator.
+
+ Raises:
+ RuntimeError: If an error occurs during processing.
+ """
+ model = MODELS.get(model, model)
if not parent_id:
parent_id = str(uuid.uuid4())
if not cookies:
- cookies = cls._cookies
- if not access_token:
- if not cookies:
- cls._cookies = cookies = get_cookies("chat.openai.com")
- if "access_token" in cookies:
- access_token = cookies["access_token"]
+ cookies = cls._cookies or get_cookies("chat.openai.com")
+ if not access_token and "access_token" in cookies:
+ access_token = cookies["access_token"]
if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n"
- access_token, cookies = cls.browse_access_token(proxy)
+ access_token, cookies = cls._browse_access_token(proxy)
cls._cookies = cookies
- headers = {
- "Authorization": f"Bearer {access_token}",
- }
+
+ headers = {"Authorization": f"Bearer {access_token}"}
+
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome110",
@@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider):
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session:
if not model:
- model = await cls.get_default_model(session, headers)
+ model = await cls._get_default_model(session, headers)
try:
image_response = None
if image:
- image_response = await cls.upload_image(session, headers, image)
+ image_response = await cls._upload_image(session, headers, image)
yield image_response
except Exception as e:
yield e
@@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider):
while not end_turn.is_end:
data = {
"action": action,
- "arkose_token": await cls.get_arkose_token(session),
+ "arkose_token": await cls._get_arkose_token(session),
"conversation_id": conversation_id,
"parent_message_id": parent_id,
"model": model,
@@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider):
}
if action != "continue":
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
- data["messages"] = cls.create_messages(prompt, image_response)
+ data["messages"] = cls._create_messages(prompt, image_response)
async with session.post(
f"{cls.url}/backend-api/conversation",
json=data,
@@ -261,62 +363,80 @@ class OpenaiChat(AsyncGeneratorProvider):
if "message_type" not in line["message"]["metadata"]:
continue
try:
- image_response = await cls.get_image_response(session, headers, line)
+ image_response = await cls._get_generated_image(session, headers, line)
if image_response:
yield image_response
except Exception as e:
yield e
if line["message"]["author"]["role"] != "assistant":
continue
- if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
- conversation_id = line["conversation_id"]
- parent_id = line["message"]["id"]
- if response_fields:
- response_fields = False
- yield ResponseFields(conversation_id, parent_id, end_turn)
- if "parts" in line["message"]["content"]:
- new_message = line["message"]["content"]["parts"][0]
- if len(new_message) > last_message:
- yield new_message[last_message:]
- last_message = len(new_message)
+ if line["message"]["content"]["content_type"] != "text":
+ continue
+ if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
+ continue
+ conversation_id = line["conversation_id"]
+ parent_id = line["message"]["id"]
+ if response_fields:
+ response_fields = False
+ yield ResponseFields(conversation_id, parent_id, end_turn)
+ if "parts" in line["message"]["content"]:
+ new_message = line["message"]["content"]["parts"][0]
+ if len(new_message) > last_message:
+ yield new_message[last_message:]
+ last_message = len(new_message)
if "finish_details" in line["message"]["metadata"]:
if line["message"]["metadata"]["finish_details"]["type"] == "stop":
end_turn.end()
- break
except Exception as e:
- yield e
+ raise e
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
- if history_disabled:
- async with session.patch(
- f"{cls.url}/backend-api/conversation/{conversation_id}",
- json={"is_visible": False},
- headers=headers
- ) as response:
- response.raise_for_status()
+ if history_disabled and auto_continue:
+ await cls._delete_conversation(session, headers, conversation_id)
@classmethod
- def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
+ def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
+ """
+ Browse to obtain an access token.
+
+ Args:
+ proxy (str): Proxy to use for browsing.
+
+ Returns:
+ tuple[str, dict]: A tuple containing the access token and cookies.
+ """
driver = get_browser(proxy=proxy)
try:
driver.get(f"{cls.url}/")
- WebDriverWait(driver, 1200).until(
- EC.presence_of_element_located((By.ID, "prompt-textarea"))
+ WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
+ access_token = driver.execute_script(
+ "let session = await fetch('/api/auth/session');"
+ "let data = await session.json();"
+ "let accessToken = data['accessToken'];"
+ "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);"
+ "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
+ "return accessToken;"
)
- javascript = """
-access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
-expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
-document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
-return access_token;
-"""
- return driver.execute_script(javascript), get_driver_cookies(driver)
+ return access_token, get_driver_cookies(driver)
finally:
driver.quit()
- @classmethod
- async def get_arkose_token(cls, session: StreamSession) -> str:
+ @classmethod
+ async def _get_arkose_token(cls, session: StreamSession) -> str:
+ """
+ Obtain an Arkose token for the session.
+
+ Args:
+ session (StreamSession): The session object.
+
+ Returns:
+ str: The Arkose token.
+
+ Raises:
+ RuntimeError: If unable to retrieve the token.
+ """
config = {
"pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
"surl": "https://tcr9i.chat.openai.com",
@@ -332,26 +452,30 @@ return access_token;
if "token" in decoded_json:
return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}")
-
-class EndTurn():
+
+class EndTurn:
+ """
+ Class to represent the end of a conversation turn.
+ """
def __init__(self):
self.is_end = False
def end(self):
self.is_end = True
-class ResponseFields():
- def __init__(
- self,
- conversation_id: str,
- message_id: str,
- end_turn: EndTurn
- ):
+class ResponseFields:
+ """
+ Class to encapsulate response fields.
+ """
+ def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
self.conversation_id = conversation_id
self.message_id = message_id
self._end_turn = end_turn
class Response():
+ """
+ Class to encapsulate a response from the chat service.
+ """
def __init__(
self,
generator: AsyncResult,
@@ -360,13 +484,13 @@ class Response():
options: dict
):
self._generator = generator
- self.action: str = action
- self.is_end: bool = False
+ self.action = action
+ self.is_end = False
self._message = None
self._messages = messages
self._options = options
self._fields = None
-
+
async def generator(self):
if self._generator:
self._generator = None
@@ -384,19 +508,16 @@ class Response():
def __aiter__(self):
return self.generator()
-
+
@async_cached_property
async def message(self) -> str:
- [_ async for _ in self.generator()]
+ await self.generator()
return self._message
-
+
async def get_fields(self):
- [_ async for _ in self.generator()]
- return {
- "conversation_id": self._fields.conversation_id,
- "parent_id": self._fields.message_id,
- }
-
+ await self.generator()
+ return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
+
async def next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create(
**self._options,
@@ -406,7 +527,7 @@ class Response():
**await self.get_fields(),
**kwargs
)
-
+
async def do_continue(self, **kwargs) -> Response:
fields = await self.get_fields()
if self.is_end:
@@ -418,7 +539,7 @@ class Response():
**fields,
**kwargs
)
-
+
async def variant(self, **kwargs) -> Response:
if self.action != "next":
raise RuntimeError("Can't create variant from continue or variant request.")
@@ -429,11 +550,9 @@ class Response():
**await self.get_fields(),
**kwargs
)
-
+
@async_cached_property
async def messages(self):
messages = self._messages
- messages.append({
- "role": "assistant", "content": await self.message
- })
+ messages.append({"role": "assistant", "content": await self.message})
return messages \ No newline at end of file