summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--g4f/Provider/base_provider.py23
1 files changed, 21 insertions, 2 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index bc47a1fa..e1dcd24d 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -8,7 +8,7 @@ from inspect import signature, Parameter
from .helper import get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages, Union
from ..base_provider import BaseProvider
-from ..errors import NestAsyncioError
+from ..errors import NestAsyncioError, ModelNotSupportedError
if sys.version_info < (3, 10):
NoneType = type(None)
@@ -251,4 +251,23 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
AsyncResult: An asynchronous generator yielding results.
"""
- raise NotImplementedError() \ No newline at end of file
+ raise NotImplementedError()
+
+class ProviderModelMixin:
+ default_model: str
+ models: list[str] = []
+ model_aliases: dict[str, str] = {}
+
+ @classmethod
+ def get_models(cls) -> list[str]:
+ return cls.models
+
+ @classmethod
+ def get_model(cls, model: str) -> str:
+ if not model:
+ return cls.default_model
+ elif model in cls.model_aliases:
+ return cls.model_aliases[model]
+ elif model not in cls.get_models():
+ raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+ return model \ No newline at end of file