From 8d5d522c4e5770386e7e222b371ab17cbb1030b1 Mon Sep 17 00:00:00 2001 From: kqlio67 <166700875+kqlio67@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:50:24 +0000 Subject: feat(g4f): Major provider updates and new model support (#2437) * refactor(g4f/Provider/Airforce.py): Enhance Airforce provider with dynamic model fetching * refactor(g4f/Provider/Blackbox.py): Enhance Blackbox AI provider configuration and streamline code * feat(g4f/Provider/RobocodersAPI.py): Add RobocodersAPI new async chat provider * refactor(g4f/client/__init__.py): Improve provider handling in async_generate method * refactor(g4f/models.py): Update provider configurations for multiple models * refactor(g4f/Provider/Blackbox.py): Streamline model configuration and improve response handling * feat(g4f/Provider/DDG.py): Enhance model support and improve conversation handling * refactor(g4f/Provider/Copilot.py): Enhance Copilot provider with model support * refactor(g4f/Provider/AmigoChat.py): update models and improve code structure * chore(g4f/Provider/not_working/AIUncensored.): move AIUncensored to not_working directory * chore(g4f/Provider/not_working/Allyfy.py): remove Allyfy provider * Update (g4f/Provider/not_working/AIUncensored.py g4f/Provider/not_working/__init__.py) * refactor(g4f/Provider/ChatGptEs.py): Implement format_prompt for message handling * refactor(g4f/Provider/Blackbox.py): Update message formatting and improve code structure * refactor(g4f/Provider/LLMPlayground.py): Enhance text generation and error handling * refactor(g4f/Provider/needs_auth/PollinationsAI.py): move PollinationsAI to needs_auth directory * refactor(g4f/Provider/Liaobots.py): Update Liaobots provider models and aliases * feat(g4f/Provider/DeepInfraChat.py): Add new DeepInfra models and aliases * Update (g4f/Provider/__init__.py) * Update (g4f/models.py) * g4f/models.py * Update g4f/models.py * Update g4f/Provider/LLMPlayground.py * Update (g4f/models.py g4f/Provider/Airforce.py 
g4f/Provider/__init__.py g4f/Provider/LLMPlayground.py) * Update g4f/Provider/__init__.py * Update (g4f/Provider/Airforce.py) --------- Co-authored-by: kqlio67 --- g4f/Provider/Airforce.py | 89 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 29 deletions(-) (limited to 'g4f/Provider/Airforce.py') diff --git a/g4f/Provider/Airforce.py b/g4f/Provider/Airforce.py index 5200d6f7..f65cd953 100644 --- a/g4f/Provider/Airforce.py +++ b/g4f/Provider/Airforce.py @@ -14,6 +14,19 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..image import ImageResponse from ..requests import StreamSession, raise_for_status +def split_message(message: str, max_length: int = 1000) -> list[str]: + """Splits the message into parts up to (max_length).""" + chunks = [] + while len(message) > max_length: + split_point = message.rfind(' ', 0, max_length) + if split_point == -1: + split_point = max_length + chunks.append(message[:split_point]) + message = message[split_point:].strip() + if message: + chunks.append(message) + return chunks + class Airforce(AsyncGeneratorProvider, ProviderModelMixin): url = "https://llmplayground.net" api_endpoint_completions = "https://api.airforce/chat/completions" @@ -84,6 +97,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): # HuggingFaceH4 "zephyr-7b": "zephyr-7b-beta", + ### imagine ### "sdxl": "stable-diffusion-xl-base", "sdxl": "stable-diffusion-xl-lightning", @@ -125,7 +139,6 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): "accept": "*/*", "accept-language": "en-US,en;q=0.9", "cache-control": "no-cache", - "origin": "https://llmplayground.net", "user-agent": "Mozilla/5.0" } if seed is None: @@ -167,35 +180,47 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): "content-type": "application/json", "user-agent": "Mozilla/5.0" } + + full_message = "\n".join( + [f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages] + ) + + message_chunks = split_message(full_message, max_length=1000) + async with StreamSession(headers=headers, proxy=proxy) as session: - data = { - "messages": messages, - "model": model, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "stream": stream - } - async with session.post(cls.api_endpoint_completions, json=data) as response: - await raise_for_status(response) - content_type = response.headers.get('Content-Type', '').lower() - if 'application/json' in content_type: - json_data = await response.json() - if json_data.get("model") == "error": - raise RuntimeError(json_data['choices'][0]['message'].get('content', '')) - if stream: - async for line in response.iter_lines(): - if line: - line = line.decode('utf-8').strip() - if line.startswith("data: ") and line != "data: [DONE]": - json_data = json.loads(line[6:]) - content = json_data['choices'][0]['delta'].get('content', '') - if content: - yield cls._filter_content(content) - else: - json_data = await response.json() - content = json_data['choices'][0]['message']['content'] - yield cls._filter_content(content) + full_response = "" + for chunk in message_chunks: + data = { + "messages": [{"role": "user", "content": chunk}], + "model": model, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "stream": stream + } + + async with session.post(cls.api_endpoint_completions, json=data) as response: + await raise_for_status(response) + content_type = response.headers.get('Content-Type', '').lower() + + if 'application/json' in content_type: + json_data = await response.json() + if json_data.get("model") == "error": + raise RuntimeError(json_data['choices'][0]['message'].get('content', '')) + if stream: + async for line in response.iter_lines(): + if line: + line = line.decode('utf-8').strip() + if line.startswith("data: ") and line != "data: [DONE]": + json_data = json.loads(line[6:]) + content = json_data['choices'][0]['delta'].get('content', '') + if content: + yield cls._filter_content(content) + else: + content = json_data['choices'][0]['message']['content'] + full_response += cls._filter_content(content) + + yield full_response @classmethod def _filter_content(cls, part_response: str) -> str: @@ -210,4 +235,10 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): '', part_response ) + + part_response = re.sub( + r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # any-uncensored + '', + part_response + ) return part_response -- cgit v1.2.3