diff options
author | hs_junxiang <jimmy871117@gmail.com> | 2023-10-13 07:45:29 +0200 |
---|---|---|
committer | hs_junxiang <jimmy871117@gmail.com> | 2023-10-13 07:46:28 +0200 |
commit | c84ff591457647bdd0f5a5620505580d3b615540 (patch) | |
tree | b93f02a6c6b77a00ef4dc882d0aeddcbe5738d54 | |
parent | ~ | g4f `v-0.1.6.2` (diff) | |
download | gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.gz gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.bz2 gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.lz gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.xz gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.tar.zst gpt4free-c84ff591457647bdd0f5a5620505580d3b615540.zip |
-rw-r--r-- | g4f/__init__.py | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py index 1a696c6c..6f777e4c 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -1,13 +1,14 @@ from __future__ import annotations from requests import get from g4f.models import Model, ModelUtils -from .Provider import BaseProvider -from .typing import Messages, CreateResult, Union +from .Provider import BaseProvider, RetryProvider +from .typing import Messages, CreateResult, Union, List from .debug import logging version = '0.1.6.2' version_check = True + def check_pypi_version() -> None: try: response = get("https://pypi.org/pypi/g4f/json").json() @@ -19,9 +20,11 @@ def check_pypi_version() -> None: except Exception as e: print(f'Failed to check g4f pypi version: {e}') + def get_model_and_provider(model : Union[Model, str], provider : Union[type[BaseProvider], None], - stream : bool) -> tuple[Model, type[BaseProvider]]: + stream : bool, + ignored : List[str] = None) -> tuple[Model, type[BaseProvider]]: if isinstance(model, str): if model in ModelUtils.convert: @@ -32,6 +35,9 @@ def get_model_and_provider(model : Union[Model, str], if not provider: provider = model.best_provider + if isinstance(provider, RetryProvider) and ignored: + provider.providers = [p for p in provider.providers if p.__name__ not in ignored] + if not provider: raise RuntimeError(f'No provider found for model: {model}') @@ -46,15 +52,17 @@ def get_model_and_provider(model : Union[Model, str], return model, provider + class ChatCompletion: @staticmethod def create(model: Union[Model, str], messages : Messages, provider : Union[type[BaseProvider], None] = None, stream : bool = False, - auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]: + auth : Union[str, None] = None, + ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]: - model, provider = get_model_and_provider(model, provider, stream) + model, provider = get_model_and_provider(model, provider, stream, ignored) if provider.needs_auth and not auth: raise ValueError( @@ -71,15 +79,17 @@ class ChatCompletion: model : Union[Model, str], messages: Messages, provider: Union[type[BaseProvider], None] = None, - stream : bool = False, **kwargs) -> str: + stream : bool = False, + ignored : List[str] = None, **kwargs) -> str: if stream: raise ValueError(f'"create_async" does not support "stream" argument') - model, provider = get_model_and_provider(model, provider, False) + model, provider = get_model_and_provider(model, provider, False, ignored) return await provider.create_async(model.name, messages, **kwargs) + class Completion: @staticmethod def create( @@ -87,6 +97,7 @@ class Completion: prompt: str, provider: Union[type[BaseProvider], None] = None, stream: bool = False, + ignored : List[str] = None, **kwargs ) -> Union[CreateResult, str]: @@ -102,7 +113,7 @@ class Completion: if model not in allowed_models: raise Exception(f'ValueError: Can\'t use {model} with Completion.create()') - model, provider = get_model_and_provider(model, provider, stream) + model, provider = get_model_and_provider(model, provider, stream, ignored) result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs) |