author     H Lohaus <hlohaus@users.noreply.github.com>   2024-01-23 20:08:41 +0100
committer  GitHub <noreply@github.com>                   2024-01-23 20:08:41 +0100
commit     2b140a32554c1e94d095c55599a2f93e86f957cf (patch)
tree       e2770d97f0242a0b99a3af68ea4fcf25227dfcc8 /g4f/Provider/DeepInfra.py
parent     ~ (diff)
parent     Add ProviderModelMixin for model selection (diff)
Diffstat (limited to 'g4f/Provider/DeepInfra.py')
-rw-r--r--  g4f/Provider/DeepInfra.py  26
1 file changed, 17 insertions(+), 9 deletions(-)
diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py
index acde1200..2f34b679 100644
--- a/g4f/Provider/DeepInfra.py
+++ b/g4f/Provider/DeepInfra.py
@@ -1,18 +1,27 @@
 from __future__ import annotations
 
 import json
-from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
-from ..requests import StreamSession
+import requests
+from ..typing import AsyncResult, Messages
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..requests import StreamSession
 
-class DeepInfra(AsyncGeneratorProvider):
+class DeepInfra(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://deepinfra.com"
     working = True
     supports_stream = True
     supports_message_history = True
-
+    default_model = 'meta-llama/Llama-2-70b-chat-hf'
 
+    @staticmethod
+    def get_models():
+        url = 'https://api.deepinfra.com/models/featured'
+        models = requests.get(url).json()
+        return [model['model_name'] for model in models]
+
+    @classmethod
     async def create_async_generator(
+        cls,
         model: str,
         messages: Messages,
         stream: bool,
@@ -21,8 +30,6 @@ class DeepInfra(AsyncGeneratorProvider):
         auth: str = None,
         **kwargs
     ) -> AsyncResult:
-        if not model:
-            model = 'meta-llama/Llama-2-70b-chat-hf'
         headers = {
             'Accept-Encoding': 'gzip, deflate, br',
             'Accept-Language': 'en-US',
@@ -49,7 +56,7 @@ class DeepInfra(AsyncGeneratorProvider):
             impersonate="chrome110"
         ) as session:
             json_data = {
-                'model'   : model,
+                'model'   : cls.get_model(model),
                 'messages': messages,
                 'stream'  : True
             }
@@ -70,7 +77,8 @@ class DeepInfra(AsyncGeneratorProvider):
                         if token:
                             if first:
                                 token = token.lstrip()
+                            if token:
                                 first = False
-                            yield token
+                                yield token
                     except Exception:
                         raise RuntimeError(f"Response: {line}")
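
The removed `if not model:` fallback inside create_async_generator is subsumed by ProviderModelMixin, whose internals this diff does not show. A minimal sketch of the resolution behavior implied by the call site `cls.get_model(model)`, assuming the mixin falls back to `default_model` for an empty name and rejects unknown names, might look like this (not the actual g4f implementation):

    # Sketch only, NOT the real g4f ProviderModelMixin. Assumes the mixin
    # resolves an empty model name to `default_model` and validates other
    # names against the provider's model list.
    class ProviderModelMixin:
        default_model: str = None

        @staticmethod
        def get_models() -> list:
            return []

        @classmethod
        def get_model(cls, model: str) -> str:
            if not model:
                # Replaces the per-provider `if not model:` branch removed above.
                return cls.default_model
            if model not in cls.get_models():
                raise ValueError(f"Model is not supported: {model}")
            return model

A hypothetical caller would then pick a model from DeepInfra's featured list (note that get_models() performs a blocking HTTP request to api.deepinfra.com) and stream tokens from the async generator:

    import asyncio
    from g4f.Provider import DeepInfra

    async def main():
        # Featured model names fetched from the DeepInfra API.
        print(DeepInfra.get_models()[:3])
        # An empty model name resolves to default_model via cls.get_model(model).
        async for token in DeepInfra.create_async_generator(
            model="",
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
        ):
            print(token, end="")

    asyncio.run(main())

This also explains the change in the last hunk: after lstrip() strips a leading-whitespace first chunk, the new `if token:` guard keeps `first` set and yields nothing until a non-empty token arrives.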