diff options
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/ReplicateHome.py | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/g4f/Provider/ReplicateHome.py b/g4f/Provider/ReplicateHome.py index e6c8d2d3..017e7acb 100644 --- a/g4f/Provider/ReplicateHome.py +++ b/g4f/Provider/ReplicateHome.py @@ -16,7 +16,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): working = True default_model = 'stability-ai/stable-diffusion-3' models = [ - # Models for image generation + # Models for image generation 'stability-ai/stable-diffusion-3', 'bytedance/sdxl-lightning-4step', 'playgroundai/playground-v2.5-1024px-aesthetic', @@ -28,7 +28,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): ] versions = { - # Model versions for generating images + # Model versions for generating images 'stability-ai/stable-diffusion-3': [ "527d2a6296facb8e47ba1eaf17f142c240c19a30894f437feee9b91cc29d8e4f" ], @@ -39,7 +39,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): "a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24" ], - # Model versions for text generation 'meta/meta-llama-3-70b-instruct': [ "dp-cf04fe09351e25db628e8b6181276547" @@ -55,6 +54,24 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): image_models = {"stability-ai/stable-diffusion-3", "bytedance/sdxl-lightning-4step", "playgroundai/playground-v2.5-1024px-aesthetic"} text_models = {"meta/meta-llama-3-70b-instruct", "mistralai/mixtral-8x7b-instruct-v0.1", "google-deepmind/gemma-2b-it"} + model_aliases = { + "stable-diffusion-3": "stability-ai/stable-diffusion-3", + "sdxl-lightning-4step": "bytedance/sdxl-lightning-4step", + "playground-v2.5-aesthetic": "playgroundai/playground-v2.5-1024px-aesthetic", + "llama-3-70b": "meta/meta-llama-3-70b-instruct", + "mixtral-8x7b": "mistralai/mixtral-8x7b-instruct-v0.1", + "gemma-2b": "google-deepmind/gemma-2b-it", + } + + @classmethod + def get_model(cls, model: str) -> str: + if model in cls.models: + return model + elif model in cls.model_aliases: + return cls.model_aliases[model] + else: + return cls.default_model + @classmethod async def create_async_generator( cls, @@ -76,6 +93,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): extra_data: Dict[str, Any] = {}, **kwargs: Any ) -> Union[str, ImageResponse]: + model = cls.get_model(model) # Use the get_model method to resolve model name headers = { 'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'en-US', @@ -109,7 +127,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): "version": version } if api_key is None: - data["model"] = cls.get_model(model) + data["model"] = model url = "https://homepage.replicate.com/api/prediction" else: url = "https://api.replicate.com/v1/predictions" |