An error occured, please try again, if the problem persists, please use a other model or provider.
";
} else {
html = markdown_render(text);
- html = html.substring(0, html.lastIndexOf('
')) + '';
+ let lastElement, lastIndex = null;
+ for (element of ['', '', '\n']) {
+ const index = html.lastIndexOf(element)
+ if (index > lastIndex) {
+ lastElement = element;
+ lastIndex = index;
+ }
+ }
+ if (lastIndex) {
+ html = html.substring(0, lastIndex) + '' + lastElement;
+ }
content_inner.innerHTML = html;
- document.querySelectorAll('code').forEach((el) => {
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
}
window.scrollTo(0, 0);
- message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
+ if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
+ message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
+ }
}
if (!error && imageInput) imageInput.value = "";
+ if (!error && fileInput) fileInput.value = "";
} catch (e) {
console.error(e);
@@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => {
`;
}
- document.querySelectorAll(`code`).forEach((el) => {
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
@@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => {
`;
}
- document.querySelectorAll(`code`).forEach((el) => {
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
};
@@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true });
(async () => {
response = await fetch('/backend-api/v2/models')
models = await response.json()
-
let select = document.getElementById('model');
- select.textContent = '';
-
- let auto = document.createElement('option');
- auto.value = '';
- auto.text = 'Model: Default';
- select.appendChild(auto);
for (model of models) {
let option = document.createElement('option');
@@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true });
response = await fetch('/backend-api/v2/providers')
providers = await response.json()
-
select = document.getElementById('provider');
- select.textContent = '';
-
- auto = document.createElement('option');
- auto.value = '';
- auto.text = 'Provider: Auto';
- select.appendChild(auto);
for (provider of providers) {
let option = document.createElement('option');
@@ -650,4 +659,27 @@ observer.observe(message_input, { attributes: true });
text += versions["version"];
}
document.getElementById("version_text").innerHTML = text
-})()
\ No newline at end of file
+})()
+
+fileInput.addEventListener('change', async (event) => {
+ if (fileInput.files.length) {
+ type = fileInput.files[0].type;
+ if (type && type.indexOf('/')) {
+ type = type.split('/').pop().replace('x-', '')
+ type = type.replace('plain', 'plaintext')
+ .replace('shellscript', 'sh')
+ .replace('svg+xml', 'svg')
+ .replace('vnd.trolltech.linguist', 'ts')
+ } else {
+ type = fileInput.files[0].name.split('.').pop()
+ }
+ fileInput.dataset.type = type
+ const reader = new FileReader();
+ reader.addEventListener('load', (event) => {
+ fileInput.dataset.text = event.target.result;
+ });
+ reader.readAsText(fileInput.files[0]);
+ } else {
+ delete fileInput.dataset.text;
+ }
+});
\ No newline at end of file
diff --git a/g4f/image.py b/g4f/image.py
index 01664f4e..cfa22ab1 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -4,9 +4,18 @@ import base64
from .typing import ImageType, Union
from PIL import Image
-ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
def to_image(image: ImageType) -> Image.Image:
+ """
+ Converts the input image to a PIL Image object.
+
+ Args:
+ image (Union[str, bytes, Image.Image]): The input image.
+
+ Returns:
+ Image.Image: The converted PIL Image object.
+ """
if isinstance(image, str):
is_data_uri_an_image(image)
image = extract_data_uri(image)
@@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image:
image = copy
return image
-def is_allowed_extension(filename) -> bool:
+def is_allowed_extension(filename: str) -> bool:
+ """
+ Checks if the given filename has an allowed extension.
+
+ Args:
+ filename (str): The filename to check.
+
+ Returns:
+ bool: True if the extension is allowed, False otherwise.
+ """
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def is_data_uri_an_image(data_uri: str) -> bool:
+ """
+ Checks if the given data URI represents an image.
+
+ Args:
+ data_uri (str): The data URI to check.
+
+ Raises:
+ ValueError: If the data URI is invalid or the image format is not allowed.
+ """
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
- # Extract the image format from the data URI
+ # Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format.lower() not in ALLOWED_EXTENSIONS:
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> bool:
+ """
+ Checks if the given binary data represents an image with an accepted format.
+
+ Args:
+ binary_data (bytes): The binary data to check.
+
+ Raises:
+ ValueError: If the image format is not allowed.
+ """
if binary_data.startswith(b'\xFF\xD8\xFF'):
pass # It's a JPEG image
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
@@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool:
pass # It's a WebP image
else:
raise ValueError("Invalid image format (from magic code).")
-
+
def extract_data_uri(data_uri: str) -> bytes:
+ """
+ Extracts the binary data from the given data URI.
+
+ Args:
+ data_uri (str): The data URI.
+
+ Returns:
+ bytes: The extracted binary data.
+ """
data = data_uri.split(",")[1]
data = base64.b64decode(data)
return data
def get_orientation(image: Image.Image) -> int:
+ """
+ Gets the orientation of the given image.
+
+ Args:
+ image (Image.Image): The image.
+
+ Returns:
+ int: The orientation value.
+ """
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
@@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int:
return orientation
def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
+ """
+ Processes the given image by adjusting its orientation and resizing it.
+
+ Args:
+ img (Image.Image): The image to process.
+ new_width (int): The new width of the image.
+ new_height (int): The new height of the image.
+
+ Returns:
+ Image.Image: The processed image.
+ """
orientation = get_orientation(img)
if orientation:
if orientation > 4:
@@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
img = img.transpose(Image.ROTATE_90)
img.thumbnail((new_width, new_height))
return img
-
+
def to_base64(image: Image.Image, compression_rate: float) -> str:
+ """
+ Converts the given image to a base64-encoded string.
+
+ Args:
+ image (Image.Image): The image to convert.
+ compression_rate (float): The compression rate (0.0 to 1.0).
+
+ Returns:
+ str: The base64-encoded image.
+ """
output_buffer = BytesIO()
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
+ """
+ Formats the given images as a markdown string.
+
+ Args:
+ images: The images to format.
+ prompt (str): The prompt for the images.
+ preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
+
+ Returns:
+ str: The formatted markdown string.
+ """
if isinstance(images, list):
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
images = "\n".join(images)
@@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20
return f"\n{start_flag}{images}\n{end_flag}\n"
def to_bytes(image: Image.Image) -> bytes:
+ """
+ Converts the given image to bytes.
+
+ Args:
+ image (Image.Image): The image to convert.
+
+ Returns:
+ bytes: The image as bytes.
+ """
bytes_io = BytesIO()
image.save(bytes_io, image.format)
image.seek(0)
diff --git a/g4f/models.py b/g4f/models.py
index 03deebf8..dd6e0a2c 100644
--- a/g4f/models.py
+++ b/g4f/models.py
@@ -31,12 +31,21 @@ from .Provider import (
@dataclass(unsafe_hash=True)
class Model:
+ """
+ Represents a machine learning model configuration.
+
+ Attributes:
+ name (str): Name of the model.
+ base_provider (str): Default provider for the model.
+ best_provider (ProviderType): The preferred provider for the model, typically with retry logic.
+ """
name: str
base_provider: str
best_provider: ProviderType = None
@staticmethod
def __all__() -> list[str]:
+ """Returns a list of all model names."""
return _all_models
default = Model(
@@ -298,6 +307,12 @@ pi = Model(
)
class ModelUtils:
+ """
+ Utility class for mapping string identifiers to Model instances.
+
+ Attributes:
+ convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
+ """
convert: dict[str, Model] = {
# gpt-3.5
'gpt-3.5-turbo' : gpt_35_turbo,
diff --git a/g4f/requests.py b/g4f/requests.py
index 1a13dec9..466d5a2a 100644
--- a/g4f/requests.py
+++ b/g4f/requests.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import json
-from contextlib import asynccontextmanager
from functools import partialmethod
from typing import AsyncGenerator
from urllib.parse import urlparse
@@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response
from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies
class StreamResponse:
+ """
+ A wrapper class for handling asynchronous streaming responses.
+
+ Attributes:
+ inner (Response): The original Response object.
+ """
+
def __init__(self, inner: Response) -> None:
+ """Initialize the StreamResponse with the provided Response object."""
self.inner: Response = inner
async def text(self) -> str:
+ """Asynchronously get the response text."""
return await self.inner.atext()
def raise_for_status(self) -> None:
+ """Raise an HTTPError if one occurred."""
self.inner.raise_for_status()
async def json(self, **kwargs) -> dict:
+ """Asynchronously parse the JSON response content."""
return json.loads(await self.inner.acontent(), **kwargs)
async def iter_lines(self) -> AsyncGenerator[bytes, None]:
+ """Asynchronously iterate over the lines of the response."""
async for line in self.inner.aiter_lines():
yield line
async def iter_content(self) -> AsyncGenerator[bytes, None]:
+ """Asynchronously iterate over the response content."""
async for chunk in self.inner.aiter_content():
yield chunk
-
+
async def __aenter__(self):
+ """Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner
self.inner = inner
self.request = inner.request
@@ -39,24 +52,47 @@ class StreamResponse:
self.headers = inner.headers
self.cookies = inner.cookies
return self
-
+
async def __aexit__(self, *args):
+ """Asynchronously exit the runtime context for the response object."""
await self.inner.aclose()
+
class StreamSession(AsyncSession):
+ """
+ An asynchronous session class for handling HTTP requests with streaming.
+
+ Inherits from AsyncSession.
+ """
+
def request(
self, method: str, url: str, **kwargs
) -> StreamResponse:
+ """Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, **kwargs))
+ # Defining HTTP methods as partial methods of the request method.
head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
post = partialmethod(request, "POST")
put = partialmethod(request, "PUT")
patch = partialmethod(request, "PATCH")
delete = partialmethod(request, "DELETE")
-
-def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120):
+
+
+def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
+ """
+ Create a Session object using a WebDriver to handle cookies and headers.
+
+ Args:
+ url (str): The URL to navigate to using the WebDriver.
+ webdriver (WebDriver, optional): The WebDriver instance to use.
+ proxy (str, optional): Proxy server to use for the Session.
+ timeout (int, optional): Timeout in seconds for the WebDriver.
+
+ Returns:
+ Session: A Session object configured with cookies and headers from the WebDriver.
+ """
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver:
bypass_cloudflare(driver, url, timeout)
cookies = get_driver_cookies(driver)
@@ -78,4 +114,4 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str =
proxies={"https": proxy, "http": proxy},
timeout=timeout,
impersonate="chrome110"
- )
+ )
\ No newline at end of file
diff --git a/g4f/version.py b/g4f/version.py
index bb4b7f17..c976c8fd 100644
--- a/g4f/version.py
+++ b/g4f/version.py
@@ -5,45 +5,94 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
from subprocess import check_output, CalledProcessError, PIPE
from .errors import VersionNotFoundError
-def get_latest_version() -> str:
+def get_pypi_version(package_name: str) -> str:
+ """
+ Get the latest version of a package from PyPI.
+
+ :param package_name: The name of the package.
+ :return: The latest version of the package as a string.
+ """
try:
- get_package_version("g4f")
- response = requests.get("https://pypi.org/pypi/g4f/json").json()
+ response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
return response["info"]["version"]
- except PackageNotFoundError:
- url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
- response = requests.get(url).json()
+ except requests.RequestException as e:
+ raise VersionNotFoundError(f"Failed to get PyPI version: {e}")
+
+def get_github_version(repo: str) -> str:
+ """
+ Get the latest release version from a GitHub repository.
+
+ :param repo: The name of the GitHub repository.
+ :return: The latest release version as a string.
+ """
+ try:
+ response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
return response["tag_name"]
+ except requests.RequestException as e:
+ raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
+
+def get_latest_version():
+ """
+ Get the latest release version from PyPI or the GitHub repository.
-class VersionUtils():
+ :return: The latest release version as a string.
+ """
+ try:
+ # Is installed via package manager?
+ get_package_version("g4f")
+ return get_pypi_version("g4f")
+ except PackageNotFoundError:
+ # Else use Github version:
+ return get_github_version("xtekky/gpt4free")
+
+class VersionUtils:
+ """
+ Utility class for managing and comparing package versions.
+ """
@cached_property
def current_version(self) -> str:
+ """
+ Get the current version of the g4f package.
+
+ :return: The current version as a string.
+ """
# Read from package manager
try:
return get_package_version("g4f")
except PackageNotFoundError:
pass
+
# Read from docker environment
version = environ.get("G4F_VERSION")
if version:
return version
+
# Read from git repository
try:
command = ["git", "describe", "--tags", "--abbrev=0"]
return check_output(command, text=True, stderr=PIPE).strip()
except CalledProcessError:
pass
+
raise VersionNotFoundError("Version not found")
-
+
@cached_property
def latest_version(self) -> str:
+ """
+ Get the latest version of the g4f package.
+
+ :return: The latest version as a string.
+ """
return get_latest_version()
-
+
def check_version(self) -> None:
+ """
+ Check if the current version is up to date with the latest version.
+ """
try:
if self.current_version != self.latest_version:
print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
except Exception as e:
print(f'Failed to check g4f version: {e}')
-
+
utils = VersionUtils()
\ No newline at end of file
diff --git a/g4f/webdriver.py b/g4f/webdriver.py
index da283409..e5ecd8bf 100644
--- a/g4f/webdriver.py
+++ b/g4f/webdriver.py
@@ -1,5 +1,4 @@
from __future__ import annotations
-
from platformdirs import user_config_dir
from selenium.webdriver.remote.webdriver import WebDriver
from undetected_chromedriver import Chrome, ChromeOptions
@@ -21,7 +20,16 @@ def get_browser(
proxy: str = None,
options: ChromeOptions = None
) -> WebDriver:
- if user_data_dir == None:
+ """
+ Creates and returns a Chrome WebDriver with the specified options.
+
+ :param user_data_dir: Directory for user data. If None, uses default directory.
+ :param headless: Boolean indicating whether to run the browser in headless mode.
+ :param proxy: Proxy settings for the browser.
+ :param options: ChromeOptions object with specific browser options.
+ :return: An instance of WebDriver.
+ """
+ if user_data_dir is None:
user_data_dir = user_config_dir("g4f")
if user_data_dir and debug.logging:
print("Open browser with config dir:", user_data_dir)
@@ -39,36 +47,45 @@ def get_browser(
headless=headless
)
-def get_driver_cookies(driver: WebDriver):
- return dict([(cookie["name"], cookie["value"]) for cookie in driver.get_cookies()])
+def get_driver_cookies(driver: WebDriver) -> dict:
+ """
+ Retrieves cookies from the given WebDriver.
+
+ :param driver: WebDriver from which to retrieve cookies.
+ :return: A dictionary of cookies.
+ """
+ return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
- # Open website
+ """
+ Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
+
+ :param driver: The WebDriver to use.
+ :param url: URL to access.
+ :param timeout: Time in seconds to wait for the page to load.
+ """
driver.get(url)
- # Is cloudflare protection
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
if debug.logging:
print("Cloudflare protection detected:", url)
try:
- # Click button in iframe
- WebDriverWait(driver, 5).until(
- EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
- )
driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
- )
- driver.find_element(By.CSS_SELECTOR, "#challenge-stage input").click()
- except:
- pass
+ ).click()
+ except Exception as e:
+ if debug.logging:
+ print(f"Error bypassing Cloudflare: {e}")
finally:
driver.switch_to.default_content()
- # No cloudflare protection
WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
)
-class WebDriverSession():
+class WebDriverSession:
+ """
+ Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
+ """
def __init__(
self,
webdriver: WebDriver = None,
@@ -81,9 +98,7 @@ class WebDriverSession():
self.webdriver = webdriver
self.user_data_dir = user_data_dir
self.headless = headless
- self.virtual_display = None
- if has_pyvirtualdisplay and virtual_display:
- self.virtual_display = Display(size=(1920, 1080))
+ self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
self.proxy = proxy
self.options = options
self.default_driver = None
@@ -94,8 +109,15 @@ class WebDriverSession():
headless: bool = False,
virtual_display: bool = False
) -> WebDriver:
- if user_data_dir == None:
- user_data_dir = self.user_data_dir
+ """
+ Reopens the WebDriver session with the specified parameters.
+
+ :param user_data_dir: Directory for user data.
+ :param headless: Boolean indicating whether to run the browser in headless mode.
+ :param virtual_display: Boolean indicating whether to use a virtual display.
+ :return: An instance of WebDriver.
+ """
+ user_data_dir = user_data_dir or self.user_data_dir
if self.default_driver:
self.default_driver.quit()
if not virtual_display and self.virtual_display:
@@ -105,6 +127,10 @@ class WebDriverSession():
return self.default_driver
def __enter__(self) -> WebDriver:
+ """
+ Context management method for entering a session.
+ :return: An instance of WebDriver.
+ """
if self.webdriver:
return self.webdriver
if self.virtual_display:
@@ -113,11 +139,15 @@ class WebDriverSession():
return self.default_driver
def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Context management method for exiting a session. Closes and quits the WebDriver.
+ """
if self.default_driver:
try:
self.default_driver.close()
- except:
- pass
+ except Exception as e:
+ if debug.logging:
+ print(f"Error closing WebDriver: {e}")
self.default_driver.quit()
if self.virtual_display:
self.virtual_display.stop()
\ No newline at end of file
--
cgit v1.2.3
From 32252def150da94f12d1f3c07f977af6d8931402 Mon Sep 17 00:00:00 2001
From: Heiner Lohaus
Date: Sun, 14 Jan 2024 15:04:37 +0100
Subject: Change doctypes style to Google Fix typo in latest_version Fix Phind
Provider Add unittest worklow and main tests
---
.github/workflows/unittest.yml | 19 ++++
etc/unittest/main.py | 73 +++++++++++++
g4f/Provider/Phind.py | 8 +-
g4f/Provider/base_provider.py | 71 +++++++++++++
g4f/Provider/bing/create_images.py | 2 +-
g4f/Provider/create_images.py | 61 ++++++++++-
g4f/gui/client/js/chat.v1.js | 6 +-
g4f/gui/server/backend.py | 211 ++++++++++++++++++++++++++-----------
g4f/version.py | 56 +++++++---
g4f/webdriver.py | 75 +++++++++----
10 files changed, 478 insertions(+), 104 deletions(-)
create mode 100644 .github/workflows/unittest.yml
create mode 100644 etc/unittest/main.py
diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml
new file mode 100644
index 00000000..e895e969
--- /dev/null
+++ b/.github/workflows/unittest.yml
@@ -0,0 +1,19 @@
+name: Unittest
+
+on: [push]
+
+jobs:
+ build:
+ name: Build unittest
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.x"
+ cache: 'pip'
+ - name: Install requirements
+ - run: pip install -r requirements.txt
+ - name: Run tests
+ run: python -m etc.unittest.main
\ No newline at end of file
diff --git a/etc/unittest/main.py b/etc/unittest/main.py
new file mode 100644
index 00000000..61f4ffda
--- /dev/null
+++ b/etc/unittest/main.py
@@ -0,0 +1,73 @@
+import sys
+import pathlib
+import unittest
+from unittest.mock import MagicMock
+
+sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
+
+import g4f
+from g4f import ChatCompletion, get_last_provider
+from g4f.gui.server.backend import Backend_Api, get_error_message
+from g4f.base_provider import BaseProvider
+
+g4f.debug.logging = False
+
+class MockProvider(BaseProvider):
+ working = True
+
+ def create_completion(
+ model, messages, stream, **kwargs
+ ):
+ yield "Mock"
+
+ async def create_async(
+ model, messages, **kwargs
+ ):
+ return "Mock"
+
+class TestBackendApi(unittest.TestCase):
+
+ def setUp(self):
+ self.app = MagicMock()
+ self.api = Backend_Api(self.app)
+
+ def test_version(self):
+ response = self.api.get_version()
+ self.assertIn("version", response)
+ self.assertIn("latest_version", response)
+
+class TestChatCompletion(unittest.TestCase):
+
+ def test_create(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ result = ChatCompletion.create(g4f.models.default, messages)
+ self.assertTrue("Hello" in result or "Good" in result)
+
+ def test_get_last_provider(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ ChatCompletion.create(g4f.models.default, messages, MockProvider)
+ self.assertEqual(get_last_provider(), MockProvider)
+
+ def test_bing_provider(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ provider = g4f.Provider.Bing
+ result = ChatCompletion.create(g4f.models.default, messages, provider)
+ self.assertTrue("Bing" in result)
+
+class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
+
+ async def test_async(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
+ self.assertTrue("Mock" in result)
+
+class TestUtilityFunctions(unittest.TestCase):
+
+ def test_get_error_message(self):
+ g4f.debug.last_provider = g4f.Provider.Bing
+ exception = Exception("Message")
+ result = get_error_message(exception)
+ self.assertEqual("Bing: Exception: Message", result)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py
index bb216989..9e80baa9 100644
--- a/g4f/Provider/Phind.py
+++ b/g4f/Provider/Phind.py
@@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider):
"rewrittenQuestion": prompt,
"challenge": 0.21132115912208504
}
- async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response:
+ async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response:
new_line = False
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
- if chunk.startswith(b"") or chunk.startswith(b""):
+ if chunk.startswith(b''):
+ break
+ if chunk.startswith(b'') or chunk.startswith(b''):
+ pass
+ elif chunk.startswith(b"") or chunk.startswith(b""):
pass
elif chunk:
yield chunk.decode()
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 3c083bda..fd92d17a 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -36,6 +36,17 @@ class AbstractProvider(BaseProvider):
) -> str:
"""
Asynchronously creates a result based on the given model and messages.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
"""
loop = loop or get_event_loop()
@@ -52,6 +63,12 @@ class AbstractProvider(BaseProvider):
def params(cls) -> str:
"""
Returns the parameters supported by the provider.
+
+ Args:
+ cls (type): The class on which this property is called.
+
+ Returns:
+ str: A string listing the supported parameters.
"""
sig = signature(
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
@@ -90,6 +107,17 @@ class AsyncProvider(AbstractProvider):
) -> CreateResult:
"""
Creates a completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to False.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the completion creation.
"""
loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs)
@@ -104,6 +132,17 @@ class AsyncProvider(AbstractProvider):
) -> str:
"""
Abstract method for creating asynchronous results.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ str: The created result as a string.
"""
raise NotImplementedError()
@@ -126,6 +165,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> CreateResult:
"""
Creates a streaming completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the streaming completion creation.
"""
loop = loop or get_event_loop()
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
@@ -146,6 +196,15 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> str:
"""
Asynchronously creates a result from a generator.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
"""
return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
@@ -162,5 +221,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> AsyncResult:
"""
Abstract method for creating an asynchronous generator.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ AsyncResult: An asynchronous generator yielding results.
"""
raise NotImplementedError()
\ No newline at end of file
diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py
index 29daccbd..060cd184 100644
--- a/g4f/Provider/bing/create_images.py
+++ b/g4f/Provider/bing/create_images.py
@@ -198,7 +198,7 @@ class CreateImagesBing:
_cookies: Dict[str, str] = {}
@classmethod
- def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str]:
+ def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]:
"""
Generator for creating imagecompletion based on a prompt.
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py
index f8a0442d..b8bcbde3 100644
--- a/g4f/Provider/create_images.py
+++ b/g4f/Provider/create_images.py
@@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType
system_message = """
You can generate custom images with the DALL-E 3 image generator.
-To generate a image with a prompt, do this:
+To generate an image with a prompt, do this:
Don't use images with data uri. It is important to use a prompt instead.
"""
class CreateImagesProvider(BaseProvider):
+ """
+ Provider class for creating images based on text prompts.
+
+ This provider handles image creation requests embedded within message content,
+ using provided image creation functions.
+
+ Attributes:
+ provider (ProviderType): The underlying provider to handle non-image related tasks.
+ create_images (callable): A function to create images synchronously.
+ create_images_async (callable): A function to create images asynchronously.
+ system_message (str): A message that explains the image creation capability.
+ include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
+ __name__ (str): Name of the provider.
+ url (str): URL of the provider.
+ working (bool): Indicates if the provider is operational.
+ supports_stream (bool): Indicates if the provider supports streaming.
+ """
+
def __init__(
self,
provider: ProviderType,
@@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider):
system_message: str = system_message,
include_placeholder: bool = True
) -> None:
+ """
+ Initializes the CreateImagesProvider.
+
+ Args:
+ provider (ProviderType): The underlying provider.
+ create_images (callable): Function to create images synchronously.
+ create_async (callable): Function to create images asynchronously.
+ system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
+ include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
+ """
self.provider = provider
self.create_images = create_images
self.create_images_async = create_async
@@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
+ """
+ Creates a completion result, processing any image creation prompts found within the messages.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process, which may contain image prompts.
+ stream (bool, optional): Indicates whether to stream the results. Defaults to False.
+ **kwargs: Additional keywordarguments for the provider.
+
+ Yields:
+ CreateResult: Yields chunks of the processed messages, including image data if applicable.
+
+ Note:
+ This method processes messages to detect image creation prompts. When such a prompt is found,
+ it calls the synchronous image creation function and includes the resulting image in the output.
+ """
messages.insert(0, {"role": "system", "content": self.system_message})
buffer = ""
for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
@@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously creates a response, processing any image creation prompts found within the messages.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process, which may contain image prompts.
+ **kwargs: Additional keyword arguments for the provider.
+
+ Returns:
+ str: The processed response string, including asynchronously generated image data if applicable.
+
+ Note:
+ This method processes messages to detect image creation prompts. When such a prompt is found,
+ it calls the asynchronous image creation function and includes the resulting image in the output.
+ """
messages.insert(0, {"role": "system", "content": self.system_message})
response = await self.provider.create_async(model, messages, **kwargs)
matches = re.findall(r'()', response)
diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js
index 7ed9f183..8b9bc181 100644
--- a/g4f/gui/client/js/chat.v1.js
+++ b/g4f/gui/client/js/chat.v1.js
@@ -652,9 +652,9 @@ observer.observe(message_input, { attributes: true });
document.title = 'g4f - gui - ' + versions["version"];
text = "version ~ "
- if (versions["version"] != versions["lastet_version"]) {
- release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"];
- text += '' + versions["version"] + ' 🆕';
+ if (versions["version"] != versions["latest_version"]) {
+ release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
+ text += '' + versions["version"] + ' 🆕';
} else {
text += versions["version"];
}
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index 9d12bea5..4a5cafa8 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -1,6 +1,7 @@
import logging
import json
from flask import request, Flask
+from typing import Generator
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
@@ -11,60 +12,123 @@ from .internet import get_search_message
debug.logging = True
class Backend_Api:
+ """
+ Handles various endpoints in a Flask application for backend operations.
+
+ This class provides methods to interact with models, providers, and to handle
+ various functionalities like conversations, error handling, and version management.
+
+ Attributes:
+ app (Flask): A Flask application instance.
+ routes (dict): A dictionary mapping API endpoints to their respective handlers.
+ """
def __init__(self, app: Flask) -> None:
+ """
+ Initialize the backend API with the given Flask application.
+
+ Args:
+ app (Flask): Flask application instance to attach routes to.
+ """
self.app: Flask = app
self.routes = {
'/backend-api/v2/models': {
- 'function': self.models,
- 'methods' : ['GET']
+ 'function': self.get_models,
+ 'methods': ['GET']
},
'/backend-api/v2/providers': {
- 'function': self.providers,
- 'methods' : ['GET']
+ 'function': self.get_providers,
+ 'methods': ['GET']
},
'/backend-api/v2/version': {
- 'function': self.version,
- 'methods' : ['GET']
+ 'function': self.get_version,
+ 'methods': ['GET']
},
'/backend-api/v2/conversation': {
- 'function': self._conversation,
+ 'function': self.handle_conversation,
'methods': ['POST']
},
'/backend-api/v2/gen.set.summarize:title': {
- 'function': self._gen_title,
+ 'function': self.generate_title,
'methods': ['POST']
},
'/backend-api/v2/error': {
- 'function': self.error,
+ 'function': self.handle_error,
'methods': ['POST']
}
}
- def error(self):
+ def handle_error(self):
+ """
+ Initialize the backend API with the given Flask application.
+
+ Args:
+ app (Flask): Flask application instance to attach routes to.
+ """
print(request.json)
-
return 'ok', 200
- def models(self):
+ def get_models(self):
+ """
+ Return a list of all models.
+
+ Fetches and returns a list of all available models in the system.
+
+ Returns:
+ List[str]: A list of model names.
+ """
return _all_models
- def providers(self):
- return [
- provider.__name__ for provider in __providers__ if provider.working
- ]
+ def get_providers(self):
+ """
+ Return a list of all working providers.
+ """
+ return [provider.__name__ for provider in __providers__ if provider.working]
- def version(self):
+ def get_version(self):
+ """
+ Returns the current and latest version of the application.
+
+ Returns:
+ dict: A dictionary containing the current and latest version.
+ """
return {
"version": version.utils.current_version,
- "lastet_version": version.get_latest_version(),
+ "latest_version": version.get_latest_version(),
}
- def _gen_title(self):
- return {
- 'title': ''
- }
+ def generate_title(self):
+ """
+ Generates and returns a title based on the request data.
+
+ Returns:
+ dict: A dictionary with the generated title.
+ """
+ return {'title': ''}
- def _conversation(self):
+ def handle_conversation(self):
+ """
+ Handles conversation requests and streams responses back.
+
+ Returns:
+ Response: A Flask response object for streaming.
+ """
+ kwargs = self._prepare_conversation_kwargs()
+
+ return self.app.response_class(
+ self._create_response_stream(kwargs),
+ mimetype='text/event-stream'
+ )
+
+ def _prepare_conversation_kwargs(self):
+ """
+ Prepares arguments for chat completion based on the request data.
+
+ Reads the request and prepares the necessary arguments for handling
+ a chat completion request.
+
+ Returns:
+ dict: Arguments prepared for chat completion.
+ """
kwargs = {}
if 'image' in request.files:
file = request.files['image']
@@ -87,47 +151,70 @@ class Backend_Api:
messages[-1]["content"] = get_search_message(messages[-1]["content"])
model = json_data.get('model')
model = model if model else models.default
- provider = json_data.get('provider', '').replace('g4f.Provider.', '')
- provider = provider if provider and provider != "Auto" else None
patch = patch_provider if json_data.get('patch_provider') else None
- def try_response():
- try:
- first = True
- for chunk in ChatCompletion.create(
- model=model,
- provider=provider,
- messages=messages,
- stream=True,
- ignore_stream_and_auth=True,
- patch_provider=patch,
- **kwargs
- ):
- if first:
- first = False
- yield json.dumps({
- 'type' : 'provider',
- 'provider': get_last_provider(True)
- }) + "\n"
- if isinstance(chunk, Exception):
- logging.exception(chunk)
- yield json.dumps({
- 'type' : 'message',
- 'message': get_error_message(chunk),
- }) + "\n"
- else:
- yield json.dumps({
- 'type' : 'content',
- 'content': str(chunk),
- }) + "\n"
- except Exception as e:
- logging.exception(e)
- yield json.dumps({
- 'type' : 'error',
- 'error': get_error_message(e)
- })
-
- return self.app.response_class(try_response(), mimetype='text/event-stream')
+ return {
+ "model": model,
+ "provider": provider,
+ "messages": messages,
+ "stream": True,
+ "ignore_stream_and_auth": True,
+ "patch_provider": patch,
+ **kwargs
+ }
+
+ def _create_response_stream(self, kwargs) -> Generator[str, None, None]:
+ """
+ Creates and returns a streaming response for the conversation.
+
+ Args:
+ kwargs (dict): Arguments for creating the chat completion.
+
+ Yields:
+ str: JSON formatted response chunks for the stream.
+
+ Raises:
+ Exception: If an error occurs during the streaming process.
+ """
+ try:
+ first = True
+ for chunk in ChatCompletion.create(**kwargs):
+ if first:
+ first = False
+ yield self._format_json('provider', get_last_provider(True))
+ if isinstance(chunk, Exception):
+ logging.exception(chunk)
+ yield self._format_json('message', get_error_message(chunk))
+ else:
+ yield self._format_json('content', str(chunk))
+ except Exception as e:
+ logging.exception(e)
+ yield self._format_json('error', get_error_message(e))
+
+ def _format_json(self, response_type: str, content) -> str:
+ """
+ Formats and returns a JSON response.
+
+ Args:
+ response_type (str): The type of the response.
+ content: The content to be included in the response.
+
+ Returns:
+ str: A JSON formatted string.
+ """
+ return json.dumps({
+ 'type': response_type,
+ response_type: content
+ }) + "\n"
def get_error_message(exception: Exception) -> str:
+ """
+ Generates a formatted error message from an exception.
+
+ Args:
+ exception (Exception): The exception to format.
+
+ Returns:
+ str: A formatted error message string.
+ """
return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}"
\ No newline at end of file
diff --git a/g4f/version.py b/g4f/version.py
index c976c8fd..9201c75c 100644
--- a/g4f/version.py
+++ b/g4f/version.py
@@ -7,10 +7,16 @@ from .errors import VersionNotFoundError
def get_pypi_version(package_name: str) -> str:
"""
- Get the latest version of a package from PyPI.
+ Retrieves the latest version of a package from PyPI.
- :param package_name: The name of the package.
- :return: The latest version of the package as a string.
+ Args:
+ package_name (str): The name of the package for which to retrieve the version.
+
+ Returns:
+ str: The latest version of the specified package from PyPI.
+
+ Raises:
+ VersionNotFoundError: If there is an error in fetching the version from PyPI.
"""
try:
response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
@@ -20,10 +26,16 @@ def get_pypi_version(package_name: str) -> str:
def get_github_version(repo: str) -> str:
"""
- Get the latest release version from a GitHub repository.
+ Retrieves the latest release version from a GitHub repository.
+
+ Args:
+ repo (str): The name of the GitHub repository.
+
+ Returns:
+ str: The latest release version from the specified GitHub repository.
- :param repo: The name of the GitHub repository.
- :return: The latest release version as a string.
+ Raises:
+ VersionNotFoundError: If there is an error in fetching the version from GitHub.
"""
try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
@@ -31,11 +43,16 @@ def get_github_version(repo: str) -> str:
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
-def get_latest_version():
+def get_latest_version() -> str:
"""
- Get the latest release version from PyPI or the GitHub repository.
+ Retrieves the latest release version of the 'g4f' package from PyPI or GitHub.
- :return: The latest release version as a string.
+ Returns:
+ str: The latest release version of 'g4f'.
+
+ Note:
+ The function first tries to fetch the version from PyPI. If the package is not found,
+ it retrieves the version from the GitHub repository.
"""
try:
# Is installed via package manager?
@@ -47,14 +64,19 @@ def get_latest_version():
class VersionUtils:
"""
- Utility class for managing and comparing package versions.
+ Utility class for managing and comparing package versions of 'g4f'.
"""
@cached_property
def current_version(self) -> str:
"""
- Get the current version of the g4f package.
+ Retrieves the current version of the 'g4f' package.
+
+ Returns:
+ str: The current version of 'g4f'.
- :return: The current version as a string.
+ Raises:
+ VersionNotFoundError: If the version cannot be determined from the package manager,
+ Docker environment, or git repository.
"""
# Read from package manager
try:
@@ -79,15 +101,19 @@ class VersionUtils:
@cached_property
def latest_version(self) -> str:
"""
- Get the latest version of the g4f package.
+ Retrieves the latest version of the 'g4f' package.
- :return: The latest version as a string.
+ Returns:
+ str: The latest version of 'g4f'.
"""
return get_latest_version()
def check_version(self) -> None:
"""
- Check if the current version is up to date with the latest version.
+ Checks if the current version of 'g4f' is up to date with the latest version.
+
+ Note:
+ If a newer version is available, it prints a message with the new version and update instructions.
"""
try:
if self.current_version != self.latest_version:
diff --git a/g4f/webdriver.py b/g4f/webdriver.py
index e5ecd8bf..9a83215f 100644
--- a/g4f/webdriver.py
+++ b/g4f/webdriver.py
@@ -21,13 +21,16 @@ def get_browser(
options: ChromeOptions = None
) -> WebDriver:
"""
- Creates and returns a Chrome WebDriver with the specified options.
+ Creates and returns a Chrome WebDriver with specified options.
- :param user_data_dir: Directory for user data. If None, uses default directory.
- :param headless: Boolean indicating whether to run the browser in headless mode.
- :param proxy: Proxy settings for the browser.
- :param options: ChromeOptions object with specific browser options.
- :return: An instance of WebDriver.
+ Args:
+ user_data_dir (str, optional): Directory for user data. If None, uses default directory.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
+ proxy (str, optional): Proxy settings for the browser. Defaults to None.
+ options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.
+
+ Returns:
+ WebDriver: An instance of WebDriver configured with the specified options.
"""
if user_data_dir is None:
user_data_dir = user_config_dir("g4f")
@@ -49,10 +52,13 @@ def get_browser(
def get_driver_cookies(driver: WebDriver) -> dict:
"""
- Retrieves cookies from the given WebDriver.
+ Retrieves cookies from the specified WebDriver.
+
+ Args:
+ driver (WebDriver): The WebDriver instance from which to retrieve cookies.
- :param driver: WebDriver from which to retrieve cookies.
- :return: A dictionary of cookies.
+ Returns:
+ dict: A dictionary containing cookies with their names as keys and values as cookie values.
"""
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
@@ -60,9 +66,13 @@ def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
"""
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
- :param driver: The WebDriver to use.
- :param url: URL to access.
- :param timeout: Time in seconds to wait for the page to load.
+ Args:
+ driver (WebDriver): The WebDriver to use for accessing the URL.
+ url (str): The URL to access.
+ timeout (int): Time in seconds to wait for the page to load.
+
+ Raises:
+ Exception: If there is an error while bypassing Cloudflare or loading the page.
"""
driver.get(url)
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
@@ -86,6 +96,7 @@ class WebDriverSession:
"""
Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
"""
+
def __init__(
self,
webdriver: WebDriver = None,
@@ -95,6 +106,17 @@ class WebDriverSession:
proxy: str = None,
options: ChromeOptions = None
):
+ """
+ Initializes a new instance of the WebDriverSession.
+
+ Args:
+ webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
+ user_data_dir (str, optional): Directory for user data. Defaults to None.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
+ virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
+ proxy (str, optional): Proxy settings for the browser. Defaults to None.
+ options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
+ """
self.webdriver = webdriver
self.user_data_dir = user_data_dir
self.headless = headless
@@ -110,14 +132,17 @@ class WebDriverSession:
virtual_display: bool = False
) -> WebDriver:
"""
- Reopens the WebDriver session with the specified parameters.
+ Reopens the WebDriver session with new settings.
+
+ Args:
+ user_data_dir (str, optional): Directory for user data. Defaults to current value.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
+ virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.
- :param user_data_dir: Directory for user data.
- :param headless: Boolean indicating whether to run the browser in headless mode.
- :param virtual_display: Boolean indicating whether to use a virtual display.
- :return: An instance of WebDriver.
+ Returns:
+ WebDriver: The reopened WebDriver instance.
"""
- user_data_dir = user_data_dir or self.user_data_dir
+ user_data_dir = user_data_data_dir or self.user_data_dir
if self.default_driver:
self.default_driver.quit()
if not virtual_display and self.virtual_display:
@@ -128,8 +153,10 @@ class WebDriverSession:
def __enter__(self) -> WebDriver:
"""
- Context management method for entering a session.
- :return: An instance of WebDriver.
+ Context management method for entering a session. Initializes and returns a WebDriver instance.
+
+ Returns:
+ WebDriver: An instance of WebDriver for this session.
"""
if self.webdriver:
return self.webdriver
@@ -141,6 +168,14 @@ class WebDriverSession:
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context management method for exiting a session. Closes and quits the WebDriver.
+
+ Args:
+ exc_type: Exception type.
+ exc_val: Exception value.
+ exc_tb: Exception traceback.
+
+ Note:
+ Closes the WebDriver and stops the virtual display if used.
"""
if self.default_driver:
try:
--
cgit v1.2.3