diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/retry_provider.py | 58 |
1 files changed, 49 insertions, 9 deletions
diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py index 4d3e77ac..9cc026fc 100644 --- a/g4f/Provider/retry_provider.py +++ b/g4f/Provider/retry_provider.py @@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError - class RetryProvider(BaseRetryProvider): + """ + A provider class to handle retries for creating completions with different providers. + + Attributes: + providers (list): A list of provider instances. + shuffle (bool): A flag indicating whether to shuffle providers before use. + exceptions (dict): A dictionary to store exceptions encountered during retries. + last_provider (BaseProvider): The last provider that was used. + """ + def create_completion( self, model: str, @@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider): stream: bool = False, **kwargs ) -> CreateResult: - if stream: - providers = [provider for provider in self.providers if provider.supports_stream] - else: - providers = self.providers + """ + Create a completion using available providers, with an option to stream the response. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. + + Yields: + CreateResult: Tokens or results from the completion. + + Raises: + Exception: Any exception encountered during the completion process. + """ + providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers if self.shuffle: random.shuffle(providers) @@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously create a completion using available providers. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + + Returns: + str: The result of the asynchronous completion. + + Raises: + Exception: Any exception encountered during the asynchronous completion process. + """ providers = self.providers if self.shuffle: random.shuffle(providers) - + self.exceptions = {} for provider in providers: self.last_provider = provider @@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider): self.exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - + self.raise_exceptions() - + def raise_exceptions(self) -> None: + """ + Raise a combined exception if any occurred during retries. + + Raises: + RetryProviderError: If any provider encountered an exception. + RetryNoProviderError: If no provider is found. + """ if self.exceptions: raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() ])) - + raise RetryNoProviderError("No provider found")
\ No newline at end of file |