summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/GeminiPro.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/GeminiPro.py')
-rw-r--r--g4f/Provider/GeminiPro.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py
index b296f253..e1738dc8 100644
--- a/g4f/Provider/GeminiPro.py
+++ b/g4f/Provider/GeminiPro.py
@@ -7,7 +7,7 @@ from aiohttp import ClientSession
from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_bytes, is_accepted_format
-
+from ..errors import MissingAuthError
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://ai.google.dev"
@@ -29,7 +29,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
model = "gemini-pro-vision" if not model and image else model
model = cls.get_model(model)
- api_key = api_key if api_key else kwargs.get("access_token")
+ if not api_key:
+ raise MissingAuthError('Missing "api_key" for auth')
headers = {
"Content-Type": "application/json",
}
@@ -53,13 +54,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
})
data = {
"contents": contents,
- # "generationConfig": {
- # "stopSequences": kwargs.get("stop"),
- # "temperature": kwargs.get("temperature"),
- # "maxOutputTokens": kwargs.get("max_tokens"),
- # "topP": kwargs.get("top_p"),
- # "topK": kwargs.get("top_k"),
- # }
+ "generationConfig": {
+ "stopSequences": kwargs.get("stop"),
+ "temperature": kwargs.get("temperature"),
+ "maxOutputTokens": kwargs.get("max_tokens"),
+ "topP": kwargs.get("top_p"),
+ "topK": kwargs.get("top_k"),
+ }
}
async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response:
if not response.ok: