summaryrefslogtreecommitdiffstats
path: root/g4f/client.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/client.py54
1 files changed, 35 insertions, 19 deletions
diff --git a/g4f/client.py b/g4f/client.py
index 750c623f..c4319872 100644
--- a/g4f/client.py
+++ b/g4f/client.py
@@ -10,10 +10,12 @@ from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .typing import Union, Iterator, Messages, ImageType
from .providers.types import BaseProvider, ProviderType
from .image import ImageResponse as ImageProviderResponse
+from .errors import NoImageResponseError, RateLimitError, MissingAuthError
+from . import get_model_and_provider, get_last_provider
+
from .Provider.BingCreateImages import BingCreateImages
from .Provider.needs_auth import Gemini, OpenaiChat
-from .errors import NoImageResponseError
-from . import get_model_and_provider, get_last_provider
+from .Provider.You import You
ImageProvider = Union[BaseProvider, object]
Proxies = Union[dict, str]
@@ -163,6 +165,7 @@ class Chat():
class ImageModels():
gemini = Gemini
openai = OpenaiChat
+ you = You
def __init__(self, client: Client) -> None:
self.client = client
@@ -171,31 +174,44 @@ class ImageModels():
def get(self, name: str, default: ImageProvider = None) -> ImageProvider:
return getattr(self, name) if hasattr(self, name) else default or self.default
+def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]:
+ for chunk in list(response):
+ if isinstance(chunk, ImageProviderResponse):
+ return ImagesResponse([Image(image) for image in chunk.get_list()])
+
+def create_image(client: Client, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> Iterator:
+ prompt = f"create a image with: {prompt}"
+ return provider.create_completion(
+ model,
+ [{"role": "user", "content": prompt}],
+ True,
+ proxy=client.get_proxy(),
+ **kwargs
+ )
+
class Images():
def __init__(self, client: Client, provider: ImageProvider = None):
self.client: Client = client
self.provider: ImageProvider = provider
self.models: ImageModels = ImageModels(client)
- def generate(self, prompt, model: str = None, **kwargs):
+ def generate(self, prompt, model: str = None, **kwargs) -> ImagesResponse:
provider = self.models.get(model, self.provider)
- if isinstance(provider, BaseProvider) or isinstance(provider, type) and issubclass(provider, BaseProvider):
- prompt = f"create a image: {prompt}"
- response = provider.create_completion(
- "",
- [{"role": "user", "content": prompt}],
- True,
- proxy=self.client.get_proxy(),
- **kwargs
- )
+ if isinstance(provider, type) and issubclass(provider, BaseProvider):
+ response = create_image(self.client, provider, prompt, **kwargs)
else:
- response = provider.create(prompt)
-
- for chunk in response:
- if isinstance(chunk, ImageProviderResponse):
- images = [chunk.images] if isinstance(chunk.images, str) else chunk.images
- return ImagesResponse([Image(image) for image in images])
- raise NoImageResponseError()
+ try:
+ response = list(provider.create(prompt))
+ except (RateLimitError, MissingAuthError) as e:
+ # Fallback for default provider
+ if self.provider is None:
+ response = create_image(self.client, self.models.you, prompt, model or "dall-e", **kwargs)
+ else:
+ raise e
+ image = iter_image_response(response)
+ if image is None:
+ raise NoImageResponseError()
+ return image
def create_variation(self, image: ImageType, model: str = None, **kwargs):
provider = self.models.get(model, self.provider)