summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/Cloudflare.py
blob: 7d477d57327fa1debbdb09bbc96170bc7c12eaec (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

import asyncio
import json
import uuid

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..errors import ResponseStatusError

class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider for the Cloudflare AI playground (https://playground.ai.cloudflare.com).

    Request keyword arguments for ``Session``/``StreamSession`` are obtained via
    ``get_args_from_nodriver`` and cached in the class attribute ``_args``.
    They are reset to ``None`` whenever an endpoint answers with an error
    status, so the next call fetches fresh ones.
    """
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.1-8b-instruct"
    # Short alias -> Cloudflare model id. A dict literal keeps only the last
    # value of a repeated key, so each alias maps to exactly one id here.
    model_aliases = {
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    # Cached keyword arguments for Session/StreamSession (must contain a
    # "cookies" entry, which is refreshed after every request). None until
    # first fetched, and again after an error response.
    _args: dict = None

    @classmethod
    def get_models(cls) -> list[str]:
        """Return the model names listed by the playground's models endpoint.

        The list is fetched once and cached in ``cls.models``. If no request
        arguments are cached yet, they are obtained synchronously by running
        ``get_args_from_nodriver`` with ``asyncio.run``. A random ``__cf_bm``
        cookie value is seeded for that call.

        Raises:
            ResponseStatusError: the models endpoint returned an error status;
                the cached request arguments are discarded before re-raising.
        """
        if not cls.models:
            if cls._args is None:
                # Called before asyncio.run(); presumably guards against (or
                # patches) a nested event loop — confirm in base_provider.
                get_running_loop(check_nested=True)
                args = get_args_from_nodriver(cls.url, cookies={
                    '__cf_bm': uuid.uuid4().hex,
                })
                cls._args = asyncio.run(args)
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                # Keep any cookies the server set for subsequent requests.
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    raise_for_status(response)
                except ResponseStatusError:
                    # Stale/blocked arguments: force a refetch next time.
                    cls._args = None
                    raise
                json_data = response.json()
                # Tolerate a payload without a "models" list instead of
                # failing with a TypeError while iterating None.
                cls.models = [model.get("name") for model in json_data.get("models") or []]
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        """Stream a chat completion from the Cloudflare playground.

        Args:
            model: model name or alias, resolved with ``cls.get_model``.
            messages: chat history, sent as the ``"messages"`` field.
            proxy: forwarded to ``get_args_from_nodriver``; only used when no
                request arguments are cached yet.
            max_tokens: sent as the ``"max_tokens"`` field.
            cookies: forwarded to ``get_args_from_nodriver``; only used when
                no request arguments are cached yet.
            timeout: forwarded to ``get_args_from_nodriver``; only used when
                no request arguments are cached yet.
            **kwargs: accepted for interface compatibility and ignored.

        Yields:
            Text chunks from the ``"response"`` field of each ``data:`` line of
            the streamed body. Empty chunks and ``</s>`` are skipped, and the
            stream stops at ``data: [DONE]``.

        Raises:
            ResponseStatusError: the inference endpoint returned an error
                status; the cached request arguments are discarded before
                re-raising.
        """
        model = cls.get_model(model)
        if cls._args is None:
            cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
        data = {
            "messages": messages,
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                # Keep any cookies the server set for subsequent requests.
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # Stale/blocked arguments: force a refetch next time.
                    cls._args = None
                    raise
                async for line in response.iter_lines():
                    if not line.startswith(b'data: '):
                        continue
                    if line == b'data: [DONE]':
                        break
                    try:
                        content = json.loads(line[6:].decode())
                    except ValueError:
                        # Malformed or non-UTF-8 event payload: skip it.
                        # (JSONDecodeError and UnicodeDecodeError are both
                        # ValueError subclasses.)
                        continue
                    if not isinstance(content, dict):
                        continue
                    chunk = content.get("response")
                    # The yield stays outside the try block so exceptions the
                    # consumer throws into this generator are not swallowed.
                    if chunk and chunk != '</s>':
                        yield chunk