summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to 'g4f')
-rw-r--r--g4f/Provider/GeminiPro.py19
-rw-r--r--g4f/__init__.py2
2 files changed, 11 insertions, 10 deletions
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py
index b296f253..e1738dc8 100644
--- a/g4f/Provider/GeminiPro.py
+++ b/g4f/Provider/GeminiPro.py
@@ -7,7 +7,7 @@ from aiohttp import ClientSession
from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_bytes, is_accepted_format
-
+from ..errors import MissingAuthError
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://ai.google.dev"
@@ -29,7 +29,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
model = "gemini-pro-vision" if not model and image else model
model = cls.get_model(model)
- api_key = api_key if api_key else kwargs.get("access_token")
+ if not api_key:
+ raise MissingAuthError('Missing "api_key" for auth')
headers = {
"Content-Type": "application/json",
}
@@ -53,13 +54,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
})
data = {
"contents": contents,
- # "generationConfig": {
- # "stopSequences": kwargs.get("stop"),
- # "temperature": kwargs.get("temperature"),
- # "maxOutputTokens": kwargs.get("max_tokens"),
- # "topP": kwargs.get("top_p"),
- # "topK": kwargs.get("top_k"),
- # }
+ "generationConfig": {
+ "stopSequences": kwargs.get("stop"),
+ "temperature": kwargs.get("temperature"),
+ "maxOutputTokens": kwargs.get("max_tokens"),
+ "topP": kwargs.get("top_p"),
+ "topK": kwargs.get("top_k"),
+ }
}
async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response:
if not response.ok:
diff --git a/g4f/__init__.py b/g4f/__init__.py
index 6716c727..5df942ae 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -42,7 +42,7 @@ def get_model_and_provider(model : Union[Model, str],
if debug.version_check:
debug.version_check = False
version.utils.check_version()
-
+
if isinstance(provider, str):
if " " in provider:
provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert]