diff options
Diffstat (limited to 'g4f/Provider/Llama2.py')
-rw-r--r-- | g4f/Provider/Llama2.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/g4f/Provider/Llama2.py b/g4f/Provider/Llama2.py index b59fde12..1b332f86 100644 --- a/g4f/Provider/Llama2.py +++ b/g4f/Provider/Llama2.py @@ -6,15 +6,14 @@ from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider models = { - "7B": {"name": "Llama 2 7B", "version": "d24902e3fa9b698cc208b5e63136c4e26e828659a9f09827ca6ec5bb83014381", "shortened":"7B"}, - "13B": {"name": "Llama 2 13B", "version": "9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3", "shortened":"13B"}, - "70B": {"name": "Llama 2 70B", "version": "2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf", "shortened":"70B"}, + "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama 2 7B", "version": "d24902e3fa9b698cc208b5e63136c4e26e828659a9f09827ca6ec5bb83014381", "shortened":"7B"}, + "meta-llama/Llama-2-13b-chat-hf": {"name": "Llama 2 13B", "version": "9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3", "shortened":"13B"}, + "meta-llama/Llama-2-70b-chat-hf": {"name": "Llama 2 70B", "version": "2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf", "shortened":"70B"}, "Llava": {"name": "Llava 13B", "version": "6bc1c7bb0d2a34e413301fee8f7cc728d2d4e75bfab186aa995f63292bda92fc", "shortened":"Llava"} } class Llama2(AsyncGeneratorProvider): url = "https://www.llama2.ai" - supports_gpt_35_turbo = True working = True @classmethod @@ -26,8 +25,8 @@ class Llama2(AsyncGeneratorProvider): **kwargs ) -> AsyncResult: if not model: - model = "70B" - if model not in models: + model = "meta-llama/Llama-2-70b-chat-hf" + elif model not in models: raise ValueError(f"Model are not supported: {model}") version = models[model]["version"] headers = { @@ -54,7 +53,7 @@ class Llama2(AsyncGeneratorProvider): "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."), "temperature": kwargs.get("temperature", 0.75), "topP": kwargs.get("top_p", 0.9), - "maxTokens": kwargs.get("max_tokens", 1024), + "maxTokens": kwargs.get("max_tokens", 8000), "image": None } started = False @@ -68,9 +67,9 @@ class Llama2(AsyncGeneratorProvider): def format_prompt(messages: Messages): messages = [ - f"[INST]{message['content']}[/INST]" + f"[INST] {message['content']} [/INST]" if message["role"] == "user" else message["content"] for message in messages ] - return "\n".join(messages)
\ No newline at end of file + return "\n".join(messages) + "\n"
\ No newline at end of file |