From 5756586cde6ed6da147119113fb5a5fd640d5f83 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 14 Jan 2024 07:45:41 +0100 Subject: Refactor code with AI Add doctypes to many functions Add file upload for text files Add alternative url to FreeChatgpt Add webp to allowed image types --- g4f/image.py | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 5 deletions(-) (limited to 'g4f/image.py') diff --git a/g4f/image.py b/g4f/image.py index 01664f4e..cfa22ab1 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -4,9 +4,18 @@ import base64 from .typing import ImageType, Union from PIL import Image -ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} +ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'} def to_image(image: ImageType) -> Image.Image: + """ + Converts the input image to a PIL Image object. + + Args: + image (Union[str, bytes, Image.Image]): The input image. + + Returns: + Image.Image: The converted PIL Image object. + """ if isinstance(image, str): is_data_uri_an_image(image) image = extract_data_uri(image) @@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image: image = copy return image -def is_allowed_extension(filename) -> bool: +def is_allowed_extension(filename: str) -> bool: + """ + Checks if the given filename has an allowed extension. + + Args: + filename (str): The filename to check. + + Returns: + bool: True if the extension is allowed, False otherwise. + """ return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS def is_data_uri_an_image(data_uri: str) -> bool: + """ + Checks if the given data URI represents an image. + + Args: + data_uri (str): The data URI to check. + + Raises: + ValueError: If the data URI is invalid or the image format is not allowed. + """ # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif) if not re.match(r'data:image/(\w+);base64,', data_uri): raise ValueError("Invalid data URI image.") - # Extract the image format from the data URI + # Extract the image format from the data URI image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1) # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) if image_format.lower() not in ALLOWED_EXTENSIONS: raise ValueError("Invalid image format (from mime file type).") def is_accepted_format(binary_data: bytes) -> bool: + """ + Checks if the given binary data represents an image with an accepted format. + + Args: + binary_data (bytes): The binary data to check. + + Raises: + ValueError: If the image format is not allowed. + """ if binary_data.startswith(b'\xFF\xD8\xFF'): pass # It's a JPEG image elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'): @@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool: pass # It's a WebP image else: raise ValueError("Invalid image format (from magic code).") - + def extract_data_uri(data_uri: str) -> bytes: + """ + Extracts the binary data from the given data URI. + + Args: + data_uri (str): The data URI. + + Returns: + bytes: The extracted binary data. + """ data = data_uri.split(",")[1] data = base64.b64decode(data) return data def get_orientation(image: Image.Image) -> int: + """ + Gets the orientation of the given image. + + Args: + image (Image.Image): The image. + + Returns: + int: The orientation value. + """ exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() if exif_data is not None: orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF @@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int: return orientation def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image: + """ + Processes the given image by adjusting its orientation and resizing it. + + Args: + img (Image.Image): The image to process. + new_width (int): The new width of the image. + new_height (int): The new height of the image. + + Returns: + Image.Image: The processed image. + """ orientation = get_orientation(img) if orientation: if orientation > 4: @@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im img = img.transpose(Image.ROTATE_90) img.thumbnail((new_width, new_height)) return img - + def to_base64(image: Image.Image, compression_rate: float) -> str: + """ + Converts the given image to a base64-encoded string. + + Args: + image (Image.Image): The image to convert. + compression_rate (float): The compression rate (0.0 to 1.0). + + Returns: + str: The base64-encoded image. + """ output_buffer = BytesIO() image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str: + """ + Formats the given images as a markdown string. + + Args: + images: The images to format. + prompt (str): The prompt for the images. + preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200". + + Returns: + str: The formatted markdown string. + """ if isinstance(images, list): images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] images = "\n".join(images) @@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20 return f"\n{start_flag}{images}\n{end_flag}\n" def to_bytes(image: Image.Image) -> bytes: + """ + Converts the given image to bytes. + + Args: + image (Image.Image): The image to convert. + + Returns: + bytes: The image as bytes. + """ bytes_io = BytesIO() image.save(bytes_io, image.format) image.seek(0) -- cgit v1.2.3