summaryrefslogtreecommitdiffstats
path: root/g4f/providers
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/providers/base_provider.py280
-rw-r--r--g4f/providers/create_images.py (renamed from g4f/Provider/create_images.py)3
-rw-r--r--g4f/providers/helper.py61
-rw-r--r--g4f/providers/retry_provider.py (renamed from g4f/Provider/retry_provider.py)3
-rw-r--r--g4f/providers/types.py (renamed from g4f/base_provider.py)6
5 files changed, 348 insertions, 5 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
new file mode 100644
index 00000000..b8649ba5
--- /dev/null
+++ b/g4f/providers/base_provider.py
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import sys
+import asyncio
+from asyncio import AbstractEventLoop
+from concurrent.futures import ThreadPoolExecutor
+from abc import abstractmethod
+from inspect import signature, Parameter
+from ..typing import CreateResult, AsyncResult, Messages, Union
+from .types import BaseProvider
+from ..errors import NestAsyncioError, ModelNotSupportedError
+from .. import debug
+
+if sys.version_info < (3, 10):
+ NoneType = type(None)
+else:
+ from types import NoneType
+
+# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
+if sys.platform == 'win32':
+ if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+def get_running_loop() -> Union[AbstractEventLoop, None]:
+ try:
+ loop = asyncio.get_running_loop()
+ if not hasattr(loop.__class__, "_nest_patched"):
+ raise NestAsyncioError(
+ 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
+ )
+ return loop
+ except RuntimeError:
+ pass
+
+class AbstractProvider(BaseProvider):
+ """
+ Abstract class for providing asynchronous functionality to derived classes.
+ """
+
+ @classmethod
+ async def create_async(
+ cls,
+ model: str,
+ messages: Messages,
+ *,
+ loop: AbstractEventLoop = None,
+ executor: ThreadPoolExecutor = None,
+ **kwargs
+ ) -> str:
+ """
+ Asynchronously creates a result based on the given model and messages.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
+ """
+ loop = loop or asyncio.get_running_loop()
+
+ def create_func() -> str:
+ return "".join(cls.create_completion(model, messages, False, **kwargs))
+
+ return await asyncio.wait_for(
+ loop.run_in_executor(executor, create_func),
+ timeout=kwargs.get("timeout")
+ )
+
+ @classmethod
+ @property
+ def params(cls) -> str:
+ """
+ Returns the parameters supported by the provider.
+
+ Args:
+ cls (type): The class on which this property is called.
+
+ Returns:
+ str: A string listing the supported parameters.
+ """
+ sig = signature(
+ cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
+ cls.create_async if issubclass(cls, AsyncProvider) else
+ cls.create_completion
+ )
+
+ def get_type_name(annotation: type) -> str:
+ return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
+
+ args = ""
+ for name, param in sig.parameters.items():
+ if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
+ continue
+ args += f"\n {name}"
+ args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
+ args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
+
+ return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
+
+
+class AsyncProvider(AbstractProvider):
+ """
+ Provides asynchronous functionality for creating completions.
+ """
+
+ @classmethod
+ def create_completion(
+ cls,
+ model: str,
+ messages: Messages,
+ stream: bool = False,
+ **kwargs
+ ) -> CreateResult:
+ """
+ Creates a completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to False.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the completion creation.
+ """
+ get_running_loop()
+ yield asyncio.run(cls.create_async(model, messages, **kwargs))
+
+ @staticmethod
+ @abstractmethod
+ async def create_async(
+ model: str,
+ messages: Messages,
+ **kwargs
+ ) -> str:
+ """
+ Abstract method for creating asynchronous results.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ str: The created result as a string.
+ """
+ raise NotImplementedError()
+
+
+class AsyncGeneratorProvider(AsyncProvider):
+ """
+ Provides asynchronous generator functionality for streaming results.
+ """
+ supports_stream = True
+
+ @classmethod
+ def create_completion(
+ cls,
+ model: str,
+ messages: Messages,
+ stream: bool = True,
+ **kwargs
+ ) -> CreateResult:
+ """
+ Creates a streaming completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the streaming completion creation.
+ """
+ loop = get_running_loop()
+ new_loop = False
+ if not loop:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ new_loop = True
+
+ generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
+ gen = generator.__aiter__()
+
+ # Fix for RuntimeError: async generator ignored GeneratorExit
+ async def await_callback(callback):
+ return await callback()
+
+ try:
+ while True:
+ yield loop.run_until_complete(await_callback(gen.__anext__))
+ except StopAsyncIteration:
+ ...
+ # Fix for: ResourceWarning: unclosed event loop
+ finally:
+ if new_loop:
+ loop.close()
+ asyncio.set_event_loop(None)
+
+ @classmethod
+ async def create_async(
+ cls,
+ model: str,
+ messages: Messages,
+ **kwargs
+ ) -> str:
+ """
+ Asynchronously creates a result from a generator.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
+ """
+ return "".join([
+ chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
+ if not isinstance(chunk, Exception)
+ ])
+
+ @staticmethod
+ @abstractmethod
+ async def create_async_generator(
+ model: str,
+ messages: Messages,
+ stream: bool = True,
+ **kwargs
+ ) -> AsyncResult:
+ """
+ Abstract method for creating an asynchronous generator.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ AsyncResult: An asynchronous generator yielding results.
+ """
+ raise NotImplementedError()
+
+class ProviderModelMixin:
+ default_model: str
+ models: list[str] = []
+ model_aliases: dict[str, str] = {}
+
+ @classmethod
+ def get_models(cls) -> list[str]:
+ return cls.models
+
+ @classmethod
+ def get_model(cls, model: str) -> str:
+ if not model:
+ model = cls.default_model
+ elif model in cls.model_aliases:
+ model = cls.model_aliases[model]
+ elif model not in cls.get_models():
+ raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+ debug.last_model = model
+ return model \ No newline at end of file
diff --git a/g4f/Provider/create_images.py b/g4f/providers/create_images.py
index 2ca92432..29a2a041 100644
--- a/g4f/Provider/create_images.py
+++ b/g4f/providers/create_images.py
@@ -2,9 +2,10 @@ from __future__ import annotations
import re
import asyncio
+
from .. import debug
from ..typing import CreateResult, Messages
-from ..base_provider import BaseProvider, ProviderType
+from .types import BaseProvider, ProviderType
system_message = """
You can generate images, pictures, photos or img with the DALL-E 3 image generator.
diff --git a/g4f/providers/helper.py b/g4f/providers/helper.py
new file mode 100644
index 00000000..49d033d1
--- /dev/null
+++ b/g4f/providers/helper.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import random
+import secrets
+import string
+from aiohttp import BaseConnector
+
+from ..typing import Messages, Optional
+from ..errors import MissingRequirementsError
+
+def format_prompt(messages: Messages, add_special_tokens=False) -> str:
+ """
+ Format a series of messages into a single string, optionally adding special tokens.
+
+ Args:
+ messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
+ add_special_tokens (bool): Whether to add special formatting tokens.
+
+ Returns:
+ str: A formatted string containing all messages.
+ """
+ if not add_special_tokens and len(messages) <= 1:
+ return messages[0]["content"]
+ formatted = "\n".join([
+ f'{message["role"].capitalize()}: {message["content"]}'
+ for message in messages
+ ])
+ return f"{formatted}\nAssistant:"
+
+def get_random_string(length: int = 10) -> str:
+ """
+ Generate a random string of specified length, containing lowercase letters and digits.
+
+ Args:
+ length (int, optional): Length of the random string to generate. Defaults to 10.
+
+ Returns:
+ str: A random string of the specified length.
+ """
+ return ''.join(
+ random.choice(string.ascii_lowercase + string.digits)
+ for _ in range(length)
+ )
+
+def get_random_hex() -> str:
+ """
+ Generate a random hexadecimal string of a fixed length.
+
+ Returns:
+ str: A random hexadecimal string of 32 characters (16 bytes).
+ """
+ return secrets.token_hex(16).zfill(32)
+
+def get_connector(connector: BaseConnector = None, proxy: str = None) -> Optional[BaseConnector]:
+ if proxy and not connector:
+ try:
+ from aiohttp_socks import ProxyConnector
+ connector = ProxyConnector.from_url(proxy)
+ except ImportError:
+ raise MissingRequirementsError('Install "aiohttp_socks" package for proxy support')
+ return connector \ No newline at end of file
diff --git a/g4f/Provider/retry_provider.py b/g4f/providers/retry_provider.py
index 9cc026fc..a7ab2881 100644
--- a/g4f/Provider/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -2,8 +2,9 @@ from __future__ import annotations
import asyncio
import random
+
from ..typing import CreateResult, Messages
-from ..base_provider import BaseRetryProvider
+from .types import BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
diff --git a/g4f/base_provider.py b/g4f/providers/types.py
index cc3451a2..7b11ec43 100644
--- a/g4f/base_provider.py
+++ b/g4f/providers/types.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Union, List, Dict, Type
-from .typing import Messages, CreateResult
+from ..typing import Messages, CreateResult
class BaseProvider(ABC):
"""
@@ -81,7 +81,7 @@ class BaseProvider(ABC):
Dict[str, str]: A dictionary with provider's details.
"""
return {'name': cls.__name__, 'url': cls.url}
-
+
class BaseRetryProvider(BaseProvider):
"""
Base class for a provider that implements retry logic.
@@ -113,5 +113,5 @@ class BaseRetryProvider(BaseProvider):
self.working = True
self.exceptions: Dict[str, Exception] = {}
self.last_provider: Type[BaseProvider] = None
-
+
ProviderType = Union[Type[BaseProvider], BaseRetryProvider] \ No newline at end of file