From d41f599adb3e4a8816ec62c4fd014cb3005aabf0 Mon Sep 17 00:00:00 2001 From: kqlio67 Date: Thu, 7 Nov 2024 21:27:26 +0200 Subject: refactor(g4f/Provider/AIUncensored.py): Enhance robustness and add features --- g4f/Provider/AIUncensored.py | 86 +++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 37 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/AIUncensored.py b/g4f/Provider/AIUncensored.py index ce492b38..db3aa6cd 100644 --- a/g4f/Provider/AIUncensored.py +++ b/g4f/Provider/AIUncensored.py @@ -1,17 +1,17 @@ from __future__ import annotations import json -from aiohttp import ClientSession -from itertools import cycle +import random +import logging +from aiohttp import ClientSession, ClientError +from typing import List from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from .helper import format_prompt from ..image import ImageResponse - class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): - url = "https://www.aiuncensored.info" + url = "https://www.aiuncensored.info/ai_uncensored" api_endpoints_text = [ "https://twitterclone-i0wr.onrender.com/api/chat", "https://twitterclone-4e8t.onrender.com/api/chat", @@ -22,8 +22,6 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): "https://twitterclone-i0wr.onrender.com/api/image", "https://twitterclone-8wd1.onrender.com/api/image", ] - api_endpoints_cycle_text = cycle(api_endpoints_text) - api_endpoints_cycle_image = cycle(api_endpoints_image) working = True supports_stream = True supports_system_message = True @@ -35,10 +33,32 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): models = [*text_models, *image_models] model_aliases = { - #"": "TextGenerations", "flux": "ImageGenerations", } + @staticmethod + def generate_cipher() -> str: + return ''.join([str(random.randint(0, 9)) for _ in range(16)]) + + @staticmethod + async def try_request(session: ClientSession, endpoints: List[str], data: dict, proxy: str = None): + available_endpoints = endpoints.copy() + random.shuffle(available_endpoints) + + while available_endpoints: + endpoint = available_endpoints.pop() + try: + async with session.post(endpoint, json=data, proxy=proxy) as response: + response.raise_for_status() + return response + except ClientError as e: + logging.warning(f"Failed to connect to {endpoint}: {str(e)}") + if not available_endpoints: + raise + continue + + raise Exception("All endpoints are unavailable") + @classmethod def get_model(cls, model: str) -> str: if model in cls.models: @@ -81,36 +101,28 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): prompt = messages[-1]['content'] data = { "prompt": prompt, + "cipher": cls.generate_cipher() } - api_endpoint = next(cls.api_endpoints_cycle_image) - async with session.post(api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - response_data = await response.json() - image_url = response_data['image_url'] - image_response = ImageResponse(images=image_url, alt=prompt) - yield image_response + response = await cls.try_request(session, cls.api_endpoints_image, data, proxy) + response_data = await response.json() + image_url = response_data['image_url'] + image_response = ImageResponse(images=image_url, alt=prompt) + yield image_response + elif model in cls.text_models: data = { - "messages": [ - { - "role": "user", - "content": format_prompt(messages) - } - ] + "messages": messages, + "cipher": cls.generate_cipher() } - api_endpoint = next(cls.api_endpoints_cycle_text) - async with session.post(api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - full_response = "" - async for line in response.content: - line = line.decode('utf-8') - if line.startswith("data: "): - try: - json_str = line[6:] - if json_str != "[DONE]": - data = json.loads(json_str) - if "data" in data: - full_response += data["data"] - yield data["data"] - except json.JSONDecodeError: - continue + response = await cls.try_request(session, cls.api_endpoints_text, data, proxy) + async for line in response.content: + line = line.decode('utf-8') + if line.startswith("data: "): + try: + json_str = line[6:] + if json_str != "[DONE]": + data = json.loads(json_str) + if "data" in data: + yield data["data"] + except json.JSONDecodeError: + continue -- cgit v1.2.3