diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/client.py | 54 |
1 files changed, 35 insertions, 19 deletions
diff --git a/g4f/client.py b/g4f/client.py index 750c623f..c4319872 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -10,10 +10,12 @@ from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .typing import Union, Iterator, Messages, ImageType from .providers.types import BaseProvider, ProviderType from .image import ImageResponse as ImageProviderResponse +from .errors import NoImageResponseError, RateLimitError, MissingAuthError +from . import get_model_and_provider, get_last_provider + from .Provider.BingCreateImages import BingCreateImages from .Provider.needs_auth import Gemini, OpenaiChat -from .errors import NoImageResponseError -from . import get_model_and_provider, get_last_provider +from .Provider.You import You ImageProvider = Union[BaseProvider, object] Proxies = Union[dict, str] @@ -163,6 +165,7 @@ class Chat(): class ImageModels(): gemini = Gemini openai = OpenaiChat + you = You def __init__(self, client: Client) -> None: self.client = client @@ -171,31 +174,44 @@ class ImageModels(): def get(self, name: str, default: ImageProvider = None) -> ImageProvider: return getattr(self, name) if hasattr(self, name) else default or self.default +def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]: + for chunk in list(response): + if isinstance(chunk, ImageProviderResponse): + return ImagesResponse([Image(image) for image in chunk.get_list()]) + +def create_image(client: Client, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> Iterator: + prompt = f"create a image with: {prompt}" + return provider.create_completion( + model, + [{"role": "user", "content": prompt}], + True, + proxy=client.get_proxy(), + **kwargs + ) + class Images(): def __init__(self, client: Client, provider: ImageProvider = None): self.client: Client = client self.provider: ImageProvider = provider self.models: ImageModels = ImageModels(client) - def generate(self, prompt, model: str = None, **kwargs): + def generate(self, prompt, model: str = None, **kwargs) -> ImagesResponse: provider = self.models.get(model, self.provider) - if isinstance(provider, BaseProvider) or isinstance(provider, type) and issubclass(provider, BaseProvider): - prompt = f"create a image: {prompt}" - response = provider.create_completion( - "", - [{"role": "user", "content": prompt}], - True, - proxy=self.client.get_proxy(), - **kwargs - ) + if isinstance(provider, type) and issubclass(provider, BaseProvider): + response = create_image(self.client, provider, prompt, **kwargs) else: - response = provider.create(prompt) - - for chunk in response: - if isinstance(chunk, ImageProviderResponse): - images = [chunk.images] if isinstance(chunk.images, str) else chunk.images - return ImagesResponse([Image(image) for image in images]) - raise NoImageResponseError() + try: + response = list(provider.create(prompt)) + except (RateLimitError, MissingAuthError) as e: + # Fallback for default provider + if self.provider is None: + response = create_image(self.client, self.models.you, prompt, model or "dall-e", **kwargs) + else: + raise e + image = iter_image_response(response) + if image is None: + raise NoImageResponseError() + return image def create_variation(self, image: ImageType, model: str = None, **kwargs): provider = self.models.get(model, self.provider) |