summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--g4f/Provider/base_provider.py43
1 files changed, 35 insertions, 8 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 47ea6ff8..564dd77e 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -3,6 +3,8 @@ from __future__ import annotations
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import ABC, abstractmethod
+from inspect import signature, Parameter
+from types import NoneType
from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages
@@ -52,17 +54,42 @@ class BaseProvider(ABC):
executor,
create_func
)
-
+
@classmethod
@property
def params(cls) -> str:
- params = [
- ("model", "str"),
- ("messages", "list[dict[str, str]]"),
- ("stream", "bool"),
- ]
- param = ", ".join([": ".join(p) for p in params])
- return f"g4f.provider.{cls.__name__} supports: ({param})"
+ if issubclass(cls, AsyncGeneratorProvider):
+ sig = signature(cls.create_async_generator)
+ elif issubclass(cls, AsyncProvider):
+ sig = signature(cls.create_async)
+ else:
+ sig = signature(cls.create_completion)
+
+ def get_type_name(annotation: type) -> str:
+ if hasattr(annotation, "__name__"):
+ annotation = annotation.__name__
+ elif isinstance(annotation, NoneType):
+ annotation = "None"
+ return str(annotation)
+
+ args = "";
+ for name, param in sig.parameters.items():
+ if name in ("self", "kwargs"):
+ continue
+ if name == "stream" and not cls.supports_stream:
+ continue
+ if args:
+ args += ", "
+ args += "\n"
+ args += " " + name
+ if name != "model" and param.annotation is not Parameter.empty:
+ args += f": {get_type_name(param.annotation)}"
+ if param.default == "":
+ args += ' = ""'
+ elif param.default is not Parameter.empty:
+ args += f" = {param.default}"
+
+ return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(BaseProvider):