From be9b8f796cb01483cf6dd807dfc6a71c51433c16 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 3 Oct 2023 22:12:56 +0200 Subject: Add streaming in openai chat Fetch access token with chromedriver --- g4f/Provider/OpenaiChat.py | 73 +++++++++++++++++++++++++++++++--------------- g4f/Provider/helper.py | 26 ++++++++++++++++- 2 files changed, 75 insertions(+), 24 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/OpenaiChat.py b/g4f/Provider/OpenaiChat.py index f7dc8298..fbd26d7c 100644 --- a/g4f/Provider/OpenaiChat.py +++ b/g4f/Provider/OpenaiChat.py @@ -1,14 +1,14 @@ from __future__ import annotations -from curl_cffi.requests import AsyncSession import uuid import json -from .base_provider import AsyncProvider, get_cookies, format_prompt +from .base_provider import AsyncGeneratorProvider +from .helper import get_browser, get_cookies, format_prompt from ..typing import AsyncGenerator +from ..requests import StreamSession - -class OpenaiChat(AsyncProvider): +class OpenaiChat(AsyncGeneratorProvider): url = "https://chat.openai.com" needs_auth = True working = True @@ -16,7 +16,7 @@ class OpenaiChat(AsyncProvider): _access_token = None @classmethod - async def create_async( + async def create_async_generator( cls, model: str, messages: list[dict[str, str]], @@ -32,7 +32,7 @@ class OpenaiChat(AsyncProvider): "Accept": "text/event-stream", "Authorization": f"Bearer {access_token}", } - async with AsyncSession(proxies=proxies, headers=headers, impersonate="chrome107") as session: + async with StreamSession(proxies=proxies, headers=headers, impersonate="chrome107") as session: messages = [ { "id": str(uuid.uuid4()), @@ -48,31 +48,58 @@ class OpenaiChat(AsyncProvider): "model": "text-davinci-002-render-sha", "history_and_training_disabled": True, } - response = await session.post("https://chat.openai.com/backend-api/conversation", json=data) - response.raise_for_status() - last_message = None - for line in response.content.decode().splitlines(): - if line.startswith("data: "): - line = line[6:] - if line == "[DONE]": - break - line = json.loads(line) - if "message" in line: - last_message = line["message"]["content"]["parts"][0] - return last_message + async with session.post(f"{cls.url}/backend-api/conversation", json=data) as response: + response.raise_for_status() + last_message = "" + async for line in response.iter_lines(): + if line.startswith(b"data: "): + line = line[6:] + if line == b"[DONE]": + break + line = json.loads(line) + if "message" in line and not line["message"]["end_turn"]: + new_message = line["message"]["content"]["parts"][0] + yield new_message[len(last_message):] + last_message = new_message + + @classmethod + def fetch_access_token(cls) -> str: + try: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + except ImportError: + return + driver = get_browser() + if not driver: + return + + driver.get(f"{cls.url}/") + try: + WebDriverWait(driver, 1200).until( + EC.presence_of_element_located((By.ID, "prompt-textarea")) + ) + javascript = "return (await (await fetch('/api/auth/session')).json())['accessToken']" + return driver.execute_script(javascript) + finally: + driver.quit() @classmethod async def get_access_token(cls, cookies: dict = None, proxies: dict = None) -> str: if not cls._access_token: cookies = cookies if cookies else get_cookies("chat.openai.com") - async with AsyncSession(proxies=proxies, cookies=cookies, impersonate="chrome107") as session: - response = await session.get("https://chat.openai.com/api/auth/session") - response.raise_for_status() - cls._access_token = response.json()["accessToken"] + async with StreamSession(proxies=proxies, cookies=cookies, impersonate="chrome107") as session: + async with session.get(f"{cls.url}/api/auth/session") as response: + response.raise_for_status() + auth = await response.json() + if "accessToken" in auth: + cls._access_token = auth["accessToken"] + cls._access_token = cls.fetch_access_token() + if not cls._access_token: + raise RuntimeError("Missing access token") return cls._access_token - @classmethod @property def params(cls): diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index 234cdaa1..8f09239a 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -56,4 +56,28 @@ def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): ) return f"{formatted}\nAssistant:" else: - return messages[0]["content"] \ No newline at end of file + return messages[0]["content"] + + +def get_browser(user_data_dir: str = None): + try: + from undetected_chromedriver import Chrome + except ImportError: + return None + + def get_user_data_dir(): + dirs = [ + '~/.config/google-chrome/Default', + '~/.var/app/com.google.Chrome/config/google-chrome/Default', + '%LOCALAPPDATA%\\Google\\Chrome\\User Data\\Default', + '~/Library/Application Support/Google/Chrome/Default', + ] + from os import path + for dir in dirs: + dir = path.expandvars(dir) + if path.exists(dir): + return dir + if not user_data_dir: + user_data_dir = get_user_data_dir() + + return Chrome(user_data_dir=user_data_dir) \ No newline at end of file -- cgit v1.2.3