diff options
Diffstat (limited to 'g4f/image.py')
-rw-r--r-- | g4f/image.py | 39 |
1 files changed, 24 insertions, 15 deletions
diff --git a/g4f/image.py b/g4f/image.py index 68767155..1a4692b3 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -3,14 +3,13 @@ from __future__ import annotations import re from io import BytesIO import base64 -from .typing import ImageType, Union +from .typing import ImageType, Union, Image try: - from PIL.Image import open as open_image, new as new_image, Image + from PIL.Image import open as open_image, new as new_image from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 has_requirements = True except ImportError: - Image = type has_requirements = False from .errors import MissingRequirementsError @@ -29,6 +28,9 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image: """ if not has_requirements: raise MissingRequirementsError('Install "pillow" package for images') + if isinstance(image, str): + is_data_uri_an_image(image) + image = extract_data_uri(image) if is_svg: try: import cairosvg @@ -39,9 +41,6 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image: buffer = BytesIO() cairosvg.svg2png(image, write_to=buffer) return open_image(buffer) - if isinstance(image, str): - is_data_uri_an_image(image) - image = extract_data_uri(image) if isinstance(image, bytes): is_accepted_format(image) return open_image(BytesIO(image)) @@ -79,9 +78,9 @@ def is_data_uri_an_image(data_uri: str) -> bool: if not re.match(r'data:image/(\w+);base64,', data_uri): raise ValueError("Invalid data URI image.") # Extract the image format from the data URI - image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1) + image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower() # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) - if image_format.lower() not in ALLOWED_EXTENSIONS: + if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml": raise ValueError("Invalid image format (from mime file type).") def is_accepted_format(binary_data: bytes) -> bool: @@ -187,7 +186,7 @@ def to_base64_jpg(image: Image, compression_rate: float) -> str: image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() -def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str: +def format_images_markdown(images, alt: str, preview: str = None) -> str: """ Formats the given images as a markdown string. @@ -200,9 +199,12 @@ def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") str: The formatted markdown string. """ if isinstance(images, str): - images = f"[![{alt}]({preview.replace('{image}', images)})]({images})" + images = f"[![{alt}]({preview.replace('{image}', images) if preview else images})]({images})" else: - images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] + images = [ + f"[![#{idx+1} {alt}]({preview.replace('{image}', image) if preview else image})]({image})" + for idx, image in enumerate(images) + ] images = "\n".join(images) start_flag = "<!-- generated images start -->\n" end_flag = "<!-- generated images end -->\n" @@ -223,7 +225,7 @@ def to_bytes(image: Image) -> bytes: image.seek(0) return bytes_io.getvalue() -class ImageResponse(): +class ImageResponse: def __init__( self, images: Union[str, list], @@ -235,10 +237,17 @@ class ImageResponse(): self.options = options def __str__(self) -> str: - return format_images_markdown(self.images, self.alt) + return format_images_markdown(self.images, self.alt, self.get("preview")) def get(self, key: str): return self.options.get(key) -class ImageRequest(ImageResponse): - pass
\ No newline at end of file +class ImageRequest: + def __init__( + self, + options: dict = {} + ): + self.options = options + + def get(self, key: str): + return self.options.get(key)
\ No newline at end of file |