diff options
-rw-r--r-- | g4f/Provider/AIUncensored.py | 80 |
1 files changed, 42 insertions, 38 deletions
diff --git a/g4f/Provider/AIUncensored.py b/g4f/Provider/AIUncensored.py index db3aa6cd..c2f0f4b3 100644 --- a/g4f/Provider/AIUncensored.py +++ b/g4f/Provider/AIUncensored.py @@ -2,9 +2,9 @@ from __future__ import annotations import json import random -import logging from aiohttp import ClientSession, ClientError -from typing import List +import asyncio +from itertools import cycle from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin @@ -38,27 +38,9 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): @staticmethod def generate_cipher() -> str: + """Generate a cipher in format like '3221229284179118'""" return ''.join([str(random.randint(0, 9)) for _ in range(16)]) - @staticmethod - async def try_request(session: ClientSession, endpoints: List[str], data: dict, proxy: str = None): - available_endpoints = endpoints.copy() - random.shuffle(available_endpoints) - - while available_endpoints: - endpoint = available_endpoints.pop() - try: - async with session.post(endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - return response - except ClientError as e: - logging.warning(f"Failed to connect to {endpoint}: {str(e)}") - if not available_endpoints: - raise - continue - - raise Exception("All endpoints are unavailable") - @classmethod def get_model(cls, model: str) -> str: if model in cls.models: @@ -103,26 +85,48 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): "prompt": prompt, "cipher": cls.generate_cipher() } - response = await cls.try_request(session, cls.api_endpoints_image, data, proxy) - response_data = await response.json() - image_url = response_data['image_url'] - image_response = ImageResponse(images=image_url, alt=prompt) - yield image_response + endpoints = cycle(cls.api_endpoints_image) + + while True: + endpoint = next(endpoints) + try: + async with session.post(endpoint, json=data, proxy=proxy, timeout=10) as response: + response.raise_for_status() + response_data = await response.json() + image_url = response_data['image_url'] + image_response = ImageResponse(images=image_url, alt=prompt) + yield image_response + return + except (ClientError, asyncio.TimeoutError): + continue + elif model in cls.text_models: data = { "messages": messages, "cipher": cls.generate_cipher() } - response = await cls.try_request(session, cls.api_endpoints_text, data, proxy) - async for line in response.content: - line = line.decode('utf-8') - if line.startswith("data: "): - try: - json_str = line[6:] - if json_str != "[DONE]": - data = json.loads(json_str) - if "data" in data: - yield data["data"] - except json.JSONDecodeError: - continue + + endpoints = cycle(cls.api_endpoints_text) + + while True: + endpoint = next(endpoints) + try: + async with session.post(endpoint, json=data, proxy=proxy, timeout=10) as response: + response.raise_for_status() + full_response = "" + async for line in response.content: + line = line.decode('utf-8') + if line.startswith("data: "): + try: + json_str = line[6:] + if json_str != "[DONE]": + data = json.loads(json_str) + if "data" in data: + full_response += data["data"] + yield data["data"] + except json.JSONDecodeError: + continue + return + except (ClientError, asyncio.TimeoutError): + continue |