From 5ca47b44b2b42abb4f48163c17500b5ee67ab28f Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 5 Sep 2023 17:27:24 +0200 Subject: Add to many provider async and stream support, Fix Ails, AItianhu, ChatgptAi, ChatgptLogin Provider, Add fallback cookies to Bing, Improve OpenaiChat Provider --- testing/test_async.py | 37 +++++++++++++++++++++++++++++++++++++ testing/test_providers.py | 23 ++++++++++++++--------- 2 files changed, 51 insertions(+), 9 deletions(-) create mode 100644 testing/test_async.py (limited to 'testing') diff --git a/testing/test_async.py b/testing/test_async.py new file mode 100644 index 00000000..692946ea --- /dev/null +++ b/testing/test_async.py @@ -0,0 +1,37 @@ +import sys +from pathlib import Path +import asyncio + +sys.path.append(str(Path(__file__).parent.parent)) + +import g4f +from g4f.Provider import AsyncProvider +from testing.test_providers import get_providers +from testing.log_time import log_time_async + +async def create_async(provider: AsyncProvider): + model = g4f.models.gpt_35_turbo.name if provider.supports_gpt_35_turbo else g4f.models.default.name + try: + response = await log_time_async( + provider.create_async, + model=model, + messages=[{"role": "user", "content": "Hello Assistant!"}] + ) + assert type(response) is str + assert len(response) > 0 + return response + except Exception as e: + return e + +async def run_async(): + _providers: list[AsyncProvider] = [ + _provider + for _provider in get_providers() + if _provider.working and hasattr(_provider, "create_async") + ] + responses = [create_async(_provider) for _provider in _providers] + responses = await asyncio.gather(*responses) + for idx, provider in enumerate(_providers): + print(f"{provider.__name__}:", responses[idx]) + +print("Total:", asyncio.run(log_time_async(run_async))) \ No newline at end of file diff --git a/testing/test_providers.py b/testing/test_providers.py index c4fcbc0c..676f1a59 100644 --- a/testing/test_providers.py +++ b/testing/test_providers.py @@ -8,6 +8,11 @@ from g4f import BaseProvider, models, Provider logging = False +class Styles: + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + def main(): providers = get_providers() failed_providers = [] @@ -24,39 +29,39 @@ def main(): print() if failed_providers: - print(f"{Fore.RED}Failed providers:\n") + print(f"{Fore.RED + Styles.BOLD}Failed providers:{Styles.ENDC}") for _provider in failed_providers: print(f"{Fore.RED}{_provider.__name__}") else: - print(f"{Fore.GREEN}All providers are working") + print(f"{Fore.GREEN + Styles.BOLD}All providers are working") def get_providers() -> list[type[BaseProvider]]: provider_names = dir(Provider) ignore_names = [ "base_provider", - "BaseProvider" + "BaseProvider", + "AsyncProvider", + "AsyncGeneratorProvider" ] provider_names = [ provider_name for provider_name in provider_names if not provider_name.startswith("__") and provider_name not in ignore_names ] - return [getattr(Provider, provider_name) for provider_name in sorted(provider_names)] + return [getattr(Provider, provider_name) for provider_name in provider_names] def create_response(_provider: type[BaseProvider]) -> str: if _provider.supports_gpt_35_turbo: model = models.gpt_35_turbo.name elif _provider.supports_gpt_4: - model = models.gpt_4 - elif hasattr(_provider, "model"): - model = _provider.model + model = models.gpt_4.name else: - model = None + model = models.default.name response = _provider.create_completion( model=model, - messages=[{"role": "user", "content": "Hello"}], + messages=[{"role": "user", "content": "Hello, who are you? Answer in detail much as possible."}], stream=False, ) return "".join(response) -- cgit v1.2.3 From 7a9b7195736153481fd8b50393004e231a3ee7a0 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 5 Sep 2023 17:35:51 +0200 Subject: Fix imports in Bing --- testing/test_providers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'testing') diff --git a/testing/test_providers.py b/testing/test_providers.py index 676f1a59..be04e7a3 100644 --- a/testing/test_providers.py +++ b/testing/test_providers.py @@ -39,6 +39,7 @@ def main(): def get_providers() -> list[type[BaseProvider]]: provider_names = dir(Provider) ignore_names = [ + "annotations", "base_provider", "BaseProvider", "AsyncProvider", -- cgit v1.2.3