diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-11-17 18:32:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-17 18:32:51 +0100 |
commit | 275574d71ece22975de7df0e226d466a2056605b (patch) | |
tree | 3f113ea8beb7c43920871019512aeb8de9d1b4f7 /g4f/Provider/needs_auth/HuggingFace.py | |
parent | Fix api streaming, fix AsyncClient (#2357) (diff) | |
parent | Add nodriver to Gemini provider, (diff) | |
download | gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.gz gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.bz2 gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.lz gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.xz gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.zst gpt4free-275574d71ece22975de7df0e226d466a2056605b.zip |
Diffstat (limited to 'g4f/Provider/needs_auth/HuggingFace.py')
-rw-r--r-- | g4f/Provider/needs_auth/HuggingFace.py | 29 |
1 file changed, 8 insertions, 21 deletions
diff --git a/g4f/Provider/needs_auth/HuggingFace.py b/g4f/Provider/needs_auth/HuggingFace.py index ecc75d1c..35270e60 100644 --- a/g4f/Provider/needs_auth/HuggingFace.py +++ b/g4f/Provider/needs_auth/HuggingFace.py @@ -1,13 +1,11 @@ from __future__ import annotations import json -from aiohttp import ClientSession, BaseConnector from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..helper import get_connector -from ...errors import RateLimitError, ModelNotFoundError -from ...requests.raise_for_status import raise_for_status +from ...errors import ModelNotFoundError +from ...requests import StreamSession, raise_for_status from ..HuggingChat import HuggingChat @@ -21,22 +19,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): model_aliases = HuggingChat.model_aliases @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - - @classmethod async def create_async_generator( cls, model: str, messages: Messages, stream: bool = True, proxy: str = None, - connector: BaseConnector = None, api_base: str = "https://api-inference.huggingface.co", api_key: str = None, max_new_tokens: int = 1024, @@ -62,7 +50,6 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): } if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" - params = { "return_full_text": False, "max_new_tokens": max_new_tokens, @@ -70,10 +57,9 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): **kwargs } payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream} - - async with ClientSession( + async with StreamSession( headers=headers, - connector=get_connector(connector, proxy) + proxy=proxy ) as session: async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response: if response.status == 
404: @@ -81,7 +67,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): await raise_for_status(response) if stream: first = True - async for line in response.content: + async for line in response.iter_lines(): if line.startswith(b"data:"): data = json.loads(line[5:]) if not data["token"]["special"]: @@ -89,7 +75,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): if first: first = False chunk = chunk.lstrip() - yield chunk + if chunk: + yield chunk else: yield (await response.json())[0]["generated_text"].strip() @@ -101,4 +88,4 @@ def format_prompt(messages: Messages) -> str: for idx, message in enumerate(messages) if message["role"] == "assistant" ]) - return f"{history}<s>[INST] {question} [/INST]" + return f"{history}<s>[INST] {question} [/INST]"
\ No newline at end of file |