summaryrefslogtreecommitdiffstats
path: root/g4f/client/async_client.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/client/async_client.py90
1 files changed, 75 insertions, 15 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 07ad3357..1508e566 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -3,6 +3,9 @@ from __future__ import annotations
import time
import random
import string
+import asyncio
+import base64
+from aiohttp import ClientSession, BaseConnector
from .types import Client as BaseClient
from .types import ProviderType, FinishReason
@@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider
from .image_models import ImageModels
from .helper import filter_json, find_stop, filter_none, cast_iter_async
from .service import get_last_provider, get_model_and_provider
+from ..Provider import ProviderUtils
from ..typing import Union, Messages, AsyncIterator, ImageType
-from ..errors import NoImageResponseError
-from ..image import ImageResponse as ImageProviderResponse
+from ..errors import NoImageResponseError, ProviderNotFoundError
+from ..requests.aiohttp import get_connector
+from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse
try:
anext
@@ -156,12 +161,28 @@ class Chat():
def __init__(self, client: AsyncClient, provider: ProviderType = None):
self.completions = Completions(client, provider)
-async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
+async def iter_image_response(
+ response: AsyncIterator,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None
+) -> Union[ImagesResponse, None]:
async for chunk in response:
if isinstance(chunk, ImageProviderResponse):
- return ImagesResponse([Image(image) for image in chunk.get_list()])
+ if response_format == "b64_json":
+ async with ClientSession(
+ connector=get_connector(connector, proxy)
+ ) as session:
+ async def fetch_image(image):
+ async with session.get(image) as response:
+ return base64.b64encode(await response.content.read()).decode()
+ images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()])
+ return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time()))
+ return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time()))
+ elif isinstance(chunk, ImageDataResponse):
+ return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time()))
-def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
+def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
prompt = f"create a image with: {prompt}"
if provider.__name__ == "You":
kwargs["chat_mode"] = "create"
@@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model
model,
[{"role": "user", "content": prompt}],
stream=True,
- proxy=client.get_proxy(),
**kwargs
)
@@ -179,31 +199,71 @@ class Images():
self.provider: ImageProvider = provider
self.models: ImageModels = ImageModels(client)
- async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
- provider = self.models.get(model, self.provider)
+ def get_provider(self, model: str, provider: ProviderType = None):
+ if isinstance(provider, str):
+ if provider in ProviderUtils.convert:
+ provider = ProviderUtils.convert[provider]
+ else:
+ raise ProviderNotFoundError(f'Provider not found: {provider}')
+ else:
+ provider = self.models.get(model, self.provider)
+ return provider
+
+ async def generate(
+ self,
+ prompt,
+ model: str = "",
+ provider: ProviderType = None,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None,
+ **kwargs
+ ) -> ImagesResponse:
+ provider = self.get_provider(model, provider)
if hasattr(provider, "create_async_generator"):
- response = create_image(self.client, provider, prompt, **kwargs)
+ response = create_image(
+ provider,
+ prompt,
+ **filter_none(
+ response_format=response_format,
+ connector=connector,
+ proxy=self.client.get_proxy() if proxy is None else proxy,
+ ),
+ **kwargs
+ )
else:
response = await provider.create_async(prompt)
return ImagesResponse([Image(image) for image in response.get_list()])
- image = await iter_image_response(response)
+ image = await iter_image_response(response, response_format, connector, proxy)
if image is None:
raise NoImageResponseError()
return image
- async def create_variation(self, image: ImageType, model: str = None, **kwargs):
- provider = self.models.get(model, self.provider)
+ async def create_variation(
+ self,
+ image: ImageType,
+ model: str = None,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None,
+ **kwargs
+ ):
+ provider = self.get_provider(model, provider)
result = None
if hasattr(provider, "create_async_generator"):
response = provider.create_async_generator(
"",
[{"role": "user", "content": "create a image like this"}],
- True,
+ stream=True,
image=image,
- proxy=self.client.get_proxy(),
+ **filter_none(
+ response_format=response_format,
+ connector=connector,
+ proxy=self.client.get_proxy() if proxy is None else proxy,
+ ),
**kwargs
)
- result = iter_image_response(response)
+ result = iter_image_response(response, response_format, connector, proxy)
if result is None:
raise NoImageResponseError()
return result