summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/helper.py')
-rw-r--r--g4f/Provider/helper.py143
1 files changed, 83 insertions, 60 deletions
diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py
index 81f417dd..fce1ee6f 100644
--- a/g4f/Provider/helper.py
+++ b/g4f/Provider/helper.py
@@ -1,36 +1,31 @@
from __future__ import annotations
import asyncio
-import webbrowser
+import os
import random
-import string
import secrets
-import os
-from os import path
+import string
from asyncio import AbstractEventLoop, BaseEventLoop
from platformdirs import user_config_dir
from browser_cookie3 import (
- chrome,
- chromium,
- opera,
- opera_gx,
- brave,
- edge,
- vivaldi,
- firefox,
- _LinuxPasswordManager
+ chrome, chromium, opera, opera_gx,
+ brave, edge, vivaldi, firefox,
+ _LinuxPasswordManager, BrowserCookieError
)
-
from ..typing import Dict, Messages
from .. import debug
-# Local Cookie Storage
+# Global variable to store cookies
_cookies: Dict[str, Dict[str, str]] = {}
-# If loop closed or not set, create new event loop.
-# If event loop is already running, handle nested event loops.
-# If "nest_asyncio" is installed, patch the event loop.
def get_event_loop() -> AbstractEventLoop:
+ """
+ Get the current asyncio event loop. If the loop is closed or not set, create a new event loop.
+ If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed.
+
+ Returns:
+ AbstractEventLoop: The current or new event loop.
+ """
try:
loop = asyncio.get_event_loop()
if isinstance(loop, BaseEventLoop):
@@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
- # Is running event loop
asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
import nest_asyncio
nest_asyncio.apply(loop)
except RuntimeError:
- # No running event loop
pass
except ImportError:
raise RuntimeError(
- 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
+ 'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
)
return loop
-def init_cookies():
- urls = [
- 'https://chat-gpt.org',
- 'https://www.aitianhu.com',
- 'https://chatgptfree.ai',
- 'https://gptchatly.com',
- 'https://bard.google.com',
- 'https://huggingface.co/chat',
- 'https://open-assistant.io/chat'
- ]
-
- browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
-
- def open_urls_in_browser(browser):
- b = webbrowser.get(browser)
- for url in urls:
- b.open(url, new=0, autoraise=True)
-
- for browser in browsers:
- try:
- open_urls_in_browser(browser)
- break
- except webbrowser.Error:
- continue
-
-# Check for broken dbus address in docker image
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
_LinuxPasswordManager.get_password = lambda a, b: b"secret"
-
-# Load cookies for a domain from all supported browsers.
-# Cache the results in the "_cookies" variable.
-def get_cookies(domain_name=''):
+
+def get_cookies(domain_name: str = '') -> Dict[str, str]:
+ """
+ Load cookies for a given domain from all supported browsers and cache the results.
+
+ Args:
+ domain_name (str): The domain for which to load cookies.
+
+ Returns:
+ Dict[str, str]: A dictionary of cookie names and values.
+ """
if domain_name in _cookies:
return _cookies[domain_name]
- def g4f(domain_name):
- user_data_dir = user_config_dir("g4f")
- cookie_file = path.join(user_data_dir, "Default", "Cookies")
- return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
+
+ cookies = _load_cookies_from_browsers(domain_name)
+ _cookies[domain_name] = cookies
+ return cookies
+
+def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
+ """
+ Helper function to load cookies from various browsers.
+
+ Args:
+ domain_name (str): The domain for which to load cookies.
+ Returns:
+ Dict[str, str]: A dictionary of cookie names and values.
+ """
cookies = {}
- for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
+ for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
try:
cookie_jar = cookie_fn(domain_name=domain_name)
if len(cookie_jar) and debug.logging:
@@ -101,13 +85,38 @@ def get_cookies(domain_name=''):
for cookie in cookie_jar:
if cookie.name not in cookies:
cookies[cookie.name] = cookie.value
- except:
+ except BrowserCookieError:
pass
- _cookies[domain_name] = cookies
- return _cookies[domain_name]
+ except Exception as e:
+ if debug.logging:
+ print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
+ return cookies
+
+def _g4f(domain_name: str) -> list:
+ """
+ Load cookies from the 'g4f' browser (if exists).
+
+ Args:
+ domain_name (str): The domain for which to load cookies.
+ Returns:
+ list: List of cookies.
+ """
+ user_data_dir = user_config_dir("g4f")
+ cookie_file = os.path.join(user_data_dir, "Default", "Cookies")
+ return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
+ """
+ Format a series of messages into a single string, optionally adding special tokens.
+
+ Args:
+ messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
+ add_special_tokens (bool): Whether to add special formatting tokens.
+
+ Returns:
+ str: A formatted string containing all messages.
+ """
if not add_special_tokens and len(messages) <= 1:
return messages[0]["content"]
formatted = "\n".join([
@@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str:
])
return f"{formatted}\nAssistant:"
-
def get_random_string(length: int = 10) -> str:
+ """
+ Generate a random string of specified length, containing lowercase letters and digits.
+
+ Args:
+ length (int, optional): Length of the random string to generate. Defaults to 10.
+
+ Returns:
+ str: A random string of the specified length.
+ """
return ''.join(
random.choice(string.ascii_lowercase + string.digits)
for _ in range(length)
)
def get_random_hex() -> str:
+ """
+ Generate a random hexadecimal string of a fixed length.
+
+ Returns:
+ str: A random hexadecimal string of 32 characters (16 bytes).
+ """
return secrets.token_hex(16).zfill(32) \ No newline at end of file