diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/helper.py | 57 |
1 files changed, 40 insertions, 17 deletions
diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index e14ae65e..5a9a9329 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -3,37 +3,46 @@ from __future__ import annotations import asyncio import sys from asyncio import AbstractEventLoop - +from os import path +from typing import Dict, List import browser_cookie3 -_cookies: dict[str, dict[str, str]] = {} - -# Use own event_loop_policy with a selector event loop on windows. +# Change event loop policy on windows if sys.platform == 'win32': - _event_loop_policy = asyncio.WindowsSelectorEventLoopPolicy() -else: - _event_loop_policy = asyncio.get_event_loop_policy() - + if isinstance( + asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy + ): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +# Local Cookie Storage +_cookies: Dict[str, Dict[str, str]] = {} + # If event loop is already running, handle nested event loops # If "nest_asyncio" is installed, patch the event loop. def get_event_loop() -> AbstractEventLoop: try: asyncio.get_running_loop() except RuntimeError: - return _event_loop_policy.get_event_loop() + try: + return asyncio.get_event_loop() + except RuntimeError: + asyncio.set_event_loop(asyncio.new_event_loop()) + return asyncio.get_event_loop() try: - event_loop = _event_loop_policy.get_event_loop() + event_loop = asyncio.get_event_loop() if not hasattr(event_loop.__class__, "_nest_patched"): import nest_asyncio nest_asyncio.apply(event_loop) return event_loop except ImportError: raise RuntimeError( - 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.') + 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.' + ) -# Load cookies for a domain from all supported browser. -# Cache the results in the "_cookies" variable -def get_cookies(cookie_domain: str) -> dict: + +# Load cookies for a domain from all supported browsers. +# Cache the results in the "_cookies" variable. +def get_cookies(cookie_domain: str) -> Dict[str, str]: if cookie_domain not in _cookies: _cookies[cookie_domain] = {} try: @@ -44,11 +53,25 @@ def get_cookies(cookie_domain: str) -> dict: return _cookies[cookie_domain] -def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): +def format_prompt(messages: List[Dict[str, str]], add_special_tokens=False) -> str: if add_special_tokens or len(messages) > 1: formatted = "\n".join( - ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] + [ + "%s: %s" % ((message["role"]).capitalize(), message["content"]) + for message in messages + ] ) return f"{formatted}\nAssistant:" else: - return messages[0]["content"]
\ No newline at end of file + return messages[0]["content"] + + +def get_browser(user_data_dir: str = None): + from undetected_chromedriver import Chrome + from platformdirs import user_config_dir + + if not user_data_dir: + user_data_dir = user_config_dir("g4f") + user_data_dir = path.join(user_data_dir, "Default") + + return Chrome(user_data_dir=user_data_dir)
\ No newline at end of file |