summaryrefslogtreecommitdiffstats
path: root/g4f/providers
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/providers/retry_provider.py41
1 files changed, 14 insertions, 27 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index cde8c848..0061bcc1 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -4,11 +4,11 @@ import asyncio
import random
from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult
-from .types import BaseProvider, BaseRetryProvider
+from .types import BaseProvider, BaseRetryProvider, ProviderType
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
-class NewBaseRetryProvider(BaseRetryProvider):
+class IterListProvider(BaseRetryProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
@@ -45,21 +45,17 @@ class NewBaseRetryProvider(BaseRetryProvider):
Raises:
Exception: Any exception encountered during the completion process.
"""
- providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
- if self.shuffle:
- random.shuffle(providers)
-
exceptions = {}
started: bool = False
- for provider in providers:
+ for provider in self.get_providers(stream):
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
- started = True
+ started = True
if started:
return
except Exception as e:
@@ -87,13 +83,9 @@ class NewBaseRetryProvider(BaseRetryProvider):
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
- providers = self.providers
- if self.shuffle:
- random.shuffle(providers)
-
exceptions = {}
- for provider in providers:
+ for provider in self.get_providers(False):
self.last_provider = provider
try:
if debug.logging:
@@ -109,8 +101,8 @@ class NewBaseRetryProvider(BaseRetryProvider):
raise_exceptions(exceptions)
- def get_providers(self, stream: bool):
- providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
+ def get_providers(self, stream: bool) -> list[ProviderType]:
+ providers = [p for p in self.providers if p.supports_stream] if stream else self.providers
if self.shuffle:
random.shuffle(providers)
return providers
@@ -138,7 +130,7 @@ class NewBaseRetryProvider(BaseRetryProvider):
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
- started = True
+ started = True
if started:
return
except Exception as e:
@@ -150,7 +142,7 @@ class NewBaseRetryProvider(BaseRetryProvider):
raise_exceptions(exceptions)
-class RetryProvider(NewBaseRetryProvider):
+class RetryProvider(IterListProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
@@ -188,11 +180,10 @@ class RetryProvider(NewBaseRetryProvider):
Raises:
Exception: Any exception encountered during the completion process.
"""
- providers = self.get_providers(stream)
- if self.single_provider_retry and len(providers) == 1:
+ if self.single_provider_retry:
exceptions = {}
started: bool = False
- provider = providers[0]
+ provider = self.providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
@@ -200,7 +191,7 @@ class RetryProvider(NewBaseRetryProvider):
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
- started = True
+ started = True
if started:
return
except Exception as e:
@@ -229,14 +220,10 @@ class RetryProvider(NewBaseRetryProvider):
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
- providers = self.providers
- if self.shuffle:
- random.shuffle(providers)
-
exceptions = {}
- if self.single_provider_retry and len(providers) == 1:
- provider = providers[0]
+ if self.single_provider_retry:
+ provider = self.providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try: