diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/client.py | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/g4f/client.py b/g4f/client.py index 4e5394b7..b44a5230 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -2,6 +2,9 @@ from __future__ import annotations import re import os +import time +import random +import string from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .typing import Union, Generator, Messages, ImageType @@ -10,10 +13,11 @@ from .image import ImageResponse as ImageProviderResponse from .Provider.BingCreateImages import BingCreateImages from .Provider.needs_auth import Gemini, OpenaiChat from .errors import NoImageResponseError -from . import get_model_and_provider +from . import get_model_and_provider, get_last_provider ImageProvider = Union[BaseProvider, object] Proxies = Union[dict, str] +IterResponse = Generator[ChatCompletion | ChatCompletionChunk, None, None] def read_json(text: str) -> dict: """ @@ -31,18 +35,16 @@ def read_json(text: str) -> dict: return text def iter_response( - response: iter, + response: iter[str], stream: bool, response_format: dict = None, max_tokens: int = None, stop: list = None -) -> Generator: +) -> IterResponse: content = "" finish_reason = None - last_chunk = None + completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) for idx, chunk in enumerate(response): - if last_chunk is not None: - yield ChatCompletionChunk(last_chunk, finish_reason) content += str(chunk) if max_tokens is not None and idx + 1 >= max_tokens: finish_reason = "length" @@ -63,16 +65,25 @@ def iter_response( if first != -1: finish_reason = "stop" if stream: - last_chunk = chunk + yield ChatCompletionChunk(chunk, None, completion_id, int(time.time())) if finish_reason is not None: break - if last_chunk is not None: - yield ChatCompletionChunk(last_chunk, finish_reason) - if not stream: + finish_reason = "stop" if finish_reason is None else finish_reason + if stream: + yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time())) + else: if response_format is not None and "type" in response_format: if response_format["type"] == "json_object": content = read_json(content) - yield ChatCompletion(content, finish_reason) + yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) + +def iter_append_model_and_provider(response: IterResponse) -> IterResponse: + last_provider = None + for chunk in response: + last_provider = get_last_provider(True) if last_provider is None else last_provider + chunk.model = last_provider.get("model") + chunk.provider = last_provider.get("name") + yield chunk class Client(): proxies: Proxies = None @@ -113,7 +124,7 @@ class Completions(): stream: bool = False, response_format: dict = None, max_tokens: int = None, - stop: Union[list. str] = None, + stop: list[str] | str = None, **kwargs ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]: if max_tokens is not None: @@ -128,7 +139,7 @@ class Completions(): ) response = provider.create_completion(model, messages, stream=stream, proxy=self.client.get_proxy(), **kwargs) stop = [stop] if isinstance(stop, str) else stop - response = iter_response(response, stream, response_format, max_tokens, stop) + response = iter_append_model_and_provider(iter_response(response, stream, response_format, max_tokens, stop)) return response if stream else next(response) class Chat(): |