diff options
author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-04-07 10:36:13 +0200 |
---|---|---|
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-04-07 10:36:13 +0200 |
commit | b35dfcd1b01c575b65e0299ef71d285dc8f41459 (patch) | |
tree | cfe5f4a390af62fafefd1d27ca2c82a23cdcab49 /g4f/Provider/needs_auth/Openai.py | |
parent | Update Gemini.py (diff) | |
download | gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.tar gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.tar.gz gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.tar.bz2 gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.tar.lz gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.tar.xz gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.tar.zst gpt4free-b35dfcd1b01c575b65e0299ef71d285dc8f41459.zip |
Diffstat (limited to 'g4f/Provider/needs_auth/Openai.py')
-rw-r--r-- | g4f/Provider/needs_auth/Openai.py | 96 |
1 files changed, 65 insertions, 31 deletions
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py index b876cd0b..6cd2cf86 100644 --- a/g4f/Provider/needs_auth/Openai.py +++ b/g4f/Provider/needs_auth/Openai.py @@ -3,10 +3,10 @@ from __future__ import annotations import json from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason -from ...typing import AsyncResult, Messages +from ...typing import Union, Optional, AsyncResult, Messages from ...requests.raise_for_status import raise_for_status from ...requests import StreamSession -from ...errors import MissingAuthError +from ...errors import MissingAuthError, ResponseError class Openai(AsyncGeneratorProvider, ProviderModelMixin): url = "https://openai.com" @@ -27,48 +27,82 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin): temperature: float = None, max_tokens: int = None, top_p: float = None, - stop: str = None, + stop: Union[str, list[str]] = None, stream: bool = False, + headers: dict = None, + extra_data: dict = {}, **kwargs ) -> AsyncResult: - if api_key is None: + if cls.needs_auth and api_key is None: raise MissingAuthError('Add a "api_key"') async with StreamSession( proxies={"all": proxy}, - headers=cls.get_headers(api_key), + headers=cls.get_headers(stream, api_key, headers), timeout=timeout ) as session: - data = { - "messages": messages, - "model": cls.get_model(model), - "temperature": temperature, - "max_tokens": max_tokens, - "top_p": top_p, - "stop": stop, - "stream": stream, - } + data = filter_none( + messages=messages, + model=cls.get_model(model), + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + stop=stop, + stream=stream, + **extra_data + ) async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response: await raise_for_status(response) - async for line in response.iter_lines(): - if line.startswith(b"data: ") or not stream: - async for chunk in cls.read_line(line[6:] if stream else line, stream): - yield chunk + if not stream: + data = await response.json() + choice = data["choices"][0] + if "content" in choice["message"]: + yield choice["message"]["content"].strip() + finish = cls.read_finish_reason(choice) + if finish is not None: + yield finish + else: + first = True + async for line in response.iter_lines(): + if line.startswith(b"data: "): + chunk = line[6:] + if chunk == b"[DONE]": + break + data = json.loads(chunk) + if "error_message" in data: + raise ResponseError(data["error_message"]) + choice = data["choices"][0] + if "content" in choice["delta"] and choice["delta"]["content"]: + delta = choice["delta"]["content"] + if first: + delta = delta.lstrip() + if delta: + first = False + yield delta + finish = cls.read_finish_reason(choice) + if finish is not None: + yield finish @staticmethod - async def read_line(line: str, stream: bool): - if line == b"[DONE]": - return - choice = json.loads(line)["choices"][0] - if stream and "content" in choice["delta"] and choice["delta"]["content"]: - yield choice["delta"]["content"] - elif not stream and "content" in choice["message"]: - yield choice["message"]["content"] + def read_finish_reason(choice: dict) -> Optional[FinishReason]: if "finish_reason" in choice and choice["finish_reason"] is not None: - yield FinishReason(choice["finish_reason"]) + return FinishReason(choice["finish_reason"]) - @staticmethod - def get_headers(api_key: str) -> dict: + @classmethod + def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict: return { - "Authorization": f"Bearer {api_key}", + "Accept": "text/event-stream" if stream else "application/json", "Content-Type": "application/json", - }
\ No newline at end of file + **( + {"Authorization": f"Bearer {api_key}"} + if cls.needs_auth and api_key is not None + else {} + ), + **({} if headers is None else headers) + } + +def filter_none(**kwargs) -> dict: + return { + key: value + for key, value in kwargs.items() + if value is not None + }
\ No newline at end of file |