summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/deprecated/OpenAssistant.py (renamed from g4f/Provider/needs_auth/OpenAssistant.py)1
-rw-r--r--g4f/Provider/needs_auth/Gemini.py2
-rw-r--r--g4f/Provider/needs_auth/Openai.py96
-rw-r--r--g4f/Provider/needs_auth/ThebApi.py57
-rw-r--r--g4f/Provider/needs_auth/__init__.py1
5 files changed, 84 insertions, 73 deletions
diff --git a/g4f/Provider/needs_auth/OpenAssistant.py b/g4f/Provider/deprecated/OpenAssistant.py
index e549b517..80cae3c2 100644
--- a/g4f/Provider/needs_auth/OpenAssistant.py
+++ b/g4f/Provider/deprecated/OpenAssistant.py
@@ -8,7 +8,6 @@ from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
-
class OpenAssistant(AsyncGeneratorProvider):
url = "https://open-assistant.io/chat"
needs_auth = True
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index 9013a4f8..fc9d9575 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -19,7 +19,7 @@ except ImportError:
from ...typing import Messages, Cookies, ImageType, AsyncResult
from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
-from requests.raise_for_status import raise_for_status
+from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError
from ...image import to_bytes, ImageResponse
from ...webdriver import get_browser, get_driver_cookies
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py
index b876cd0b..6cd2cf86 100644
--- a/g4f/Provider/needs_auth/Openai.py
+++ b/g4f/Provider/needs_auth/Openai.py
@@ -3,10 +3,10 @@ from __future__ import annotations
import json
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
-from ...typing import AsyncResult, Messages
+from ...typing import Union, Optional, AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
-from ...errors import MissingAuthError
+from ...errors import MissingAuthError, ResponseError
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://openai.com"
@@ -27,48 +27,82 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
temperature: float = None,
max_tokens: int = None,
top_p: float = None,
- stop: str = None,
+ stop: Union[str, list[str]] = None,
stream: bool = False,
+ headers: dict = None,
+ extra_data: dict = {},
**kwargs
) -> AsyncResult:
- if api_key is None:
+ if cls.needs_auth and api_key is None:
raise MissingAuthError('Add a "api_key"')
async with StreamSession(
proxies={"all": proxy},
- headers=cls.get_headers(api_key),
+ headers=cls.get_headers(stream, api_key, headers),
timeout=timeout
) as session:
- data = {
- "messages": messages,
- "model": cls.get_model(model),
- "temperature": temperature,
- "max_tokens": max_tokens,
- "top_p": top_p,
- "stop": stop,
- "stream": stream,
- }
+ data = filter_none(
+ messages=messages,
+ model=cls.get_model(model),
+ temperature=temperature,
+ max_tokens=max_tokens,
+ top_p=top_p,
+ stop=stop,
+ stream=stream,
+ **extra_data
+ )
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response)
- async for line in response.iter_lines():
- if line.startswith(b"data: ") or not stream:
- async for chunk in cls.read_line(line[6:] if stream else line, stream):
- yield chunk
+ if not stream:
+ data = await response.json()
+ choice = data["choices"][0]
+ if "content" in choice["message"]:
+ yield choice["message"]["content"].strip()
+ finish = cls.read_finish_reason(choice)
+ if finish is not None:
+ yield finish
+ else:
+ first = True
+ async for line in response.iter_lines():
+ if line.startswith(b"data: "):
+ chunk = line[6:]
+ if chunk == b"[DONE]":
+ break
+ data = json.loads(chunk)
+ if "error_message" in data:
+ raise ResponseError(data["error_message"])
+ choice = data["choices"][0]
+ if "content" in choice["delta"] and choice["delta"]["content"]:
+ delta = choice["delta"]["content"]
+ if first:
+ delta = delta.lstrip()
+ if delta:
+ first = False
+ yield delta
+ finish = cls.read_finish_reason(choice)
+ if finish is not None:
+ yield finish
@staticmethod
- async def read_line(line: str, stream: bool):
- if line == b"[DONE]":
- return
- choice = json.loads(line)["choices"][0]
- if stream and "content" in choice["delta"] and choice["delta"]["content"]:
- yield choice["delta"]["content"]
- elif not stream and "content" in choice["message"]:
- yield choice["message"]["content"]
+ def read_finish_reason(choice: dict) -> Optional[FinishReason]:
if "finish_reason" in choice and choice["finish_reason"] is not None:
- yield FinishReason(choice["finish_reason"])
+ return FinishReason(choice["finish_reason"])
- @staticmethod
- def get_headers(api_key: str) -> dict:
+ @classmethod
+ def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {
- "Authorization": f"Bearer {api_key}",
+ "Accept": "text/event-stream" if stream else "application/json",
"Content-Type": "application/json",
- } \ No newline at end of file
+ **(
+ {"Authorization": f"Bearer {api_key}"}
+ if cls.needs_auth and api_key is not None
+ else {}
+ ),
+ **({} if headers is None else headers)
+ }
+
+def filter_none(**kwargs) -> dict:
+ return {
+ key: value
+ for key, value in kwargs.items()
+ if value is not None
+ } \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/ThebApi.py b/g4f/Provider/needs_auth/ThebApi.py
index 1c7baf8d..48879bcb 100644
--- a/g4f/Provider/needs_auth/ThebApi.py
+++ b/g4f/Provider/needs_auth/ThebApi.py
@@ -1,10 +1,7 @@
from __future__ import annotations
-import requests
-
-from ...typing import Any, CreateResult, Messages
-from ..base_provider import AbstractProvider, ProviderModelMixin
-from ...errors import MissingAuthError
+from ...typing import CreateResult, Messages
+from .Openai import Openai
models = {
"theb-ai": "TheB.AI",
@@ -30,7 +27,7 @@ models = {
"qwen-7b-chat": "Qwen 7B"
}
-class ThebApi(AbstractProvider, ProviderModelMixin):
+class ThebApi(Openai):
url = "https://theb.ai"
working = True
needs_auth = True
@@ -38,44 +35,26 @@ class ThebApi(AbstractProvider, ProviderModelMixin):
models = list(models)
@classmethod
- def create_completion(
+ def create_async_generator(
cls,
model: str,
messages: Messages,
- stream: bool,
- auth: str = None,
- proxy: str = None,
+ api_base: str = "https://api.theb.ai/v1",
+ temperature: float = 1,
+ top_p: float = 1,
**kwargs
) -> CreateResult:
- if not auth:
- raise MissingAuthError("Missing auth")
- headers = {
- 'accept': 'application/json',
- 'authorization': f'Bearer {auth}',
- 'content-type': 'application/json',
- }
- # response = requests.get("https://api.baizhi.ai/v1/models", headers=headers).json()["data"]
- # models = dict([(m["id"], m["name"]) for m in response])
- # print(json.dumps(models, indent=4))
- data: dict[str, Any] = {
- "model": cls.get_model(model),
- "messages": messages,
- "stream": False,
+ if "auth" in kwargs:
+ kwargs["api_key"] = kwargs["auth"]
+ system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"])
+ if not system_message:
+ system_message = "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."
+ messages = [message for message in messages if message["role"] != "system"]
+ data = {
"model_params": {
- "system_prompt": kwargs.get("system_message", "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."),
- "temperature": 1,
- "top_p": 1,
- **kwargs
+ "system_prompt": system_message,
+ "temperature": temperature,
+ "top_p": top_p,
}
}
- response = requests.post(
- "https://api.theb.ai/v1/chat/completions",
- headers=headers,
- json=data,
- proxies={"https": proxy}
- )
- try:
- response.raise_for_status()
- yield response.json()["choices"][0]["message"]["content"]
- except:
- raise RuntimeError(f"Response: {next(response.iter_lines()).decode()}") \ No newline at end of file
+ return super().create_async_generator(model, messages, api_base=api_base, extra_data=data, **kwargs) \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py
index 92fa165b..581335e1 100644
--- a/g4f/Provider/needs_auth/__init__.py
+++ b/g4f/Provider/needs_auth/__init__.py
@@ -3,7 +3,6 @@ from .Raycast import Raycast
from .Theb import Theb
from .ThebApi import ThebApi
from .OpenaiChat import OpenaiChat
-from .OpenAssistant import OpenAssistant
from .Poe import Poe
from .Openai import Openai
from .Groq import Groq \ No newline at end of file