summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/Gemini.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/Gemini.py80
1 files changed, 26 insertions, 54 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index dad54c84..781aa410 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -6,24 +6,20 @@ import random
import re
from aiohttp import ClientSession, BaseConnector
-
-from ..helper import get_connector
-
try:
- from selenium.webdriver.common.by import By
- from selenium.webdriver.support.ui import WebDriverWait
- from selenium.webdriver.support import expected_conditions as EC
+ import nodriver
+ has_nodriver = True
except ImportError:
- pass
+ has_nodriver = False
from ... import debug
from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider, BaseConversation
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
-from ...errors import MissingAuthError, MissingRequirementsError
+from ...requests.aiohttp import get_connector
+from ...errors import MissingAuthError
from ...image import ImageResponse, to_bytes
-from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = {
"authority": "gemini.google.com",
@@ -64,9 +60,9 @@ class Gemini(AsyncGeneratorProvider):
@classmethod
async def nodriver_login(cls, proxy: str = None) -> AsyncIterator[str]:
- try:
- import nodriver as uc
- except ImportError:
+ if not has_nodriver:
+ if debug.logging:
+ print("Skip nodriver login in Gemini provider")
return
try:
from platformdirs import user_config_dir
@@ -75,7 +71,7 @@ class Gemini(AsyncGeneratorProvider):
user_data_dir = None
if debug.logging:
print(f"Open nodriver with user_dir: {user_data_dir}")
- browser = await uc.start(
+ browser = await nodriver.start(
user_data_dir=user_data_dir,
browser_args=None if proxy is None else [f"--proxy-server={proxy}"],
)
@@ -92,30 +88,6 @@ class Gemini(AsyncGeneratorProvider):
cls._cookies = cookies
@classmethod
- async def webdriver_login(cls, proxy: str) -> AsyncIterator[str]:
- driver = None
- try:
- driver = get_browser(proxy=proxy)
- try:
- driver.get(f"{cls.url}/app")
- WebDriverWait(driver, 5).until(
- EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))
- )
- except:
- login_url = os.environ.get("G4F_LOGIN_URL")
- if login_url:
- yield f"Please login: [Google Gemini]({login_url})\n\n"
- WebDriverWait(driver, 240).until(
- EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea"))
- )
- cls._cookies = get_driver_cookies(driver)
- except MissingRequirementsError:
- pass
- finally:
- if driver:
- driver.close()
-
- @classmethod
async def create_async_generator(
cls,
model: str,
@@ -143,9 +115,6 @@ class Gemini(AsyncGeneratorProvider):
if not cls._snlm0e:
async for chunk in cls.nodriver_login(proxy):
yield chunk
- if cls._cookies is None:
- async for chunk in cls.webdriver_login(proxy):
- yield chunk
if not cls._snlm0e:
if cls._cookies is None or "__Secure-1PSID" not in cls._cookies:
raise MissingAuthError('Missing "__Secure-1PSID" cookie')
@@ -211,20 +180,23 @@ class Gemini(AsyncGeneratorProvider):
yield content[last_content_len:]
last_content_len = len(content)
if image_prompt:
- images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
- if response_format == "b64_json":
- yield ImageResponse(images, image_prompt, {"cookies": cls._cookies})
- else:
- resolved_images = []
- preview = []
- for image in images:
- async with client.get(image, allow_redirects=False) as fetch:
- image = fetch.headers["location"]
- async with client.get(image, allow_redirects=False) as fetch:
- image = fetch.headers["location"]
- resolved_images.append(image)
- preview.append(image.replace('=s512', '=s200'))
- yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
+ try:
+ images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
+ if response_format == "b64_json":
+ yield ImageResponse(images, image_prompt, {"cookies": cls._cookies})
+ else:
+ resolved_images = []
+ preview = []
+ for image in images:
+ async with client.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ async with client.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ resolved_images.append(image)
+ preview.append(image.replace('=s512', '=s200'))
+ yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
+ except TypeError:
+ pass
def build_request(
prompt: str,