diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/helper.py | 143 |
1 files changed, 83 insertions, 60 deletions
diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index 81f417dd..fce1ee6f 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -1,36 +1,31 @@ from __future__ import annotations import asyncio -import webbrowser +import os import random -import string import secrets -import os -from os import path +import string from asyncio import AbstractEventLoop, BaseEventLoop from platformdirs import user_config_dir from browser_cookie3 import ( - chrome, - chromium, - opera, - opera_gx, - brave, - edge, - vivaldi, - firefox, - _LinuxPasswordManager + chrome, chromium, opera, opera_gx, + brave, edge, vivaldi, firefox, + _LinuxPasswordManager, BrowserCookieError ) - from ..typing import Dict, Messages from .. import debug -# Local Cookie Storage +# Global variable to store cookies _cookies: Dict[str, Dict[str, str]] = {} -# If loop closed or not set, create new event loop. -# If event loop is already running, handle nested event loops. -# If "nest_asyncio" is installed, patch the event loop. def get_event_loop() -> AbstractEventLoop: + """ + Get the current asyncio event loop. If the loop is closed or not set, create a new event loop. + If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed. + + Returns: + AbstractEventLoop: The current or new event loop. + """ try: loop = asyncio.get_event_loop() if isinstance(loop, BaseEventLoop): @@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - # Is running event loop asyncio.get_running_loop() if not hasattr(loop.__class__, "_nest_patched"): import nest_asyncio nest_asyncio.apply(loop) except RuntimeError: - # No running event loop pass except ImportError: raise RuntimeError( - 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.' + 'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.' ) return loop -def init_cookies(): - urls = [ - 'https://chat-gpt.org', - 'https://www.aitianhu.com', - 'https://chatgptfree.ai', - 'https://gptchatly.com', - 'https://bard.google.com', - 'https://huggingface.co/chat', - 'https://open-assistant.io/chat' - ] - - browsers = ['google-chrome', 'chrome', 'firefox', 'safari'] - - def open_urls_in_browser(browser): - b = webbrowser.get(browser) - for url in urls: - b.open(url, new=0, autoraise=True) - - for browser in browsers: - try: - open_urls_in_browser(browser) - break - except webbrowser.Error: - continue - -# Check for broken dbus address in docker image if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null": _LinuxPasswordManager.get_password = lambda a, b: b"secret" - -# Load cookies for a domain from all supported browsers. -# Cache the results in the "_cookies" variable. -def get_cookies(domain_name=''): + +def get_cookies(domain_name: str = '') -> Dict[str, str]: + """ + Load cookies for a given domain from all supported browsers and cache the results. + + Args: + domain_name (str): The domain for which to load cookies. + + Returns: + Dict[str, str]: A dictionary of cookie names and values. + """ if domain_name in _cookies: return _cookies[domain_name] - def g4f(domain_name): - user_data_dir = user_config_dir("g4f") - cookie_file = path.join(user_data_dir, "Default", "Cookies") - return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name) + + cookies = _load_cookies_from_browsers(domain_name) + _cookies[domain_name] = cookies + return cookies + +def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]: + """ + Helper function to load cookies from various browsers. + + Args: + domain_name (str): The domain for which to load cookies. + Returns: + Dict[str, str]: A dictionary of cookie names and values. + """ cookies = {} - for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]: + for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]: try: cookie_jar = cookie_fn(domain_name=domain_name) if len(cookie_jar) and debug.logging: @@ -101,13 +85,38 @@ def get_cookies(domain_name=''): for cookie in cookie_jar: if cookie.name not in cookies: cookies[cookie.name] = cookie.value - except: + except BrowserCookieError: pass - _cookies[domain_name] = cookies - return _cookies[domain_name] + except Exception as e: + if debug.logging: + print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}") + return cookies + +def _g4f(domain_name: str) -> list: + """ + Load cookies from the 'g4f' browser (if exists). + + Args: + domain_name (str): The domain for which to load cookies. + Returns: + list: List of cookies. + """ + user_data_dir = user_config_dir("g4f") + cookie_file = os.path.join(user_data_dir, "Default", "Cookies") + return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name) def format_prompt(messages: Messages, add_special_tokens=False) -> str: + """ + Format a series of messages into a single string, optionally adding special tokens. + + Args: + messages (Messages): A list of message dictionaries, each containing 'role' and 'content'. + add_special_tokens (bool): Whether to add special formatting tokens. + + Returns: + str: A formatted string containing all messages. + """ if not add_special_tokens and len(messages) <= 1: return messages[0]["content"] formatted = "\n".join([ @@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str: ]) return f"{formatted}\nAssistant:" - def get_random_string(length: int = 10) -> str: + """ + Generate a random string of specified length, containing lowercase letters and digits. + + Args: + length (int, optional): Length of the random string to generate. Defaults to 10. + + Returns: + str: A random string of the specified length. + """ return ''.join( random.choice(string.ascii_lowercase + string.digits) for _ in range(length) ) def get_random_hex() -> str: + """ + Generate a random hexadecimal string of a fixed length. + + Returns: + str: A random hexadecimal string of 32 characters (16 bytes). + """ return secrets.token_hex(16).zfill(32)
\ No newline at end of file |