diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/client/async_client.py | 90 |
1 files changed, 75 insertions, 15 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 07ad3357..1508e566 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -3,6 +3,9 @@ from __future__ import annotations import time import random import string +import asyncio +import base64 +from aiohttp import ClientSession, BaseConnector from .types import Client as BaseClient from .types import ProviderType, FinishReason @@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider from .image_models import ImageModels from .helper import filter_json, find_stop, filter_none, cast_iter_async from .service import get_last_provider, get_model_and_provider +from ..Provider import ProviderUtils from ..typing import Union, Messages, AsyncIterator, ImageType -from ..errors import NoImageResponseError -from ..image import ImageResponse as ImageProviderResponse +from ..errors import NoImageResponseError, ProviderNotFoundError +from ..requests.aiohttp import get_connector +from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse try: anext @@ -156,12 +161,28 @@ class Chat(): def __init__(self, client: AsyncClient, provider: ProviderType = None): self.completions = Completions(client, provider) -async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: +async def iter_image_response( + response: AsyncIterator, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None +) -> Union[ImagesResponse, None]: async for chunk in response: if isinstance(chunk, ImageProviderResponse): - return ImagesResponse([Image(image) for image in chunk.get_list()]) + if response_format == "b64_json": + async with ClientSession( + connector=get_connector(connector, proxy) + ) as session: + async def fetch_image(image): + async with session.get(image) as response: + return base64.b64encode(await response.content.read()).decode() + images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()]) + return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time())) + return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time())) + elif isinstance(chunk, ImageDataResponse): + return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time())) -def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: +def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: prompt = f"create a image with: {prompt}" if provider.__name__ == "You": kwargs["chat_mode"] = "create" @@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model model, [{"role": "user", "content": prompt}], stream=True, - proxy=client.get_proxy(), **kwargs ) @@ -179,31 +199,71 @@ class Images(): self.provider: ImageProvider = provider self.models: ImageModels = ImageModels(client) - async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse: - provider = self.models.get(model, self.provider) + def get_provider(self, model: str, provider: ProviderType = None): + if isinstance(provider, str): + if provider in ProviderUtils.convert: + provider = ProviderUtils.convert[provider] + else: + raise ProviderNotFoundError(f'Provider not found: {provider}') + else: + provider = self.models.get(model, self.provider) + return provider + + async def generate( + self, + prompt, + model: str = "", + provider: ProviderType = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ) -> ImagesResponse: + provider = self.get_provider(model, provider) if hasattr(provider, "create_async_generator"): - response = create_image(self.client, provider, prompt, **kwargs) + response = create_image( + provider, + prompt, + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), + **kwargs + ) else: response = await provider.create_async(prompt) return ImagesResponse([Image(image) for image in response.get_list()]) - image = await iter_image_response(response) + image = await iter_image_response(response, response_format, connector, proxy) if image is None: raise NoImageResponseError() return image - async def create_variation(self, image: ImageType, model: str = None, **kwargs): - provider = self.models.get(model, self.provider) + async def create_variation( + self, + image: ImageType, + model: str = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ): + provider = self.get_provider(model, provider) result = None if hasattr(provider, "create_async_generator"): response = provider.create_async_generator( "", [{"role": "user", "content": "create a image like this"}], - True, + stream=True, image=image, - proxy=self.client.get_proxy(), + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), **kwargs ) - result = iter_image_response(response) + result = iter_image_response(response, response_format, connector, proxy) if result is None: raise NoImageResponseError() return result |