summaryrefslogtreecommitdiffstats
path: root/g4f/__init__.py
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-01-10 10:34:56 +0100
committerGitHub <noreply@github.com>2024-01-10 10:34:56 +0100
commitbee75be8e38d25c4568c641412a49b576d425b24 (patch)
tree63ea1505dbe2b84c3011164a7b2699e642d94c19 /g4f/__init__.py
parentMerge pull request #1441 from w453y/patch-1 (diff)
downloadgpt4free-bee75be8e38d25c4568c641412a49b576d425b24.tar
gpt4free-bee75be8e38d25c4568c641412a49b576d425b24.tar.gz
gpt4free-bee75be8e38d25c4568c641412a49b576d425b24.tar.bz2
gpt4free-bee75be8e38d25c4568c641412a49b576d425b24.tar.lz
gpt4free-bee75be8e38d25c4568c641412a49b576d425b24.tar.xz
gpt4free-bee75be8e38d25c4568c641412a49b576d425b24.tar.zst
gpt4free-bee75be8e38d25c4568c641412a49b576d425b24.zip
Diffstat (limited to 'g4f/__init__.py')
-rw-r--r--g4f/__init__.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py
index 699dc238..dc7808f9 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -68,6 +68,7 @@ class ChatCompletion:
ignored : list[str] = None,
ignore_working: bool = False,
ignore_stream_and_auth: bool = False,
+ patch_provider: callable = None,
**kwargs) -> Union[CreateResult, str]:
model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)
@@ -83,6 +84,9 @@ class ChatCompletion:
if proxy:
kwargs['proxy'] = proxy
+ if patch_provider:
+ provider = patch_provider(provider)
+
result = provider.create_completion(model, messages, stream, **kwargs)
return result if stream else ''.join(result)
@@ -92,6 +96,7 @@ class ChatCompletion:
provider : Union[ProviderType, str, None] = None,
stream : bool = False,
ignored : list[str] = None,
+ patch_provider: callable = None,
**kwargs) -> Union[AsyncResult, str]:
model, provider = get_model_and_provider(model, provider, False, ignored)
@@ -101,6 +106,9 @@ class ChatCompletion:
return provider.create_async_generator(model, messages, **kwargs)
raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
+ if patch_provider:
+ provider = patch_provider(provider)
+
return provider.create_async(model, messages, **kwargs)
class Completion: