summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--g4f/Provider/base_provider.py198
1 files changed, 133 insertions, 65 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index e7e88841..fd92d17a 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,28 +1,29 @@
from __future__ import annotations
-
import sys
import asyncio
-from asyncio import AbstractEventLoop
+from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
-from abc import abstractmethod
-from inspect import signature, Parameter
-from .helper import get_event_loop, get_cookies, format_prompt
-from ..typing import CreateResult, AsyncResult, Messages
-from ..base_provider import BaseProvider
+from abc import abstractmethod
+from inspect import signature, Parameter
+from .helper import get_event_loop, get_cookies, format_prompt
+from ..typing import CreateResult, AsyncResult, Messages
+from ..base_provider import BaseProvider
if sys.version_info < (3, 10):
NoneType = type(None)
else:
from types import NoneType
-# Change event loop policy on windows for curl_cffi
+# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
- if isinstance(
- asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
- ):
+ if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class AbstractProvider(BaseProvider):
+ """
+ Abstract class for providing asynchronous functionality to derived classes.
+ """
+
@classmethod
async def create_async(
cls,
@@ -33,62 +34,67 @@ class AbstractProvider(BaseProvider):
executor: ThreadPoolExecutor = None,
**kwargs
) -> str:
- if not loop:
- loop = get_event_loop()
+ """
+ Asynchronously creates a result based on the given model and messages.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
+ """
+ loop = loop or get_event_loop()
def create_func() -> str:
- return "".join(cls.create_completion(
- model,
- messages,
- False,
- **kwargs
- ))
+ return "".join(cls.create_completion(model, messages, False, **kwargs))
return await asyncio.wait_for(
- loop.run_in_executor(
- executor,
- create_func
- ),
+ loop.run_in_executor(executor, create_func),
timeout=kwargs.get("timeout", 0)
)
@classmethod
@property
def params(cls) -> str:
- if issubclass(cls, AsyncGeneratorProvider):
- sig = signature(cls.create_async_generator)
- elif issubclass(cls, AsyncProvider):
- sig = signature(cls.create_async)
- else:
- sig = signature(cls.create_completion)
+ """
+ Returns the parameters supported by the provider.
+
+ Args:
+ cls (type): The class on which this property is called.
+
+ Returns:
+ str: A string listing the supported parameters.
+ """
+ sig = signature(
+ cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
+ cls.create_async if issubclass(cls, AsyncProvider) else
+ cls.create_completion
+ )
def get_type_name(annotation: type) -> str:
- if hasattr(annotation, "__name__"):
- annotation = annotation.__name__
- elif isinstance(annotation, NoneType):
- annotation = "None"
- return str(annotation)
-
+ return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
+
args = ""
for name, param in sig.parameters.items():
- if name in ("self", "kwargs"):
- continue
- if name == "stream" and not cls.supports_stream:
+ if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
continue
- if args:
- args += ", "
- args += "\n " + name
- if name != "model" and param.annotation is not Parameter.empty:
- args += f": {get_type_name(param.annotation)}"
- if param.default == "":
- args += ' = ""'
- elif param.default is not Parameter.empty:
- args += f" = {param.default}"
+ args += f"\n {name}"
+ args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
+ args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(AbstractProvider):
+ """
+ Provides asynchronous functionality for creating completions.
+ """
+
@classmethod
def create_completion(
cls,
@@ -99,8 +105,21 @@ class AsyncProvider(AbstractProvider):
loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
- if not loop:
- loop = get_event_loop()
+ """
+ Creates a completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to False.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the completion creation.
+ """
+ loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs)
yield loop.run_until_complete(coro)
@@ -111,10 +130,27 @@ class AsyncProvider(AbstractProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Abstract method for creating asynchronous results.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ str: The created result as a string.
+ """
raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
+ """
+ Provides asynchronous generator functionality for streaming results.
+ """
supports_stream = True
@classmethod
@@ -127,15 +163,24 @@ class AsyncGeneratorProvider(AsyncProvider):
loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
- if not loop:
- loop = get_event_loop()
- generator = cls.create_async_generator(
- model,
- messages,
- stream=stream,
- **kwargs
- )
+ """
+ Creates a streaming completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the streaming completion creation.
+ """
+ loop = loop or get_event_loop()
+ generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
+
while True:
try:
yield loop.run_until_complete(gen.__anext__())
@@ -149,21 +194,44 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously creates a result from a generator.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
+ """
return "".join([
- chunk async for chunk in cls.create_async_generator(
- model,
- messages,
- stream=False,
- **kwargs
- ) if not isinstance(chunk, Exception)
+ chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
+ if not isinstance(chunk, Exception)
])
@staticmethod
@abstractmethod
- def create_async_generator(
+ async def create_async_generator(
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> AsyncResult:
- raise NotImplementedError()
+ """
+ Abstract method for creating an asynchronous generator.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ AsyncResult: An asynchronous generator yielding results.
+ """
+ raise NotImplementedError() \ No newline at end of file