summaryrefslogblamecommitdiffstats
path: root/g4f/Provider/create_images.py
blob: b8bcbde385b1c33551b937d09881c52c2745007b (plain) (tree)
1
2
3
4
5
6
7
8
9
10



                                  
                    




                                                                 
                                            





                                                                        

















                                                                                                             







                                             









                                                                                                                        



                                               
                                                      
                                         
                               

                                                       







                             















                                                                                                            













                                                                                        

                                                                         














                                                             














                                                                                                             

                                                                              
                                                                      
                    


                                               

                                                                 

                                                                

                                                
                                          


                                                            
                      
from __future__ import annotations

import re
import asyncio
from .. import debug
from ..typing import CreateResult, Messages
from ..base_provider import BaseProvider, ProviderType

system_message = """
You can generate custom images with the DALL-E 3 image generator.
To generate an image with a prompt, do this:
<img data-prompt=\"keywords for the image\">
Don't use images with data uri. It is important to use a prompt instead.
<img data-prompt=\"image caption\">
"""

class CreateImagesProvider(BaseProvider):
    """
    Provider class for creating images based on text prompts.

    This provider handles image creation requests embedded within message content, 
    using provided image creation functions.

    Attributes:
        provider (ProviderType): The underlying provider to handle non-image related tasks.
        create_images (callable): A function to create images synchronously.
        create_images_async (callable): A function to create images asynchronously.
        system_message (str): A message that explains the image creation capability.
        include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
        __name__ (str): Name of the provider.
        url (str): URL of the provider.
        working (bool): Indicates if the provider is operational.
        supports_stream (bool): Indicates if the provider supports streaming.
    """

    def __init__(
        self,
        provider: ProviderType,
        create_images: callable,
        create_async: callable,
        system_message: str = system_message,
        include_placeholder: bool = True
    ) -> None:
        """
        Initializes the CreateImagesProvider.

        Args:
            provider (ProviderType): The underlying provider.
            create_images (callable): Function to create images synchronously.
            create_async (callable): Function to create images asynchronously.
            system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
            include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
        """
        self.provider = provider
        self.create_images = create_images
        self.create_images_async = create_async
        self.system_message = system_message
        self.include_placeholder = include_placeholder
        self.__name__ = provider.__name__
        self.url = provider.url
        self.working = provider.working
        self.supports_stream = provider.supports_stream

    def create_completion(
        self,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """
        Creates a completion result, processing any image creation prompts found within the messages.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to process, which may contain image prompts.
            stream (bool, optional): Indicates whether to stream the results. Defaults to False.
            **kwargs: Additional keywordarguments for the provider.

        Yields:
            CreateResult: Yields chunks of the processed messages, including image data if applicable.

        Note:
            This method processes messages to detect image creation prompts. When such a prompt is found, 
            it calls the synchronous image creation function and includes the resulting image in the output.
        """
        messages.insert(0, {"role": "system", "content": self.system_message})
        buffer = ""
        for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
            if buffer or "<" in chunk:
                buffer += chunk
                if ">" in buffer:
                    match = re.search(r'<img data-prompt="(.*?)">', buffer)
                    if match:
                        placeholder, prompt = match.group(0), match.group(1)
                        start, append = buffer.split(placeholder, 1)
                        if start:
                            yield start
                        if self.include_placeholder:
                            yield placeholder
                        if debug.logging:
                            print(f"Create images with prompt: {prompt}")
                        yield from self.create_images(prompt)
                        if append:
                            yield append
                    else:
                        yield buffer
                    buffer = ""
            else:
                yield chunk

    async def create_async(
        self,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously creates a response, processing any image creation prompts found within the messages.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to process, which may contain image prompts.
            **kwargs: Additional keyword arguments for the provider.

        Returns:
            str: The processed response string, including asynchronously generated image data if applicable.

        Note:
            This method processes messages to detect image creation prompts. When such a prompt is found, 
            it calls the asynchronous image creation function and includes the resulting image in the output.
        """
        messages.insert(0, {"role": "system", "content": self.system_message})
        response = await self.provider.create_async(model, messages, **kwargs)
        matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
        results = []
        placeholders = []
        for placeholder, prompt in matches:
            if placeholder not in placeholders:
                if debug.logging:
                    print(f"Create images with prompt: {prompt}")
                results.append(self.create_images_async(prompt))
                placeholders.append(placeholder)
        results = await asyncio.gather(*results)
        for idx, result in enumerate(results):
            placeholder = placeholder[idx]
            if self.include_placeholder:
                result = placeholder + result
            response = response.replace(placeholder, result)
        return response