summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorhs_junxiang <jimmy871117@gmail.com>2023-10-13 07:45:29 +0200
committerhs_junxiang <jimmy871117@gmail.com>2023-10-13 07:46:28 +0200
commitc84ff591457647bdd0f5a5620505580d3b615540 (patch)
treeb93f02a6c6b77a00ef4dc882d0aeddcbe5738d54
parent~ | g4f `v-0.1.6.2` (diff)
downloadgpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar
gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.gz
gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.bz2
gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.lz
gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.xz
gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.zst
gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.zip
-rw-r--r--g4f/__init__.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py
index 1a696c6c..6f777e4c 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -1,13 +1,14 @@
from __future__ import annotations
from requests import get
from g4f.models import Model, ModelUtils
-from .Provider import BaseProvider
-from .typing import Messages, CreateResult, Union
+from .Provider import BaseProvider, RetryProvider
+from .typing import Messages, CreateResult, Union, List
from .debug import logging
version = '0.1.6.2'
version_check = True
+
def check_pypi_version() -> None:
try:
response = get("https://pypi.org/pypi/g4f/json").json()
@@ -19,9 +20,11 @@ def check_pypi_version() -> None:
except Exception as e:
print(f'Failed to check g4f pypi version: {e}')
+
def get_model_and_provider(model : Union[Model, str],
provider : Union[type[BaseProvider], None],
- stream : bool) -> tuple[Model, type[BaseProvider]]:
+ stream : bool,
+ ignored : List[str] = None) -> tuple[Model, type[BaseProvider]]:
if isinstance(model, str):
if model in ModelUtils.convert:
@@ -32,6 +35,9 @@ def get_model_and_provider(model : Union[Model, str],
if not provider:
provider = model.best_provider
+ if isinstance(provider, RetryProvider) and ignored:
+ provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
+
if not provider:
raise RuntimeError(f'No provider found for model: {model}')
@@ -46,15 +52,17 @@ def get_model_and_provider(model : Union[Model, str],
return model, provider
+
class ChatCompletion:
@staticmethod
def create(model: Union[Model, str],
messages : Messages,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
- auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]:
+ auth : Union[str, None] = None,
+ ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]:
- model, provider = get_model_and_provider(model, provider, stream)
+ model, provider = get_model_and_provider(model, provider, stream, ignored)
if provider.needs_auth and not auth:
raise ValueError(
@@ -71,15 +79,17 @@ class ChatCompletion:
model : Union[Model, str],
messages: Messages,
provider: Union[type[BaseProvider], None] = None,
- stream : bool = False, **kwargs) -> str:
+ stream : bool = False,
+ ignored : List[str] = None, **kwargs) -> str:
if stream:
raise ValueError(f'"create_async" does not support "stream" argument')
- model, provider = get_model_and_provider(model, provider, False)
+ model, provider = get_model_and_provider(model, provider, False, ignored)
return await provider.create_async(model.name, messages, **kwargs)
+
class Completion:
@staticmethod
def create(
@@ -87,6 +97,7 @@ class Completion:
prompt: str,
provider: Union[type[BaseProvider], None] = None,
stream: bool = False,
+ ignored : List[str] = None,
**kwargs
) -> Union[CreateResult, str]:
@@ -102,7 +113,7 @@ class Completion:
if model not in allowed_models:
raise Exception(f'ValueError: Can\'t use {model} with Completion.create()')
- model, provider = get_model_and_provider(model, provider, stream)
+ model, provider = get_model_and_provider(model, provider, stream, ignored)
result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)