diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/GeminiPro.py | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index b296f253..e1738dc8 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -7,7 +7,7 @@ from aiohttp import ClientSession from ..typing import AsyncResult, Messages, ImageType from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..image import to_bytes, is_accepted_format - +from ..errors import MissingAuthError class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): url = "https://ai.google.dev" @@ -29,7 +29,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): ) -> AsyncResult: model = "gemini-pro-vision" if not model and image else model model = cls.get_model(model) - api_key = api_key if api_key else kwargs.get("access_token") + if not api_key: + raise MissingAuthError('Missing "api_key" for auth') headers = { "Content-Type": "application/json", } @@ -53,13 +54,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): }) data = { "contents": contents, - # "generationConfig": { - # "stopSequences": kwargs.get("stop"), - # "temperature": kwargs.get("temperature"), - # "maxOutputTokens": kwargs.get("max_tokens"), - # "topP": kwargs.get("top_p"), - # "topK": kwargs.get("top_k"), - # } + "generationConfig": { + "stopSequences": kwargs.get("stop"), + "temperature": kwargs.get("temperature"), + "maxOutputTokens": kwargs.get("max_tokens"), + "topP": kwargs.get("top_p"), + "topK": kwargs.get("top_k"), + } } async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response: if not response.ok: |