summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to 'g4f')
-rw-r--r--g4f/Provider/AIUncensored.py86
1 files changed, 49 insertions, 37 deletions
diff --git a/g4f/Provider/AIUncensored.py b/g4f/Provider/AIUncensored.py
index ce492b38..db3aa6cd 100644
--- a/g4f/Provider/AIUncensored.py
+++ b/g4f/Provider/AIUncensored.py
@@ -1,17 +1,17 @@
from __future__ import annotations
import json
-from aiohttp import ClientSession
-from itertools import cycle
+import random
+import logging
+from aiohttp import ClientSession, ClientError
+from typing import List
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from .helper import format_prompt
from ..image import ImageResponse
-
class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin):
- url = "https://www.aiuncensored.info"
+ url = "https://www.aiuncensored.info/ai_uncensored"
api_endpoints_text = [
"https://twitterclone-i0wr.onrender.com/api/chat",
"https://twitterclone-4e8t.onrender.com/api/chat",
@@ -22,8 +22,6 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin):
"https://twitterclone-i0wr.onrender.com/api/image",
"https://twitterclone-8wd1.onrender.com/api/image",
]
- api_endpoints_cycle_text = cycle(api_endpoints_text)
- api_endpoints_cycle_image = cycle(api_endpoints_image)
working = True
supports_stream = True
supports_system_message = True
@@ -35,10 +33,32 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin):
models = [*text_models, *image_models]
model_aliases = {
- #"": "TextGenerations",
"flux": "ImageGenerations",
}
+ @staticmethod
+ def generate_cipher() -> str:
+ return ''.join([str(random.randint(0, 9)) for _ in range(16)])
+
+ @staticmethod
+ async def try_request(session: ClientSession, endpoints: List[str], data: dict, proxy: str = None):
+ available_endpoints = endpoints.copy()
+ random.shuffle(available_endpoints)
+
+ while available_endpoints:
+ endpoint = available_endpoints.pop()
+ try:
+ async with session.post(endpoint, json=data, proxy=proxy) as response:
+ response.raise_for_status()
+ return response
+ except ClientError as e:
+ logging.warning(f"Failed to connect to {endpoint}: {str(e)}")
+ if not available_endpoints:
+ raise
+ continue
+
+ raise Exception("All endpoints are unavailable")
+
@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
@@ -81,36 +101,28 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin):
prompt = messages[-1]['content']
data = {
"prompt": prompt,
+ "cipher": cls.generate_cipher()
}
- api_endpoint = next(cls.api_endpoints_cycle_image)
- async with session.post(api_endpoint, json=data, proxy=proxy) as response:
- response.raise_for_status()
- response_data = await response.json()
- image_url = response_data['image_url']
- image_response = ImageResponse(images=image_url, alt=prompt)
- yield image_response
+ response = await cls.try_request(session, cls.api_endpoints_image, data, proxy)
+ response_data = await response.json()
+ image_url = response_data['image_url']
+ image_response = ImageResponse(images=image_url, alt=prompt)
+ yield image_response
+
elif model in cls.text_models:
data = {
- "messages": [
- {
- "role": "user",
- "content": format_prompt(messages)
- }
- ]
+ "messages": messages,
+ "cipher": cls.generate_cipher()
}
- api_endpoint = next(cls.api_endpoints_cycle_text)
- async with session.post(api_endpoint, json=data, proxy=proxy) as response:
- response.raise_for_status()
- full_response = ""
- async for line in response.content:
- line = line.decode('utf-8')
- if line.startswith("data: "):
- try:
- json_str = line[6:]
- if json_str != "[DONE]":
- data = json.loads(json_str)
- if "data" in data:
- full_response += data["data"]
- yield data["data"]
- except json.JSONDecodeError:
- continue
+ response = await cls.try_request(session, cls.api_endpoints_text, data, proxy)
+ async for line in response.content:
+ line = line.decode('utf-8')
+ if line.startswith("data: "):
+ try:
+ json_str = line[6:]
+ if json_str != "[DONE]":
+ data = json.loads(json_str)
+ if "data" in data:
+ yield data["data"]
+ except json.JSONDecodeError:
+ continue