From 47900f23718e398fc086a6dfbf6590b4c5859c28 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Fri, 9 Feb 2024 03:31:05 +0100 Subject: Resolve images in Gemini Provider --- g4f/Provider/needs_auth/Gemini.py | 41 +++++++++++++++++++++++++++++++++++++-- g4f/gui/client/html/index.html | 2 +- g4f/image.py | 5 ++--- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index a6e4c15d..da7230dd 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -1,16 +1,25 @@ from __future__ import annotations +import os import json import random import re from aiohttp import ClientSession +try: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC +except ImportError: + pass + from ...typing import Messages, Cookies, ImageType, AsyncResult from ..base_provider import AsyncGeneratorProvider from ..helper import format_prompt, get_cookies -from ...errors import MissingAuthError +from ...errors import MissingAuthError, MissingRequirementsError from ...image import to_bytes, ImageResponse +from ...webdriver import get_browser, get_driver_cookies REQUEST_HEADERS = { "authority": "gemini.google.com", @@ -55,6 +64,27 @@ class Gemini(AsyncGeneratorProvider): **kwargs ) -> AsyncResult: prompt = format_prompt(messages) + + try: + driver = get_browser(proxy=proxy) + try: + driver.get(f"{cls.url}/app") + WebDriverWait(driver, 5).until( + EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea")) + ) + except: + login_url = os.environ.get("G4F_LOGIN_URL") + if login_url: + yield f"Please login: [Google Gemini]({login_url})\n\n" + WebDriverWait(driver, 240).until( + EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea")) + ) + cookies = get_driver_cookies(driver) + except MissingRequirementsError: + pass + finally: + driver.close() + if not cookies: cookies = get_cookies(".google.com", False) if "__Secure-1PSID" not in cookies: @@ -108,7 +138,14 @@ class Gemini(AsyncGeneratorProvider): yield content if image_prompt: images = [image[0][3][3] for image in response_part[4][0][12][7][0]] - yield ImageResponse(images, image_prompt) + resolved_images = [] + for image in images: + async with session.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + async with session.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + resolved_images.append(image) + yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images}) def build_request( prompt: str, diff --git a/g4f/gui/client/html/index.html b/g4f/gui/client/html/index.html index 5edb55e8..55b54b48 100644 --- a/g4f/gui/client/html/index.html +++ b/g4f/gui/client/html/index.html @@ -154,7 +154,7 @@ - + diff --git a/g4f/image.py b/g4f/image.py index 3f26f75f..f0ee0395 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -46,9 +46,8 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image: return open_image(BytesIO(image)) elif not isinstance(image, Image): image = open_image(image) - copy = image.copy() - copy.format = image.format - return copy + image.load() + return image return image def is_allowed_extension(filename: str) -> bool: -- cgit v1.2.3