diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/Bing.py | 2 | ||||
-rw-r--r-- | g4f/Provider/GigaChat.py | 23 | ||||
-rw-r--r-- | g4f/Provider/HuggingFace.py | 10 |
3 files changed, 14 insertions, 21 deletions
diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py index 69c32775..7ff4d74b 100644 --- a/g4f/Provider/Bing.py +++ b/g4f/Provider/Bing.py @@ -311,7 +311,7 @@ def create_message( "allowedMessageTypes": Defaults.allowedMessageTypes, "sliceIds": Defaults.sliceIds[tone], "verbosity": "verbose", - "scenario": "CopilotMicrosoftCom", # "SERP", + "scenario": "CopilotMicrosoftCom" if tone == "copilot" else "SERP", "plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [], "traceId": get_random_hex(40), "conversationHistoryOptionsSets": ["autosave","savemem","uprofupd","uprofgen"], diff --git a/g4f/Provider/GigaChat.py b/g4f/Provider/GigaChat.py index 699353b1..c1ec7f5e 100644 --- a/g4f/Provider/GigaChat.py +++ b/g4f/Provider/GigaChat.py @@ -1,35 +1,28 @@ from __future__ import annotations -import base64 import os import ssl import time import uuid import json -from aiohttp import ClientSession, BaseConnector, TCPConnector +from aiohttp import ClientSession, TCPConnector, BaseConnector from g4f.requests import raise_for_status -from ..typing import AsyncResult, Messages, ImageType +from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..image import to_bytes, is_accepted_format from ..errors import MissingAuthError from .helper import get_connector -access_token = '' +access_token = "" token_expires_at = 0 -ssl_ctx = ssl.create_default_context( - cafile=os.path.dirname(__file__) + '/gigachat_crt/russian_trusted_root_ca_pem.crt') - - class GigaChat(AsyncGeneratorProvider, ProviderModelMixin): url = "https://developers.sber.ru/gigachat" working = True supports_message_history = True supports_system_message = True supports_stream = True - needs_auth = True default_model = "GigaChat:latest" models = ["GigaChat:latest", "GigaChat-Plus", "GigaChat-Pro"] @@ -42,18 +35,20 @@ class GigaChat(AsyncGeneratorProvider, ProviderModelMixin): stream: bool = True, proxy: str = None, api_key: str = None, + coonector: BaseConnector = None, scope: str = "GIGACHAT_API_PERS", update_interval: float = 0, **kwargs ) -> AsyncResult: global access_token, token_expires_at model = cls.get_model(model) - if not api_key: raise MissingAuthError('Missing "api_key"') - - connector = TCPConnector(ssl_context=ssl_ctx) - + + cafile = os.path.join(os.path.dirname(__file__), "gigachat_crt/russian_trusted_root_ca_pem.crt") + ssl_context = ssl.create_default_context(cafile=cafile) if os.path.exists(cafile) else None + if connector is None and ssl_context is not None: + connector = TCPConnector(ssl_context=ssl_context) async with ClientSession(connector=get_connector(connector, proxy)) as session: if token_expires_at - int(time.time() * 1000) < 60000: async with session.post(url="https://ngw.devices.sberbank.ru:9443/api/v2/oauth", diff --git a/g4f/Provider/HuggingFace.py b/g4f/Provider/HuggingFace.py index a73411ce..647780fd 100644 --- a/g4f/Provider/HuggingFace.py +++ b/g4f/Provider/HuggingFace.py @@ -7,6 +7,7 @@ from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import get_connector from ..errors import RateLimitError, ModelNotFoundError +from ..requests.raise_for_status import raise_for_status class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): url = "https://huggingface.co/chat" @@ -44,12 +45,9 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): connector=get_connector(connector, proxy) ) as session: async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response: - if response.status == 429: - raise RateLimitError("Rate limit reached. Set a api_key") - elif response.status == 404: + if response.status == 404: raise ModelNotFoundError(f"Model is not supported: {model}") - elif response.status != 200: - raise RuntimeError(f"Response {response.status}: {await response.text()}") + await raise_for_status(response) if stream: first = True async for line in response.content: @@ -68,7 +66,7 @@ def format_prompt(messages: Messages) -> str: system_messages = [message["content"] for message in messages if message["role"] == "system"] question = " ".join([messages[-1]["content"], *system_messages]) history = "".join([ - f"<s>[INST]{messages[idx-1]['content']} [/INST] {message}</s>" + f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>" for idx, message in enumerate(messages) if message["role"] == "assistant" ]) |