diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-04-10 08:14:50 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-10 08:14:50 +0200 |
commit | 00951eb79114adf74ad1a3f1ce596e9e0fa932bf (patch) | |
tree | fea75e7745d69b09d91b0003e5dbf12b77380223 /g4f/Provider/DeepInfraImage.py | |
parent | Update Dockerfile (diff) | |
download | gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.tar gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.tar.gz gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.tar.bz2 gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.tar.lz gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.tar.xz gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.tar.zst gpt4free-00951eb79114adf74ad1a3f1ce596e9e0fa932bf.zip |
Diffstat (limited to 'g4f/Provider/DeepInfraImage.py')
-rw-r--r-- | g4f/Provider/DeepInfraImage.py | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/g4f/Provider/DeepInfraImage.py b/g4f/Provider/DeepInfraImage.py new file mode 100644 index 00000000..6099b793 --- /dev/null +++ b/g4f/Provider/DeepInfraImage.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import requests + +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..typing import AsyncResult, Messages +from ..requests import StreamSession, raise_for_status +from ..image import ImageResponse + +class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin): + url = "https://deepinfra.com" + working = True + default_model = 'stability-ai/sdxl' + + @classmethod + def get_models(cls): + if not cls.models: + url = 'https://api.deepinfra.com/models/featured' + models = requests.get(url).json() + cls.models = [model['model_name'] for model in models if model["reported_type"] == "text-to-image"] + return cls.models + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + **kwargs + ) -> AsyncResult: + yield await cls.create_async(messages[-1]["content"], model, **kwargs) + + @classmethod + async def create_async( + cls, + prompt: str, + model: str, + api_key: str = None, + api_base: str = "https://api.deepinfra.com/v1/inference", + proxy: str = None, + timeout: int = 180, + extra_data: dict = {}, + **kwargs + ) -> ImageResponse: + headers = { + 'Accept-Encoding': 'gzip, deflate, br', + 'Accept-Language': 'en-US', + 'Connection': 'keep-alive', + 'Origin': 'https://deepinfra.com', + 'Referer': 'https://deepinfra.com/', + 'Sec-Fetch-Dest': 'empty', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Site': 'same-site', + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36', + 'X-Deepinfra-Source': 'web-embed', + 'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"', + 'sec-ch-ua-mobile': '?0', + 'sec-ch-ua-platform': '"macOS"', + } + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + async with StreamSession( + proxies={"all": proxy}, + headers=headers, + timeout=timeout + ) as session: + model = cls.get_model(model) + data = {"prompt": prompt, **extra_data} + data = {"input": data} if model == cls.default_model else data + async with session.post(f"{api_base.rstrip('/')}/{model}", json=data) as response: + await raise_for_status(response) + data = await response.json() + images = data["output"] if "output" in data else data["images"] + images = images[0] if len(images) == 1 else images + return ImageResponse(images, prompt)
\ No newline at end of file |