From fd5fa8a4ebaf80084894141a1164b2da8f36d73d Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Fri, 24 Jan 2025 09:45:40 +0100 Subject: Update provider parameters, check for valid provider Fix reading model list in GeminiPro Fix check content-type in OpenaiAPI --- g4f/Provider/needs_auth/GeminiPro.py | 6 ++++-- g4f/Provider/needs_auth/OpenaiAPI.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'g4f/Provider/needs_auth') diff --git a/g4f/Provider/needs_auth/GeminiPro.py b/g4f/Provider/needs_auth/GeminiPro.py index 89dbf52e..9d6de82c 100644 --- a/g4f/Provider/needs_auth/GeminiPro.py +++ b/g4f/Provider/needs_auth/GeminiPro.py @@ -23,6 +23,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): working = True supports_message_history = True + supports_system_message = True needs_auth = True default_model = "gemini-1.5-pro" @@ -39,7 +40,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]: if not cls.models: try: - response = requests.get(f"{api_base}/models?key={api_key}") + url = f"{cls.api_base if not api_base else api_base}/models" + response = requests.get(url, params={"key": api_key}) raise_for_status(response) data = response.json() cls.models = [ @@ -50,7 +52,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): cls.models.sort() except Exception as e: debug.log(e) - cls.models = cls.fallback_models + return cls.fallback_models return cls.models @classmethod diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py index fadf5f53..1b2e7ae4 100644 --- a/g4f/Provider/needs_auth/OpenaiAPI.py +++ b/g4f/Provider/needs_auth/OpenaiAPI.py @@ -108,7 +108,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin): if api_endpoint is None: api_endpoint = f"{api_base.rstrip('/')}/chat/completions" async with session.post(api_endpoint, json=data) as response: - if response.headers.get("content-type", None if stream else "application/json") == "application/json": + content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json") + if content_type.startswith("application/json"): data = await response.json() cls.raise_error(data) await raise_for_status(response) @@ -122,7 +123,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin): if "finish_reason" in choice and choice["finish_reason"] is not None: yield FinishReason(choice["finish_reason"]) return - elif response.headers.get("content-type", "text/event-stream" if stream else None) == "text/event-stream": + elif content_type.startswith("text/event-stream"): await raise_for_status(response) first = True async for line in response.iter_lines(): @@ -147,7 +148,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin): break else: await raise_for_status(response) - raise ResponseError(f"Not supported content-type: {response.headers.get('content-type')}") + raise ResponseError(f"Not supported content-type: {content_type}") @classmethod def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict: -- cgit v1.2.3