diff options
Diffstat (limited to 'g4f/providers/retry_provider.py')
-rw-r--r-- | g4f/providers/retry_provider.py | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py new file mode 100644 index 00000000..a7ab2881 --- /dev/null +++ b/g4f/providers/retry_provider.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import asyncio +import random + +from ..typing import CreateResult, Messages +from .types import BaseRetryProvider +from .. import debug +from ..errors import RetryProviderError, RetryNoProviderError + +class RetryProvider(BaseRetryProvider): + """ + A provider class to handle retries for creating completions with different providers. + + Attributes: + providers (list): A list of provider instances. + shuffle (bool): A flag indicating whether to shuffle providers before use. + exceptions (dict): A dictionary to store exceptions encountered during retries. + last_provider (BaseProvider): The last provider that was used. + """ + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + """ + Create a completion using available providers, with an option to stream the response. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. + + Yields: + CreateResult: Tokens or results from the completion. + + Raises: + Exception: Any exception encountered during the completion process. + """ + providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + started: bool = False + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + self.exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + + self.raise_exceptions() + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Asynchronously create a completion using available providers. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + + Returns: + str: The result of the asynchronous completion. + + Raises: + Exception: Any exception encountered during the asynchronous completion process. + """ + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + for provider in providers: + self.last_provider = provider + try: + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60) + ) + except Exception as e: + self.exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + + self.raise_exceptions() + + def raise_exceptions(self) -> None: + """ + Raise a combined exception if any occurred during retries. + + Raises: + RetryProviderError: If any provider encountered an exception. + RetryNoProviderError: If no provider is found. + """ + if self.exceptions: + raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ + f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() + ])) + + raise RetryNoProviderError("No provider found")
\ No newline at end of file |