path: root/g4f/Provider
author     Heiner Lohaus <hlohaus@users.noreply.github.com>  2024-04-21 22:39:00 +0200
committer  Heiner Lohaus <hlohaus@users.noreply.github.com>  2024-04-21 22:39:00 +0200
commit     3a23e81de93c4c9a83aa22b70ea13066f06541e3 (patch)
tree       caa0c5f892c3ab8df393c1821bbdab780c5d83de /g4f/Provider
parent     Add image model list (diff)
download   gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar (also .tar.gz, .tar.bz2, .tar.lz, .tar.xz, .tar.zst, .zip)
Diffstat (limited to '')
-rw-r--r--  g4f/Provider/MetaAI.py                  2
-rw-r--r--  g4f/Provider/Replicate.py              84
-rw-r--r--  g4f/Provider/__init__.py                2
-rw-r--r--  g4f/Provider/needs_auth/OpenaiChat.py  15
-rw-r--r--  g4f/Provider/unfinished/Replicate.py   78
5 files changed, 94 insertions, 87 deletions
diff --git a/g4f/Provider/MetaAI.py b/g4f/Provider/MetaAI.py
index 045255e7..caed7778 100644
--- a/g4f/Provider/MetaAI.py
+++ b/g4f/Provider/MetaAI.py
@@ -89,7 +89,7 @@ class MetaAI(AsyncGeneratorProvider):
headers = {}
headers = {
'content-type': 'application/x-www-form-urlencoded',
- 'cookie': format_cookies(cookies),
+ 'cookie': format_cookies(self.cookies),
'origin': 'https://www.meta.ai',
'referer': 'https://www.meta.ai/',
'x-asbd-id': '129477',
diff --git a/g4f/Provider/Replicate.py b/g4f/Provider/Replicate.py
new file mode 100644
index 00000000..593fd04d
--- /dev/null
+++ b/g4f/Provider/Replicate.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from .helper import format_prompt, filter_none
+from ..typing import AsyncResult, Messages
+from ..requests import raise_for_status
+from ..requests.aiohttp import StreamSession
+from ..errors import ResponseError, MissingAuthError
+
+class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
+ url = "https://replicate.com"
+ working = True
+ default_model = "meta/meta-llama-3-70b-instruct"
+
+ @classmethod
+ async def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ api_key: str = None,
+ proxy: str = None,
+ timeout: int = 180,
+ system_prompt: str = None,
+ max_new_tokens: int = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: float = None,
+ stop: list = None,
+ extra_data: dict = {},
+ headers: dict = {
+ "accept": "application/json",
+ },
+ **kwargs
+ ) -> AsyncResult:
+ model = cls.get_model(model)
+ if cls.needs_auth and api_key is None:
+ raise MissingAuthError("api_key is missing")
+ if api_key is not None:
+ headers["Authorization"] = f"Bearer {api_key}"
+ api_base = "https://api.replicate.com/v1/models/"
+ else:
+ api_base = "https://replicate.com/api/models/"
+ async with StreamSession(
+ proxy=proxy,
+ headers=headers,
+ timeout=timeout
+ ) as session:
+ data = {
+ "stream": True,
+ "input": {
+ "prompt": format_prompt(messages),
+ **filter_none(
+ system_prompt=system_prompt,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ stop_sequences=",".join(stop) if stop else None
+ ),
+ **extra_data
+ },
+ }
+ url = f"{api_base.rstrip('/')}/{model}/predictions"
+ async with session.post(url, json=data) as response:
+ message = "Model not found" if response.status == 404 else None
+ await raise_for_status(response, message)
+ result = await response.json()
+ if "id" not in result:
+ raise ResponseError(f"Invalid response: {result}")
+ async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response:
+ await raise_for_status(response)
+ event = None
+ async for line in response.iter_lines():
+ if line.startswith(b"event: "):
+ event = line[7:]
+ if event == b"done":
+ break
+ elif event == b"output":
+ if line.startswith(b"data: "):
+ new_text = line[6:].decode()
+ if new_text:
+ yield new_text
+ else:
+ yield "\n" \ No newline at end of file
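Note: a minimal usage sketch for the provider added above. The call signature and the model name mirror this diff; the prompt text and token value are illustrative, and api_key may be omitted as long as the provider's needs_auth flag stays off, in which case the request goes to the unauthenticated replicate.com endpoint.

import asyncio

from g4f.Provider import Replicate

async def main():
    messages = [{"role": "user", "content": "Say hello in one sentence."}]
    # Per the diff: with an api_key the request targets
    # https://api.replicate.com/v1/models/..., otherwise
    # https://replicate.com/api/models/... without authentication.
    async for chunk in Replicate.create_async_generator(
        model="meta/meta-llama-3-70b-instruct",
        messages=messages,
        api_key=None,  # or a Replicate token, e.g. "r8_..." (placeholder)
    ):
        print(chunk, end="", flush=True)

asyncio.run(main())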
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index 27c14672..d2d9bfda 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -9,7 +9,6 @@ from .deprecated import *
from .not_working import *
from .selenium import *
from .needs_auth import *
-from .unfinished import *
from .Aichatos import Aichatos
from .Aura import Aura
@@ -46,6 +45,7 @@ from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount
from .PerplexityLabs import PerplexityLabs
from .Pi import Pi
+from .Replicate import Replicate
from .ReplicateImage import ReplicateImage
from .Vercel import Vercel
from .WhiteRabbitNeo import WhiteRabbitNeo
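Note: with the export above, the provider is also reachable from the high-level client. The sketch below assumes the usual g4f.ChatCompletion.create(...) entry point and its streaming behaviour, neither of which is part of this diff; only the provider registration is.

import g4f
from g4f.Provider import Replicate

# Assumed high-level API; provider selection is what this __init__.py change enables.
response = g4f.ChatCompletion.create(
    model="meta/meta-llama-3-70b-instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    provider=Replicate,
    stream=True,
)
for chunk in response:
    print(chunk, end="", flush=True)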
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 7952d606..3d6e9858 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -340,9 +340,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises:
RuntimeError: If an error occurs during processing.
"""
-
async with StreamSession(
- proxies={"all": proxy},
+ proxy=proxy,
impersonate="chrome",
timeout=timeout
) as session:
@@ -364,26 +363,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
api_key = cls._api_key = None
cls._create_request_args()
if debug.logging:
- print("OpenaiChat: Load default_model failed")
+ print("OpenaiChat: Load default model failed")
print(f"{e.__class__.__name__}: {e}")
arkose_token = None
if cls.default_model is None:
+ error = None
try:
arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies, headers)
cls._set_api_key(api_key)
except NoValidHarFileError as e:
- ...
+ error = e
if cls._api_key is None:
await cls.nodriver_access_token()
if cls._api_key is None and cls.needs_auth:
- raise e
+ raise error
cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
async with session.post(
f"{cls.url}/backend-anon/sentinel/chat-requirements"
- if not cls._api_key else
+ if cls._api_key is None else
f"{cls.url}/backend-api/sentinel/chat-requirements",
json={"conversation_mode_kind": "primary_assistant"},
headers=cls._headers
@@ -412,7 +412,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
print("OpenaiChat: Upload image failed")
print(f"{e.__class__.__name__}: {e}")
- model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
+ model = cls.get_model(model)
+ model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model
if conversation is None:
conversation = Conversation(conversation_id, str(uuid.uuid4()) if parent_id is None else parent_id)
else:
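Note: the second hunk above stops swallowing NoValidHarFileError (and no longer re-raises a possibly unbound name); the error is saved and only raised if the browser fallback also fails to produce an API key. A self-contained sketch of that control flow, with hypothetical stand-ins for getArkoseAndAccessToken() and nodriver_access_token():

from typing import Optional

class NoValidHarFileError(Exception):
    pass

def load_key_from_har() -> str:
    # stand-in for getArkoseAndAccessToken(); pretend no usable .har file exists
    raise NoValidHarFileError("no valid .har file found")

def load_key_from_browser() -> Optional[str]:
    # stand-in for nodriver_access_token(); pretend the fallback also fails
    return None

def get_api_key(needs_auth: bool = True) -> Optional[str]:
    error = None
    api_key = None
    try:
        api_key = load_key_from_har()
    except NoValidHarFileError as e:
        error = e                     # keep the error instead of discarding it
    if api_key is None:
        api_key = load_key_from_browser()
    if api_key is None and needs_auth:
        raise error                   # surface the saved error only now
    return api_key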
diff --git a/g4f/Provider/unfinished/Replicate.py b/g4f/Provider/unfinished/Replicate.py
deleted file mode 100644
index aaaf31b3..00000000
--- a/g4f/Provider/unfinished/Replicate.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt, filter_none
-from ...typing import AsyncResult, Messages
-from ...requests import StreamSession, raise_for_status
-from ...image import ImageResponse
-from ...errors import ResponseError, MissingAuthError
-
-class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
- url = "https://replicate.com"
- working = True
- default_model = "mistralai/mixtral-8x7b-instruct-v0.1"
- api_base = "https://api.replicate.com/v1/models/"
-
- @classmethod
- async def create_async_generator(
- cls,
- model: str,
- messages: Messages,
- api_key: str = None,
- proxy: str = None,
- timeout: int = 180,
- system_prompt: str = None,
- max_new_tokens: int = None,
- temperature: float = None,
- top_p: float = None,
- top_k: float = None,
- stop: list = None,
- extra_data: dict = {},
- headers: dict = {},
- **kwargs
- ) -> AsyncResult:
- model = cls.get_model(model)
- if api_key is None:
- raise MissingAuthError("api_key is missing")
- headers["Authorization"] = f"Bearer {api_key}"
- async with StreamSession(
- proxies={"all": proxy},
- headers=headers,
- timeout=timeout
- ) as session:
- data = {
- "stream": True,
- "input": {
- "prompt": format_prompt(messages),
- **filter_none(
- system_prompt=system_prompt,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- top_p=top_p,
- top_k=top_k,
- stop_sequences=",".join(stop) if stop else None
- ),
- **extra_data
- },
- }
- url = f"{cls.api_base.rstrip('/')}/{model}/predictions"
- async with session.post(url, json=data) as response:
- await raise_for_status(response)
- result = await response.json()
- if "id" not in result:
- raise ResponseError(f"Invalid response: {result}")
- async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response:
- await raise_for_status(response)
- event = None
- async for line in response.iter_lines():
- if line.startswith(b"event: "):
- event = line[7:]
- elif event == b"output":
- if line.startswith(b"data: "):
- yield line[6:].decode()
- elif not line.startswith(b"id: "):
- continue#yield "+"+line.decode()
- elif event == b"done":
- break \ No newline at end of file