summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/AIUncensored.py80
1 files changed, 42 insertions, 38 deletions
diff --git a/g4f/Provider/AIUncensored.py b/g4f/Provider/AIUncensored.py
index db3aa6cd..c2f0f4b3 100644
--- a/g4f/Provider/AIUncensored.py
+++ b/g4f/Provider/AIUncensored.py
@@ -2,9 +2,9 @@ from __future__ import annotations
import json
import random
-import logging
from aiohttp import ClientSession, ClientError
-from typing import List
+import asyncio
+from itertools import cycle
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
@@ -38,27 +38,9 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin):
@staticmethod
def generate_cipher() -> str:
+ """Generate a cipher in format like '3221229284179118'"""
return ''.join([str(random.randint(0, 9)) for _ in range(16)])
- @staticmethod
- async def try_request(session: ClientSession, endpoints: List[str], data: dict, proxy: str = None):
- available_endpoints = endpoints.copy()
- random.shuffle(available_endpoints)
-
- while available_endpoints:
- endpoint = available_endpoints.pop()
- try:
- async with session.post(endpoint, json=data, proxy=proxy) as response:
- response.raise_for_status()
- return response
- except ClientError as e:
- logging.warning(f"Failed to connect to {endpoint}: {str(e)}")
- if not available_endpoints:
- raise
- continue
-
- raise Exception("All endpoints are unavailable")
-
@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
@@ -103,26 +85,48 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin):
"prompt": prompt,
"cipher": cls.generate_cipher()
}
- response = await cls.try_request(session, cls.api_endpoints_image, data, proxy)
- response_data = await response.json()
- image_url = response_data['image_url']
- image_response = ImageResponse(images=image_url, alt=prompt)
- yield image_response
+ endpoints = cycle(cls.api_endpoints_image)
+
+ while True:
+ endpoint = next(endpoints)
+ try:
+ async with session.post(endpoint, json=data, proxy=proxy, timeout=10) as response:
+ response.raise_for_status()
+ response_data = await response.json()
+ image_url = response_data['image_url']
+ image_response = ImageResponse(images=image_url, alt=prompt)
+ yield image_response
+ return
+ except (ClientError, asyncio.TimeoutError):
+ continue
+
elif model in cls.text_models:
data = {
"messages": messages,
"cipher": cls.generate_cipher()
}
- response = await cls.try_request(session, cls.api_endpoints_text, data, proxy)
- async for line in response.content:
- line = line.decode('utf-8')
- if line.startswith("data: "):
- try:
- json_str = line[6:]
- if json_str != "[DONE]":
- data = json.loads(json_str)
- if "data" in data:
- yield data["data"]
- except json.JSONDecodeError:
- continue
+
+ endpoints = cycle(cls.api_endpoints_text)
+
+ while True:
+ endpoint = next(endpoints)
+ try:
+ async with session.post(endpoint, json=data, proxy=proxy, timeout=10) as response:
+ response.raise_for_status()
+ full_response = ""
+ async for line in response.content:
+ line = line.decode('utf-8')
+ if line.startswith("data: "):
+ try:
+ json_str = line[6:]
+ if json_str != "[DONE]":
+ data = json.loads(json_str)
+ if "data" in data:
+ full_response += data["data"]
+ yield data["data"]
+ except json.JSONDecodeError:
+ continue
+ return
+ except (ClientError, asyncio.TimeoutError):
+ continue