summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/GeminiPro.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/GeminiPro.py')
-rw-r--r--g4f/Provider/GeminiPro.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py
index 792cd5d1..87ded3ac 100644
--- a/g4f/Provider/GeminiPro.py
+++ b/g4f/Provider/GeminiPro.py
@@ -32,17 +32,20 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
model = "gemini-pro-vision" if not model and image else model
model = cls.get_model(model)
- if not api_key and not api_base:
- raise MissingAuthError('Missing "api_key" or "api_base"')
+ if not api_key:
+ raise MissingAuthError('Missing "api_key"')
if not api_base:
api_base = f"https://generativelanguage.googleapis.com/v1beta"
method = "streamGenerateContent" if stream else "generateContent"
url = f"{api_base.rstrip('/')}/models/{model}:{method}"
- if api_key:
+ headers = None
+ if api_base:
+ headers = {f"Authorization": "Bearer {api_key}"}
+ else:
url += f"?key={api_key}"
- async with ClientSession() as session:
+ async with ClientSession(headers=headers) as session:
contents = [
{
"role": "model" if message["role"] == "assistant" else message["role"],
@@ -79,12 +82,11 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
lines = [b"{\n"]
elif chunk == b",\r\n" or chunk == b"]":
try:
- data = b"".join(lines)
- data = json.loads(data)
+ data = json.loads(b"".join(lines))
yield data["candidates"][0]["content"]["parts"][0]["text"]
except:
data = data.decode() if isinstance(data, bytes) else data
- raise RuntimeError(f"Read chunk failed. data: {data}")
+ raise RuntimeError(f"Read chunk failed: {data}")
lines = []
else:
lines.append(chunk)