summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/retry_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/retry_provider.py')
-rw-r--r--g4f/Provider/retry_provider.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py
new file mode 100644
index 00000000..e1a9cd1f
--- /dev/null
+++ b/g4f/Provider/retry_provider.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import random
+
+from ..typing import CreateResult
+from .base_provider import BaseProvider, AsyncProvider
+
+
+class RetryProvider(AsyncProvider):
+ __name__ = "RetryProvider"
+ working = True
+ needs_auth = False
+ supports_stream = True
+ supports_gpt_35_turbo = False
+ supports_gpt_4 = False
+
+ def __init__(
+ self,
+ providers: list[type[BaseProvider]],
+ shuffle: bool = True
+ ) -> None:
+ self.providers = providers
+ self.shuffle = shuffle
+
+
+ def create_completion(
+ self,
+ model: str,
+ messages: list[dict[str, str]],
+ stream: bool = False,
+ **kwargs
+ ) -> CreateResult:
+ if stream:
+ providers = [provider for provider in self.providers if provider.supports_stream]
+ else:
+ providers = self.providers
+ if self.shuffle:
+ random.shuffle(providers)
+
+ self.exceptions = {}
+ started = False
+ for provider in providers:
+ try:
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ self.exceptions[provider.__name__] = e
+ if started:
+ break
+
+ self.raise_exceptions()
+
+ async def create_async(
+ self,
+ model: str,
+ messages: list[dict[str, str]],
+ **kwargs
+ ) -> str:
+ providers = [provider for provider in self.providers if issubclass(provider, AsyncProvider)]
+ if self.shuffle:
+ random.shuffle(providers)
+
+ self.exceptions = {}
+ for provider in providers:
+ try:
+ return await provider.create_async(model, messages, **kwargs)
+ except Exception as e:
+ self.exceptions[provider.__name__] = e
+
+ self.raise_exceptions()
+
+ def raise_exceptions(self):
+ if self.exceptions:
+ raise RuntimeError("\n".join(["All providers failed:"] + [
+ f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions
+ ]))
+
+ raise RuntimeError("No provider found") \ No newline at end of file