summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/Gemini.py41
-rw-r--r--g4f/gui/client/html/index.html2
-rw-r--r--g4f/image.py5
3 files changed, 42 insertions, 6 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index a6e4c15d..da7230dd 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -1,16 +1,25 @@
from __future__ import annotations
+import os
import json
import random
import re
from aiohttp import ClientSession
+try:
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver.support.ui import WebDriverWait
+ from selenium.webdriver.support import expected_conditions as EC
+except ImportError:
+ pass
+
from ...typing import Messages, Cookies, ImageType, AsyncResult
from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
-from ...errors import MissingAuthError
+from ...errors import MissingAuthError, MissingRequirementsError
from ...image import to_bytes, ImageResponse
+from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = {
"authority": "gemini.google.com",
@@ -55,6 +64,27 @@ class Gemini(AsyncGeneratorProvider):
**kwargs
) -> AsyncResult:
prompt = format_prompt(messages)
+
+ try:
+ driver = get_browser(proxy=proxy)
+ try:
+ driver.get(f"{cls.url}/app")
+ WebDriverWait(driver, 5).until(
+ EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))
+ )
+ except:
+ login_url = os.environ.get("G4F_LOGIN_URL")
+ if login_url:
+ yield f"Please login: [Google Gemini]({login_url})\n\n"
+ WebDriverWait(driver, 240).until(
+ EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))
+ )
+ cookies = get_driver_cookies(driver)
+ except MissingRequirementsError:
+ pass
+ finally:
+ driver.close()
+
if not cookies:
cookies = get_cookies(".google.com", False)
if "__Secure-1PSID" not in cookies:
@@ -108,7 +138,14 @@ class Gemini(AsyncGeneratorProvider):
yield content
if image_prompt:
images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
- yield ImageResponse(images, image_prompt)
+ resolved_images = []
+ for image in images:
+ async with session.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ async with session.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ resolved_images.append(image)
+ yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images})
def build_request(
prompt: str,
diff --git a/g4f/gui/client/html/index.html b/g4f/gui/client/html/index.html
index 5edb55e8..55b54b48 100644
--- a/g4f/gui/client/html/index.html
+++ b/g4f/gui/client/html/index.html
@@ -154,7 +154,7 @@
<option value="Bing">Bing</option>
<option value="OpenaiChat">OpenaiChat</option>
<option value="HuggingChat">HuggingChat</option>
- <option value="Bard">Bard</option>
+ <option value="Gemini">Gemini</option>
<option value="Liaobots">Liaobots</option>
<option value="Phind">Phind</option>
<option value="">----</option>
diff --git a/g4f/image.py b/g4f/image.py
index 3f26f75f..f0ee0395 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -46,9 +46,8 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
return open_image(BytesIO(image))
elif not isinstance(image, Image):
image = open_image(image)
- copy = image.copy()
- copy.format = image.format
- return copy
+ image.load()
+ return image
return image
def is_allowed_extension(filename: str) -> bool: