summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/Bard.py22
-rw-r--r--g4f/Provider/Bing.py466
-rw-r--r--g4f/Provider/Hugchat.py67
-rw-r--r--g4f/Provider/OpenaiChat.py74
-rw-r--r--g4f/Provider/__init__.py4
-rw-r--r--g4f/Provider/base_provider.py85
6 files changed, 430 insertions, 288 deletions
diff --git a/g4f/Provider/Bard.py b/g4f/Provider/Bard.py
index cbe728cd..a8c7d13f 100644
--- a/g4f/Provider/Bard.py
+++ b/g4f/Provider/Bard.py
@@ -2,42 +2,26 @@ import json
import random
import re
-import browser_cookie3
from aiohttp import ClientSession
import asyncio
from ..typing import Any, CreateResult
-from .base_provider import BaseProvider
+from .base_provider import AsyncProvider, get_cookies
-class Bard(BaseProvider):
+class Bard(AsyncProvider):
url = "https://bard.google.com"
needs_auth = True
working = True
@classmethod
- def create_completion(
- cls,
- model: str,
- messages: list[dict[str, str]],
- stream: bool,
- proxy: str = None,
- cookies: dict = {},
- **kwargs: Any,
- ) -> CreateResult:
- yield asyncio.run(cls.create_async(str, messages, proxy, cookies))
-
- @classmethod
async def create_async(
cls,
model: str,
messages: list[dict[str, str]],
proxy: str = None,
- cookies: dict = {},
+ cookies: dict = get_cookies(".google.com"),
**kwargs: Any,
) -> str:
- if not cookies:
- for cookie in browser_cookie3.load(domain_name='.google.com'):
- cookies[cookie.name] = cookie.value
formatted = "\n".join(
["%s: %s" % (message["role"], message["content"]) for message in messages]
diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py
index 48b5477d..2c2e60ad 100644
--- a/g4f/Provider/Bing.py
+++ b/g4f/Provider/Bing.py
@@ -2,42 +2,39 @@ import asyncio
import json
import os
import random
-import ssl
-import uuid
import aiohttp
-import certifi
-import requests
-
-from ..typing import Any, AsyncGenerator, CreateResult, Tuple, Union
-from .base_provider import BaseProvider
+import asyncio
+from aiohttp import ClientSession
+from ..typing import Any, AsyncGenerator, CreateResult, Union
+from .base_provider import AsyncGeneratorProvider, get_cookies
-class Bing(BaseProvider):
+class Bing(AsyncGeneratorProvider):
url = "https://bing.com/chat"
+ needs_auth = True
+ working = True
supports_gpt_4 = True
-
+ supports_stream=True
+
@staticmethod
- def create_completion(
- model: str,
- messages: list[dict[str, str]],
- stream: bool,
- **kwargs: Any,
- ) -> CreateResult:
+ def create_async_generator(
+ model: str,
+ messages: list[dict[str, str]],
+ cookies: dict = get_cookies(".bing.com"),
+ **kwargs
+ ) -> AsyncGenerator:
if len(messages) < 2:
prompt = messages[0]["content"]
- context = False
+ context = None
else:
prompt = messages[-1]["content"]
- context = convert(messages[:-1])
+ context = create_context(messages[:-1])
- response = run(stream_generate(prompt, jailbreak, context))
- for token in response:
- yield token
+ return stream_generate(prompt, context, cookies)
-
-def convert(messages: list[dict[str, str]]):
+def create_context(messages: list[dict[str, str]]):
context = ""
for message in messages:
@@ -45,250 +42,43 @@ def convert(messages: list[dict[str, str]]):
return context
-
-jailbreak = {
- "optionsSets": [
- "saharasugg",
- "enablenewsfc",
- "clgalileo",
- "gencontentv3",
- "nlu_direct_response_filter",
- "deepleo",
- "disable_emoji_spoken_text",
- "responsible_ai_policy_235",
- "enablemm",
- "h3precise"
- # "harmonyv3",
- "dtappid",
- "cricinfo",
- "cricinfov2",
- "dv3sugg",
- "nojbfedge",
- ]
-}
-
-
-ssl_context = ssl.create_default_context()
-ssl_context.load_verify_locations(certifi.where())
-
-
-def _format(msg: dict[str, Any]) -> str:
- return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter
-
-
-async def stream_generate(
- prompt: str,
- mode: dict[str, list[str]] = jailbreak,
- context: Union[bool, str] = False,
-):
- timeout = aiohttp.ClientTimeout(total=900)
- session = aiohttp.ClientSession(timeout=timeout)
-
- conversationId, clientId, conversationSignature = await create_conversation()
-
- wss = await session.ws_connect(
- "wss://sydney.bing.com/sydney/ChatHub",
- ssl=ssl_context,
- autoping=False,
- headers={
- "accept": "application/json",
- "accept-language": "en-US,en;q=0.9",
- "content-type": "application/json",
- "sec-ch-ua": '"Not_A Brand";v="99", "Microsoft Edge";v="110", "Chromium";v="110"',
- "sec-ch-ua-arch": '"x86"',
- "sec-ch-ua-bitness": '"64"',
- "sec-ch-ua-full-version": '"109.0.1518.78"',
- "sec-ch-ua-full-version-list": '"Chromium";v="110.0.5481.192", "Not A(Brand";v="24.0.0.0", "Microsoft Edge";v="110.0.1587.69"',
- "sec-ch-ua-mobile": "?0",
- "sec-ch-ua-model": "",
- "sec-ch-ua-platform": '"Windows"',
- "sec-ch-ua-platform-version": '"15.0.0"',
- "sec-fetch-dest": "empty",
- "sec-fetch-mode": "cors",
- "sec-fetch-site": "same-origin",
- "x-ms-client-request-id": str(uuid.uuid4()),
- "x-ms-useragent": "azsdk-js-api-client-factory/1.0.0-beta.1 core-rest-pipeline/1.10.0 OS/Win32",
- "Referer": "https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx",
- "Referrer-Policy": "origin-when-cross-origin",
- "x-forwarded-for": Defaults.ip_address,
- },
- )
-
- await wss.send_str(_format({"protocol": "json", "version": 1}))
- await wss.receive(timeout=900)
-
- argument: dict[str, Any] = {
- **mode,
+class Conversation():
+ def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
+ self.conversationId = conversationId
+ self.clientId = clientId
+ self.conversationSignature = conversationSignature
+
+async def create_conversation(session: ClientSession) -> Conversation:
+ url = 'https://www.bing.com/turing/conversation/create'
+ async with await session.get(url) as response:
+ response = await response.json()
+ conversationId = response.get('conversationId')
+ clientId = response.get('clientId')
+ conversationSignature = response.get('conversationSignature')
+
+ if not conversationId or not clientId or not conversationSignature:
+ raise Exception('Failed to create conversation.')
+
+ return Conversation(conversationId, clientId, conversationSignature)
+
+async def list_conversations(session: ClientSession) -> list:
+ url = "https://www.bing.com/turing/conversation/chats"
+ async with session.get(url) as response:
+ response = await response.json()
+ return response["chats"]
+
+async def delete_conversation(session: ClientSession, conversation: Conversation) -> list:
+ url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
+ json = {
+ "conversationId": conversation.conversationId,
+ "conversationSignature": conversation.conversationSignature,
+ "participant": {"id": conversation.clientId},
"source": "cib",
- "allowedMessageTypes": Defaults.allowedMessageTypes,
- "sliceIds": Defaults.sliceIds,
- "traceId": os.urandom(16).hex(),
- "isStartOfSession": True,
- "message": Defaults.location
- | {
- "author": "user",
- "inputMethod": "Keyboard",
- "text": prompt,
- "messageType": "Chat",
- },
- "conversationSignature": conversationSignature,
- "participant": {"id": clientId},
- "conversationId": conversationId,
- }
-
- if context:
- argument["previousMessages"] = [
- {
- "author": "user",
- "description": context,
- "contextType": "WebPage",
- "messageType": "Context",
- "messageId": "discover-web--page-ping-mriduna-----",
- }
- ]
-
- struct: dict[str, list[dict[str, Any]] | str | int] = {
- "arguments": [argument],
- "invocationId": "0",
- "target": "chat",
- "type": 4,
+ "optionsSets": ["autosave"]
}
-
- await wss.send_str(_format(struct))
-
- final = False
- draw = False
- resp_txt = ""
- result_text = ""
- resp_txt_no_link = ""
- cache_text = ""
-
- while not final:
- msg = await wss.receive(timeout=900)
- objects = msg.data.split(Defaults.delimiter) # type: ignore
-
- for obj in objects: # type: ignore
- if obj is None or not obj:
- continue
-
- response = json.loads(obj) # type: ignore
- if response.get("type") == 1 and response["arguments"][0].get(
- "messages",
- ):
- if not draw:
- if (
- response["arguments"][0]["messages"][0]["contentOrigin"]
- != "Apology"
- ) and not draw:
- resp_txt = result_text + response["arguments"][0]["messages"][
- 0
- ]["adaptiveCards"][0]["body"][0].get("text", "")
- resp_txt_no_link = result_text + response["arguments"][0][
- "messages"
- ][0].get("text", "")
-
- if response["arguments"][0]["messages"][0].get(
- "messageType",
- ):
- resp_txt = (
- resp_txt
- + response["arguments"][0]["messages"][0][
- "adaptiveCards"
- ][0]["body"][0]["inlines"][0].get("text")
- + "\n"
- )
- result_text = (
- result_text
- + response["arguments"][0]["messages"][0][
- "adaptiveCards"
- ][0]["body"][0]["inlines"][0].get("text")
- + "\n"
- )
-
- if cache_text.endswith(" "):
- final = True
- if wss and not wss.closed:
- await wss.close()
- if session and not session.closed:
- await session.close()
-
- yield (resp_txt.replace(cache_text, ""))
- cache_text = resp_txt
-
- elif response.get("type") == 2:
- if response["item"]["result"].get("error"):
- if wss and not wss.closed:
- await wss.close()
- if session and not session.closed:
- await session.close()
-
- raise Exception(
- f"{response['item']['result']['value']}: {response['item']['result']['message']}"
- )
-
- if draw:
- cache = response["item"]["messages"][1]["adaptiveCards"][0]["body"][
- 0
- ]["text"]
- response["item"]["messages"][1]["adaptiveCards"][0]["body"][0][
- "text"
- ] = (cache + resp_txt)
-
- if (
- response["item"]["messages"][-1]["contentOrigin"] == "Apology"
- and resp_txt
- ):
- response["item"]["messages"][-1]["text"] = resp_txt_no_link
- response["item"]["messages"][-1]["adaptiveCards"][0]["body"][0][
- "text"
- ] = resp_txt
-
- # print('Preserved the message from being deleted', file=sys.stderr)
-
- final = True
- if wss and not wss.closed:
- await wss.close()
- if session and not session.closed:
- await session.close()
-
-
-async def create_conversation() -> Tuple[str, str, str]:
- create = requests.get(
- "https://www.bing.com/turing/conversation/create",
- headers={
- "authority": "edgeservices.bing.com",
- "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
- "accept-language": "en-US,en;q=0.9",
- "cache-control": "max-age=0",
- "sec-ch-ua": '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
- "sec-ch-ua-arch": '"x86"',
- "sec-ch-ua-bitness": '"64"',
- "sec-ch-ua-full-version": '"110.0.1587.69"',
- "sec-ch-ua-full-version-list": '"Chromium";v="110.0.5481.192", "Not A(Brand";v="24.0.0.0", "Microsoft Edge";v="110.0.1587.69"',
- "sec-ch-ua-mobile": "?0",
- "sec-ch-ua-model": '""',
- "sec-ch-ua-platform": '"Windows"',
- "sec-ch-ua-platform-version": '"15.0.0"',
- "sec-fetch-dest": "document",
- "sec-fetch-mode": "navigate",
- "sec-fetch-site": "none",
- "sec-fetch-user": "?1",
- "upgrade-insecure-requests": "1",
- "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.69",
- "x-edge-shopping-flag": "1",
- "x-forwarded-for": Defaults.ip_address,
- },
- )
-
- conversationId = create.json().get("conversationId")
- clientId = create.json().get("clientId")
- conversationSignature = create.json().get("conversationSignature")
-
- if not conversationId or not clientId or not conversationSignature:
- raise Exception("Failed to create conversation.")
-
- return conversationId, clientId, conversationSignature
-
+ async with session.post(url, json=json) as response:
+ response = await response.json()
+ return response["result"]["value"] == "Success"
class Defaults:
delimiter = "\x1e"
@@ -309,9 +99,6 @@ class Defaults:
]
sliceIds = [
- # "222dtappid",
- # "225cricinfo",
- # "224locals0"
"winmuid3tf",
"osbsdusgreccf",
"ttstmout",
@@ -349,6 +136,151 @@ class Defaults:
],
}
+ headers = {
+ 'accept': '*/*',
+ 'accept-language': 'en-US,en;q=0.9',
+ 'cache-control': 'max-age=0',
+ 'sec-ch-ua': '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
+ 'sec-ch-ua-arch': '"x86"',
+ 'sec-ch-ua-bitness': '"64"',
+ 'sec-ch-ua-full-version': '"110.0.1587.69"',
+ 'sec-ch-ua-full-version-list': '"Chromium";v="110.0.5481.192", "Not A(Brand";v="24.0.0.0", "Microsoft Edge";v="110.0.1587.69"',
+ 'sec-ch-ua-mobile': '?0',
+ 'sec-ch-ua-model': '""',
+ 'sec-ch-ua-platform': '"Windows"',
+ 'sec-ch-ua-platform-version': '"15.0.0"',
+ 'sec-fetch-dest': 'document',
+ 'sec-fetch-mode': 'navigate',
+ 'sec-fetch-site': 'none',
+ 'sec-fetch-user': '?1',
+ 'upgrade-insecure-requests': '1',
+ 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.69',
+ 'x-edge-shopping-flag': '1',
+ 'x-forwarded-for': ip_address,
+ }
+
+ optionsSets = [
+ 'saharasugg',
+ 'enablenewsfc',
+ 'clgalileo',
+ 'gencontentv3',
+ "nlu_direct_response_filter",
+ "deepleo",
+ "disable_emoji_spoken_text",
+ "responsible_ai_policy_235",
+ "enablemm",
+ "h3precise"
+ "dtappid",
+ "cricinfo",
+ "cricinfov2",
+ "dv3sugg",
+ "nojbfedge"
+ ]
+
+def format_message(message: dict) -> str:
+ return json.dumps(message, ensure_ascii=False) + Defaults.delimiter
+
+def create_message(conversation: Conversation, prompt: str, context: str=None) -> str:
+ struct = {
+ 'arguments': [
+ {
+ 'optionsSets': Defaults.optionsSets,
+ 'source': 'cib',
+ 'allowedMessageTypes': Defaults.allowedMessageTypes,
+ 'sliceIds': Defaults.sliceIds,
+ 'traceId': os.urandom(16).hex(),
+ 'isStartOfSession': True,
+ 'message': Defaults.location | {
+ 'author': 'user',
+ 'inputMethod': 'Keyboard',
+ 'text': prompt,
+ 'messageType': 'Chat'
+ },
+ 'conversationSignature': conversation.conversationSignature,
+ 'participant': {
+ 'id': conversation.clientId
+ },
+ 'conversationId': conversation.conversationId
+ }
+ ],
+ 'invocationId': '0',
+ 'target': 'chat',
+ 'type': 4
+ }
+
+ if context:
+ struct['arguments'][0]['previousMessages'] = [{
+ "author": "user",
+ "description": context,
+ "contextType": "WebPage",
+ "messageType": "Context",
+ "messageId": "discover-web--page-ping-mriduna-----"
+ }]
+ return format_message(struct)
+
+async def stream_generate(
+ prompt: str,
+ context: str=None,
+ cookies: dict=None
+ ):
+ async with ClientSession(
+ timeout=aiohttp.ClientTimeout(total=900),
+ cookies=cookies,
+ headers=Defaults.headers,
+ ) as session:
+ conversation = await create_conversation(session)
+ try:
+ async with session.ws_connect(
+ 'wss://sydney.bing.com/sydney/ChatHub',
+ autoping=False,
+ ) as wss:
+
+ await wss.send_str(format_message({'protocol': 'json', 'version': 1}))
+ msg = await wss.receive(timeout=900)
+
+ await wss.send_str(create_message(conversation, prompt, context))
+
+ response_txt = ''
+ result_text = ''
+ returned_text = ''
+ final = False
+
+ while not final:
+ msg = await wss.receive(timeout=900)
+ objects = msg.data.split(Defaults.delimiter)
+ for obj in objects:
+ if obj is None or not obj:
+ continue
+
+ response = json.loads(obj)
+ if response.get('type') == 1 and response['arguments'][0].get('messages'):
+ message = response['arguments'][0]['messages'][0]
+ if (message['contentOrigin'] != 'Apology'):
+ response_txt = result_text + \
+ message['adaptiveCards'][0]['body'][0].get('text', '')
+
+ if message.get('messageType'):
+ inline_txt = message['adaptiveCards'][0]['body'][0]['inlines'][0].get('text')
+ response_txt += inline_txt + '\n'
+ result_text += inline_txt + '\n'
+
+ if returned_text.endswith(' '):
+ final = True
+ break
+
+ if response_txt.startswith(returned_text):
+ new = response_txt[len(returned_text):]
+ if new != "\n":
+ yield new
+ returned_text = response_txt
+ elif response.get('type') == 2:
+ result = response['item']['result']
+ if result.get('error'):
+ raise Exception(f"{result['value']}: {result['message']}")
+ final = True
+ break
+ finally:
+ await delete_conversation(session, conversation)
def run(generator: AsyncGenerator[Union[Any, str], Any]):
loop = asyncio.get_event_loop()
diff --git a/g4f/Provider/Hugchat.py b/g4f/Provider/Hugchat.py
new file mode 100644
index 00000000..cedf8402
--- /dev/null
+++ b/g4f/Provider/Hugchat.py
@@ -0,0 +1,67 @@
+has_module = False
+try:
+ from hugchat.hugchat import ChatBot
+except ImportError:
+ has_module = False
+
+from .base_provider import BaseProvider, get_cookies
+from g4f.typing import CreateResult
+
+class Hugchat(BaseProvider):
+ url = "https://huggingface.co/chat/"
+ needs_auth = True
+ working = has_module
+ llms = ['OpenAssistant/oasst-sft-6-llama-30b-xor', 'meta-llama/Llama-2-70b-chat-hf']
+
+ @classmethod
+ def create_completion(
+ cls,
+ model: str,
+ messages: list[dict[str, str]],
+ stream: bool = False,
+ proxy: str = None,
+ cookies: str = get_cookies(".huggingface.co"),
+ **kwargs
+ ) -> CreateResult:
+ bot = ChatBot(
+ cookies=cookies
+ )
+
+ if proxy and "://" not in proxy:
+ proxy = f"http://{proxy}"
+ bot.session.proxies = {"http": proxy, "https": proxy}
+
+ if model:
+ try:
+ if not isinstance(model, int):
+ model = cls.llms.index(model)
+ bot.switch_llm(model)
+ except:
+ raise RuntimeError(f"Model are not supported: {model}")
+
+ if len(messages) > 1:
+ formatted = "\n".join(
+ ["%s: %s" % (message["role"], message["content"]) for message in messages]
+ )
+ prompt = f"{formatted}\nAssistant:"
+ else:
+ prompt = messages.pop()["content"]
+
+ try:
+ yield bot.chat(prompt, **kwargs)
+ finally:
+ bot.delete_conversation(bot.current_conversation)
+ bot.current_conversation = ""
+ pass
+
+ @classmethod
+ @property
+ def params(cls):
+ params = [
+ ("model", "str"),
+ ("messages", "list[dict[str, str]]"),
+ ("stream", "bool"),
+ ("proxy", "str"),
+ ]
+ param = ", ".join([": ".join(p) for p in params])
+ return f"g4f.provider.{cls.__name__} supports: ({param})"
diff --git a/g4f/Provider/OpenaiChat.py b/g4f/Provider/OpenaiChat.py
new file mode 100644
index 00000000..cca258b3
--- /dev/null
+++ b/g4f/Provider/OpenaiChat.py
@@ -0,0 +1,74 @@
+has_module = True
+try:
+ from revChatGPT.V1 import AsyncChatbot
+except ImportError:
+ has_module = False
+from .base_provider import AsyncGeneratorProvider, get_cookies
+from ..typing import AsyncGenerator
+
+class OpenaiChat(AsyncGeneratorProvider):
+ url = "https://chat.openai.com"
+ needs_auth = True
+ working = has_module
+ supports_gpt_35_turbo = True
+ supports_gpt_4 = True
+ supports_stream = True
+
+ @classmethod
+ async def create_async_generator(
+ cls,
+ model: str,
+ messages: list[dict[str, str]],
+ proxy: str = None,
+ access_token: str = None,
+ cookies: dict = None,
+ **kwargs
+ ) -> AsyncGenerator:
+
+ config = {"access_token": access_token, "model": model}
+ if proxy:
+ if "://" not in proxy:
+ proxy = f"http://{proxy}"
+ config["proxy"] = proxy
+
+ bot = AsyncChatbot(
+ config=config
+ )
+
+ if not access_token:
+ cookies = cookies if cookies else get_cookies("chat.openai.com")
+ response = await bot.session.get("https://chat.openai.com/api/auth/session", cookies=cookies)
+ access_token = response.json()["accessToken"]
+ bot.set_access_token(access_token)
+
+ if len(messages) > 1:
+ formatted = "\n".join(
+ ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
+ )
+ prompt = f"{formatted}\nAssistant:"
+ else:
+ prompt = messages.pop()["content"]
+
+ returned = None
+ async for message in bot.ask(prompt):
+ message = message["message"]
+ if returned:
+ if message.startswith(returned):
+ new = message[len(returned):]
+ if new:
+ yield new
+ else:
+ yield message
+ returned = message
+
+ @classmethod
+ @property
+ def params(cls):
+ params = [
+ ("model", "str"),
+ ("messages", "list[dict[str, str]]"),
+ ("stream", "bool"),
+ ("proxy", "str"),
+ ]
+ param = ", ".join([": ".join(p) for p in params])
+ return f"g4f.provider.{cls.__name__} supports: ({param})"
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index e27dee5d..5ad9f156 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -14,9 +14,11 @@ from .EasyChat import EasyChat
from .Forefront import Forefront
from .GetGpt import GetGpt
from .H2o import H2o
+from .Hugchat import Hugchat
from .Liaobots import Liaobots
from .Lockchat import Lockchat
from .Opchatgpts import Opchatgpts
+from .OpenaiChat import OpenaiChat
from .Raycast import Raycast
from .Theb import Theb
from .Vercel import Vercel
@@ -44,10 +46,12 @@ __all__ = [
"Forefront",
"GetGpt",
"H2o",
+ "Hugchat",
"Liaobots",
"Lockchat",
"Opchatgpts",
"Raycast",
+ "OpenaiChat",
"Theb",
"Vercel",
"Wewordle",
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 98ad3514..56d79ee6 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
-from ..typing import Any, CreateResult
+from ..typing import Any, CreateResult, AsyncGenerator, Union
+import browser_cookie3
+import asyncio
+from time import time
+import math
class BaseProvider(ABC):
url: str
@@ -30,4 +34,81 @@ class BaseProvider(ABC):
("stream", "bool"),
]
param = ", ".join([": ".join(p) for p in params])
- return f"g4f.provider.{cls.__name__} supports: ({param})" \ No newline at end of file
+ return f"g4f.provider.{cls.__name__} supports: ({param})"
+
+
+_cookies = {}
+
+def get_cookies(cookie_domain: str) -> dict:
+ if cookie_domain not in _cookies:
+ _cookies[cookie_domain] = {}
+ for cookie in browser_cookie3.load(cookie_domain):
+ _cookies[cookie_domain][cookie.name] = cookie.value
+ return _cookies[cookie_domain]
+
+
+class AsyncProvider(BaseProvider):
+ @classmethod
+ def create_completion(
+ cls,
+ model: str,
+ messages: list[dict[str, str]],
+ stream: bool = False,
+ **kwargs: Any
+ ) -> CreateResult:
+ yield asyncio.run(cls.create_async(model, messages, **kwargs))
+
+ @staticmethod
+ @abstractmethod
+ async def create_async(
+ model: str,
+ messages: list[dict[str, str]],
+ **kwargs: Any,
+ ) -> str:
+ raise NotImplementedError()
+
+
+class AsyncGeneratorProvider(AsyncProvider):
+ @classmethod
+ def create_completion(
+ cls,
+ model: str,
+ messages: list[dict[str, str]],
+ stream: bool = True,
+ **kwargs: Any
+ ) -> CreateResult:
+ if stream:
+ yield from run_generator(cls.create_async_generator(model, messages, **kwargs))
+ else:
+ yield from AsyncProvider.create_completion(cls=cls, model=model, messages=messages, **kwargs)
+
+ @classmethod
+ async def create_async(
+ cls,
+ model: str,
+ messages: list[dict[str, str]],
+ **kwargs: Any,
+ ) -> str:
+ chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)]
+ if chunks:
+ return "".join(chunks)
+
+ @staticmethod
+ @abstractmethod
+ def create_async_generator(
+ model: str,
+ messages: list[dict[str, str]],
+ ) -> AsyncGenerator:
+ raise NotImplementedError()
+
+
+def run_generator(generator: AsyncGenerator[Union[Any, str], Any]):
+ loop = asyncio.new_event_loop()
+ gen = generator.__aiter__()
+
+ while True:
+ try:
+ yield loop.run_until_complete(gen.__anext__())
+
+ except StopAsyncIteration:
+ break