author    | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-04-22 20:02:17 +0200
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-04-22 20:02:17 +0200
commit    | 002a4a1d7fad2b0d980aea6bce351c2b6a579c05 (patch)
tree      | 85994a09dedf7176f524720bed6dd596baa2293c /g4f/Provider
parent    | Add gemini-1.5-pro-latest model (diff)
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/DeepInfra.py   |  3
-rw-r--r-- | g4f/Provider/HuggingChat.py | 36
-rw-r--r-- | g4f/Provider/Llama.py       |  6
-rw-r--r-- | g4f/Provider/Replicate.py   |  3
4 files changed, 34 insertions, 14 deletions
diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py
index 35ff84a1..a74601e8 100644
--- a/g4f/Provider/DeepInfra.py
+++ b/g4f/Provider/DeepInfra.py
@@ -9,13 +9,14 @@ class DeepInfra(Openai):
     label = "DeepInfra"
     url = "https://deepinfra.com"
     working = True
+    needs_auth = False
     has_auth = True
     supports_stream = True
     supports_message_history = True
     default_model = "meta-llama/Meta-Llama-3-70b-instruct"
     default_vision_model = "llava-hf/llava-1.5-7b-hf"
     model_aliases = {
-        'mixtral-8x22b': 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1'
+        'dbrx-instruct': 'databricks/dbrx-instruct',
     }

     @classmethod
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py
index 668ce4b1..527f0a56 100644
--- a/g4f/Provider/HuggingChat.py
+++ b/g4f/Provider/HuggingChat.py
@@ -6,12 +6,14 @@ from aiohttp import ClientSession, BaseConnector

 from ..typing import AsyncResult, Messages
 from ..requests.raise_for_status import raise_for_status
+from ..providers.conversation import BaseConversation
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from .helper import format_prompt, get_connector
+from .helper import format_prompt, get_connector, get_cookies

 class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://huggingface.co/chat"
     working = True
+    needs_auth = True
     default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
     models = [
         "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
@@ -22,9 +24,6 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
         'mistralai/Mistral-7B-Instruct-v0.2',
         'meta-llama/Meta-Llama-3-70B-Instruct'
     ]
-    model_aliases = {
-        "openchat/openchat_3.5": "openchat/openchat-3.5-0106",
-    }

     @classmethod
     def get_models(cls):
@@ -45,9 +44,16 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
         connector: BaseConnector = None,
         web_search: bool = False,
         cookies: dict = None,
+        conversation: Conversation = None,
+        return_conversation: bool = False,
+        delete_conversation: bool = True,
         **kwargs
     ) -> AsyncResult:
         options = {"model": cls.get_model(model)}
+        if cookies is None:
+            cookies = get_cookies("huggingface.co", False)
+        if return_conversation:
+            delete_conversation = False

         system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
         if system_prompt:
@@ -61,9 +67,14 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
             headers=headers,
             connector=get_connector(connector, proxy)
         ) as session:
-            async with session.post(f"{cls.url}/conversation", json=options) as response:
-                await raise_for_status(response)
-                conversation_id = (await response.json())["conversationId"]
+            if conversation is None:
+                async with session.post(f"{cls.url}/conversation", json=options) as response:
+                    await raise_for_status(response)
+                    conversation_id = (await response.json())["conversationId"]
+                if return_conversation:
+                    yield Conversation(conversation_id)
+            else:
+                conversation_id = conversation.conversation_id
             async with session.get(f"{cls.url}/conversation/{conversation_id}/__data.json") as response:
                 await raise_for_status(response)
                 data: list = (await response.json())["nodes"][1]["data"]
@@ -72,7 +83,7 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
             message_id: str = data[message_keys["id"]]
             options = {
                 "id": message_id,
-                "inputs": format_prompt(messages),
+                "inputs": format_prompt(messages) if conversation is None else messages[-1]["content"],
                 "is_continue": False,
                 "is_retry": False,
                 "web_search": web_search
@@ -92,5 +103,10 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
                         yield token
                     elif line["type"] == "finalAnswer":
                         break
-            async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response:
-                await raise_for_status(response)
+            if delete_conversation:
+                async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response:
+                    await raise_for_status(response)
+
+class Conversation(BaseConversation):
+    def __init__(self, conversation_id: str) -> None:
+        self.conversation_id = conversation_id
\ No newline at end of file
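The HuggingChat hunks above add stateful, multi-turn support: with return_conversation=True the generator yields a Conversation wrapper (carrying the server-side conversationId) before any tokens, and passing that object back via conversation= skips creating a new conversation, sends only the last user message as "inputs", and leaves deletion to the caller. Since needs_auth is now set, cookies default to get_cookies("huggingface.co", False), i.e. a local huggingface.co login. Below is a minimal caller sketch under those assumptions; the printing loop and isinstance checks are illustrative, not part of the commit:

import asyncio

from g4f.Provider.HuggingChat import HuggingChat, Conversation

async def main():
    conversation = None

    # Turn 1: return_conversation=True makes the provider yield the
    # Conversation object before any tokens and implies
    # delete_conversation=False, per the diff.
    async for chunk in HuggingChat.create_async_generator(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=[{"role": "user", "content": "Remember the number 42."}],
        return_conversation=True,
    ):
        if isinstance(chunk, Conversation):
            conversation = chunk  # keep the conversationId for turn 2
        else:
            print(chunk, end="")

    # Turn 2: reusing the conversation sends only messages[-1]["content"];
    # delete_conversation defaults to True, so the server-side
    # conversation is removed once this call finishes.
    async for chunk in HuggingChat.create_async_generator(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=[{"role": "user", "content": "Which number did I ask you to remember?"}],
        conversation=conversation,
    ):
        if not isinstance(chunk, Conversation):
            print(chunk, end="")

asyncio.run(main())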
diff --git a/g4f/Provider/Llama.py b/g4f/Provider/Llama.py
index 8f3e9ea2..f2c78b36 100644
--- a/g4f/Provider/Llama.py
+++ b/g4f/Provider/Llama.py
@@ -11,7 +11,7 @@ class Llama(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://www.llama2.ai"
     working = True
     supports_message_history = True
-    default_model = "meta/llama-3-70b-chat"
+    default_model = "meta/meta-llama-3-70b-instruct"
     models = [
         "meta/llama-2-7b-chat",
         "meta/llama-2-13b-chat",
@@ -20,8 +20,8 @@ class Llama(AsyncGeneratorProvider, ProviderModelMixin):
         "meta/meta-llama-3-70b-instruct",
     ]
     model_aliases = {
-        "meta-llama/Meta-Llama-3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-        "meta-llama/Meta-Llama-3-70b-instruct": "meta/meta-llama-3-70b-instruct",
+        "meta-llama/Meta-Llama-3-8B-Instruct": "meta/meta-llama-3-8b-instruct",
+        "meta-llama/Meta-Llama-3-70B-Instruct": "meta/meta-llama-3-70b-instruct",
         "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat",
         "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat",
         "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat",
diff --git a/g4f/Provider/Replicate.py b/g4f/Provider/Replicate.py
index 593fd04d..89777cf2 100644
--- a/g4f/Provider/Replicate.py
+++ b/g4f/Provider/Replicate.py
@@ -11,6 +11,9 @@ class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://replicate.com"
     working = True
     default_model = "meta/meta-llama-3-70b-instruct"
+    model_aliases = {
+        "meta-llama/Meta-Llama-3-70B-Instruct": default_model
+    }

     @classmethod
     async def create_async_generator(
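The Llama and Replicate hunks are two sides of the same cleanup: the canonical Hugging Face ids now use their upstream casing ("Meta-Llama-3-70B-Instruct", not "...70b-instruct") and both providers map them to the native "meta/meta-llama-3-70b-instruct" name, while DeepInfra swaps its stale mixtral-8x22b alias for dbrx-instruct. All of these feed the alias lookup in ProviderModelMixin.get_model (imported from .base_provider in the diffs above); the sketch below is a simplification of that resolution path, not the actual mixin, which may also validate against models and raise on unknown names:

class ProviderModelMixin:
    default_model: str = None
    models: list = []
    model_aliases: dict = {}

    @classmethod
    def get_model(cls, model: str) -> str:
        # No model requested -> fall back to the provider default.
        if not model and cls.default_model is not None:
            return cls.default_model
        # Canonical id -> provider-native name, e.g.
        # "meta-llama/Meta-Llama-3-70B-Instruct" -> "meta/meta-llama-3-70b-instruct"
        if model in cls.model_aliases:
            return cls.model_aliases[model]
        return model

Because alias keys are matched as exact strings, the casing fix in Llama.py is functional, not cosmetic: before this commit a request for "meta-llama/Meta-Llama-3-70B-Instruct" missed the lowercase "70b" key and fell through unresolved.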