diff options
author | kqlio67 <kqlio67@users.noreply.github.com> | 2024-09-25 20:31:27 +0200 |
---|---|---|
committer | kqlio67 <kqlio67@users.noreply.github.com> | 2024-09-25 20:31:27 +0200 |
commit | 85c95be22d571f68393be7130248dd910d06a8da (patch) | |
tree | c9aa6494db1a590908b7e5d0e39a102ad6b0b402 | |
parent | feat(g4f/client/async_client.py): enhance image and chat response handling (diff) | |
download | gpt4free-85c95be22d571f68393be7130248dd910d06a8da.tar gpt4free-85c95be22d571f68393be7130248dd910d06a8da.tar.gz gpt4free-85c95be22d571f68393be7130248dd910d06a8da.tar.bz2 gpt4free-85c95be22d571f68393be7130248dd910d06a8da.tar.lz gpt4free-85c95be22d571f68393be7130248dd910d06a8da.tar.xz gpt4free-85c95be22d571f68393be7130248dd910d06a8da.tar.zst gpt4free-85c95be22d571f68393be7130248dd910d06a8da.zip |
-rw-r--r-- | g4f/client/async_client.py | 322 |
1 files changed, 162 insertions, 160 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 3ac77b41..2fe4640b 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -1,27 +1,32 @@ from __future__ import annotations -import os import time import random import string -import logging import asyncio -from typing import Union, AsyncIterator -from ..providers.base_provider import AsyncGeneratorProvider -from ..image import ImageResponse, to_image, to_data_uri -from ..typing import Messages, ImageType -from ..providers.types import BaseProvider, ProviderType, FinishReason -from ..providers.conversation import BaseConversation -from ..image import ImageResponse as ImageProviderResponse -from ..errors import NoImageResponseError -from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse -from .image_models import ImageModels -from .types import IterResponse, ImageProvider +import base64 +from aiohttp import ClientSession, BaseConnector + from .types import Client as BaseClient -from .service import get_model_and_provider, get_last_provider -from .helper import find_stop, filter_json, filter_none -from ..models import ModelUtils -from ..Provider import IterListProvider +from .types import ProviderType, FinishReason +from .stubs import ChatCompletion, ChatCompletionChunk, ImagesResponse, Image +from .types import AsyncIterResponse, ImageProvider +from .image_models import ImageModels +from .helper import filter_json, find_stop, filter_none, cast_iter_async +from .service import get_last_provider, get_model_and_provider +from ..Provider import ProviderUtils +from ..typing import Union, Messages, AsyncIterator, ImageType +from ..errors import NoImageResponseError, ProviderNotFoundError +from ..requests.aiohttp import get_connector +from ..providers.conversation import BaseConversation +from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse + +try: + anext +except NameError: + async def anext(iter): + async for chunk in iter: + return chunk async def iter_response( response: AsyncIterator[str], @@ -29,37 +34,30 @@ async def iter_response( response_format: dict = None, max_tokens: int = None, stop: list = None -) -> AsyncIterator[ChatCompletion | ChatCompletionChunk]: +) -> AsyncIterResponse: content = "" finish_reason = None completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) - - async for idx, chunk in enumerate(response): + count: int = 0 + async for chunk in response: if isinstance(chunk, FinishReason): finish_reason = chunk.reason break elif isinstance(chunk, BaseConversation): yield chunk continue - content += str(chunk) - - if max_tokens is not None and idx + 1 >= max_tokens: + count += 1 + if max_tokens is not None and count >= max_tokens: finish_reason = "length" - - first, content, chunk = find_stop(stop, content, chunk if stream else None) - + first, content, chunk = find_stop(stop, content, chunk) if first != -1: finish_reason = "stop" - if stream: yield ChatCompletionChunk(chunk, None, completion_id, int(time.time())) - if finish_reason is not None: break - finish_reason = "stop" if finish_reason is None else finish_reason - if stream: yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time())) else: @@ -68,12 +66,12 @@ async def iter_response( content = filter_json(content) yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) -async def iter_append_model_and_provider(response: AsyncIterator) -> AsyncIterator: +async def iter_append_model_and_provider(response: AsyncIterResponse) -> AsyncIterResponse: last_provider = None async for chunk in response: last_provider = get_last_provider(True) if last_provider is None else last_provider chunk.model = last_provider.get("model") - chunk.provider = last_provider.get("name") + chunk.provider = last_provider.get("name") yield chunk class AsyncClient(BaseClient): @@ -82,32 +80,59 @@ class AsyncClient(BaseClient): provider: ProviderType = None, image_provider: ImageProvider = None, **kwargs - ) -> None: + ): super().__init__(**kwargs) self.chat: Chat = Chat(self, provider) - self._images: Images = Images(self, image_provider) + self.images: Images = Images(self, image_provider) - @property - def images(self) -> Images: - return self._images +def create_response( + messages: Messages, + model: str, + provider: ProviderType = None, + stream: bool = False, + proxy: str = None, + max_tokens: int = None, + stop: list[str] = None, + api_key: str = None, + **kwargs +): + has_asnyc = hasattr(provider, "create_async_generator") + if has_asnyc: + create = provider.create_async_generator + else: + create = provider.create_completion + response = create( + model, messages, + stream=stream, + **filter_none( + proxy=proxy, + max_tokens=max_tokens, + stop=stop, + api_key=api_key + ), + **kwargs + ) + if not has_asnyc: + response = cast_iter_async(response) + return response -class Completions: +class Completions(): def __init__(self, client: AsyncClient, provider: ProviderType = None): self.client: AsyncClient = client self.provider: ProviderType = provider - async def create( + def create( self, messages: Messages, model: str, provider: ProviderType = None, stream: bool = False, proxy: str = None, - response_format: dict = None, max_tokens: int = None, stop: Union[list[str], str] = None, api_key: str = None, - ignored: list[str] = None, + response_format: dict = None, + ignored : list[str] = None, ignore_working: bool = False, ignore_stream: bool = False, **kwargs @@ -118,156 +143,133 @@ class Completions: stream, ignored, ignore_working, - ignore_stream, + ignore_stream ) - stop = [stop] if isinstance(stop, str) else stop - - response = await provider.create_completion( - model, - messages, - stream=stream, - **filter_none( - proxy=self.client.get_proxy() if proxy is None else proxy, - max_tokens=max_tokens, - stop=stop, - api_key=self.client.api_key if api_key is None else api_key - ), + response = create_response( + messages, model, + provider, stream, + proxy=self.client.get_proxy() if proxy is None else proxy, + max_tokens=max_tokens, + stop=stop, + api_key=self.client.api_key if api_key is None else api_key, **kwargs ) - response = iter_response(response, stream, response_format, max_tokens, stop) response = iter_append_model_and_provider(response) - - return response if stream else await anext(response) + return response if stream else anext(response) -class Chat: +class Chat(): completions: Completions def __init__(self, client: AsyncClient, provider: ProviderType = None): self.completions = Completions(client, provider) -async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: - logging.info("Starting iter_image_response") +async def iter_image_response( + response: AsyncIterator, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None +) -> Union[ImagesResponse, None]: async for chunk in response: - logging.info(f"Processing chunk: {chunk}") if isinstance(chunk, ImageProviderResponse): - logging.info("Found ImageProviderResponse") - return ImagesResponse([Image(image) for image in chunk.get_list()]) - - logging.warning("No ImageProviderResponse found in the response") - return None - -async def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: - logging.info(f"Creating image with provider: {provider}, model: {model}, prompt: {prompt}") + if response_format == "b64_json": + async with ClientSession( + connector=get_connector(connector, proxy), + cookies=chunk.options.get("cookies") + ) as session: + async def fetch_image(image): + async with session.get(image) as response: + return base64.b64encode(await response.content.read()).decode() + images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()]) + return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time())) + return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time())) + elif isinstance(chunk, ImageDataResponse): + return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time())) +def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: if isinstance(provider, type) and provider.__name__ == "You": kwargs["chat_mode"] = "create" else: - prompt = f"create an image with: {prompt}" - - response = await provider.create_completion( + prompt = f"create a image with: {prompt}" + return provider.create_async_generator( model, [{"role": "user", "content": prompt}], stream=True, - proxy=client.get_proxy(), **kwargs ) - - logging.info(f"Response from create_completion: {response}") - return response -class Images: +class Images(): def __init__(self, client: AsyncClient, provider: ImageProvider = None): self.client: AsyncClient = client self.provider: ImageProvider = provider self.models: ImageModels = ImageModels(client) - async def generate(self, prompt: str, model: str = None, **kwargs) -> ImagesResponse: - logging.info(f"Starting asynchronous image generation for model: {model}, prompt: {prompt}") - provider = self.models.get(model, self.provider) - if provider is None: - raise ValueError(f"Unknown model: {model}") - - logging.info(f"Provider: {provider}") - - if isinstance(provider, IterListProvider): - if provider.providers: - provider = provider.providers[0] - logging.info(f"Using first provider from IterListProvider: {provider}") + def get_provider(self, model: str, provider: ProviderType = None): + if isinstance(provider, str): + if provider in ProviderUtils.convert: + provider = ProviderUtils.convert[provider] else: - raise ValueError(f"IterListProvider for model {model} has no providers") - - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): - logging.info("Using AsyncGeneratorProvider") - messages = [{"role": "user", "content": prompt}] - async for response in provider.create_async_generator(model, messages, **kwargs): - if isinstance(response, ImageResponse): - return self._process_image_response(response) - elif isinstance(response, str): - image_response = ImageResponse([response], prompt) - return self._process_image_response(image_response) - elif hasattr(provider, 'create'): - logging.info("Using provider's create method") - if asyncio.iscoroutinefunction(provider.create): - response = await provider.create(prompt) - else: - response = provider.create(prompt) - - if isinstance(response, ImageResponse): - return self._process_image_response(response) - elif isinstance(response, str): - image_response = ImageResponse([response], prompt) - return self._process_image_response(image_response) + raise ProviderNotFoundError(f'Provider not found: {provider}') else: - raise ValueError(f"Provider {provider} does not support image generation") - - logging.error(f"Unexpected response type: {type(response)}") - raise NoImageResponseError(f"Unexpected response type: {type(response)}") - - def _process_image_response(self, response: ImageResponse) -> ImagesResponse: - processed_images = [] - for image_data in response.get_list(): - if image_data.startswith('http://') or image_data.startswith('https://'): - processed_images.append(Image(url=image_data)) - else: - image = to_image(image_data) - file_name = self._save_image(image) - processed_images.append(Image(url=file_name)) - return ImagesResponse(processed_images) - - def _save_image(self, image: 'PILImage') -> str: - os.makedirs('generated_images', exist_ok=True) - file_name = f"generated_images/image_{int(time.time())}.png" - image.save(file_name) - return file_name - - async def create_variation(self, image: Union[str, bytes], model: str = None, **kwargs) -> ImagesResponse: - provider = self.models.get(model, self.provider) - if provider is None: - raise ValueError(f"Unknown model: {model}") + provider = self.models.get(model, self.provider) + return provider - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): - messages = [{"role": "user", "content": "create a variation of this image"}] - image_data = to_data_uri(image) - async for response in provider.create_async_generator(model, messages, image=image_data, **kwargs): - if isinstance(response, ImageResponse): - return self._process_image_response(response) - elif isinstance(response, str): - image_response = ImageResponse([response], "Image variation") - return self._process_image_response(image_response) - elif hasattr(provider, 'create_variation'): - if asyncio.iscoroutinefunction(provider.create_variation): - response = await provider.create_variation(image, **kwargs) - else: - response = provider.create_variation(image, **kwargs) - - if isinstance(response, ImageResponse): - return self._process_image_response(response) - elif isinstance(response, str): - image_response = ImageResponse([response], "Image variation") - return self._process_image_response(image_response) + async def generate( + self, + prompt, + model: str = "", + provider: ProviderType = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ) -> ImagesResponse: + provider = self.get_provider(model, provider) + if hasattr(provider, "create_async_generator"): + response = create_image( + provider, + prompt, + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), + **kwargs + ) else: - raise ValueError(f"Provider {provider} does not support image variation") + response = await provider.create_async(prompt) + return ImagesResponse([Image(image) for image in response.get_list()]) + image = await iter_image_response(response, response_format, connector, proxy) + if image is None: + raise NoImageResponseError() + return image - raise NoImageResponseError("Failed to create image variation") + async def create_variation( + self, + image: ImageType, + model: str = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ): + provider = self.get_provider(model, provider) + result = None + if hasattr(provider, "create_async_generator"): + response = provider.create_async_generator( + "", + [{"role": "user", "content": "create a image like this"}], + stream=True, + image=image, + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), + **kwargs + ) + result = iter_image_response(response, response_format, connector, proxy) + if result is None: + raise NoImageResponseError() + return result |