diff options
Diffstat (limited to 'g4f/providers')
-rw-r--r-- | g4f/providers/retry_provider.py | 41 |
1 files changed, 14 insertions, 27 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index cde8c848..0061bcc1 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -4,11 +4,11 @@ import asyncio import random from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult -from .types import BaseProvider, BaseRetryProvider +from .types import BaseProvider, BaseRetryProvider, ProviderType from .. import debug from ..errors import RetryProviderError, RetryNoProviderError -class NewBaseRetryProvider(BaseRetryProvider): +class IterListProvider(BaseRetryProvider): def __init__( self, providers: List[Type[BaseProvider]], @@ -45,21 +45,17 @@ class NewBaseRetryProvider(BaseRetryProvider): Raises: Exception: Any exception encountered during the completion process. """ - providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers - if self.shuffle: - random.shuffle(providers) - exceptions = {} started: bool = False - for provider in providers: + for provider in self.get_providers(stream): self.last_provider = provider try: if debug.logging: print(f"Using {provider.__name__} provider") for token in provider.create_completion(model, messages, stream, **kwargs): yield token - started = True + started = True if started: return except Exception as e: @@ -87,13 +83,9 @@ class NewBaseRetryProvider(BaseRetryProvider): Raises: Exception: Any exception encountered during the asynchronous completion process. """ - providers = self.providers - if self.shuffle: - random.shuffle(providers) - exceptions = {} - for provider in providers: + for provider in self.get_providers(False): self.last_provider = provider try: if debug.logging: @@ -109,8 +101,8 @@ class NewBaseRetryProvider(BaseRetryProvider): raise_exceptions(exceptions) - def get_providers(self, stream: bool): - providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers + def get_providers(self, stream: bool) -> list[ProviderType]: + providers = [p for p in self.providers if p.supports_stream] if stream else self.providers if self.shuffle: random.shuffle(providers) return providers @@ -138,7 +130,7 @@ class NewBaseRetryProvider(BaseRetryProvider): else: for token in provider.create_completion(model, messages, stream, **kwargs): yield token - started = True + started = True if started: return except Exception as e: @@ -150,7 +142,7 @@ class NewBaseRetryProvider(BaseRetryProvider): raise_exceptions(exceptions) -class RetryProvider(NewBaseRetryProvider): +class RetryProvider(IterListProvider): def __init__( self, providers: List[Type[BaseProvider]], @@ -188,11 +180,10 @@ class RetryProvider(NewBaseRetryProvider): Raises: Exception: Any exception encountered during the completion process. """ - providers = self.get_providers(stream) - if self.single_provider_retry and len(providers) == 1: + if self.single_provider_retry: exceptions = {} started: bool = False - provider = providers[0] + provider = self.providers[0] self.last_provider = provider for attempt in range(self.max_retries): try: @@ -200,7 +191,7 @@ class RetryProvider(NewBaseRetryProvider): print(f"Using {provider.__name__} provider (attempt {attempt + 1})") for token in provider.create_completion(model, messages, stream, **kwargs): yield token - started = True + started = True if started: return except Exception as e: @@ -229,14 +220,10 @@ class RetryProvider(NewBaseRetryProvider): Raises: Exception: Any exception encountered during the asynchronous completion process. """ - providers = self.providers - if self.shuffle: - random.shuffle(providers) - exceptions = {} - if self.single_provider_retry and len(providers) == 1: - provider = providers[0] + if self.single_provider_retry: + provider = self.providers[0] self.last_provider = provider for attempt in range(self.max_retries): try: |