diff options
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/__init__.py | 6 | ||||
-rw-r--r-- | g4f/Provider/retry_provider.py | 81 |
2 files changed, 86 insertions, 1 deletions
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 0ca22533..b9ee2544 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -38,10 +38,14 @@ from .FastGpt import FastGpt from .V50 import V50 from .Wuguokai import Wuguokai -from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider +from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider +from .retry_provider import RetryProvider __all__ = [ 'BaseProvider', + 'AsyncProvider', + 'AsyncGeneratorProvider', + 'RetryProvider', 'Acytoo', 'Aichat', 'Ails', diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py new file mode 100644 index 00000000..e1a9cd1f --- /dev/null +++ b/g4f/Provider/retry_provider.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import random + +from ..typing import CreateResult +from .base_provider import BaseProvider, AsyncProvider + + +class RetryProvider(AsyncProvider): + __name__ = "RetryProvider" + working = True + needs_auth = False + supports_stream = True + supports_gpt_35_turbo = False + supports_gpt_4 = False + + def __init__( + self, + providers: list[type[BaseProvider]], + shuffle: bool = True + ) -> None: + self.providers = providers + self.shuffle = shuffle + + + def create_completion( + self, + model: str, + messages: list[dict[str, str]], + stream: bool = False, + **kwargs + ) -> CreateResult: + if stream: + providers = [provider for provider in self.providers if provider.supports_stream] + else: + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + started = False + for provider in providers: + try: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + self.exceptions[provider.__name__] = e + if started: + break + + self.raise_exceptions() + + async def create_async( + self, + model: str, + messages: list[dict[str, str]], + **kwargs + ) -> str: + providers = [provider for provider in self.providers if issubclass(provider, AsyncProvider)] + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + for provider in providers: + try: + return await provider.create_async(model, messages, **kwargs) + except Exception as e: + self.exceptions[provider.__name__] = e + + self.raise_exceptions() + + def raise_exceptions(self): + if self.exceptions: + raise RuntimeError("\n".join(["All providers failed:"] + [ + f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions + ])) + + raise RuntimeError("No provider found")
\ No newline at end of file |