diff options
Diffstat (limited to 'g4f/Provider/needs_auth/Theb.py')
-rw-r--r-- | g4f/Provider/needs_auth/Theb.py | 217 |
1 files changed, 137 insertions, 80 deletions
diff --git a/g4f/Provider/needs_auth/Theb.py b/g4f/Provider/needs_auth/Theb.py index b3c9019d..89c69727 100644 --- a/g4f/Provider/needs_auth/Theb.py +++ b/g4f/Provider/needs_auth/Theb.py @@ -1,101 +1,158 @@ from __future__ import annotations -import json -import random -import requests +import time -from ...typing import Any, CreateResult, Messages +from ...typing import CreateResult, Messages from ..base_provider import BaseProvider -from ..helper import format_prompt +from ..helper import WebDriver, format_prompt, get_browser +models = { + "theb-ai": "TheB.AI", + "theb-ai-free": "TheB.AI Free", + "gpt-3.5-turbo": "GPT-3.5 Turbo (New)", + "gpt-3.5-turbo-16k": "GPT-3.5-16K", + "gpt-4-turbo": "GPT-4 Turbo", + "gpt-4": "GPT-4", + "gpt-4-32k": "GPT-4 32K", + "claude-2": "Claude 2", + "claude-instant-1": "Claude Instant 1.2", + "palm-2": "PaLM 2", + "palm-2-32k": "PaLM 2 32K", + "palm-2-codey": "Codey", + "palm-2-codey-32k": "Codey 32K", + "vicuna-13b-v1.5": "Vicuna v1.5 13B", + "llama-2-7b-chat": "Llama 2 7B", + "llama-2-13b-chat": "Llama 2 13B", + "llama-2-70b-chat": "Llama 2 70B", + "code-llama-7b": "Code Llama 7B", + "code-llama-13b": "Code Llama 13B", + "code-llama-34b": "Code Llama 34B", + "qwen-7b-chat": "Qwen 7B" +} class Theb(BaseProvider): - url = "https://theb.ai" - working = True - supports_stream = True - supports_gpt_35_turbo = True - needs_auth = True + url = "https://beta.theb.ai" + working = True + supports_gpt_35_turbo = True + supports_gpt_4 = True + supports_stream = True - @staticmethod + @classmethod def create_completion( + cls, model: str, messages: Messages, stream: bool, proxy: str = None, + browser: WebDriver = None, + headless: bool = True, **kwargs ) -> CreateResult: - auth = kwargs.get("auth", { - "bearer_token":"free", - "org_id":"theb", - }) - - bearer_token = auth["bearer_token"] - org_id = auth["org_id"] + if model in models: + model = models[model] + prompt = format_prompt(messages) + driver = browser if browser else get_browser(None, headless, proxy) - headers = { - 'authority': 'beta.theb.ai', - 'accept': 'text/event-stream', - 'accept-language': 'id-ID,id;q=0.9,en-US;q=0.8,en;q=0.7', - 'authorization': f'Bearer {bearer_token}', - 'content-type': 'application/json', - 'origin': 'https://beta.theb.ai', - 'referer': 'https://beta.theb.ai/home', - 'sec-ch-ua': '"Chromium";v="116", "Not)A;Brand";v="24", "Google Chrome";v="116"', - 'sec-ch-ua-mobile': '?0', - 'sec-ch-ua-platform': '"Windows"', - 'sec-fetch-dest': 'empty', - 'sec-fetch-mode': 'cors', - 'sec-fetch-site': 'same-origin', - 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36', - 'x-ai-model': 'ee8d4f29cb7047f78cbe84313ed6ace8', - } + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + from selenium.webdriver.common.keys import Keys - req_rand = random.randint(100000000, 9999999999) + + try: + driver.get(f"{cls.url}/home") + wait = WebDriverWait(driver, 10 if headless else 240) + wait.until(EC.visibility_of_element_located((By.TAG_NAME, "body"))) + time.sleep(0.1) + try: + driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() + driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() + except: + pass + if model: + # Load model panel + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "#SelectModel svg"))) + time.sleep(0.1) + driver.find_element(By.CSS_SELECTOR, "#SelectModel svg").click() + try: + driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() + driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click() + except: + pass + # Select model + selector = f"div.flex-col div.items-center span[title='{model}']" + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, selector))) + span = driver.find_element(By.CSS_SELECTOR, selector) + container = span.find_element(By.XPATH, "//div/../..") + button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border") + button.click() - json_data: dict[str, Any] = { - "text" : format_prompt(messages), - "category" : "04f58f64a4aa4191a957b47290fee864", - "model" : "ee8d4f29cb7047f78cbe84313ed6ace8", - "model_params": { - "system_prompt" : "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture.\nKnowledge cutoff: 2021-09\nCurrent date: {{YYYY-MM-DD}}", - "temperature" : kwargs.get("temperature", 1), - "top_p" : kwargs.get("top_p", 1), - "frequency_penalty" : kwargs.get("frequency_penalty", 0), - "presence_penalty" : kwargs.get("presence_penalty", 0), - "long_term_memory" : "auto" - } + # Register fetch hook + script = """ +window._fetch = window.fetch; +window.fetch = (url, options) => { + // Call parent fetch method + const result = window._fetch(url, options); + if (!url.startsWith("/api/conversation")) { + return result; + } + // Load response reader + result.then((response) => { + if (!response.body.locked) { + window._reader = response.body.getReader(); } + }); + // Return dummy response + return new Promise((resolve, reject) => { + resolve(new Response(new ReadableStream())) + }); +} +window._last_message = ""; +""" + driver.execute_script(script) - response = requests.post( - f"https://beta.theb.ai/api/conversation?org_id={org_id}&req_rand={req_rand}", - headers=headers, - json=json_data, - stream=True, - proxies={"https": proxy} - ) - - response.raise_for_status() - content = "" - next_content = "" - for chunk in response.iter_lines(): - if b"content" in chunk: - next_content = content - data = json.loads(chunk.decode().split("data: ")[1]) - content = data["content"] - yield content.replace(next_content, "") + # Submit prompt + wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize"))) + driver.find_element(By.ID, "textareaAutosize").send_keys(prompt) + driver.find_element(By.ID, "textareaAutosize").send_keys(Keys.ENTER) - @classmethod - @property - def params(cls): - params = [ - ("model", "str"), - ("messages", "list[dict[str, str]]"), - ("auth", "list[dict[str, str]]"), - ("stream", "bool"), - ("temperature", "float"), - ("presence_penalty", "int"), - ("frequency_penalty", "int"), - ("top_p", "int") - ] - param = ", ".join([": ".join(p) for p in params]) - return f"g4f.provider.{cls.__name__} supports: ({param})"
\ No newline at end of file + # Read response with reader + script = """ +if(window._reader) { + chunk = await window._reader.read(); + if (chunk['done']) { + return null; + } + text = (new TextDecoder()).decode(chunk['value']); + message = ''; + text.split('\\r\\n').forEach((line, index) => { + if (line.startsWith('data: ')) { + try { + line = JSON.parse(line.substring('data: '.length)); + message = line["args"]["content"]; + } catch(e) { } + } + }); + if (message) { + try { + return message.substring(window._last_message.length); + } finally { + window._last_message = message; + } + } +} +return ''; +""" + while True: + chunk = driver.execute_script(script) + if chunk: + yield chunk + elif chunk != "": + break + else: + time.sleep(0.1) + finally: + if not browser: + driver.close() + time.sleep(0.1) + driver.quit()
\ No newline at end of file |