From 9cbe9c1ccb2381e37402a36297f11a0f96b1b557 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 21 Jan 2024 02:20:23 +0100 Subject: Improve tests --- .github/workflows/copilot.yml | 4 ++- .github/workflows/unittest.yml | 2 +- etc/unittest/__main__.py | 6 ++++ etc/unittest/asyncio.py | 57 ++++++++++++++++++++++++++++++ etc/unittest/backend.py | 38 ++++++++++++++++++++ etc/unittest/include.py | 11 ++++++ etc/unittest/main.py | 78 +++++++++++------------------------------- etc/unittest/mocks.py | 25 ++++++++++++++ g4f/Provider/Bing.py | 13 +++---- g4f/Provider/base_provider.py | 41 ++++++++++++---------- g4f/gui/server/backend.py | 3 +- g4f/image.py | 14 ++++---- g4f/typing.py | 7 ++++ 13 files changed, 204 insertions(+), 95 deletions(-) create mode 100644 etc/unittest/__main__.py create mode 100644 etc/unittest/asyncio.py create mode 100644 etc/unittest/backend.py create mode 100644 etc/unittest/include.py create mode 100644 etc/unittest/mocks.py diff --git a/.github/workflows/copilot.yml b/.github/workflows/copilot.yml index c34dcb11..38c24378 100644 --- a/.github/workflows/copilot.yml +++ b/.github/workflows/copilot.yml @@ -9,7 +9,9 @@ on: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -permissions: write-all +permissions: + contents: read + pull-requests: write jobs: review: diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 20d9e55f..0646eab5 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -16,4 +16,4 @@ jobs: - name: Install requirements run: pip install -r requirements.txt - name: Run tests - run: python -m etc.unittest.main \ No newline at end of file + run: python -m etc.unittest \ No newline at end of file diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py new file mode 100644 index 00000000..243c56b2 --- /dev/null +++ b/etc/unittest/__main__.py @@ -0,0 +1,6 @@ +import unittest +from .asyncio import * +from .backend import * +from .main import * + +unittest.main() \ No newline at end of file diff --git a/etc/unittest/asyncio.py b/etc/unittest/asyncio.py new file mode 100644 index 00000000..74e29986 --- /dev/null +++ b/etc/unittest/asyncio.py @@ -0,0 +1,57 @@ +from .include import DEFAULT_MESSAGES +import asyncio +import nest_asyncio +import unittest +import g4f +from g4f import ChatCompletion +from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock + +class TestChatCompletion(unittest.TestCase): + + async def run_exception(self): + return ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) + + def test_exception(self): + self.assertRaises(g4f.errors.NestAsyncioError, asyncio.run, self.run_exception()) + + def test_create(self): + result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) + self.assertEqual("Mock",result) + + def test_create_generator(self): + result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) + self.assertEqual("Mock",result) + +class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): + + async def test_base(self): + result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) + self.assertEqual("Mock",result) + + async def test_async(self): + result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) + self.assertEqual("Mock",result) + + async def test_create_generator(self): + result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) + self.assertEqual("Mock",result) + +class TestChatCompletionNestAsync(unittest.IsolatedAsyncioTestCase): + + def setUp(self) -> None: + nest_asyncio.apply() + + async def test_create(self): + result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) + self.assertEqual("Mock",result) + + async def test_nested(self): + result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) + self.assertEqual("Mock",result) + + async def test_nested_generator(self): + result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) + self.assertEqual("Mock",result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/etc/unittest/backend.py b/etc/unittest/backend.py new file mode 100644 index 00000000..f5961e2d --- /dev/null +++ b/etc/unittest/backend.py @@ -0,0 +1,38 @@ +from . import include +import unittest +from unittest.mock import MagicMock +from .mocks import ProviderMock +import g4f +from g4f.gui.server.backend import Backend_Api, get_error_message + +class TestBackendApi(unittest.TestCase): + + def setUp(self): + self.app = MagicMock() + self.api = Backend_Api(self.app) + + def test_version(self): + response = self.api.get_version() + self.assertIn("version", response) + self.assertIn("latest_version", response) + + def test_get_models(self): + response = self.api.get_models() + self.assertIsInstance(response, list) + self.assertTrue(len(response) > 0) + + def test_get_providers(self): + response = self.api.get_providers() + self.assertIsInstance(response, list) + self.assertTrue(len(response) > 0) + +class TestUtilityFunctions(unittest.TestCase): + + def test_get_error_message(self): + g4f.debug.last_provider = ProviderMock + exception = Exception("Message") + result = get_error_message(exception) + self.assertEqual("ProviderMock: Exception: Message", result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/etc/unittest/include.py b/etc/unittest/include.py new file mode 100644 index 00000000..e67fd5a7 --- /dev/null +++ b/etc/unittest/include.py @@ -0,0 +1,11 @@ +import sys +import pathlib + +sys.path.append(str(pathlib.Path(__file__).parent.parent.parent)) + +import g4f + +g4f.debug.logging = False +g4f.debug.version_check = False + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] \ No newline at end of file diff --git a/etc/unittest/main.py b/etc/unittest/main.py index ad1fe02d..5a220323 100644 --- a/etc/unittest/main.py +++ b/etc/unittest/main.py @@ -1,75 +1,37 @@ -import sys -import pathlib +from .include import DEFAULT_MESSAGES import unittest -from unittest.mock import MagicMock - -sys.path.append(str(pathlib.Path(__file__).parent.parent.parent)) - +import asyncio import g4f from g4f import ChatCompletion, get_last_provider -from g4f.gui.server.backend import Backend_Api, get_error_message -from g4f.base_provider import BaseProvider - -g4f.debug.logging = False -g4f.debug.version_check = False - -class MockProvider(BaseProvider): - working = True - - def create_completion( - model, messages, stream, **kwargs - ): - yield "Mock" - - async def create_async( - model, messages, **kwargs - ): - return "Mock" - -class TestBackendApi(unittest.TestCase): - - def setUp(self): - self.app = MagicMock() - self.api = Backend_Api(self.app) - - def test_version(self): - response = self.api.get_version() - self.assertIn("version", response) - self.assertIn("latest_version", response) +from g4f.Provider import RetryProvider +from .mocks import ProviderMock class TestChatCompletion(unittest.TestCase): def test_create_default(self): - messages = [{'role': 'user', 'content': 'Hello'}] - result = ChatCompletion.create(g4f.models.default, messages) + result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES) if "Good" not in result and "Hi" not in result: self.assertIn("Hello", result) - - def test_get_last_provider(self): - messages = [{'role': 'user', 'content': 'Hello'}] - ChatCompletion.create(g4f.models.default, messages, MockProvider) - self.assertEqual(get_last_provider(), MockProvider) - + def test_bing_provider(self): - messages = [{'role': 'user', 'content': 'Hello'}] provider = g4f.Provider.Bing - result = ChatCompletion.create(g4f.models.default, messages, provider) + result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, provider) self.assertIn("Bing", result) -class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): - - async def test_async(self): - messages = [{'role': 'user', 'content': 'Hello'}] - result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider) - self.assertEqual("Mock", result) +class TestGetLastProvider(unittest.TestCase): -class TestUtilityFunctions(unittest.TestCase): - - def test_get_error_message(self): - g4f.debug.last_provider = g4f.Provider.Bing - exception = Exception("Message") - result = get_error_message(exception) - self.assertEqual("Bing: Exception: Message", result) + def test_get_last_provider(self): + ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) + self.assertEqual(get_last_provider(), ProviderMock) + + def test_get_last_provider_retry(self): + ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock])) + self.assertEqual(get_last_provider(), ProviderMock) + + def test_get_last_provider_async(self): + coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock) + asyncio.run(coroutine) + self.assertEqual(get_last_provider(), ProviderMock) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py new file mode 100644 index 00000000..a9505997 --- /dev/null +++ b/etc/unittest/mocks.py @@ -0,0 +1,25 @@ +from g4f.Provider.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider + +class ProviderMock(AbstractProvider): + working = True + + def create_completion( + model, messages, stream, **kwargs + ): + yield "Mock" + +class AsyncProviderMock(AsyncProvider): + working = True + + async def create_async( + model, messages, **kwargs + ): + return "Mock" + +class AsyncGeneratorProviderMock(AsyncGeneratorProvider): + working = True + + async def create_async_generator( + model, messages, stream, **kwargs + ): + yield "Mock" \ No newline at end of file diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py index 34687866..b869a6ef 100644 --- a/g4f/Provider/Bing.py +++ b/g4f/Provider/Bing.py @@ -64,12 +64,7 @@ class Bing(AsyncGeneratorProvider): prompt = messages[-1]["content"] context = create_context(messages[:-1]) - if not cookies: - cookies = Defaults.cookies - else: - for key, value in Defaults.cookies.items(): - if key not in cookies: - cookies[key] = value + cookies = {**Defaults.cookies, **cookies} if cookies else Defaults.cookies gpt4_turbo = True if model.startswith("gpt-4-turbo") else False @@ -207,10 +202,12 @@ def create_message( request_id = str(uuid.uuid4()) struct = { 'arguments': [{ - 'source': 'cib', 'optionsSets': options_sets, + 'source': 'cib', + 'optionsSets': options_sets, 'allowedMessageTypes': Defaults.allowedMessageTypes, 'sliceIds': Defaults.sliceIds, - 'traceId': os.urandom(16).hex(), 'isStartOfSession': True, + 'traceId': os.urandom(16).hex(), + 'isStartOfSession': True, 'requestId': request_id, 'message': { **Defaults.location, diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 95f1b0b2..bc47a1fa 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -5,8 +5,8 @@ from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter -from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages +from .helper import get_cookies, format_prompt +from ..typing import CreateResult, AsyncResult, Messages, Union from ..base_provider import BaseProvider from ..errors import NestAsyncioError @@ -20,6 +20,17 @@ if sys.platform == 'win32': if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +def get_running_loop() -> Union[AbstractEventLoop, None]: + try: + loop = asyncio.get_running_loop() + if not hasattr(loop.__class__, "_nest_patched"): + raise NestAsyncioError( + 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' + ) + return loop + except RuntimeError: + pass + class AbstractProvider(BaseProvider): """ Abstract class for providing asynchronous functionality to derived classes. @@ -56,7 +67,7 @@ class AbstractProvider(BaseProvider): return await asyncio.wait_for( loop.run_in_executor(executor, create_func), - timeout=kwargs.get("timeout", 0) + timeout=kwargs.get("timeout") ) @classmethod @@ -118,14 +129,7 @@ class AsyncProvider(AbstractProvider): Returns: CreateResult: The result of the completion creation. """ - try: - loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) - except RuntimeError: - pass + get_running_loop() yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -180,15 +184,12 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - try: - loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) - except RuntimeError: + loop = get_running_loop() + new_loop = False + if not loop: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + new_loop = True generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() @@ -199,6 +200,10 @@ class AsyncGeneratorProvider(AsyncProvider): except StopAsyncIteration: break + if new_loop: + loop.close() + asyncio.set_event_loop(None) + @classmethod async def create_async( cls, diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index b4c8f56c..d5c59ed1 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -2,7 +2,7 @@ import logging import json from flask import request, Flask from typing import Generator -from g4f import debug, version, models +from g4f import version, models from g4f import _all_models, get_last_provider, ChatCompletion from g4f.image import is_allowed_extension, to_image from g4f.errors import VersionNotFoundError @@ -10,7 +10,6 @@ from g4f.Provider import __providers__ from g4f.Provider.bing.create_images import patch_provider from .internet import get_search_message -debug.logging = True class Backend_Api: """ diff --git a/g4f/image.py b/g4f/image.py index cfa22ab1..24ded915 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -112,7 +112,7 @@ def get_orientation(image: Image.Image) -> int: """ exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() if exif_data is not None: - orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF + orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF if orientation is not None: return orientation @@ -156,23 +156,23 @@ def to_base64(image: Image.Image, compression_rate: float) -> str: image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() -def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str: +def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str: """ Formats the given images as a markdown string. Args: images: The images to format. - prompt (str): The prompt for the images. + alt (str): The alt for the images. preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200". Returns: str: The formatted markdown string. """ - if isinstance(images, list): - images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] - images = "\n".join(images) + if isinstance(images, str): + images = f"[![{alt}]({preview.replace('{image}', images)})]({images})" else: - images = f"[![{prompt}]({images})]({images})" + images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] + images = "\n".join(images) start_flag = "\n" end_flag = "\n" return f"\n{start_flag}{images}\n{end_flag}\n" diff --git a/g4f/typing.py b/g4f/typing.py index c972f505..a6a62e3f 100644 --- a/g4f/typing.py +++ b/g4f/typing.py @@ -18,7 +18,14 @@ __all__ = [ 'AsyncGenerator', 'Generator', 'Tuple', + 'Union', + 'List', + 'Dict', + 'Type', 'TypedDict', 'SHA256', 'CreateResult', + 'AsyncResult', + 'Messages', + 'ImageType' ] -- cgit v1.2.3