diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/providers/retry_provider.py | 130 |
1 files changed, 108 insertions, 22 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index a7ab2881..52f473e9 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -3,22 +3,37 @@ from __future__ import annotations import asyncio import random -from ..typing import CreateResult, Messages -from .types import BaseRetryProvider +from ..typing import Type, List, CreateResult, Messages, Iterator +from .types import BaseProvider, BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError class RetryProvider(BaseRetryProvider): + def __init__( + self, + providers: List[Type[BaseProvider]], + shuffle: bool = True + ) -> None: + """ + Initialize the BaseRetryProvider. + + Args: + providers (List[Type[BaseProvider]]): List of providers to use. + shuffle (bool): Whether to shuffle the providers list. + """ + self.providers = providers + self.shuffle = shuffle + self.working = True + self.last_provider: Type[BaseProvider] = None + """ A provider class to handle retries for creating completions with different providers. Attributes: providers (list): A list of provider instances. shuffle (bool): A flag indicating whether to shuffle providers before use. - exceptions (dict): A dictionary to store exceptions encountered during retries. last_provider (BaseProvider): The last provider that was used. """ - def create_completion( self, model: str, @@ -44,7 +59,7 @@ class RetryProvider(BaseRetryProvider): if self.shuffle: random.shuffle(providers) - self.exceptions = {} + exceptions = {} started: bool = False for provider in providers: self.last_provider = provider @@ -57,13 +72,13 @@ class RetryProvider(BaseRetryProvider): if started: return except Exception as e: - self.exceptions[provider.__name__] = e + exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") if started: raise e - self.raise_exceptions() + raise_exceptions(exceptions) async def create_async( self, @@ -88,7 +103,7 @@ class RetryProvider(BaseRetryProvider): if self.shuffle: random.shuffle(providers) - self.exceptions = {} + exceptions = {} for provider in providers: self.last_provider = provider try: @@ -97,23 +112,94 @@ class RetryProvider(BaseRetryProvider): timeout=kwargs.get("timeout", 60) ) except Exception as e: - self.exceptions[provider.__name__] = e + exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - self.raise_exceptions() + raise_exceptions(exceptions) - def raise_exceptions(self) -> None: - """ - Raise a combined exception if any occurred during retries. +class IterProvider(BaseRetryProvider): + __name__ = "IterProvider" - Raises: - RetryProviderError: If any provider encountered an exception. - RetryNoProviderError: If no provider is found. - """ - if self.exceptions: - raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ - f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() - ])) + def __init__( + self, + providers: List[BaseProvider], + ) -> None: + providers.reverse() + self.providers: List[BaseProvider] = providers + self.working: bool = True + self.last_provider: BaseProvider = None + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + exceptions: dict = {} + started: bool = False + for provider in self.iter_providers(): + if stream and not provider.supports_stream: + continue + try: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + raise_exceptions(exceptions) + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + exceptions: dict = {} + for provider in self.iter_providers(): + try: + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60) + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + raise_exceptions(exceptions) + + def iter_providers(self) -> Iterator[BaseProvider]: + used_provider = [] + try: + while self.providers: + provider = self.providers.pop() + used_provider.append(provider) + self.last_provider = provider + if debug.logging: + print(f"Using {provider.__name__} provider") + yield provider + finally: + used_provider.reverse() + self.providers = [*used_provider, *self.providers] + +def raise_exceptions(exceptions: dict) -> None: + """ + Raise a combined exception if any occurred during retries. + + Raises: + RetryProviderError: If any provider encountered an exception. + RetryNoProviderError: If no provider is found. + """ + if exceptions: + raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ + f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in exceptions.items() + ])) - raise RetryNoProviderError("No provider found")
\ No newline at end of file + raise RetryNoProviderError("No provider found")
\ No newline at end of file |