From a28bab938704a15c825c1b45a8983c72e8c90ace Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Mon, 29 Jan 2024 18:14:46 +0100 Subject: Add aiohttp_socks to requirements Fix preview for uploaded and generated images in gui Improve typing, readme --- README.md | 55 +++++++++++++++++++++++++++++++++-- docker/Dockerfile | 3 +- g4f/Provider/Bing.py | 4 +-- g4f/Provider/bing/create_images.py | 9 +++--- g4f/Provider/bing/upload_image.py | 2 +- g4f/Provider/needs_auth/OpenaiChat.py | 31 ++++++++++---------- g4f/defaults.py | 13 +++++++++ g4f/gui/client/js/chat.v1.js | 16 ++++++++++ g4f/image.py | 39 +++++++++++++++---------- g4f/requests.py | 24 +++++---------- g4f/requests_aiohttp.py | 13 ++------- g4f/typing.py | 7 +++-- g4f/webdriver.py | 9 +++--- requirements.txt | 3 +- 14 files changed, 148 insertions(+), 80 deletions(-) create mode 100644 g4f/defaults.py diff --git a/README.md b/README.md index f82d6adb..9264d3fc 100644 --- a/README.md +++ b/README.md @@ -99,8 +99,29 @@ or set the api base in your client to: [http://localhost:1337/v1](http://localho ##### Install using pypi: +Install all supported tools / all used packages: ``` -pip install -U "g4f[all]" +pip install -U g4f[all] +``` +Install packages for uploading / generating images: +``` +pip install -U g4f[image] +``` +Install the packages required for providers with webdriver: +``` +pip install -U g4f[webdriver] +``` +Install the packages required for the OpenaiChat provider: +``` +pip install -U g4f[openai] +``` +Install the packages required for the interference api: +``` +pip install -U g4f[api] +``` +Install the packages required for the web gui: +``` +pip install -U g4f[gui] ``` ##### or: @@ -202,8 +223,9 @@ docker-compose down ### The Web UI -To use it in the web interface, type the following codes in the command line. -```python3 +To start the web interface, type the following codes in the command line. + +```python from g4f.gui import run_gui run_gui() ``` @@ -283,6 +305,33 @@ for message in response: print(message) ``` +##### Cookies / Access Token + +For generating images with Bing and for the OpenAi Chat you need cookies or a token from your browser session. From Bing you need the "_U" cookie and from OpenAI you need the "access_token". You can pass the cookies / the access token in the create function or you use the `set_cookies` setter: + +```python +from g4f import set_cookies + +set_cookies(".bing", { + "_U": "cookie value" +}) +set_cookies("chat.openai.com", { + "access_token": "token value" +}) + +from g4f.gui import run_gui +run_gui() +``` + +Alternatively, g4f reads the cookies with “browser_cookie3” from your browser +or it starts a browser instance with selenium "webdriver" for logging in. +If you use the pip package, you have to install “browser_cookie3” or "webdriver" by yourself. + +```bash +pip install browser_cookie3 +pip install g4f[webdriver] +``` + ##### Using Browser Some providers using a browser to bypass the bot protection. They using the selenium webdriver to control the browser. The browser settings and the login data are saved in a custom directory. If the headless mode is enabled, the browser windows are loaded invisibly. For performance reasons, it is recommended to reuse the browser instances and close them yourself at the end: diff --git a/docker/Dockerfile b/docker/Dockerfile index 294e1372..88e21b18 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -86,6 +86,5 @@ RUN pip install --upgrade pip && pip install -r requirements.txt # Copy the entire package into the container. ADD --chown=$G4F_USER:$G4F_USER g4f $G4F_DIR/g4f - # Expose ports -EXPOSE 8080 1337 +EXPOSE 8080 1337 \ No newline at end of file diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py index 32879fa6..40a42bf5 100644 --- a/g4f/Provider/Bing.py +++ b/g4f/Provider/Bing.py @@ -288,8 +288,6 @@ async def stream_generate( ) as session: conversation = await create_conversation(session) image_request = await upload_image(session, image, tone) if image else None - if image_request: - yield image_request try: async with session.ws_connect( @@ -327,7 +325,7 @@ async def stream_generate( elif message.get('contentType') == "IMAGE": prompt = message.get('text') try: - image_response = ImageResponse(await create_images(session, prompt), prompt) + image_response = ImageResponse(await create_images(session, prompt), prompt, {"preview": "{image}?w=200&h=200"}) except: response_txt += f"\nhttps://www.bing.com/images/create?q={parse.quote(prompt)}" final = True diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py index a3fcd91b..e1031e61 100644 --- a/g4f/Provider/bing/create_images.py +++ b/g4f/Provider/bing/create_images.py @@ -187,11 +187,11 @@ def get_cookies_from_browser(proxy: str = None) -> dict[str, str]: class CreateImagesBing: """A class for creating images using Bing.""" - + def __init__(self, cookies: dict[str, str] = {}, proxy: str = None) -> None: self.cookies = cookies self.proxy = proxy - + def create_completion(self, prompt: str) -> Generator[ImageResponse, None, None]: """ Generator for creating imagecompletion based on a prompt. @@ -229,9 +229,7 @@ class CreateImagesBing: proxy = os.environ.get("G4F_PROXY") async with create_session(cookies, proxy) as session: images = await create_images(session, prompt, self.proxy) - return ImageResponse(images, prompt) - -service = CreateImagesBing() + return ImageResponse(images, prompt, {"preview": "{image}?w=200&h=200"}) def patch_provider(provider: ProviderType) -> CreateImagesProvider: """ @@ -243,6 +241,7 @@ def patch_provider(provider: ProviderType) -> CreateImagesProvider: Returns: CreateImagesProvider: The patched provider with image creation capabilities. """ + service = CreateImagesBing() return CreateImagesProvider( provider, service.create_completion, diff --git a/g4f/Provider/bing/upload_image.py b/g4f/Provider/bing/upload_image.py index f9e11561..6d51aba0 100644 --- a/g4f/Provider/bing/upload_image.py +++ b/g4f/Provider/bing/upload_image.py @@ -149,4 +149,4 @@ def parse_image_response(response: dict) -> ImageRequest: if IMAGE_CONFIG["enableFaceBlurDebug"] else f"https://www.bing.com/images/blob?bcid={result['bcid']}" ) - return ImageRequest(result["imageUrl"], "", result) \ No newline at end of file + return ImageRequest(result) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 60a101d7..253d4f77 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -150,8 +150,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): headers=headers ) as response: response.raise_for_status() - download_url = (await response.json())["download_url"] - return ImageRequest(download_url, image_data["file_name"], image_data) + image_data["download_url"] = (await response.json())["download_url"] + return ImageRequest(image_data) @classmethod async def get_default_model(cls, session: StreamSession, headers: dict): @@ -175,7 +175,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return cls.default_model @classmethod - def create_messages(cls, prompt: str, image_response: ImageRequest = None): + def create_messages(cls, prompt: str, image_request: ImageRequest = None): """ Create a list of messages for the user input @@ -187,7 +187,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): A list of messages with the user input and the image, if any """ # Check if there is an image response - if not image_response: + if not image_request: # Create a content object with the text type and the prompt content = {"content_type": "text", "parts": [prompt]} else: @@ -195,10 +195,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): content = { "content_type": "multimodal_text", "parts": [{ - "asset_pointer": f"file-service://{image_response.get('file_id')}", - "height": image_response.get("height"), - "size_bytes": image_response.get("file_size"), - "width": image_response.get("width"), + "asset_pointer": f"file-service://{image_request.get('file_id')}", + "height": image_request.get("height"), + "size_bytes": image_request.get("file_size"), + "width": image_request.get("width"), }, prompt] } # Create a message object with the user role and the content @@ -208,16 +208,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "content": content, }] # Check if there is an image response - if image_response: + if image_request: # Add the metadata object with the attachments messages[0]["metadata"] = { "attachments": [{ - "height": image_response.get("height"), - "id": image_response.get("file_id"), - "mimeType": image_response.get("mime_type"), - "name": image_response.get("file_name"), - "size": image_response.get("file_size"), - "width": image_response.get("width"), + "height": image_request.get("height"), + "id": image_request.get("file_id"), + "mimeType": image_request.get("mime_type"), + "name": image_request.get("file_name"), + "size": image_request.get("file_size"), + "width": image_request.get("width"), }] } return messages @@ -352,7 +352,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): image_response = None if image: image_response = await cls.upload_image(session, headers, image) - yield image_response except Exception as e: yield e end_turn = EndTurn() diff --git a/g4f/defaults.py b/g4f/defaults.py new file mode 100644 index 00000000..6ae6d7eb --- /dev/null +++ b/g4f/defaults.py @@ -0,0 +1,13 @@ +DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate, br', + 'Accept-Language': 'en-US', + 'Connection': 'keep-alive', + 'Sec-Ch-Ua': '"Not A(Brand";v="99", "Google Chrome";v="121", "Chromium";v="121"', + 'Sec-Ch-Ua-Mobile': '?0', + 'Sec-Ch-Ua-Platform': '"Windows"', + 'Sec-Fetch-Dest': 'empty', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Site': 'same-site', + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36' +} \ No newline at end of file diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js index 99a75569..86eef8c9 100644 --- a/g4f/gui/client/js/chat.v1.js +++ b/g4f/gui/client/js/chat.v1.js @@ -59,6 +59,10 @@ const handle_ask = async () => {
${markdown_render(message)} + ${imageInput.dataset.src + ? 'Image upload' + : '' + }
`; @@ -666,6 +670,18 @@ observer.observe(message_input, { attributes: true }); })() imageInput.addEventListener('click', async (event) => { imageInput.value = ''; + delete imageInput.dataset.src; +}); +imageInput.addEventListener('change', async (event) => { + if (imageInput.files.length) { + const reader = new FileReader(); + reader.addEventListener('load', (event) => { + imageInput.dataset.src = event.target.result; + }); + reader.readAsDataURL(imageInput.files[0]); + } else { + delete imageInput.dataset.src; + } }); fileInput.addEventListener('click', async (event) => { fileInput.value = ''; diff --git a/g4f/image.py b/g4f/image.py index 68767155..1a4692b3 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -3,14 +3,13 @@ from __future__ import annotations import re from io import BytesIO import base64 -from .typing import ImageType, Union +from .typing import ImageType, Union, Image try: - from PIL.Image import open as open_image, new as new_image, Image + from PIL.Image import open as open_image, new as new_image from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 has_requirements = True except ImportError: - Image = type has_requirements = False from .errors import MissingRequirementsError @@ -29,6 +28,9 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image: """ if not has_requirements: raise MissingRequirementsError('Install "pillow" package for images') + if isinstance(image, str): + is_data_uri_an_image(image) + image = extract_data_uri(image) if is_svg: try: import cairosvg @@ -39,9 +41,6 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image: buffer = BytesIO() cairosvg.svg2png(image, write_to=buffer) return open_image(buffer) - if isinstance(image, str): - is_data_uri_an_image(image) - image = extract_data_uri(image) if isinstance(image, bytes): is_accepted_format(image) return open_image(BytesIO(image)) @@ -79,9 +78,9 @@ def is_data_uri_an_image(data_uri: str) -> bool: if not re.match(r'data:image/(\w+);base64,', data_uri): raise ValueError("Invalid data URI image.") # Extract the image format from the data URI - image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1) + image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower() # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) - if image_format.lower() not in ALLOWED_EXTENSIONS: + if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml": raise ValueError("Invalid image format (from mime file type).") def is_accepted_format(binary_data: bytes) -> bool: @@ -187,7 +186,7 @@ def to_base64_jpg(image: Image, compression_rate: float) -> str: image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() -def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str: +def format_images_markdown(images, alt: str, preview: str = None) -> str: """ Formats the given images as a markdown string. @@ -200,9 +199,12 @@ def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") str: The formatted markdown string. """ if isinstance(images, str): - images = f"[![{alt}]({preview.replace('{image}', images)})]({images})" + images = f"[![{alt}]({preview.replace('{image}', images) if preview else images})]({images})" else: - images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] + images = [ + f"[![#{idx+1} {alt}]({preview.replace('{image}', image) if preview else image})]({image})" + for idx, image in enumerate(images) + ] images = "\n".join(images) start_flag = "\n" end_flag = "\n" @@ -223,7 +225,7 @@ def to_bytes(image: Image) -> bytes: image.seek(0) return bytes_io.getvalue() -class ImageResponse(): +class ImageResponse: def __init__( self, images: Union[str, list], @@ -235,10 +237,17 @@ class ImageResponse(): self.options = options def __str__(self) -> str: - return format_images_markdown(self.images, self.alt) + return format_images_markdown(self.images, self.alt, self.get("preview")) def get(self, key: str): return self.options.get(key) -class ImageRequest(ImageResponse): - pass \ No newline at end of file +class ImageRequest: + def __init__( + self, + options: dict = {} + ): + self.options = options + + def get(self, key: str): + return self.options.get(key) \ No newline at end of file diff --git a/g4f/requests.py b/g4f/requests.py index 275e108b..d7b5996b 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -7,13 +7,13 @@ try: from .requests_curl_cffi import StreamResponse, StreamSession has_curl_cffi = True except ImportError: - Session = type + from typing import Type as Session from .requests_aiohttp import StreamResponse, StreamSession has_curl_cffi = False from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies from .errors import MissingRequirementsError - +from .defaults import DEFAULT_HEADERS def get_args_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> dict: """ @@ -36,22 +36,14 @@ def get_args_from_browser(url: str, webdriver: WebDriver = None, proxy: str = No return { 'cookies': cookies, 'headers': { - 'accept': '*/*', - "accept-language": "en-US", - "accept-encoding": "gzip, deflate, br", - 'authority': parse.netloc, - 'origin': f'{parse.scheme}://{parse.netloc}', - 'referer': url, - "sec-ch-ua": "\"Google Chrome\";v=\"121\", \"Not;A=Brand\";v=\"8\", \"Chromium\";v=\"121\"", - "sec-ch-ua-mobile": "?0", - "sec-ch-ua-platform": "Windows", - 'sec-fetch-dest': 'empty', - 'sec-fetch-mode': 'cors', - 'sec-fetch-site': 'same-origin', - 'user-agent': user_agent, + **DEFAULT_HEADERS, + 'Authority': parse.netloc, + 'Origin': f'{parse.scheme}://{parse.netloc}', + 'Referer': url, + 'User-Agent': user_agent, }, } - + def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session: if not has_curl_cffi: raise MissingRequirementsError('Install "curl_cffi" package') diff --git a/g4f/requests_aiohttp.py b/g4f/requests_aiohttp.py index aa097312..0da8973b 100644 --- a/g4f/requests_aiohttp.py +++ b/g4f/requests_aiohttp.py @@ -4,6 +4,7 @@ from aiohttp import ClientSession, ClientResponse, ClientTimeout from typing import AsyncGenerator, Any from .Provider.helper import get_connector +from .defaults import DEFAULT_HEADERS class StreamResponse(ClientResponse): async def iter_lines(self) -> AsyncGenerator[bytes, None]: @@ -17,17 +18,7 @@ class StreamSession(ClientSession): def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs): if impersonate: headers = { - 'Accept-Encoding': 'gzip, deflate, br', - 'Accept-Language': 'en-US', - 'Connection': 'keep-alive', - 'Sec-Fetch-Dest': 'empty', - 'Sec-Fetch-Mode': 'cors', - 'Sec-Fetch-Site': 'same-site', - "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36', - 'Accept': '*/*', - 'sec-ch-ua': '"Google Chrome";v="107", "Chromium";v="107", "Not?A_Brand";v="24"', - 'sec-ch-ua-mobile': '?0', - 'sec-ch-ua-platform': '"Windows"', + **DEFAULT_HEADERS, **headers } super().__init__( diff --git a/g4f/typing.py b/g4f/typing.py index c5a981bd..386b3dfc 100644 --- a/g4f/typing.py +++ b/g4f/typing.py @@ -1,9 +1,10 @@ import sys from typing import Any, AsyncGenerator, Generator, NewType, Tuple, Union, List, Dict, Type, IO, Optional + try: from PIL.Image import Image except ImportError: - Image = type + from typing import Type as Image if sys.version_info >= (3, 8): from typing import TypedDict @@ -14,7 +15,7 @@ SHA256 = NewType('sha_256_hash', str) CreateResult = Generator[str, None, None] AsyncResult = AsyncGenerator[str, None] Messages = List[Dict[str, str]] -Cookies = List[Dict[str, str]] +Cookies = Dict[str, str] ImageType = Union[str, bytes, IO, Image, None] __all__ = [ @@ -33,5 +34,7 @@ __all__ = [ 'CreateResult', 'AsyncResult', 'Messages', + 'Cookies', + 'Image', 'ImageType' ] diff --git a/g4f/webdriver.py b/g4f/webdriver.py index 44765402..d28cd97b 100644 --- a/g4f/webdriver.py +++ b/g4f/webdriver.py @@ -18,6 +18,7 @@ import time from shutil import which from os import path from os import access, R_OK +from .typing import Cookies from .errors import MissingRequirementsError from . import debug @@ -56,9 +57,7 @@ def get_browser( if proxy: options.add_argument(f'--proxy-server={proxy}') # Check for system driver in docker - driver = which('chromedriver') - if not driver: - driver = '/usr/bin/chromedriver' + driver = which('chromedriver') or '/usr/bin/chromedriver' if not path.isfile(driver) or not access(driver, R_OK): driver = None return Chrome( @@ -68,7 +67,7 @@ def get_browser( headless=headless ) -def get_driver_cookies(driver: WebDriver) -> dict: +def get_driver_cookies(driver: WebDriver) -> Cookies: """ Retrieves cookies from the specified WebDriver. @@ -115,8 +114,8 @@ def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None: driver.switch_to.window(window_handle) break + # Click on the challenge button in the iframe try: - # Click on the challenge button in the iframe driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe")) WebDriverWait(driver, 5).until( EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input")) diff --git a/requirements.txt b/requirements.txt index 79bda967..ecb69a11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,5 @@ async-property undetected-chromedriver brotli beautifulsoup4 -setuptools \ No newline at end of file +setuptools +aiohttp_socks \ No newline at end of file -- cgit v1.2.3 From 770bdc54fc52a71ff56ec0e7f1a38adba01f0ae0 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Mon, 29 Jan 2024 20:13:54 +0100 Subject: Improve readme / unittests --- README.md | 24 ++++++++++++++---------- etc/unittest/asyncio.py | 30 +++++++++++++++++------------- etc/unittest/main.py | 13 ++++++++++--- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 9264d3fc..288f886f 100644 --- a/README.md +++ b/README.md @@ -103,25 +103,29 @@ Install all supported tools / all used packages: ``` pip install -U g4f[all] ``` -Install packages for uploading / generating images: +Install required packages for the OpenaiChat provider: ``` -pip install -U g4f[image] +pip install -U g4f[openai] ``` -Install the packages required for providers with webdriver: +Install required packages for the interference api: ``` -pip install -U g4f[webdriver] +pip install -U g4f[api] ``` -Install the packages required for the OpenaiChat provider: +Install required packages for the web interface: ``` -pip install -U g4f[openai] +pip install -U g4f[gui] ``` -Install the packages required for the interference api: +Install required packages for uploading / generating images: ``` -pip install -U g4f[api] +pip install -U g4f[image] ``` -Install the packages required for the web gui: +Install required packages for providers with webdriver: ``` -pip install -U g4f[gui] +pip install -U g4f[webdriver] +``` +Install required packages for proxy support: +``` +pip install -U aiohttp_socks ``` ##### or: diff --git a/etc/unittest/asyncio.py b/etc/unittest/asyncio.py index a31ce211..e886c43a 100644 --- a/etc/unittest/asyncio.py +++ b/etc/unittest/asyncio.py @@ -1,4 +1,3 @@ -from .include import DEFAULT_MESSAGES import asyncio try: import nest_asyncio @@ -6,55 +5,60 @@ try: except: has_nest_asyncio = False import unittest + import g4f from g4f import ChatCompletion from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock - + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] + class TestChatCompletion(unittest.TestCase): - + async def run_exception(self): return ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) - + def test_exception(self): + if hasattr(asyncio, '_nest_patched'): + self.skipTest('asyncio is already patched') self.assertRaises(g4f.errors.NestAsyncioError, asyncio.run, self.run_exception()) def test_create(self): result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) self.assertEqual("Mock",result) - + def test_create_generator(self): result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) self.assertEqual("Mock",result) class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): - + async def test_base(self): result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) self.assertEqual("Mock",result) - + async def test_async(self): result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) self.assertEqual("Mock",result) - + async def test_create_generator(self): result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) self.assertEqual("Mock",result) - + class TestChatCompletionNestAsync(unittest.IsolatedAsyncioTestCase): - + def setUp(self) -> None: if not has_nest_asyncio: self.skipTest('"nest_asyncio" not installed') nest_asyncio.apply() - + async def test_create(self): result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) self.assertEqual("Mock",result) - + async def test_nested(self): result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) self.assertEqual("Mock",result) - + async def test_nested_generator(self): result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) self.assertEqual("Mock",result) diff --git a/etc/unittest/main.py b/etc/unittest/main.py index f5eb5138..cc3c6a18 100644 --- a/etc/unittest/main.py +++ b/etc/unittest/main.py @@ -24,12 +24,19 @@ class TestGetLastProvider(unittest.TestCase): def test_get_last_provider(self): ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) self.assertEqual(get_last_provider(), ProviderMock) - + def test_get_last_provider_retry(self): ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock])) self.assertEqual(get_last_provider(), ProviderMock) - + def test_get_last_provider_async(self): coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) asyncio.run(coroutine) - self.assertEqual(get_last_provider(), ProviderMock) \ No newline at end of file + self.assertEqual(get_last_provider(), ProviderMock) + + def test_get_last_provider_as_dict(self): + ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) + last_provider_dict = get_last_provider(True) + self.assertIsInstance(last_provider_dict, dict) + self.assertIn('name', last_provider_dict) + self.assertEqual(ProviderMock.__name__, last_provider_dict['name']) \ No newline at end of file -- cgit v1.2.3