summaryrefslogtreecommitdiffstats
path: root/g4f/providers/retry_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/providers/retry_provider.py')
-rw-r--r--g4f/providers/retry_provider.py130
1 files changed, 108 insertions, 22 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index a7ab2881..52f473e9 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -3,22 +3,37 @@ from __future__ import annotations
import asyncio
import random
-from ..typing import CreateResult, Messages
-from .types import BaseRetryProvider
+from ..typing import Type, List, CreateResult, Messages, Iterator
+from .types import BaseProvider, BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
class RetryProvider(BaseRetryProvider):
+ def __init__(
+ self,
+ providers: List[Type[BaseProvider]],
+ shuffle: bool = True
+ ) -> None:
+ """
+ Initialize the BaseRetryProvider.
+
+ Args:
+ providers (List[Type[BaseProvider]]): List of providers to use.
+ shuffle (bool): Whether to shuffle the providers list.
+ """
+ self.providers = providers
+ self.shuffle = shuffle
+ self.working = True
+ self.last_provider: Type[BaseProvider] = None
+
"""
A provider class to handle retries for creating completions with different providers.
Attributes:
providers (list): A list of provider instances.
shuffle (bool): A flag indicating whether to shuffle providers before use.
- exceptions (dict): A dictionary to store exceptions encountered during retries.
last_provider (BaseProvider): The last provider that was used.
"""
-
def create_completion(
self,
model: str,
@@ -44,7 +59,7 @@ class RetryProvider(BaseRetryProvider):
if self.shuffle:
random.shuffle(providers)
- self.exceptions = {}
+ exceptions = {}
started: bool = False
for provider in providers:
self.last_provider = provider
@@ -57,13 +72,13 @@ class RetryProvider(BaseRetryProvider):
if started:
return
except Exception as e:
- self.exceptions[provider.__name__] = e
+ exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
- self.raise_exceptions()
+ raise_exceptions(exceptions)
async def create_async(
self,
@@ -88,7 +103,7 @@ class RetryProvider(BaseRetryProvider):
if self.shuffle:
random.shuffle(providers)
- self.exceptions = {}
+ exceptions = {}
for provider in providers:
self.last_provider = provider
try:
@@ -97,23 +112,94 @@ class RetryProvider(BaseRetryProvider):
timeout=kwargs.get("timeout", 60)
)
except Exception as e:
- self.exceptions[provider.__name__] = e
+ exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
- self.raise_exceptions()
+ raise_exceptions(exceptions)
- def raise_exceptions(self) -> None:
- """
- Raise a combined exception if any occurred during retries.
+class IterProvider(BaseRetryProvider):
+ __name__ = "IterProvider"
- Raises:
- RetryProviderError: If any provider encountered an exception.
- RetryNoProviderError: If no provider is found.
- """
- if self.exceptions:
- raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
- f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
- ]))
+ def __init__(
+ self,
+ providers: List[BaseProvider],
+ ) -> None:
+ providers.reverse()
+ self.providers: List[BaseProvider] = providers
+ self.working: bool = True
+ self.last_provider: BaseProvider = None
+
+ def create_completion(
+ self,
+ model: str,
+ messages: Messages,
+ stream: bool = False,
+ **kwargs
+ ) -> CreateResult:
+ exceptions: dict = {}
+ started: bool = False
+ for provider in self.iter_providers():
+ if stream and not provider.supports_stream:
+ continue
+ try:
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
+ raise_exceptions(exceptions)
+
+ async def create_async(
+ self,
+ model: str,
+ messages: Messages,
+ **kwargs
+ ) -> str:
+ exceptions: dict = {}
+ for provider in self.iter_providers():
+ try:
+ return await asyncio.wait_for(
+ provider.create_async(model, messages, **kwargs),
+ timeout=kwargs.get("timeout", 60)
+ )
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ raise_exceptions(exceptions)
+
+ def iter_providers(self) -> Iterator[BaseProvider]:
+ used_provider = []
+ try:
+ while self.providers:
+ provider = self.providers.pop()
+ used_provider.append(provider)
+ self.last_provider = provider
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ yield provider
+ finally:
+ used_provider.reverse()
+ self.providers = [*used_provider, *self.providers]
+
+def raise_exceptions(exceptions: dict) -> None:
+ """
+ Raise a combined exception if any occurred during retries.
+
+ Raises:
+ RetryProviderError: If any provider encountered an exception.
+ RetryNoProviderError: If no provider is found.
+ """
+ if exceptions:
+ raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
+ f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in exceptions.items()
+ ]))
- raise RetryNoProviderError("No provider found") \ No newline at end of file
+ raise RetryNoProviderError("No provider found") \ No newline at end of file