diff options
Diffstat (limited to 'g4f/Provider/needs_auth/Poe.py')
-rw-r--r-- | g4f/Provider/needs_auth/Poe.py | 93 |
1 files changed, 43 insertions, 50 deletions
diff --git a/g4f/Provider/needs_auth/Poe.py b/g4f/Provider/needs_auth/Poe.py index a894bcb1..1c8c97d7 100644 --- a/g4f/Provider/needs_auth/Poe.py +++ b/g4f/Provider/needs_auth/Poe.py @@ -4,7 +4,7 @@ import time from ...typing import CreateResult, Messages from ..base_provider import BaseProvider -from ..helper import WebDriver, format_prompt, get_browser +from ..helper import WebDriver, WebDriverSession, format_prompt models = { "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"}, @@ -33,7 +33,7 @@ class Poe(BaseProvider): messages: Messages, stream: bool, proxy: str = None, - browser: WebDriver = None, + web_driver: WebDriver = None, user_data_dir: str = None, headless: bool = True, **kwargs @@ -43,56 +43,54 @@ class Poe(BaseProvider): elif model not in models: raise ValueError(f"Model are not supported: {model}") prompt = format_prompt(messages) - driver = browser if browser else get_browser(user_data_dir, headless, proxy) - script = """ -window._message = window._last_message = ""; -window._message_finished = false; -class ProxiedWebSocket extends WebSocket { - constructor(url, options) { - super(url, options); - this.addEventListener("message", (e) => { - const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"]; - if ("messageAdded" in data) { - if (data["messageAdded"]["author"] != "human") { - window._message = data["messageAdded"]["text"]; - if (data["messageAdded"]["state"] == "complete") { - window._message_finished = true; + session = WebDriverSession(web_driver, user_data_dir, headless, proxy=proxy) + with session as driver: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + + driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", { + "source": """ + window._message = window._last_message = ""; + window._message_finished = false; + class ProxiedWebSocket extends WebSocket { + constructor(url, options) { + super(url, options); + this.addEventListener("message", (e) => { + const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"]; + if ("messageAdded" in data) { + if (data["messageAdded"]["author"] != "human") { + window._message = data["messageAdded"]["text"]; + if (data["messageAdded"]["state"] == "complete") { + window._message_finished = true; + } } } - } - }); - } -} -window.WebSocket = ProxiedWebSocket; -""" - driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", { - "source": script - }) - - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC + }); + } + } + window.WebSocket = ProxiedWebSocket; + """ + }) - try: - driver.get(f"{cls.url}/{models[model]['name']}") - wait = WebDriverWait(driver, 10 if headless else 240) - wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) - except: - # Reopen browser for login - if not browser: - driver.quit() - driver = get_browser(None, False, proxy) + try: driver.get(f"{cls.url}/{models[model]['name']}") - wait = WebDriverWait(driver, 240) + wait = WebDriverWait(driver, 10 if headless else 240) wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) - else: - raise RuntimeError("Prompt textarea not found. You may not be logged in.") + except: + # Reopen browser for login + if not web_driver: + driver = session.reopen(headless=False) + driver.get(f"{cls.url}/{models[model]['name']}") + wait = WebDriverWait(driver, 240) + wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']"))) + else: + raise RuntimeError("Prompt textarea not found. You may not be logged in.") - driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt) - driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click() + driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt) + driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click() - try: script = """ if(window._message && window._message != window._last_message) { try { @@ -113,9 +111,4 @@ if(window._message && window._message != window._last_message) { elif chunk != "": break else: - time.sleep(0.1) - finally: - if not browser: - driver.close() - time.sleep(0.1) - driver.quit()
\ No newline at end of file + time.sleep(0.1)
\ No newline at end of file |