summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/Poe.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth/Poe.py')
-rw-r--r--g4f/Provider/needs_auth/Poe.py93
1 files changed, 43 insertions, 50 deletions
diff --git a/g4f/Provider/needs_auth/Poe.py b/g4f/Provider/needs_auth/Poe.py
index a894bcb1..1c8c97d7 100644
--- a/g4f/Provider/needs_auth/Poe.py
+++ b/g4f/Provider/needs_auth/Poe.py
@@ -4,7 +4,7 @@ import time
from ...typing import CreateResult, Messages
from ..base_provider import BaseProvider
-from ..helper import WebDriver, format_prompt, get_browser
+from ..helper import WebDriver, WebDriverSession, format_prompt
models = {
"meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
@@ -33,7 +33,7 @@ class Poe(BaseProvider):
messages: Messages,
stream: bool,
proxy: str = None,
- browser: WebDriver = None,
+ web_driver: WebDriver = None,
user_data_dir: str = None,
headless: bool = True,
**kwargs
@@ -43,56 +43,54 @@ class Poe(BaseProvider):
elif model not in models:
raise ValueError(f"Model are not supported: {model}")
prompt = format_prompt(messages)
- driver = browser if browser else get_browser(user_data_dir, headless, proxy)
- script = """
-window._message = window._last_message = "";
-window._message_finished = false;
-class ProxiedWebSocket extends WebSocket {
- constructor(url, options) {
- super(url, options);
- this.addEventListener("message", (e) => {
- const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"];
- if ("messageAdded" in data) {
- if (data["messageAdded"]["author"] != "human") {
- window._message = data["messageAdded"]["text"];
- if (data["messageAdded"]["state"] == "complete") {
- window._message_finished = true;
+ session = WebDriverSession(web_driver, user_data_dir, headless, proxy=proxy)
+ with session as driver:
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver.support.ui import WebDriverWait
+ from selenium.webdriver.support import expected_conditions as EC
+
+ driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
+ "source": """
+ window._message = window._last_message = "";
+ window._message_finished = false;
+ class ProxiedWebSocket extends WebSocket {
+ constructor(url, options) {
+ super(url, options);
+ this.addEventListener("message", (e) => {
+ const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"];
+ if ("messageAdded" in data) {
+ if (data["messageAdded"]["author"] != "human") {
+ window._message = data["messageAdded"]["text"];
+ if (data["messageAdded"]["state"] == "complete") {
+ window._message_finished = true;
+ }
}
}
- }
- });
- }
-}
-window.WebSocket = ProxiedWebSocket;
-"""
- driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
- "source": script
- })
-
- from selenium.webdriver.common.by import By
- from selenium.webdriver.support.ui import WebDriverWait
- from selenium.webdriver.support import expected_conditions as EC
+ });
+ }
+ }
+ window.WebSocket = ProxiedWebSocket;
+ """
+ })
- try:
- driver.get(f"{cls.url}/{models[model]['name']}")
- wait = WebDriverWait(driver, 10 if headless else 240)
- wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
- except:
- # Reopen browser for login
- if not browser:
- driver.quit()
- driver = get_browser(None, False, proxy)
+ try:
driver.get(f"{cls.url}/{models[model]['name']}")
- wait = WebDriverWait(driver, 240)
+ wait = WebDriverWait(driver, 10 if headless else 240)
wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
- else:
- raise RuntimeError("Prompt textarea not found. You may not be logged in.")
+ except:
+ # Reopen browser for login
+ if not web_driver:
+ driver = session.reopen(headless=False)
+ driver.get(f"{cls.url}/{models[model]['name']}")
+ wait = WebDriverWait(driver, 240)
+ wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
+ else:
+ raise RuntimeError("Prompt textarea not found. You may not be logged in.")
- driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt)
- driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
+ driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt)
+ driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
- try:
script = """
if(window._message && window._message != window._last_message) {
try {
@@ -113,9 +111,4 @@ if(window._message && window._message != window._last_message) {
elif chunk != "":
break
else:
- time.sleep(0.1)
- finally:
- if not browser:
- driver.close()
- time.sleep(0.1)
- driver.quit() \ No newline at end of file
+ time.sleep(0.1) \ No newline at end of file