diff options
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/GeminiPro.py | 86 | ||||
-rw-r--r-- | g4f/Provider/__init__.py | 1 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 89 |
3 files changed, 144 insertions, 32 deletions
from __future__ import annotations

import base64
import json
from aiohttp import ClientSession

from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_bytes, is_accepted_format


class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
    """Async provider backed by Google's Generative Language API (Gemini).

    Supports plain text chat ("gemini-pro") and image input
    ("gemini-pro-vision"), with optional server-side streaming via
    the ``streamGenerateContent`` endpoint.
    """
    url = "https://ai.google.dev"
    working = True
    supports_message_history = True
    default_model = "gemini-pro"
    models = ["gemini-pro", "gemini-pro-vision"]

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        proxy: str = None,
        api_key: str = None,
        image: ImageType = None,
        **kwargs
    ) -> AsyncResult:
        """Yield response text chunks from the Gemini API.

        Args:
            model: Model name; defaults to the vision model when an image
                is given and no model was requested.
            messages: Conversation history as OpenAI-style role/content dicts.
            stream: When True, use streamGenerateContent and yield incrementally.
            proxy: Optional HTTP proxy URL passed to aiohttp.
            api_key: Google API key (falls back to ``kwargs["access_token"]``).
            image: Optional image payload appended to the last message.

        Raises:
            RuntimeError: On an HTTP error response or unparsable stream data.
        """
        # Auto-select the vision model when an image is supplied without a model.
        model = "gemini-pro-vision" if not model and image else model
        model = cls.get_model(model)
        api_key = api_key if api_key else kwargs.get("access_token")
        headers = {
            "Content-Type": "application/json",
        }
        async with ClientSession(headers=headers) as session:
            method = "streamGenerateContent" if stream else "generateContent"
            url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{method}"
            # Gemini uses the role name "model" where OpenAI-style history
            # uses "assistant"; other roles pass through unchanged.
            contents = [
                {
                    "role": "model" if message["role"] == "assistant" else message["role"],
                    "parts": [{"text": message["content"]}]
                }
                for message in messages
            ]
            if image:
                image = to_bytes(image)
                contents[-1]["parts"].append({
                    "inline_data": {
                        "mime_type": is_accepted_format(image),
                        "data": base64.b64encode(image).decode()
                    }
                })
            data = {
                "contents": contents,
            }
            async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response:
                if not response.ok:
                    data = await response.json()
                    # The streaming endpoint wraps the error object in a list,
                    # while generateContent returns a bare dict — handle both.
                    data = data[0] if isinstance(data, list) else data
                    raise RuntimeError(data["error"]["message"])
                if stream:
                    lines: list[bytes] = []
                    async for chunk in response.content:
                        if chunk == b"[{\n":
                            # Start of the response array: begin buffering the
                            # first JSON object (re-adding its opening brace).
                            lines = [b"{\n"]
                        elif chunk in (b",\r\n", b"]"):
                            # Delimiter between objects (or end of array):
                            # parse the buffered object.  Keep the yield OUT of
                            # the try block so exceptions thrown into the
                            # generator are not masked as RuntimeError.
                            raw = b"".join(lines)
                            try:
                                parsed = json.loads(raw)
                                text = parsed["candidates"][0]["content"]["parts"][0]["text"]
                            except Exception as e:
                                raise RuntimeError(f"Read text failed. data: {raw.decode()}") from e
                            yield text
                            lines = []
                        else:
                            lines.append(chunk)
                else:
                    data = await response.json()
                    yield data["candidates"][0]["content"]["parts"][0]["text"]
\ No newline at end of file diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index bad77e9b..270b6356 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -34,6 +34,7 @@ from .FakeGpt import FakeGpt from .FreeChatgpt import FreeChatgpt from .FreeGpt import FreeGpt from .GeekGpt import GeekGpt +from .GeminiPro import GeminiPro from .GeminiProChat import GeminiProChat from .Gpt6 import Gpt6 from .GPTalk import GPTalk diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index b3577ad5..001f5a3c 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -23,10 +23,11 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt, get_cookies from ...webdriver import get_browser, get_driver_cookies from ...typing import AsyncResult, Messages, Cookies, ImageType -from ...requests import StreamSession +from ...requests import get_args_from_browser +from ...requests.aiohttp import StreamSession from ...image import to_image, to_bytes, ImageResponse, ImageRequest from ...errors import MissingRequirementsError, MissingAuthError - +from ... 
import debug class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """A class for creating and managing conversations with OpenAI chat service""" @@ -39,7 +40,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): default_model = None models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"] model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"} - _cookies: dict = {} + _args: dict = None @classmethod async def create( @@ -169,11 +170,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """ if not cls.default_model: async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: + response.raise_for_status() data = await response.json() if "categories" in data: cls.default_model = data["categories"][-1]["default_model"] - else: - raise RuntimeError(f"Response: {data}") + return cls.default_model + raise RuntimeError(f"Response: {data}") return cls.default_model @classmethod @@ -249,8 +251,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): first_part = line["message"]["content"]["parts"][0] if "asset_pointer" not in first_part or "metadata" not in first_part: return - file_id = first_part["asset_pointer"].split("file-service://", 1)[1] + if first_part["metadata"] is None: + return prompt = first_part["metadata"]["dalle"]["prompt"] + file_id = first_part["asset_pointer"].split("file-service://", 1)[1] try: async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response: response.raise_for_status() @@ -289,7 +293,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): messages: Messages, proxy: str = None, timeout: int = 120, - access_token: str = None, + api_key: str = None, cookies: Cookies = None, auto_continue: bool = False, history_disabled: bool = True, @@ -308,7 +312,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): messages (Messages): The list of previous messages. proxy (str): Proxy to use for requests. 
timeout (int): Timeout for requests. - access_token (str): Access token for authentication. + api_key (str): Access token for authentication. cookies (dict): Cookies to use for authentication. auto_continue (bool): Flag to automatically continue the conversation. history_disabled (bool): Flag to disable history and training. @@ -329,35 +333,47 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package') if not parent_id: parent_id = str(uuid.uuid4()) - if not cookies: - cookies = cls._cookies or get_cookies("chat.openai.com", False) - if not access_token and "access_token" in cookies: - access_token = cookies["access_token"] - if not access_token: - login_url = os.environ.get("G4F_LOGIN_URL") - if login_url: - yield f"Please login: [ChatGPT]({login_url})\n\n" - try: - access_token, cookies = cls.browse_access_token(proxy) - except MissingRequirementsError: - raise MissingAuthError(f'Missing "access_token"') - cls._cookies = cookies - - auth_headers = {"Authorization": f"Bearer {access_token}"} + if cls._args is None and cookies is None: + cookies = get_cookies("chat.openai.com", False) + api_key = kwargs["access_token"] if "access_token" in kwargs else api_key + if api_key is None: + api_key = cookies["access_token"] if "access_token" in cookies else api_key + if cls._args is None: + cls._args = { + "headers": {"Cookie": "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")}, + "cookies": {} if cookies is None else cookies + } + if api_key is not None: + cls._args["headers"]["Authorization"] = f"Bearer {api_key}" async with StreamSession( proxies={"https": proxy}, - impersonate="chrome110", + impersonate="chrome", timeout=timeout, - headers={"Cookie": "; ".join(f"{k}={v}" for k, v in cookies.items())} + headers=cls._args["headers"] ) as session: + if api_key is not None: + try: + cls.default_model = await cls.get_default_model(session, 
cls._args["headers"]) + except Exception as e: + if debug.logging: + print(f"{e.__class__.__name__}: {e}") + if cls.default_model is None: + login_url = os.environ.get("G4F_LOGIN_URL") + if login_url: + yield f"Please login: [ChatGPT]({login_url})\n\n" + try: + cls._args = cls.browse_access_token(proxy) + except MissingRequirementsError: + raise MissingAuthError(f'Missing or invalid "access_token". Add a new "api_key" please') + cls.default_model = await cls.get_default_model(session, cls._args["headers"]) try: image_response = None if image: - image_response = await cls.upload_image(session, auth_headers, image, kwargs.get("image_name")) + image_response = await cls.upload_image(session, cls._args["headers"], image, kwargs.get("image_name")) except Exception as e: yield e end_turn = EndTurn() - model = cls.get_model(model or await cls.get_default_model(session, auth_headers)) + model = cls.get_model(model) model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model while not end_turn.is_end: arkose_token = await cls.get_arkose_token(session) @@ -375,13 +391,19 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if action != "continue": prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"] data["messages"] = cls.create_messages(prompt, image_response) + + # Update cookies before next request + for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar: + cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value + cls._args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in cls._args["cookies"].items()) + async with session.post( f"{cls.url}/backend-api/conversation", json=data, headers={ "Accept": "text/event-stream", "OpenAI-Sentinel-Arkose-Token": arkose_token, - **auth_headers + **cls._args["headers"] } ) as response: if not response.ok: @@ -403,8 +425,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if "message_type" not in 
line["message"]["metadata"]: continue try: - image_response = await cls.get_generated_image(session, auth_headers, line) - if image_response: + image_response = await cls.get_generated_image(session, cls._args["headers"], line) + if image_response is not None: yield image_response except Exception as e: yield e @@ -432,7 +454,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): action = "continue" await asyncio.sleep(5) if history_disabled and auto_continue: - await cls.delete_conversation(session, auth_headers, conversation_id) + await cls.delete_conversation(session, cls._args["headers"], conversation_id) @classmethod def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: @@ -457,7 +479,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';" "return accessToken;" ) - return access_token, get_driver_cookies(driver) + args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False) + args["headers"]["Authorization"] = f"Bearer {access_token}" + args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in args["cookies"].items() if k != "access_token") + return args finally: driver.close() |