summaryrefslogtreecommitdiffstats
path: root/testing/test_providers.py
diff options
context:
space:
mode:
Diffstat (limited to 'testing/test_providers.py')
-rw-r--r--testing/test_providers.py33
1 files changed, 23 insertions, 10 deletions
diff --git a/testing/test_providers.py b/testing/test_providers.py
index fee79e20..6d3b62d8 100644
--- a/testing/test_providers.py
+++ b/testing/test_providers.py
@@ -1,5 +1,6 @@
import sys
from pathlib import Path
+from colorama import Fore
sys.path.append(str(Path(__file__).parent.parent))
@@ -20,9 +21,14 @@ def main():
if _provider.working and not result:
failed_providers.append([_provider, result])
- print("Failed providers:")
- for _provider, result in failed_providers:
- print(f"{_provider.__name__}: {result}")
+ print()
+
+ if failed_providers:
+ print(f"{Fore.RED}Failed providers:\n")
+ for _provider, result in failed_providers:
+ print(f"{Fore.RED}{_provider.__name__}")
+ else:
+ print(f"{Fore.GREEN}All providers are working")
def get_providers() -> list[type[BaseProvider]]:
@@ -36,18 +42,21 @@ def get_providers() -> list[type[BaseProvider]]:
for provider_name in provider_names
if not provider_name.startswith("__") and provider_name not in ignore_names
]
- return [getattr(Provider, provider_name) for provider_name in provider_names]
+ return [getattr(Provider, provider_name) for provider_name in sorted(provider_names)]
def create_response(_provider: type[BaseProvider]) -> str:
- model = (
- models.gpt_35_turbo.name
- if _provider.supports_gpt_35_turbo
- else _provider.model
- )
+ if _provider.supports_gpt_35_turbo:
+ model = models.gpt_35_turbo.name
+ elif _provider.supports_gpt_4:
+ model = models.gpt_4
+ elif hasattr(_provider, "model"):
+ model = _provider.model
+ else:
+ model = None
response = _provider.create_completion(
model=model,
- messages=[{"role": "user", "content": "Hello world!"}],
+ messages=[{"role": "user", "content": "Hello"}],
stream=False,
)
return "".join(response)
@@ -57,9 +66,13 @@ def judge(_provider: type[BaseProvider]) -> bool:
if _provider.needs_auth:
return _provider.working
+ return test(_provider)
+
+def test(_provider: type[BaseProvider]) -> bool:
try:
response = create_response(_provider)
assert type(response) is str
+ assert len(response) > 0
return response
except Exception as e:
if logging: