From d733930a2b1876340039d90f19ece81fab0d078d Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Fri, 23 Feb 2024 02:51:10 +0100 Subject: Fix unittests, use Union typing --- etc/unittest/client.py | 8 +++++--- g4f/api/__init__.py | 10 +++++----- g4f/client.py | 4 ++-- g4f/stubs.py | 10 +++++++--- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/etc/unittest/client.py b/etc/unittest/client.py index 2bc00c2e..ec8aa4b7 100644 --- a/etc/unittest/client.py +++ b/etc/unittest/client.py @@ -35,13 +35,15 @@ class TestPassModel(unittest.TestCase): response = client.chat.completions.create(messages, "Hello", stream=True) for chunk in response: self.assertIsInstance(chunk, ChatCompletionChunk) - self.assertIsInstance(chunk.choices[0].delta.content, str) + if chunk.choices[0].delta.content is not None: + self.assertIsInstance(chunk.choices[0].delta.content, str) messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]] response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2) response = list(response) - self.assertEqual(len(response), 2) + self.assertEqual(len(response), 3) for chunk in response: - self.assertEqual(chunk.choices[0].delta.content, "You ") + if chunk.choices[0].delta.content is not None: + self.assertEqual(chunk.choices[0].delta.content, "You ") def test_stop(self): client = Client(provider=YieldProviderMock) diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 9033aafe..d1e8539f 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -6,7 +6,7 @@ import nest_asyncio from fastapi import FastAPI, Response, Request from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse from pydantic import BaseModel -from typing import List +from typing import List, Union import g4f import g4f.debug @@ -16,12 +16,12 @@ from g4f.typing import Messages class ChatCompletionsConfig(BaseModel): messages: Messages model: str - provider: str | None + provider: Union[str, None] stream: bool = False - temperature: float | None + temperature: Union[float, None] max_tokens: int = None - stop: list[str] | str | None - access_token: str | None + stop: Union[list[str], str, None] + access_token: Union[str, None] class Api: def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False, diff --git a/g4f/client.py b/g4f/client.py index b44a5230..023d53f6 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -17,7 +17,7 @@ from . import get_model_and_provider, get_last_provider ImageProvider = Union[BaseProvider, object] Proxies = Union[dict, str] -IterResponse = Generator[ChatCompletion | ChatCompletionChunk, None, None] +IterResponse = Generator[Union[ChatCompletion, ChatCompletionChunk], None, None] def read_json(text: str) -> dict: """ @@ -124,7 +124,7 @@ class Completions(): stream: bool = False, response_format: dict = None, max_tokens: int = None, - stop: list[str] | str = None, + stop: Union[list[str], str] = None, **kwargs ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]: if max_tokens is not None: diff --git a/g4f/stubs.py b/g4f/stubs.py index b9934b8c..49cf8a88 100644 --- a/g4f/stubs.py +++ b/g4f/stubs.py @@ -1,6 +1,8 @@ from __future__ import annotations +from typing import Union + class Model(): ... @@ -52,7 +54,7 @@ class ChatCompletionChunk(Model): } class ChatCompletionMessage(Model): - def __init__(self, content: str | None): + def __init__(self, content: Union[str, None]): self.role = "assistant" self.content = content @@ -72,7 +74,9 @@ class ChatCompletionChoice(Model): } class ChatCompletionDelta(Model): - def __init__(self, content: str | None): + content: Union[str, None] = None + + def __init__(self, content: Union[str, None]): if content is not None: self.content = content @@ -80,7 +84,7 @@ class ChatCompletionDelta(Model): return self.__dict__ class ChatCompletionDeltaChoice(Model): - def __init__(self, delta: ChatCompletionDelta, finish_reason: str | None): + def __init__(self, delta: ChatCompletionDelta, finish_reason: Union[str, None]): self.delta = delta self.finish_reason = finish_reason -- cgit v1.2.3