summaryrefslogtreecommitdiffstats
path: root/g4f/Provider
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider')
-rw-r--r--g4f/Provider/ReplicateHome.py26
1 files changed, 22 insertions, 4 deletions
diff --git a/g4f/Provider/ReplicateHome.py b/g4f/Provider/ReplicateHome.py
index e6c8d2d3..017e7acb 100644
--- a/g4f/Provider/ReplicateHome.py
+++ b/g4f/Provider/ReplicateHome.py
@@ -16,7 +16,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
working = True
default_model = 'stability-ai/stable-diffusion-3'
models = [
- # Models for image generation
+ # Models for image generation
'stability-ai/stable-diffusion-3',
'bytedance/sdxl-lightning-4step',
'playgroundai/playground-v2.5-1024px-aesthetic',
@@ -28,7 +28,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
]
versions = {
- # Model versions for generating images
+ # Model versions for generating images
'stability-ai/stable-diffusion-3': [
"527d2a6296facb8e47ba1eaf17f142c240c19a30894f437feee9b91cc29d8e4f"
],
@@ -39,7 +39,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
"a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24"
],
-
# Model versions for text generation
'meta/meta-llama-3-70b-instruct': [
"dp-cf04fe09351e25db628e8b6181276547"
@@ -55,6 +54,24 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
image_models = {"stability-ai/stable-diffusion-3", "bytedance/sdxl-lightning-4step", "playgroundai/playground-v2.5-1024px-aesthetic"}
text_models = {"meta/meta-llama-3-70b-instruct", "mistralai/mixtral-8x7b-instruct-v0.1", "google-deepmind/gemma-2b-it"}
+ model_aliases = {
+ "stable-diffusion-3": "stability-ai/stable-diffusion-3",
+ "sdxl-lightning-4step": "bytedance/sdxl-lightning-4step",
+ "playground-v2.5-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic",
+ "llama-3-70b": "meta/meta-llama-3-70b-instruct",
+ "mixtral-8x7b": "mistralai/mixtral-8x7b-instruct-v0.1",
+ "gemma-2b": "google-deepmind/gemma-2b-it",
+ }
+
+ @classmethod
+ def get_model(cls, model: str) -> str:
+ if model in cls.models:
+ return model
+ elif model in cls.model_aliases:
+ return cls.model_aliases[model]
+ else:
+ return cls.default_model
+
@classmethod
async def create_async_generator(
cls,
@@ -76,6 +93,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
extra_data: Dict[str, Any] = {},
**kwargs: Any
) -> Union[str, ImageResponse]:
+ model = cls.get_model(model) # Use the get_model method to resolve model name
headers = {
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US',
@@ -109,7 +127,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
"version": version
}
if api_key is None:
- data["model"] = cls.get_model(model)
+ data["model"] = model
url = "https://homepage.replicate.com/api/prediction"
else:
url = "https://api.replicate.com/v1/predictions"