summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to 'g4f')
-rw-r--r--g4f/client/async_client.py151
1 files changed, 95 insertions, 56 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 9caa74b2..b4d52a60 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -33,6 +33,12 @@ except NameError:
except StopAsyncIteration:
raise StopIteration
+async def safe_aclose(generator):
+ try:
+ await generator.aclose()
+ except Exception as e:
+ logging.warning(f"Error while closing generator: {e}")
+
async def iter_response(
response: AsyncIterator[str],
stream: bool,
@@ -45,48 +51,56 @@ async def iter_response(
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
idx = 0
- async for chunk in response:
- if isinstance(chunk, FinishReason):
- finish_reason = chunk.reason
- break
- elif isinstance(chunk, BaseConversation):
- yield chunk
- continue
+ try:
+ async for chunk in response:
+ if isinstance(chunk, FinishReason):
+ finish_reason = chunk.reason
+ break
+ elif isinstance(chunk, BaseConversation):
+ yield chunk
+ continue
- content += str(chunk)
- idx += 1
+ content += str(chunk)
+ idx += 1
- if max_tokens is not None and idx >= max_tokens:
- finish_reason = "length"
+ if max_tokens is not None and idx >= max_tokens:
+ finish_reason = "length"
- first, content, chunk = find_stop(stop, content, chunk if stream else None)
+ first, content, chunk = find_stop(stop, content, chunk if stream else None)
- if first != -1:
- finish_reason = "stop"
+ if first != -1:
+ finish_reason = "stop"
- if stream:
- yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
+ if stream:
+ yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
- if finish_reason is not None:
- break
+ if finish_reason is not None:
+ break
- finish_reason = "stop" if finish_reason is None else finish_reason
+ finish_reason = "stop" if finish_reason is None else finish_reason
- if stream:
- yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
- else:
- if response_format is not None and "type" in response_format:
- if response_format["type"] == "json_object":
- content = filter_json(content)
- yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
+ if stream:
+ yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
+ else:
+ if response_format is not None and "type" in response_format:
+ if response_format["type"] == "json_object":
+ content = filter_json(content)
+ yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
+ finally:
+ if hasattr(response, 'aclose'):
+ await safe_aclose(response)
async def iter_append_model_and_provider(response: AsyncIterator) -> AsyncIterator:
last_provider = None
- async for chunk in response:
- last_provider = get_last_provider(True) if last_provider is None else last_provider
- chunk.model = last_provider.get("model")
- chunk.provider = last_provider.get("name")
- yield chunk
+ try:
+ async for chunk in response:
+ last_provider = get_last_provider(True) if last_provider is None else last_provider
+ chunk.model = last_provider.get("model")
+ chunk.provider = last_provider.get("name")
+ yield chunk
+ finally:
+ if hasattr(response, 'aclose'):
+ await safe_aclose(response)
class AsyncClient(BaseClient):
def __init__(
@@ -158,8 +172,6 @@ class Completions:
response = iter_append_model_and_provider(response)
return response if stream else await anext(response)
-
-
class Chat:
completions: Completions
@@ -168,14 +180,18 @@ class Chat:
async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
logging.info("Starting iter_image_response")
- async for chunk in response:
- logging.info(f"Processing chunk: {chunk}")
- if isinstance(chunk, ImageProviderResponse):
- logging.info("Found ImageProviderResponse")
- return ImagesResponse([Image(image) for image in chunk.get_list()])
-
- logging.warning("No ImageProviderResponse found in the response")
- return None
+ try:
+ async for chunk in response:
+ logging.info(f"Processing chunk: {chunk}")
+ if isinstance(chunk, ImageProviderResponse):
+ logging.info("Found ImageProviderResponse")
+ return ImagesResponse([Image(image) for image in chunk.get_list()])
+
+ logging.warning("No ImageProviderResponse found in the response")
+ return None
+ finally:
+ if hasattr(response, 'aclose'):
+ await safe_aclose(response)
async def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
logging.info(f"Creating image with provider: {provider}, model: {model}, prompt: {prompt}")
@@ -220,12 +236,25 @@ class Images:
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
logging.info("Using AsyncGeneratorProvider")
messages = [{"role": "user", "content": prompt}]
- async for response in provider.create_async_generator(model, messages, **kwargs):
- if isinstance(response, ImageResponse):
- return self._process_image_response(response)
- elif isinstance(response, str):
- image_response = ImageResponse([response], prompt)
- return self._process_image_response(image_response)
+ generator = None
+ try:
+ generator = provider.create_async_generator(model, messages, **kwargs)
+ async for response in generator:
+ logging.debug(f"Received response: {type(response)}")
+ if isinstance(response, ImageResponse):
+ return self._process_image_response(response)
+ elif isinstance(response, str):
+ image_response = ImageResponse([response], prompt)
+ return self._process_image_response(image_response)
+ except RuntimeError as e:
+ if "async generator ignored GeneratorExit" in str(e):
+ logging.warning("Generator ignored GeneratorExit, handling gracefully")
+ else:
+ raise
+ finally:
+ if generator and hasattr(generator, 'aclose'):
+ await safe_aclose(generator)
+ logging.info("AsyncGeneratorProvider processing completed")
elif hasattr(provider, 'create'):
logging.info("Using provider's create method")
async_create = asyncio.iscoroutinefunction(provider.create)
@@ -241,7 +270,7 @@ class Images:
return self._process_image_response(image_response)
elif hasattr(provider, 'create_completion'):
logging.info("Using provider's create_completion method")
- response = await create_image(provider, prompt, model, **kwargs)
+ response = await create_image(self.client, provider, prompt, model, **kwargs)
async for chunk in response:
if isinstance(chunk, ImageProviderResponse):
logging.info("Found ImageProviderResponse")
@@ -277,12 +306,24 @@ class Images:
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
messages = [{"role": "user", "content": "create a variation of this image"}]
image_data = to_data_uri(image)
- async for response in provider.create_async_generator(model, messages, image=image_data, **kwargs):
- if isinstance(response, ImageResponse):
- return self._process_image_response(response)
- elif isinstance(response, str):
- image_response = ImageResponse([response], "Image variation")
- return self._process_image_response(image_response)
+ generator = None
+ try:
+ generator = provider.create_async_generator(model, messages, image=image_data, **kwargs)
+ async for response in generator:
+ if isinstance(response, ImageResponse):
+ return self._process_image_response(response)
+ elif isinstance(response, str):
+ image_response = ImageResponse([response], "Image variation")
+ return self._process_image_response(image_response)
+ except RuntimeError as e:
+ if "async generator ignored GeneratorExit" in str(e):
+ logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully")
+ else:
+ raise
+ finally:
+ if generator and hasattr(generator, 'aclose'):
+ await safe_aclose(generator)
+ logging.info("AsyncGeneratorProvider processing completed in create_variation")
elif hasattr(provider, 'create_variation'):
if asyncio.iscoroutinefunction(provider.create_variation):
response = await provider.create_variation(image, **kwargs)
@@ -296,5 +337,3 @@ class Images:
return self._process_image_response(image_response)
else:
raise ValueError(f"Provider {provider} does not support image variation")
-
- raise NoImageResponseError("Failed to create image variation")