diff options
Diffstat (limited to 'g4f/Provider/HuggingChat.py')
-rw-r--r-- | g4f/Provider/HuggingChat.py | 34 |
1 files changed, 10 insertions, 24 deletions
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py index 509a7f16..2481aa31 100644 --- a/g4f/Provider/HuggingChat.py +++ b/g4f/Provider/HuggingChat.py @@ -4,12 +4,13 @@ import json import requests try: - from curl_cffi import requests as cf_reqs + from curl_cffi import Session has_curl_cffi = True except ImportError: has_curl_cffi = False from ..typing import CreateResult, Messages from ..errors import MissingRequirementsError +from ..requests.raise_for_status import raise_for_status from .base_provider import ProviderModelMixin, AbstractProvider from .helper import format_prompt @@ -18,7 +19,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): working = True supports_stream = True default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct" - + models = [ 'meta-llama/Meta-Llama-3.1-70B-Instruct', 'CohereForAI/c4ai-command-r-plus-08-2024', @@ -30,7 +31,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): 'mistralai/Mistral-Nemo-Instruct-2407', 'microsoft/Phi-3.5-mini-instruct', ] - + model_aliases = { "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct", "command-r-plus": "CohereForAI/c4ai-command-r-plus-08-2024", @@ -44,15 +45,6 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): } @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - - @classmethod def create_completion( cls, model: str, @@ -65,7 +57,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): model = cls.get_model(model) if model in cls.models: - session = cf_reqs.Session() + session = Session() session.headers = { 'accept': '*/*', 'accept-language': 'en', @@ -82,20 +74,18 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): 'sec-fetch-site': 'same-origin', 'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36', } - json_data = { 'model': model, } - response = session.post('https://huggingface.co/chat/conversation', json=json_data) - if response.status_code != 200: - raise RuntimeError(f"Request failed with status code: {response.status_code}, response: {response.text}") + raise_for_status(response) conversationId = response.json().get('conversationId') # Get the data response and parse it properly response = session.get(f'https://huggingface.co/chat/conversation/{conversationId}/__data.json?x-sveltekit-invalidated=11') - + raise_for_status(response) + # Split the response content by newlines and parse each line as JSON try: json_data = None @@ -156,6 +146,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): headers=headers, files=files, ) + raise_for_status(response) full_response = "" for line in response.iter_lines(): @@ -182,9 +173,4 @@ class HuggingChat(AbstractProvider, ProviderModelMixin): full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip() if not stream: - yield full_response - - @classmethod - def supports_model(cls, model: str) -> bool: - """Check if the model is supported by the provider.""" - return model in cls.models or model in cls.model_aliases + yield full_response
\ No newline at end of file |