diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-01-23 20:08:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-23 20:08:41 +0100 |
commit | 2b140a32554c1e94d095c55599a2f93e86f957cf (patch) | |
tree | e2770d97f0242a0b99a3af68ea4fcf25227dfcc8 /g4f/Provider/base_provider.py | |
parent | ~ (diff) | |
parent | Add ProviderModelMixin for model selection (diff) | |
download | gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.gz gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.bz2 gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.lz gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.xz gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.zst gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.zip |
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index bc47a1fa..e1dcd24d 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -8,7 +8,7 @@ from inspect import signature, Parameter from .helper import get_cookies, format_prompt from ..typing import CreateResult, AsyncResult, Messages, Union from ..base_provider import BaseProvider -from ..errors import NestAsyncioError +from ..errors import NestAsyncioError, ModelNotSupportedError if sys.version_info < (3, 10): NoneType = type(None) @@ -251,4 +251,23 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: AsyncResult: An asynchronous generator yielding results. """ - raise NotImplementedError()
\ No newline at end of file + raise NotImplementedError() + +class ProviderModelMixin: + default_model: str + models: list[str] = [] + model_aliases: dict[str, str] = {} + + @classmethod + def get_models(cls) -> list[str]: + return cls.models + + @classmethod + def get_model(cls, model: str) -> str: + if not model: + return cls.default_model + elif model in cls.model_aliases: + return cls.model_aliases[model] + elif model not in cls.get_models(): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + return model
\ No newline at end of file |