summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/DeepInfra.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/DeepInfra.py26
1 files changed, 17 insertions, 9 deletions
diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py
index acde1200..2f34b679 100644
--- a/g4f/Provider/DeepInfra.py
+++ b/g4f/Provider/DeepInfra.py
@@ -1,18 +1,27 @@
from __future__ import annotations
import json
-from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
-from ..requests import StreamSession
+import requests
+from ..typing import AsyncResult, Messages
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..requests import StreamSession
-class DeepInfra(AsyncGeneratorProvider):
+class DeepInfra(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://deepinfra.com"
working = True
supports_stream = True
supports_message_history = True
-
+ default_model = 'meta-llama/Llama-2-70b-chat-hf'
+
@staticmethod
+ def get_models():
+ url = 'https://api.deepinfra.com/models/featured'
+ models = requests.get(url).json()
+ return [model['model_name'] for model in models]
+
+ @classmethod
async def create_async_generator(
+ cls,
model: str,
messages: Messages,
stream: bool,
@@ -21,8 +30,6 @@ class DeepInfra(AsyncGeneratorProvider):
auth: str = None,
**kwargs
) -> AsyncResult:
- if not model:
- model = 'meta-llama/Llama-2-70b-chat-hf'
headers = {
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US',
@@ -49,7 +56,7 @@ class DeepInfra(AsyncGeneratorProvider):
impersonate="chrome110"
) as session:
json_data = {
- 'model' : model,
+ 'model' : cls.get_model(model),
'messages': messages,
'stream' : True
}
@@ -70,7 +77,8 @@ class DeepInfra(AsyncGeneratorProvider):
if token:
if first:
token = token.lstrip()
+ if token:
first = False
- yield token
+ yield token
except Exception:
raise RuntimeError(f"Response: {line}")