summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/Gemini.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/Gemini.py36
1 files changed, 19 insertions, 17 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index e468f64a..f9b1c4a5 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -4,6 +4,7 @@ import os
import json
import random
import re
+import base64
from aiohttp import ClientSession, BaseConnector
@@ -22,7 +23,7 @@ from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError
-from ...image import to_bytes, to_data_uri, ImageResponse
+from ...image import to_bytes, ImageResponse, ImageDataResponse
from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = {
@@ -122,6 +123,7 @@ class Gemini(AsyncGeneratorProvider):
connector: BaseConnector = None,
image: ImageType = None,
image_name: str = None,
+ response_format: str = None,
**kwargs
) -> AsyncResult:
prompt = format_prompt(messages)
@@ -192,22 +194,22 @@ class Gemini(AsyncGeneratorProvider):
if image_prompt:
images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
resolved_images = []
- preview = []
- for image in images:
- async with client.get(image, allow_redirects=False) as fetch:
- image = fetch.headers["location"]
- async with client.get(image, allow_redirects=False) as fetch:
- image = fetch.headers["location"]
- resolved_images.append(image)
- preview.append(image.replace('=s512', '=s200'))
- # preview_url = image.replace('=s512', '=s200')
- # async with client.get(preview_url) as fetch:
- # preview_data = to_data_uri(await fetch.content.read())
- # async with client.get(image) as fetch:
- # data = to_data_uri(await fetch.content.read())
- # preview.append(preview_data)
- # resolved_images.append(data)
- yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
+ if response_format == "b64_json":
+ for image in images:
+ async with client.get(image) as response:
+ data = base64.b64encode(await response.content.read()).decode()
+ resolved_images.append(data)
+ yield ImageDataResponse(resolved_images, image_prompt)
+ else:
+ preview = []
+ for image in images:
+ async with client.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ async with client.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ resolved_images.append(image)
+ preview.append(image.replace('=s512', '=s200'))
+ yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
def build_request(
prompt: str,