summaryrefslogtreecommitdiffstats
path: root/g4f/__init__.py
diff options
context:
space:
mode:
authorTekky <98614666+xtekky@users.noreply.github.com>2023-11-20 19:27:38 +0100
committerGitHub <noreply@github.com>2023-11-20 19:27:38 +0100
commite8d88c955f75f539dd71bd4b713e90094751161c (patch)
tree6d1ad2636abfd7ad0b4f5a59aa4630cf2a29723e /g4f/__init__.py
parentMerge pull request #1275 from egcash/patch-1 (diff)
parentMerge branch 'main' into webdriver (diff)
downloadgpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.tar
gpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.tar.gz
gpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.tar.bz2
gpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.tar.lz
gpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.tar.xz
gpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.tar.zst
gpt4free-e8d88c955f75f539dd71bd4b713e90094751161c.zip
Diffstat (limited to 'g4f/__init__.py')
-rw-r--r--g4f/__init__.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py
index faef7923..2c9ef7d7 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -1,8 +1,8 @@
from __future__ import annotations
from requests import get
from .models import Model, ModelUtils, _all_models
-from .Provider import BaseProvider, RetryProvider
-from .typing import Messages, CreateResult, Union, List
+from .Provider import BaseProvider, AsyncGeneratorProvider, RetryProvider
+from .typing import Messages, CreateResult, AsyncResult, Union, List
from . import debug
version = '0.1.8.7'
@@ -80,13 +80,15 @@ class ChatCompletion:
messages : Messages,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
- ignored : List[str] = None, **kwargs) -> str:
-
- if stream:
- raise ValueError('"create_async" does not support "stream" argument')
-
+ ignored : List[str] = None,
+ **kwargs) -> Union[AsyncResult, str]:
model, provider = get_model_and_provider(model, provider, False, ignored)
+ if stream:
+ if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ return await provider.create_async_generator(model.name, messages, **kwargs)
+ raise ValueError(f'{provider.__name__} does not support "stream" argument')
+
return await provider.create_async(model.name, messages, **kwargs)
class Completion: