diff options
Diffstat (limited to 'g4f/providers/create_images.py')
-rw-r--r-- | g4f/providers/create_images.py | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/g4f/providers/create_images.py b/g4f/providers/create_images.py new file mode 100644 index 00000000..29a2a041 --- /dev/null +++ b/g4f/providers/create_images.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import re +import asyncio + +from .. import debug +from ..typing import CreateResult, Messages +from .types import BaseProvider, ProviderType + +system_message = """ +You can generate images, pictures, photos or img with the DALL-E 3 image generator. +To generate an image with a prompt, do this: + +<img data-prompt=\"keywords for the image\"> + +Never use own image links. Don't wrap it in backticks. +It is important to use a only a img tag with a prompt. + +<img data-prompt=\"image caption\"> +""" + +class CreateImagesProvider(BaseProvider): + """ + Provider class for creating images based on text prompts. + + This provider handles image creation requests embedded within message content, + using provided image creation functions. + + Attributes: + provider (ProviderType): The underlying provider to handle non-image related tasks. + create_images (callable): A function to create images synchronously. + create_images_async (callable): A function to create images asynchronously. + system_message (str): A message that explains the image creation capability. + include_placeholder (bool): Flag to determine whether to include the image placeholder in the output. + __name__ (str): Name of the provider. + url (str): URL of the provider. + working (bool): Indicates if the provider is operational. + supports_stream (bool): Indicates if the provider supports streaming. + """ + + def __init__( + self, + provider: ProviderType, + create_images: callable, + create_async: callable, + system_message: str = system_message, + include_placeholder: bool = True + ) -> None: + """ + Initializes the CreateImagesProvider. + + Args: + provider (ProviderType): The underlying provider. + create_images (callable): Function to create images synchronously. + create_async (callable): Function to create images asynchronously. + system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message. + include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True. + """ + self.provider = provider + self.create_images = create_images + self.create_images_async = create_async + self.system_message = system_message + self.include_placeholder = include_placeholder + self.__name__ = provider.__name__ + self.url = provider.url + self.working = provider.working + self.supports_stream = provider.supports_stream + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + """ + Creates a completion result, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + stream (bool, optional): Indicates whether to stream the results. Defaults to False. + **kwargs: Additional keywordarguments for the provider. + + Yields: + CreateResult: Yields chunks of the processed messages, including image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the synchronous image creation function and includes the resulting image in the output. + """ + messages.insert(0, {"role": "system", "content": self.system_message}) + buffer = "" + for chunk in self.provider.create_completion(model, messages, stream, **kwargs): + if isinstance(chunk, str) and buffer or "<" in chunk: + buffer += chunk + if ">" in buffer: + match = re.search(r'<img data-prompt="(.*?)">', buffer) + if match: + placeholder, prompt = match.group(0), match.group(1) + start, append = buffer.split(placeholder, 1) + if start: + yield start + if self.include_placeholder: + yield placeholder + if debug.logging: + print(f"Create images with prompt: {prompt}") + yield from self.create_images(prompt) + if append: + yield append + else: + yield buffer + buffer = "" + else: + yield chunk + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Asynchronously creates a response, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + **kwargs: Additional keyword arguments for the provider. + + Returns: + str: The processed response string, including asynchronously generated image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the asynchronous image creation function and includes the resulting image in the output. + """ + messages.insert(0, {"role": "system", "content": self.system_message}) + response = await self.provider.create_async(model, messages, **kwargs) + matches = re.findall(r'(<img data-prompt="(.*?)">)', response) + results = [] + placeholders = [] + for placeholder, prompt in matches: + if placeholder not in placeholders: + if debug.logging: + print(f"Create images with prompt: {prompt}") + results.append(self.create_images_async(prompt)) + placeholders.append(placeholder) + results = await asyncio.gather(*results) + for idx, result in enumerate(results): + placeholder = placeholder[idx] + if self.include_placeholder: + result = placeholder + result + response = response.replace(placeholder, result) + return response
\ No newline at end of file |