summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/HuggingFace.py
blob: a73411ced3deb71674fb6b9192275016e152d929 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from __future__ import annotations

import json
from aiohttp import ClientSession, BaseConnector

from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import get_connector
from ..errors import RateLimitError, ModelNotFoundError

class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider that sends chat prompts to a Hugging Face inference endpoint.

    The conversation is flattened into a single prompt string by
    ``format_prompt`` and POSTed to ``{api_base}/models/{model}``. Output is
    either streamed token by token (server-sent ``data:`` lines) or returned
    as one ``generated_text`` string.
    """
    url = "https://huggingface.co/chat"
    working = True
    supports_message_history = True
    default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        connector: BaseConnector = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs
    ) -> AsyncResult:
        """Generate a completion for ``messages`` and yield it as text chunks.

        Args:
            model: Model id; resolved through ``cls.get_model``.
            messages: Conversation history (dicts with ``role``/``content``).
            stream: If true, yield tokens as they arrive; otherwise yield the
                whole generated text once.
            proxy: Optional proxy URL, passed to ``get_connector``.
            connector: Optional aiohttp connector, passed to ``get_connector``.
            api_base: Base URL of the inference API (trailing ``/`` ignored).
            api_key: Optional token sent as ``Authorization: Bearer <key>``.
            max_new_tokens: Forwarded as the ``max_new_tokens`` parameter.
            temperature: Forwarded as the ``temperature`` parameter.
            **kwargs: Extra generation parameters merged into ``parameters``.

        Yields:
            Generated text. Leading whitespace of the output is stripped in
            both modes; the non-streaming result is also right-stripped.

        Raises:
            RateLimitError: on HTTP 429.
            ModelNotFoundError: on HTTP 404.
            RuntimeError: on any other non-200 status, or when the stream
                carries an ``error`` event.
        """
        model = cls.get_model(model)
        headers = {}
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        params = {
            "return_full_text": False,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            **kwargs
        }
        payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
        async with ClientSession(
            headers=headers,
            connector=get_connector(connector, proxy)
        ) as session:
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 429:
                    raise RateLimitError("Rate limit reached. Set a api_key")
                elif response.status == 404:
                    raise ModelNotFoundError(f"Model is not supported: {model}")
                elif response.status != 200:
                    raise RuntimeError(f"Response {response.status}: {await response.text()}")
                if stream:
                    # Strip leading whitespace until the first non-empty text is
                    # produced, so a whitespace-only first token does not let
                    # leading whitespace leak into the output.
                    started = False
                    async for line in response.content:
                        if not line.startswith(b"data:"):
                            continue
                        data = json.loads(line[5:])
                        # An error event has no "token" key; surface it clearly
                        # instead of failing with a KeyError below.
                        if "error" in data:
                            raise RuntimeError(f"Response error: {data['error']}")
                        token = data["token"]
                        if token["special"]:
                            # Special tokens (e.g. end-of-sequence) are not text.
                            continue
                        chunk = token["text"]
                        if not started:
                            chunk = chunk.lstrip()
                            if not chunk:
                                continue
                            started = True
                        yield chunk
                else:
                    yield (await response.json())[0]["generated_text"].strip()

def format_prompt(messages: Messages) -> str:
    """Flatten a chat history into a Mistral-style ``[INST]`` prompt string.

    Each assistant message becomes a completed turn
    ``<s>[INST] <previous message> [/INST] <assistant reply></s>``. The last
    message, followed by the contents of all system messages (space-joined),
    becomes the final open ``[INST] ... [/INST]`` turn.

    Args:
        messages: Sequence of dicts with ``role`` and ``content`` keys.

    Returns:
        The prompt string to send as the model input.
    """
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    turns = []
    for idx, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        # Pair the reply with the message before it. Guard idx == 0 so a
        # leading assistant message does not wrap around to messages[-1].
        prompt = messages[idx - 1]["content"] if idx > 0 else ""
        turns.append(f"<s>[INST] {prompt} [/INST] {message['content']}</s>")
    history = "".join(turns)
    return f"{history}<s>[INST] {question} [/INST]"