summaryrefslogtreecommitdiffstats
path: root/g4f/image.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/image.py')
-rw-r--r--g4f/image.py60
1 files changed, 38 insertions, 22 deletions
diff --git a/g4f/image.py b/g4f/image.py
index 94b8c24c..68767155 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -1,39 +1,52 @@
+from __future__ import annotations
+
import re
from io import BytesIO
import base64
from .typing import ImageType, Union
-from PIL import Image
+
+try:
+ from PIL.Image import open as open_image, new as new_image, Image
+ from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
+ has_requirements = True
+except ImportError:
+ Image = type
+ has_requirements = False
+
+from .errors import MissingRequirementsError
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
-def to_image(image: ImageType, is_svg: bool = False) -> Image.Image:
+def to_image(image: ImageType, is_svg: bool = False) -> Image:
"""
Converts the input image to a PIL Image object.
Args:
- image (Union[str, bytes, Image.Image]): The input image.
+ image (Union[str, bytes, Image]): The input image.
Returns:
- Image.Image: The converted PIL Image object.
+ Image: The converted PIL Image object.
"""
+ if not has_requirements:
+ raise MissingRequirementsError('Install "pillow" package for images')
if is_svg:
try:
import cairosvg
except ImportError:
- raise RuntimeError('Install "cairosvg" package for svg images')
+ raise MissingRequirementsError('Install "cairosvg" package for svg images')
if not isinstance(image, bytes):
image = image.read()
buffer = BytesIO()
cairosvg.svg2png(image, write_to=buffer)
- return Image.open(buffer)
+ return open_image(buffer)
if isinstance(image, str):
is_data_uri_an_image(image)
image = extract_data_uri(image)
if isinstance(image, bytes):
is_accepted_format(image)
- return Image.open(BytesIO(image))
- elif not isinstance(image, Image.Image):
- image = Image.open(image)
+ return open_image(BytesIO(image))
+ elif not isinstance(image, Image):
+ image = open_image(image)
copy = image.copy()
copy.format = image.format
return copy
@@ -110,12 +123,12 @@ def extract_data_uri(data_uri: str) -> bytes:
data = base64.b64decode(data)
return data
-def get_orientation(image: Image.Image) -> int:
+def get_orientation(image: Image) -> int:
"""
Gets the orientation of the given image.
Args:
- image (Image.Image): The image.
+ image (Image): The image.
Returns:
int: The orientation value.
@@ -126,40 +139,40 @@ def get_orientation(image: Image.Image) -> int:
if orientation is not None:
return orientation
-def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
+def process_image(img: Image, new_width: int, new_height: int) -> Image:
"""
Processes the given image by adjusting its orientation and resizing it.
Args:
- img (Image.Image): The image to process.
+ img (Image): The image to process.
new_width (int): The new width of the image.
new_height (int): The new height of the image.
Returns:
- Image.Image: The processed image.
+ Image: The processed image.
"""
# Fix orientation
orientation = get_orientation(img)
if orientation:
if orientation > 4:
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ img = img.transpose(FLIP_LEFT_RIGHT)
if orientation in [3, 4]:
- img = img.transpose(Image.ROTATE_180)
+ img = img.transpose(ROTATE_180)
if orientation in [5, 6]:
- img = img.transpose(Image.ROTATE_270)
+ img = img.transpose(ROTATE_270)
if orientation in [7, 8]:
- img = img.transpose(Image.ROTATE_90)
+ img = img.transpose(ROTATE_90)
# Resize image
img.thumbnail((new_width, new_height))
# Remove transparency
if img.mode != "RGB":
img.load()
- white = Image.new('RGB', img.size, (255, 255, 255))
+ white = new_image('RGB', img.size, (255, 255, 255))
white.paste(img, mask=img.split()[3])
return white
return img
-def to_base64(image: Image.Image, compression_rate: float) -> str:
+def to_base64_jpg(image: Image, compression_rate: float) -> str:
"""
Converts the given image to a base64-encoded string.
@@ -195,7 +208,7 @@ def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200")
end_flag = "<!-- generated images end -->\n"
return f"\n{start_flag}{images}\n{end_flag}\n"
-def to_bytes(image: Image.Image) -> bytes:
+def to_bytes(image: Image) -> bytes:
"""
Converts the given image to bytes.
@@ -225,4 +238,7 @@ class ImageResponse():
return format_images_markdown(self.images, self.alt)
def get(self, key: str):
- return self.options.get(key) \ No newline at end of file
+ return self.options.get(key)
+
+class ImageRequest(ImageResponse):
+ pass \ No newline at end of file