diff options
Diffstat (limited to 'g4f/Provider/GeminiPro.py')
-rw-r--r-- | g4f/Provider/GeminiPro.py | 28 |
1 files changed, 18 insertions, 10 deletions
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index e1738dc8..87ded3ac 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -13,6 +13,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): url = "https://ai.google.dev" working = True supports_message_history = True + needs_auth = True default_model = "gemini-pro" models = ["gemini-pro", "gemini-pro-vision"] @@ -24,19 +25,27 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): stream: bool = False, proxy: str = None, api_key: str = None, + api_base: str = None, image: ImageType = None, **kwargs ) -> AsyncResult: model = "gemini-pro-vision" if not model and image else model model = cls.get_model(model) + if not api_key: - raise MissingAuthError('Missing "api_key" for auth') - headers = { - "Content-Type": "application/json", - } + raise MissingAuthError('Missing "api_key"') + if not api_base: + api_base = f"https://generativelanguage.googleapis.com/v1beta" + + method = "streamGenerateContent" if stream else "generateContent" + url = f"{api_base.rstrip('/')}/models/{model}:{method}" + headers = None + if api_base: + headers = {f"Authorization": "Bearer {api_key}"} + else: + url += f"?key={api_key}" + async with ClientSession(headers=headers) as session: - method = "streamGenerateContent" if stream else "generateContent" - url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{method}" contents = [ { "role": "model" if message["role"] == "assistant" else message["role"], @@ -62,7 +71,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): "topK": kwargs.get("top_k"), } } - async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response: + async with session.post(url, json=data, proxy=proxy) as response: if not response.ok: data = await response.json() raise RuntimeError(data[0]["error"]["message"]) @@ -73,12 +82,11 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): lines = [b"{\n"] elif chunk == b",\r\n" or chunk == b"]": try: - data = b"".join(lines) - data = json.loads(data) + data = json.loads(b"".join(lines)) yield data["candidates"][0]["content"]["parts"][0]["text"] except: data = data.decode() if isinstance(data, bytes) else data - raise RuntimeError(f"Read text failed. data: {data}") + raise RuntimeError(f"Read chunk failed: {data}") lines = [] else: lines.append(chunk) |