summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-01-14 15:32:51 +0100
committerGitHub <noreply@github.com>2024-01-14 15:32:51 +0100
commit1ca80ed48b55d6462b4bd445e66d4f7de7442c2b (patch)
tree05a94b53b83461b8249de965e093b4fd3722e2d1
parentMerge pull request #1466 from hlohaus/upp (diff)
parentChange doctypes style to Google (diff)
downloadgpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar
gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.gz
gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.bz2
gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.lz
gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.xz
gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.zst
gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.zip
-rw-r--r--.github/workflows/unittest.yml19
-rw-r--r--etc/unittest/main.py73
-rw-r--r--g4f/Provider/Bing.py230
-rw-r--r--g4f/Provider/FreeChatgpt.py15
-rw-r--r--g4f/Provider/Phind.py8
-rw-r--r--g4f/Provider/base_provider.py198
-rw-r--r--g4f/Provider/bing/conversation.py44
-rw-r--r--g4f/Provider/bing/create_images.py224
-rw-r--r--g4f/Provider/bing/upload_image.py188
-rw-r--r--g4f/Provider/create_images.py61
-rw-r--r--g4f/Provider/helper.py143
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py339
-rw-r--r--g4f/Provider/retry_provider.py58
-rw-r--r--g4f/__init__.py91
-rw-r--r--g4f/base_provider.py81
-rw-r--r--g4f/gui/client/css/style.css13
-rw-r--r--g4f/gui/client/html/index.html24
-rw-r--r--g4f/gui/client/js/chat.v1.js80
-rw-r--r--g4f/gui/server/backend.py211
-rw-r--r--g4f/image.py105
-rw-r--r--g4f/models.py15
-rw-r--r--g4f/requests.py48
-rw-r--r--g4f/version.py93
-rw-r--r--g4f/webdriver.py111
24 files changed, 1841 insertions, 631 deletions
diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml
new file mode 100644
index 00000000..e895e969
--- /dev/null
+++ b/.github/workflows/unittest.yml
@@ -0,0 +1,19 @@
+name: Unittest
+
+on: [push]
+
+jobs:
+ build:
+ name: Build unittest
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.x"
+ cache: 'pip'
+ - name: Install requirements
+ - run: pip install -r requirements.txt
+ - name: Run tests
+ run: python -m etc.unittest.main \ No newline at end of file
diff --git a/etc/unittest/main.py b/etc/unittest/main.py
new file mode 100644
index 00000000..61f4ffda
--- /dev/null
+++ b/etc/unittest/main.py
@@ -0,0 +1,73 @@
+import sys
+import pathlib
+import unittest
+from unittest.mock import MagicMock
+
+sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
+
+import g4f
+from g4f import ChatCompletion, get_last_provider
+from g4f.gui.server.backend import Backend_Api, get_error_message
+from g4f.base_provider import BaseProvider
+
+g4f.debug.logging = False
+
+class MockProvider(BaseProvider):
+ working = True
+
+ def create_completion(
+ model, messages, stream, **kwargs
+ ):
+ yield "Mock"
+
+ async def create_async(
+ model, messages, **kwargs
+ ):
+ return "Mock"
+
+class TestBackendApi(unittest.TestCase):
+
+ def setUp(self):
+ self.app = MagicMock()
+ self.api = Backend_Api(self.app)
+
+ def test_version(self):
+ response = self.api.get_version()
+ self.assertIn("version", response)
+ self.assertIn("latest_version", response)
+
+class TestChatCompletion(unittest.TestCase):
+
+ def test_create(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ result = ChatCompletion.create(g4f.models.default, messages)
+ self.assertTrue("Hello" in result or "Good" in result)
+
+ def test_get_last_provider(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ ChatCompletion.create(g4f.models.default, messages, MockProvider)
+ self.assertEqual(get_last_provider(), MockProvider)
+
+ def test_bing_provider(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ provider = g4f.Provider.Bing
+ result = ChatCompletion.create(g4f.models.default, messages, provider)
+ self.assertTrue("Bing" in result)
+
+class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
+
+ async def test_async(self):
+ messages = [{'role': 'user', 'content': 'Hello'}]
+ result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
+ self.assertTrue("Mock" in result)
+
+class TestUtilityFunctions(unittest.TestCase):
+
+ def test_get_error_message(self):
+ g4f.debug.last_provider = g4f.Provider.Bing
+ exception = Exception("Message")
+ result = get_error_message(exception)
+ self.assertEqual("Bing: Exception: Message", result)
+
+if __name__ == '__main__':
+ unittest.main() \ No newline at end of file
diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py
index 50e29d23..34687866 100644
--- a/g4f/Provider/Bing.py
+++ b/g4f/Provider/Bing.py
@@ -15,12 +15,18 @@ from .bing.upload_image import upload_image
from .bing.create_images import create_images
from .bing.conversation import Conversation, create_conversation, delete_conversation
-class Tones():
+class Tones:
+ """
+ Defines the different tone options for the Bing provider.
+ """
creative = "Creative"
balanced = "Balanced"
precise = "Precise"
class Bing(AsyncGeneratorProvider):
+ """
+ Bing provider for generating responses using the Bing API.
+ """
url = "https://bing.com/chat"
working = True
supports_message_history = True
@@ -38,6 +44,19 @@ class Bing(AsyncGeneratorProvider):
web_search: bool = False,
**kwargs
) -> AsyncResult:
+ """
+ Creates an asynchronous generator for producing responses from Bing.
+
+ :param model: The model to use.
+ :param messages: Messages to process.
+ :param proxy: Proxy to use for requests.
+ :param timeout: Timeout for requests.
+ :param cookies: Cookies for the session.
+ :param tone: The tone of the response.
+ :param image: The image type to be used.
+ :param web_search: Flag to enable or disable web search.
+ :return: An asynchronous result object.
+ """
if len(messages) < 2:
prompt = messages[0]["content"]
context = None
@@ -56,65 +75,48 @@ class Bing(AsyncGeneratorProvider):
return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout)
-def create_context(messages: Messages):
+def create_context(messages: Messages) -> str:
+ """
+ Creates a context string from a list of messages.
+
+ :param messages: A list of message dictionaries.
+ :return: A string representing the context created from the messages.
+ """
return "".join(
- f"[{message['role']}]" + ("(#message)" if message['role']!="system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
+ f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
for message in messages
)
class Defaults:
+ """
+ Default settings and configurations for the Bing provider.
+ """
delimiter = "\x1e"
ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
+ # List of allowed message types for Bing responses
allowedMessageTypes = [
- "ActionRequest",
- "Chat",
- "Context",
- # "Disengaged", unwanted
- "Progress",
- # "AdsQuery", unwanted
- "SemanticSerp",
- "GenerateContentQuery",
- "SearchQuery",
- # The following message types should not be added so that it does not flood with
- # useless messages (such as "Analyzing images" or "Searching the web") while it's retrieving the AI response
- # "InternalSearchQuery",
- # "InternalSearchResult",
- "RenderCardRequest",
- # "RenderContentRequest"
+ "ActionRequest", "Chat", "Context", "Progress", "SemanticSerp",
+ "GenerateContentQuery", "SearchQuery", "RenderCardRequest"
]
sliceIds = [
- 'abv2',
- 'srdicton',
- 'convcssclick',
- 'stylewv2',
- 'contctxp2tf',
- '802fluxv1pc_a',
- '806log2sphs0',
- '727savemem',
- '277teditgnds0',
- '207hlthgrds0',
+ 'abv2', 'srdicton', 'convcssclick', 'stylewv2', 'contctxp2tf',
+ '802fluxv1pc_a', '806log2sphs0', '727savemem', '277teditgnds0', '207hlthgrds0'
]
+ # Default location settings
location = {
- "locale": "en-US",
- "market": "en-US",
- "region": "US",
- "locationHints": [
- {
- "country": "United States",
- "state": "California",
- "city": "Los Angeles",
- "timezoneoffset": 8,
- "countryConfidence": 8,
- "Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
- "RegionType": 2,
- "SourceType": 1,
- }
- ],
+ "locale": "en-US", "market": "en-US", "region": "US",
+ "locationHints": [{
+ "country": "United States", "state": "California", "city": "Los Angeles",
+ "timezoneoffset": 8, "countryConfidence": 8,
+ "Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
+ "RegionType": 2, "SourceType": 1
+ }],
}
+ # Default headers for requests
headers = {
'accept': '*/*',
'accept-language': 'en-US,en;q=0.9',
@@ -139,23 +141,13 @@ class Defaults:
}
optionsSets = [
- 'nlu_direct_response_filter',
- 'deepleo',
- 'disable_emoji_spoken_text',
- 'responsible_ai_policy_235',
- 'enablemm',
- 'iyxapbing',
- 'iycapbing',
- 'gencontentv3',
- 'fluxsrtrunc',
- 'fluxtrunc',
- 'fluxv1',
- 'rai278',
- 'replaceurl',
- 'eredirecturl',
- 'nojbfedge'
+ 'nlu_direct_response_filter', 'deepleo', 'disable_emoji_spoken_text',
+ 'responsible_ai_policy_235', 'enablemm', 'iyxapbing', 'iycapbing',
+ 'gencontentv3', 'fluxsrtrunc', 'fluxtrunc', 'fluxv1', 'rai278',
+ 'replaceurl', 'eredirecturl', 'nojbfedge'
]
+ # Default cookies
cookies = {
'SRCHD' : 'AF=NOFORM',
'PPLState' : '1',
@@ -166,6 +158,12 @@ class Defaults:
}
def format_message(msg: dict) -> str:
+ """
+ Formats a message dictionary into a JSON string with a delimiter.
+
+ :param msg: The message dictionary to format.
+ :return: A formatted string representation of the message.
+ """
return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter
def create_message(
@@ -177,7 +175,20 @@ def create_message(
web_search: bool = False,
gpt4_turbo: bool = False
) -> str:
+ """
+ Creates a message for the Bing API with specified parameters.
+
+ :param conversation: The current conversation object.
+ :param prompt: The user's input prompt.
+ :param tone: The desired tone for the response.
+ :param context: Additional context for the prompt.
+ :param image_response: The response if an image is involved.
+ :param web_search: Flag to enable web search.
+ :param gpt4_turbo: Flag to enable GPT-4 Turbo.
+ :return: A formatted string message for the Bing API.
+ """
options_sets = Defaults.optionsSets
+ # Append tone-specific options
if tone == Tones.creative:
options_sets.append("h3imaginative")
elif tone == Tones.precise:
@@ -186,54 +197,49 @@ def create_message(
options_sets.append("galileo")
else:
options_sets.append("harmonyv3")
-
+
+ # Additional configurations based on parameters
if not web_search:
options_sets.append("nosearchall")
-
if gpt4_turbo:
options_sets.append("dlgpt4t")
-
+
request_id = str(uuid.uuid4())
struct = {
- 'arguments': [
- {
- 'source': 'cib',
- 'optionsSets': options_sets,
- 'allowedMessageTypes': Defaults.allowedMessageTypes,
- 'sliceIds': Defaults.sliceIds,
- 'traceId': os.urandom(16).hex(),
- 'isStartOfSession': True,
+ 'arguments': [{
+ 'source': 'cib', 'optionsSets': options_sets,
+ 'allowedMessageTypes': Defaults.allowedMessageTypes,
+ 'sliceIds': Defaults.sliceIds,
+ 'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
+ 'requestId': request_id,
+ 'message': {
+ **Defaults.location,
+ 'author': 'user',
+ 'inputMethod': 'Keyboard',
+ 'text': prompt,
+ 'messageType': 'Chat',
'requestId': request_id,
- 'message': {**Defaults.location, **{
- 'author': 'user',
- 'inputMethod': 'Keyboard',
- 'text': prompt,
- 'messageType': 'Chat',
- 'requestId': request_id,
- 'messageId': request_id,
- }},
- "verbosity": "verbose",
- "scenario": "SERP",
- "plugins":[
- {"id":"c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}
- ] if web_search else [],
- 'tone': tone,
- 'spokenTextMode': 'None',
- 'conversationId': conversation.conversationId,
- 'participant': {
- 'id': conversation.clientId
- },
- }
- ],
+ 'messageId': request_id
+ },
+ "verbosity": "verbose",
+ "scenario": "SERP",
+ "plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [],
+ 'tone': tone,
+ 'spokenTextMode': 'None',
+ 'conversationId': conversation.conversationId,
+ 'participant': {'id': conversation.clientId},
+ }],
'invocationId': '1',
'target': 'chat',
'type': 4
}
- if image_response.get('imageUrl') and image_response.get('originalImageUrl'):
+
+ if image_response and image_response.get('imageUrl') and image_response.get('originalImageUrl'):
struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl')
struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl')
struct['arguments'][0]['experienceType'] = None
struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None}
+
if context:
struct['arguments'][0]['previousMessages'] = [{
"author": "user",
@@ -242,30 +248,46 @@ def create_message(
"messageType": "Context",
"messageId": "discover-web--page-ping-mriduna-----"
}]
+
return format_message(struct)
async def stream_generate(
- prompt: str,
- tone: str,
- image: ImageType = None,
- context: str = None,
- proxy: str = None,
- cookies: dict = None,
- web_search: bool = False,
- gpt4_turbo: bool = False,
- timeout: int = 900
- ):
+ prompt: str,
+ tone: str,
+ image: ImageType = None,
+ context: str = None,
+ proxy: str = None,
+ cookies: dict = None,
+ web_search: bool = False,
+ gpt4_turbo: bool = False,
+ timeout: int = 900
+):
+ """
+ Asynchronously streams generated responses from the Bing API.
+
+ :param prompt: The user's input prompt.
+ :param tone: The desired tone for the response.
+ :param image: The image type involved in the response.
+ :param context: Additional context for the prompt.
+ :param proxy: Proxy settings for the request.
+ :param cookies: Cookies for the session.
+ :param web_search: Flag to enable web search.
+ :param gpt4_turbo: Flag to enable GPT-4 Turbo.
+ :param timeout: Timeout for the request.
+ :return: An asynchronous generator yielding responses.
+ """
headers = Defaults.headers
if cookies:
headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
+
async with ClientSession(
- timeout=ClientTimeout(total=timeout),
- headers=headers
+ timeout=ClientTimeout(total=timeout), headers=headers
) as session:
conversation = await create_conversation(session, proxy)
image_response = await upload_image(session, image, tone, proxy) if image else None
if image_response:
yield image_response
+
try:
async with session.ws_connect(
'wss://sydney.bing.com/sydney/ChatHub',
@@ -289,7 +311,7 @@ async def stream_generate(
if obj is None or not obj:
continue
response = json.loads(obj)
- if response.get('type') == 1 and response['arguments'][0].get('messages'):
+ if response and response.get('type') == 1 and response['arguments'][0].get('messages'):
message = response['arguments'][0]['messages'][0]
image_response = None
if (message['contentOrigin'] != 'Apology'):
diff --git a/g4f/Provider/FreeChatgpt.py b/g4f/Provider/FreeChatgpt.py
index 75514118..0f993690 100644
--- a/g4f/Provider/FreeChatgpt.py
+++ b/g4f/Provider/FreeChatgpt.py
@@ -1,16 +1,20 @@
from __future__ import annotations
-import json
+import json, random
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider
-
models = {
- "claude-v2": "claude-2.0",
- "gemini-pro": "google-gemini-pro"
+ "claude-v2": "claude-2.0",
+ "claude-v2.1":"claude-2.1",
+ "gemini-pro": "google-gemini-pro"
}
+urls = [
+ "https://free.chatgpt.org.uk",
+ "https://ai.chatgpt.org.uk"
+]
class FreeChatgpt(AsyncGeneratorProvider):
url = "https://free.chatgpt.org.uk"
@@ -31,6 +35,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
model = models[model]
elif not model:
model = "gpt-3.5-turbo"
+ url = random.choice(urls)
headers = {
"Accept": "application/json, text/event-stream",
"Content-Type":"application/json",
@@ -55,7 +60,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
"top_p":1,
**kwargs
}
- async with session.post(f'{cls.url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
+ async with session.post(f'{url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
response.raise_for_status()
started = False
async for line in response.content:
diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py
index bb216989..9e80baa9 100644
--- a/g4f/Provider/Phind.py
+++ b/g4f/Provider/Phind.py
@@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider):
"rewrittenQuestion": prompt,
"challenge": 0.21132115912208504
}
- async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response:
+ async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response:
new_line = False
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
- if chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
+ if chunk.startswith(b'<PHIND_DONE/>'):
+ break
+ if chunk.startswith(b'<PHIND_WEBRESULTS>') or chunk.startswith(b'<PHIND_FOLLOWUP>'):
+ pass
+ elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
pass
elif chunk:
yield chunk.decode()
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index e7e88841..fd92d17a 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,28 +1,29 @@
from __future__ import annotations
-
import sys
import asyncio
-from asyncio import AbstractEventLoop
+from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
-from abc import abstractmethod
-from inspect import signature, Parameter
-from .helper import get_event_loop, get_cookies, format_prompt
-from ..typing import CreateResult, AsyncResult, Messages
-from ..base_provider import BaseProvider
+from abc import abstractmethod
+from inspect import signature, Parameter
+from .helper import get_event_loop, get_cookies, format_prompt
+from ..typing import CreateResult, AsyncResult, Messages
+from ..base_provider import BaseProvider
if sys.version_info < (3, 10):
NoneType = type(None)
else:
from types import NoneType
-# Change event loop policy on windows for curl_cffi
+# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
- if isinstance(
- asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
- ):
+ if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class AbstractProvider(BaseProvider):
+ """
+ Abstract class for providing asynchronous functionality to derived classes.
+ """
+
@classmethod
async def create_async(
cls,
@@ -33,62 +34,67 @@ class AbstractProvider(BaseProvider):
executor: ThreadPoolExecutor = None,
**kwargs
) -> str:
- if not loop:
- loop = get_event_loop()
+ """
+ Asynchronously creates a result based on the given model and messages.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
+ """
+ loop = loop or get_event_loop()
def create_func() -> str:
- return "".join(cls.create_completion(
- model,
- messages,
- False,
- **kwargs
- ))
+ return "".join(cls.create_completion(model, messages, False, **kwargs))
return await asyncio.wait_for(
- loop.run_in_executor(
- executor,
- create_func
- ),
+ loop.run_in_executor(executor, create_func),
timeout=kwargs.get("timeout", 0)
)
@classmethod
@property
def params(cls) -> str:
- if issubclass(cls, AsyncGeneratorProvider):
- sig = signature(cls.create_async_generator)
- elif issubclass(cls, AsyncProvider):
- sig = signature(cls.create_async)
- else:
- sig = signature(cls.create_completion)
+ """
+ Returns the parameters supported by the provider.
+
+ Args:
+ cls (type): The class on which this property is called.
+
+ Returns:
+ str: A string listing the supported parameters.
+ """
+ sig = signature(
+ cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
+ cls.create_async if issubclass(cls, AsyncProvider) else
+ cls.create_completion
+ )
def get_type_name(annotation: type) -> str:
- if hasattr(annotation, "__name__"):
- annotation = annotation.__name__
- elif isinstance(annotation, NoneType):
- annotation = "None"
- return str(annotation)
-
+ return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
+
args = ""
for name, param in sig.parameters.items():
- if name in ("self", "kwargs"):
- continue
- if name == "stream" and not cls.supports_stream:
+ if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
continue
- if args:
- args += ", "
- args += "\n " + name
- if name != "model" and param.annotation is not Parameter.empty:
- args += f": {get_type_name(param.annotation)}"
- if param.default == "":
- args += ' = ""'
- elif param.default is not Parameter.empty:
- args += f" = {param.default}"
+ args += f"\n {name}"
+ args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
+ args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(AbstractProvider):
+ """
+ Provides asynchronous functionality for creating completions.
+ """
+
@classmethod
def create_completion(
cls,
@@ -99,8 +105,21 @@ class AsyncProvider(AbstractProvider):
loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
- if not loop:
- loop = get_event_loop()
+ """
+ Creates a completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to False.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the completion creation.
+ """
+ loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs)
yield loop.run_until_complete(coro)
@@ -111,10 +130,27 @@ class AsyncProvider(AbstractProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Abstract method for creating asynchronous results.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ str: The created result as a string.
+ """
raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
+ """
+ Provides asynchronous generator functionality for streaming results.
+ """
supports_stream = True
@classmethod
@@ -127,15 +163,24 @@ class AsyncGeneratorProvider(AsyncProvider):
loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
- if not loop:
- loop = get_event_loop()
- generator = cls.create_async_generator(
- model,
- messages,
- stream=stream,
- **kwargs
- )
+ """
+ Creates a streaming completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the streaming completion creation.
+ """
+ loop = loop or get_event_loop()
+ generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
+
while True:
try:
yield loop.run_until_complete(gen.__anext__())
@@ -149,21 +194,44 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously creates a result from a generator.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
+ """
return "".join([
- chunk async for chunk in cls.create_async_generator(
- model,
- messages,
- stream=False,
- **kwargs
- ) if not isinstance(chunk, Exception)
+ chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
+ if not isinstance(chunk, Exception)
])
@staticmethod
@abstractmethod
- def create_async_generator(
+ async def create_async_generator(
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> AsyncResult:
- raise NotImplementedError()
+ """
+ Abstract method for creating an asynchronous generator.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ AsyncResult: An asynchronous generator yielding results.
+ """
+ raise NotImplementedError() \ No newline at end of file
diff --git a/g4f/Provider/bing/conversation.py b/g4f/Provider/bing/conversation.py
index 9e011c26..36ada3b0 100644
--- a/g4f/Provider/bing/conversation.py
+++ b/g4f/Provider/bing/conversation.py
@@ -1,13 +1,33 @@
from aiohttp import ClientSession
-
-class Conversation():
+class Conversation:
+ """
+ Represents a conversation with specific attributes.
+ """
def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
+ """
+ Initialize a new conversation instance.
+
+ Args:
+ conversationId (str): Unique identifier for the conversation.
+ clientId (str): Client identifier.
+ conversationSignature (str): Signature for the conversation.
+ """
self.conversationId = conversationId
self.clientId = clientId
self.conversationSignature = conversationSignature
async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation:
+ """
+ Create a new conversation asynchronously.
+
+ Args:
+ session (ClientSession): An instance of aiohttp's ClientSession.
+ proxy (str, optional): Proxy URL. Defaults to None.
+
+ Returns:
+ Conversation: An instance representing the created conversation.
+ """
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4'
async with session.get(url, proxy=proxy) as response:
try:
@@ -24,12 +44,32 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv
return Conversation(conversationId, clientId, conversationSignature)
async def list_conversations(session: ClientSession) -> list:
+ """
+ List all conversations asynchronously.
+
+ Args:
+ session (ClientSession): An instance of aiohttp's ClientSession.
+
+ Returns:
+ list: A list of conversations.
+ """
url = "https://www.bing.com/turing/conversation/chats"
async with session.get(url) as response:
response = await response.json()
return response["chats"]
async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool:
+ """
+ Delete a conversation asynchronously.
+
+ Args:
+ session (ClientSession): An instance of aiohttp's ClientSession.
+ conversation (Conversation): The conversation to delete.
+ proxy (str, optional): Proxy URL. Defaults to None.
+
+ Returns:
+ bool: True if deletion was successful, False otherwise.
+ """
url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
json = {
"conversationId": conversation.conversationId,
diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py
index a1ecace3..060cd184 100644
--- a/g4f/Provider/bing/create_images.py
+++ b/g4f/Provider/bing/create_images.py
@@ -1,9 +1,16 @@
+"""
+This module provides functionalities for creating and managing images using Bing's service.
+It includes functions for user login, session creation, image creation, and processing.
+"""
+
import asyncio
-import time, json, os
+import time
+import json
+import os
from aiohttp import ClientSession
from bs4 import BeautifulSoup
from urllib.parse import quote
-from typing import Generator
+from typing import Generator, List, Dict
from ..create_images import CreateImagesProvider
from ..helper import get_cookies, get_event_loop
@@ -12,23 +19,47 @@ from ...base_provider import ProviderType
from ...image import format_images_markdown
BING_URL = "https://www.bing.com"
+TIMEOUT_LOGIN = 1200
+TIMEOUT_IMAGE_CREATION = 300
+ERRORS = [
+ "this prompt is being reviewed",
+ "this prompt has been blocked",
+ "we're working hard to offer image creator in more languages",
+ "we can't create your images right now"
+]
+BAD_IMAGES = [
+ "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
+ "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
+]
+
+def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
+ """
+ Waits for the user to log in within a given timeout period.
-def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None:
+ Args:
+ driver (WebDriver): Webdriver for browser automation.
+ timeout (int): Maximum waiting time in seconds.
+
+ Raises:
+ RuntimeError: If the login process exceeds the timeout.
+ """
driver.get(f"{BING_URL}/")
- value = driver.get_cookie("_U")
- if value:
- return
start_time = time.time()
- while True:
+ while not driver.get_cookie("_U"):
if time.time() - start_time > timeout:
raise RuntimeError("Timeout error")
- value = driver.get_cookie("_U")
- if value:
- time.sleep(1)
- return
time.sleep(0.5)
-def create_session(cookies: dict) -> ClientSession:
+def create_session(cookies: Dict[str, str]) -> ClientSession:
+ """
+ Creates a new client session with specified cookies and headers.
+
+ Args:
+ cookies (Dict[str, str]): Cookies to be used for the session.
+
+ Returns:
+ ClientSession: The created client session.
+ """
headers = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"accept-encoding": "gzip, deflate, br",
@@ -47,28 +78,32 @@ def create_session(cookies: dict) -> ClientSession:
"upgrade-insecure-requests": "1",
}
if cookies:
- headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
+ headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
return ClientSession(headers=headers)
-async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 300) -> list:
- url_encoded_prompt = quote(prompt)
+async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = TIMEOUT_IMAGE_CREATION) -> List[str]:
+ """
+ Creates images based on a given prompt using Bing's service.
+
+ Args:
+ session (ClientSession): Active client session.
+ prompt (str): Prompt to generate images.
+ proxy (str, optional): Proxy configuration.
+ timeout (int): Timeout for the request.
+
+ Returns:
+ List[str]: A list of URLs to the created images.
+
+ Raises:
+ RuntimeError: If image creation fails or times out.
+ """
+ url_encoded_prompt = quote(prompt)
payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE"
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
- async with session.post(
- url,
- allow_redirects=False,
- data=payload,
- timeout=timeout,
- ) as response:
+ async with session.post(url, allow_redirects=False, data=payload, timeout=timeout) as response:
response.raise_for_status()
- errors = [
- "this prompt is being reviewed",
- "this prompt has been blocked",
- "we're working hard to offer image creator in more languages",
- "we can't create your images right now"
- ]
text = (await response.text()).lower()
- for error in errors:
+ for error in ERRORS:
if error in text:
raise RuntimeError(f"Create images failed: {error}")
if response.status != 302:
@@ -107,54 +142,109 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
raise RuntimeError(error)
return read_images(text)
-def read_images(text: str) -> list:
- html_soup = BeautifulSoup(text, "html.parser")
- tags = html_soup.find_all("img")
- image_links = [img["src"] for img in tags if "mimg" in img["class"]]
- images = [link.split("?w=")[0] for link in image_links]
- bad_images = [
- "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
- "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
- ]
- if any(im in bad_images for im in images):
+def read_images(html_content: str) -> List[str]:
+ """
+ Extracts image URLs from the HTML content.
+
+ Args:
+ html_content (str): HTML content containing image URLs.
+
+ Returns:
+ List[str]: A list of image URLs.
+ """
+ soup = BeautifulSoup(html_content, "html.parser")
+ tags = soup.find_all("img", class_="mimg")
+ images = [img["src"].split("?w=")[0] for img in tags]
+ if any(im in BAD_IMAGES for im in images):
raise RuntimeError("Bad images found")
if not images:
raise RuntimeError("No images found")
return images
-async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str:
- session = create_session(cookies)
- try:
+async def create_images_markdown(cookies: Dict[str, str], prompt: str, proxy: str = None) -> str:
+ """
+ Creates markdown formatted string with images based on the prompt.
+
+ Args:
+ cookies (Dict[str, str]): Cookies to be used for the session.
+ prompt (str): Prompt to generate images.
+ proxy (str, optional): Proxy configuration.
+
+ Returns:
+ str: Markdown formatted string with images.
+ """
+ async with create_session(cookies) as session:
images = await create_images(session, prompt, proxy)
return format_images_markdown(images, prompt)
- finally:
- await session.close()
-def get_cookies_from_browser(proxy: str = None) -> dict:
- driver = get_browser(proxy=proxy)
- try:
+def get_cookies_from_browser(proxy: str = None) -> Dict[str, str]:
+ """
+ Retrieves cookies from the browser using webdriver.
+
+ Args:
+ proxy (str, optional): Proxy configuration.
+
+ Returns:
+ Dict[str, str]: Retrieved cookies.
+ """
+ with get_browser(proxy=proxy) as driver:
wait_for_login(driver)
+ time.sleep(1)
return get_driver_cookies(driver)
- finally:
- driver.quit()
-
-def create_completion(prompt: str, cookies: dict = None, proxy: str = None) -> Generator:
- loop = get_event_loop()
- if not cookies:
- cookies = get_cookies(".bing.com")
- if "_U" not in cookies:
- login_url = os.environ.get("G4F_LOGIN_URL")
- if login_url:
- yield f"Please login: [Bing]({login_url})\n\n"
- cookies = get_cookies_from_browser(proxy)
- yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
-
-async def create_async(prompt: str, cookies: dict = None, proxy: str = None) -> str:
- if not cookies:
- cookies = get_cookies(".bing.com")
- if "_U" not in cookies:
- cookies = get_cookies_from_browser(proxy)
- return await create_images_markdown(cookies, prompt, proxy)
+
+class CreateImagesBing:
+ """A class for creating images using Bing."""
+
+ _cookies: Dict[str, str] = {}
+
+ @classmethod
+ def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]:
+ """
+ Generator for creating imagecompletion based on a prompt.
+
+ Args:
+ prompt (str): Prompt to generate images.
+ cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
+ proxy (str, optional): Proxy configuration.
+
+ Yields:
+ Generator[str, None, None]: The final output as markdown formatted string with images.
+ """
+ loop = get_event_loop()
+ cookies = cookies or cls._cookies or get_cookies(".bing.com")
+ if "_U" not in cookies:
+ login_url = os.environ.get("G4F_LOGIN_URL")
+ if login_url:
+ yield f"Please login: [Bing]({login_url})\n\n"
+ cls._cookies = cookies = get_cookies_from_browser(proxy)
+ yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
+
+ @classmethod
+ async def create_async(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> str:
+ """
+ Asynchronously creates a markdown formatted string with images based on the prompt.
+
+ Args:
+ prompt (str): Prompt to generate images.
+ cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
+ proxy (str, optional): Proxy configuration.
+
+ Returns:
+ str: Markdown formatted string with images.
+ """
+ cookies = cookies or cls._cookies or get_cookies(".bing.com")
+ if "_U" not in cookies:
+ cls._cookies = cookies = get_cookies_from_browser(proxy)
+ return await create_images_markdown(cookies, prompt, proxy)
def patch_provider(provider: ProviderType) -> CreateImagesProvider:
- return CreateImagesProvider(provider, create_completion, create_async) \ No newline at end of file
+ """
+ Patches a provider to include image creation capabilities.
+
+ Args:
+ provider (ProviderType): The provider to be patched.
+
+ Returns:
+ CreateImagesProvider: The patched provider with image creation capabilities.
+ """
+ return CreateImagesProvider(provider, CreateImagesBing.create_completion, CreateImagesBing.create_async) \ No newline at end of file
diff --git a/g4f/Provider/bing/upload_image.py b/g4f/Provider/bing/upload_image.py
index 1af902ef..4d70659f 100644
--- a/g4f/Provider/bing/upload_image.py
+++ b/g4f/Provider/bing/upload_image.py
@@ -1,64 +1,107 @@
-from __future__ import annotations
+"""
+Module to handle image uploading and processing for Bing AI integrations.
+"""
+from __future__ import annotations
import string
import random
import json
import math
-from ...typing import ImageType
from aiohttp import ClientSession
+from PIL import Image
+
+from ...typing import ImageType, Tuple
from ...image import to_image, process_image, to_base64, ImageResponse
-image_config = {
+IMAGE_CONFIG = {
"maxImagePixels": 360000,
"imageCompressionRate": 0.7,
- "enableFaceBlurDebug": 0,
+ "enableFaceBlurDebug": False,
}
async def upload_image(
- session: ClientSession,
- image: ImageType,
- tone: str,
+ session: ClientSession,
+ image_data: ImageType,
+ tone: str,
proxy: str = None
) -> ImageResponse:
- image = to_image(image)
- width, height = image.size
- max_image_pixels = image_config['maxImagePixels']
- if max_image_pixels / (width * height) < 1:
- new_width = int(width * math.sqrt(max_image_pixels / (width * height)))
- new_height = int(height * math.sqrt(max_image_pixels / (width * height)))
- else:
- new_width = width
- new_height = height
- new_img = process_image(image, new_width, new_height)
- new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate'])
- data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
- headers = session.headers.copy()
- headers["content-type"] = f'multipart/form-data; boundary={boundary}'
- headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
- headers["origin"] = 'https://www.bing.com'
+ """
+ Uploads an image to Bing's AI service and returns the image response.
+
+ Args:
+ session (ClientSession): The active session.
+ image_data (bytes): The image data to be uploaded.
+ tone (str): The tone of the conversation.
+ proxy (str, optional): Proxy if any. Defaults to None.
+
+ Raises:
+ RuntimeError: If the image upload fails.
+
+ Returns:
+ ImageResponse: The response from the image upload.
+ """
+ image = to_image(image_data)
+ new_width, new_height = calculate_new_dimensions(image)
+ processed_img = process_image(image, new_width, new_height)
+ img_binary_data = to_base64(processed_img, IMAGE_CONFIG['imageCompressionRate'])
+
+ data, boundary = build_image_upload_payload(img_binary_data, tone)
+ headers = prepare_headers(session, boundary)
+
async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
if response.status != 200:
raise RuntimeError("Failed to upload image.")
- image_info = await response.json()
- if not image_info.get('blobId'):
- raise RuntimeError("Failed to parse image info.")
- result = {'bcid': image_info.get('blobId', "")}
- result['blurredBcid'] = image_info.get('processedBlobId', "")
- if result['blurredBcid'] != "":
- result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
- elif result['bcid'] != "":
- result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
- result['originalImageUrl'] = (
- "https://www.bing.com/images/blob?bcid="
- + result['blurredBcid']
- if image_config["enableFaceBlurDebug"]
- else "https://www.bing.com/images/blob?bcid="
- + result['bcid']
- )
- return ImageResponse(result["imageUrl"], "", result)
-
-def build_image_upload_api_payload(image_bin: str, tone: str):
- payload = {
+ return parse_image_response(await response.json())
+
+def calculate_new_dimensions(image: Image.Image) -> Tuple[int, int]:
+ """
+ Calculates the new dimensions for the image based on the maximum allowed pixels.
+
+ Args:
+ image (Image): The PIL Image object.
+
+ Returns:
+ Tuple[int, int]: The new width and height for the image.
+ """
+ width, height = image.size
+ max_image_pixels = IMAGE_CONFIG['maxImagePixels']
+ if max_image_pixels / (width * height) < 1:
+ scale_factor = math.sqrt(max_image_pixels / (width * height))
+ return int(width * scale_factor), int(height * scale_factor)
+ return width, height
+
+def build_image_upload_payload(image_bin: str, tone: str) -> Tuple[str, str]:
+ """
+ Builds the payload for image uploading.
+
+ Args:
+ image_bin (str): Base64 encoded image binary data.
+ tone (str): The tone of the conversation.
+
+ Returns:
+ Tuple[str, str]: The data and boundary for the payload.
+ """
+ boundary = "----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
+ data = f"--{boundary}\r\n" \
+ f"Content-Disposition: form-data; name=\"knowledgeRequest\"\r\n\r\n" \
+ f"{json.dumps(build_knowledge_request(tone), ensure_ascii=False)}\r\n" \
+ f"--{boundary}\r\n" \
+ f"Content-Disposition: form-data; name=\"imageBase64\"\r\n\r\n" \
+ f"{image_bin}\r\n" \
+ f"--{boundary}--\r\n"
+ return data, boundary
+
+def build_knowledge_request(tone: str) -> dict:
+ """
+ Builds the knowledge request payload.
+
+ Args:
+ tone (str): The tone of the conversation.
+
+ Returns:
+ dict: The knowledge request payload.
+ """
+ return {
'invokedSkills': ["ImageById"],
'subscriptionId': "Bing.Chat.Multimodal",
'invokedSkillsRequestData': {
@@ -69,21 +112,46 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
'convotone': tone
}
}
- knowledge_request = {
- 'imageInfo': {},
- 'knowledgeRequest': payload
- }
- boundary="----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
- data = (
- f'--{boundary}'
- + '\r\nContent-Disposition: form-data; name="knowledgeRequest"\r\n\r\n'
- + json.dumps(knowledge_request, ensure_ascii=False)
- + "\r\n--"
- + boundary
- + '\r\nContent-Disposition: form-data; name="imageBase64"\r\n\r\n'
- + image_bin
- + "\r\n--"
- + boundary
- + "--\r\n"
+
+def prepare_headers(session: ClientSession, boundary: str) -> dict:
+ """
+ Prepares the headers for the image upload request.
+
+ Args:
+ session (ClientSession): The active session.
+ boundary (str): The boundary string for the multipart/form-data.
+
+ Returns:
+ dict: The headers for the request.
+ """
+ headers = session.headers.copy()
+ headers["Content-Type"] = f'multipart/form-data; boundary={boundary}'
+ headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
+ headers["Origin"] = 'https://www.bing.com'
+ return headers
+
+def parse_image_response(response: dict) -> ImageResponse:
+ """
+ Parses the response from the image upload.
+
+ Args:
+ response (dict): The response dictionary.
+
+ Raises:
+ RuntimeError: If parsing the image info fails.
+
+ Returns:
+ ImageResponse: The parsed image response.
+ """
+ if not response.get('blobId'):
+ raise RuntimeError("Failed to parse image info.")
+
+ result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")}
+ result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}"
+
+ result['originalImageUrl'] = (
+ f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}"
+ if IMAGE_CONFIG["enableFaceBlurDebug"] else
+ f"https://www.bing.com/images/blob?bcid={result['bcid']}"
)
- return data, boundary \ No newline at end of file
+ return ImageResponse(result["imageUrl"], "", result) \ No newline at end of file
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py
index f8a0442d..b8bcbde3 100644
--- a/g4f/Provider/create_images.py
+++ b/g4f/Provider/create_images.py
@@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType
system_message = """
You can generate custom images with the DALL-E 3 image generator.
-To generate a image with a prompt, do this:
+To generate an image with a prompt, do this:
<img data-prompt=\"keywords for the image\">
Don't use images with data uri. It is important to use a prompt instead.
<img data-prompt=\"image caption\">
"""
class CreateImagesProvider(BaseProvider):
+ """
+ Provider class for creating images based on text prompts.
+
+ This provider handles image creation requests embedded within message content,
+ using provided image creation functions.
+
+ Attributes:
+ provider (ProviderType): The underlying provider to handle non-image related tasks.
+ create_images (callable): A function to create images synchronously.
+ create_images_async (callable): A function to create images asynchronously.
+ system_message (str): A message that explains the image creation capability.
+ include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
+ __name__ (str): Name of the provider.
+ url (str): URL of the provider.
+ working (bool): Indicates if the provider is operational.
+ supports_stream (bool): Indicates if the provider supports streaming.
+ """
+
def __init__(
self,
provider: ProviderType,
@@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider):
system_message: str = system_message,
include_placeholder: bool = True
) -> None:
+ """
+ Initializes the CreateImagesProvider.
+
+ Args:
+ provider (ProviderType): The underlying provider.
+ create_images (callable): Function to create images synchronously.
+ create_async (callable): Function to create images asynchronously.
+ system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
+ include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
+ """
self.provider = provider
self.create_images = create_images
self.create_images_async = create_async
@@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
+ """
+ Creates a completion result, processing any image creation prompts found within the messages.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process, which may contain image prompts.
+ stream (bool, optional): Indicates whether to stream the results. Defaults to False.
+ **kwargs: Additional keywordarguments for the provider.
+
+ Yields:
+ CreateResult: Yields chunks of the processed messages, including image data if applicable.
+
+ Note:
+ This method processes messages to detect image creation prompts. When such a prompt is found,
+ it calls the synchronous image creation function and includes the resulting image in the output.
+ """
messages.insert(0, {"role": "system", "content": self.system_message})
buffer = ""
for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
@@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously creates a response, processing any image creation prompts found within the messages.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process, which may contain image prompts.
+ **kwargs: Additional keyword arguments for the provider.
+
+ Returns:
+ str: The processed response string, including asynchronously generated image data if applicable.
+
+ Note:
+ This method processes messages to detect image creation prompts. When such a prompt is found,
+ it calls the asynchronous image creation function and includes the resulting image in the output.
+ """
messages.insert(0, {"role": "system", "content": self.system_message})
response = await self.provider.create_async(model, messages, **kwargs)
matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py
index 81f417dd..fce1ee6f 100644
--- a/g4f/Provider/helper.py
+++ b/g4f/Provider/helper.py
@@ -1,36 +1,31 @@
from __future__ import annotations
import asyncio
-import webbrowser
+import os
import random
-import string
import secrets
-import os
-from os import path
+import string
from asyncio import AbstractEventLoop, BaseEventLoop
from platformdirs import user_config_dir
from browser_cookie3 import (
- chrome,
- chromium,
- opera,
- opera_gx,
- brave,
- edge,
- vivaldi,
- firefox,
- _LinuxPasswordManager
+ chrome, chromium, opera, opera_gx,
+ brave, edge, vivaldi, firefox,
+ _LinuxPasswordManager, BrowserCookieError
)
-
from ..typing import Dict, Messages
from .. import debug
-# Local Cookie Storage
+# Global variable to store cookies
_cookies: Dict[str, Dict[str, str]] = {}
-# If loop closed or not set, create new event loop.
-# If event loop is already running, handle nested event loops.
-# If "nest_asyncio" is installed, patch the event loop.
def get_event_loop() -> AbstractEventLoop:
+ """
+ Get the current asyncio event loop. If the loop is closed or not set, create a new event loop.
+ If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed.
+
+ Returns:
+ AbstractEventLoop: The current or new event loop.
+ """
try:
loop = asyncio.get_event_loop()
if isinstance(loop, BaseEventLoop):
@@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
- # Is running event loop
asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
import nest_asyncio
nest_asyncio.apply(loop)
except RuntimeError:
- # No running event loop
pass
except ImportError:
raise RuntimeError(
- 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
+ 'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
)
return loop
-def init_cookies():
- urls = [
- 'https://chat-gpt.org',
- 'https://www.aitianhu.com',
- 'https://chatgptfree.ai',
- 'https://gptchatly.com',
- 'https://bard.google.com',
- 'https://huggingface.co/chat',
- 'https://open-assistant.io/chat'
- ]
-
- browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
-
- def open_urls_in_browser(browser):
- b = webbrowser.get(browser)
- for url in urls:
- b.open(url, new=0, autoraise=True)
-
- for browser in browsers:
- try:
- open_urls_in_browser(browser)
- break
- except webbrowser.Error:
- continue
-
-# Check for broken dbus address in docker image
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
_LinuxPasswordManager.get_password = lambda a, b: b"secret"
-
-# Load cookies for a domain from all supported browsers.
-# Cache the results in the "_cookies" variable.
-def get_cookies(domain_name=''):
+
+def get_cookies(domain_name: str = '') -> Dict[str, str]:
+ """
+ Load cookies for a given domain from all supported browsers and cache the results.
+
+ Args:
+ domain_name (str): The domain for which to load cookies.
+
+ Returns:
+ Dict[str, str]: A dictionary of cookie names and values.
+ """
if domain_name in _cookies:
return _cookies[domain_name]
- def g4f(domain_name):
- user_data_dir = user_config_dir("g4f")
- cookie_file = path.join(user_data_dir, "Default", "Cookies")
- return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
+
+ cookies = _load_cookies_from_browsers(domain_name)
+ _cookies[domain_name] = cookies
+ return cookies
+
+def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
+ """
+ Helper function to load cookies from various browsers.
+
+ Args:
+ domain_name (str): The domain for which to load cookies.
+ Returns:
+ Dict[str, str]: A dictionary of cookie names and values.
+ """
cookies = {}
- for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
+ for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
try:
cookie_jar = cookie_fn(domain_name=domain_name)
if len(cookie_jar) and debug.logging:
@@ -101,13 +85,38 @@ def get_cookies(domain_name=''):
for cookie in cookie_jar:
if cookie.name not in cookies:
cookies[cookie.name] = cookie.value
- except:
+ except BrowserCookieError:
pass
- _cookies[domain_name] = cookies
- return _cookies[domain_name]
+ except Exception as e:
+ if debug.logging:
+ print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
+ return cookies
+
+def _g4f(domain_name: str) -> list:
+ """
+ Load cookies from the 'g4f' browser (if exists).
+
+ Args:
+ domain_name (str): The domain for which to load cookies.
+ Returns:
+ list: List of cookies.
+ """
+ user_data_dir = user_config_dir("g4f")
+ cookie_file = os.path.join(user_data_dir, "Default", "Cookies")
+ return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
+ """
+ Format a series of messages into a single string, optionally adding special tokens.
+
+ Args:
+ messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
+ add_special_tokens (bool): Whether to add special formatting tokens.
+
+ Returns:
+ str: A formatted string containing all messages.
+ """
if not add_special_tokens and len(messages) <= 1:
return messages[0]["content"]
formatted = "\n".join([
@@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str:
])
return f"{formatted}\nAssistant:"
-
def get_random_string(length: int = 10) -> str:
+ """
+ Generate a random string of specified length, containing lowercase letters and digits.
+
+ Args:
+ length (int, optional): Length of the random string to generate. Defaults to 10.
+
+ Returns:
+ str: A random string of the specified length.
+ """
return ''.join(
random.choice(string.ascii_lowercase + string.digits)
for _ in range(length)
)
def get_random_hex() -> str:
+ """
+ Generate a random hexadecimal string of a fixed length.
+
+ Returns:
+ str: A random hexadecimal string of 32 characters (16 bytes).
+ """
return secrets.token_hex(16).zfill(32) \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index a790f0de..7d352a46 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -1,6 +1,9 @@
from __future__ import annotations
+import asyncio
+import uuid
+import json
+import os
-import uuid, json, asyncio, os
from py_arkose_generator.arkose import get_values_for_request
from async_property import async_cached_property
from selenium.webdriver.common.by import By
@@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession
from ...image import to_image, to_bytes, ImageType, ImageResponse
-models = {
+# Aliases for model names
+MODELS = {
"gpt-3.5": "text-davinci-002-render-sha",
"gpt-3.5-turbo": "text-davinci-002-render-sha",
"gpt-4": "gpt-4",
@@ -22,13 +26,15 @@ models = {
}
class OpenaiChat(AsyncGeneratorProvider):
- url = "https://chat.openai.com"
- working = True
- needs_auth = True
+ """A class for creating and managing conversations with OpenAI chat service"""
+
+ url = "https://chat.openai.com"
+ working = True
+ needs_auth = True
supports_gpt_35_turbo = True
- supports_gpt_4 = True
- _cookies: dict = {}
- _default_model: str = None
+ supports_gpt_4 = True
+ _cookies: dict = {}
+ _default_model: str = None
@classmethod
async def create(
@@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider):
image: ImageType = None,
**kwargs
) -> Response:
+ """Create a new conversation or continue an existing one
+
+ Args:
+ prompt: The user input to start or continue the conversation
+ model: The name of the model to use for generating responses
+ messages: The list of previous messages in the conversation
+ history_disabled: A flag indicating if the history and training should be disabled
+ action: The type of action to perform, either "next", "continue", or "variant"
+ conversation_id: The ID of the existing conversation, if any
+ parent_id: The ID of the parent message, if any
+ image: The image to include in the user input, if any
+ **kwargs: Additional keyword arguments to pass to the generator
+
+ Returns:
+ A Response object that contains the generator, action, messages, and options
+ """
+ # Add the user input to the messages list
if prompt:
messages.append({
"role": "user",
@@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider):
)
@classmethod
- async def upload_image(
+ async def _upload_image(
cls,
session: StreamSession,
headers: dict,
image: ImageType
) -> ImageResponse:
+ """Upload an image to the service and get the download URL
+
+ Args:
+ session: The StreamSession object to use for requests
+ headers: The headers to include in the requests
+ image: The image to upload, either a PIL Image object or a bytes object
+
+ Returns:
+ An ImageResponse object that contains the download URL, file name, and other data
+ """
+ # Convert the image to a PIL Image object and get the extension
image = to_image(image)
extension = image.format.lower()
+ # Convert the image to a bytes object and get the size
data_bytes = to_bytes(image)
data = {
"file_name": f"{image.width}x{image.height}.{extension}",
"file_size": len(data_bytes),
"use_case": "multimodal"
}
+ # Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
response.raise_for_status()
image_data = {
@@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider):
"height": image.height,
"width": image.width
}
+ # Put the image bytes to the upload URL and check the status
async with session.put(
image_data["upload_url"],
data=data_bytes,
@@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider):
}
) as response:
response.raise_for_status()
+ # Post the file ID to the service and get the download URL
async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
json={},
@@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider):
return ImageResponse(download_url, image_data["file_name"], image_data)
@classmethod
- async def get_default_model(cls, session: StreamSession, headers: dict):
+ async def _get_default_model(cls, session: StreamSession, headers: dict):
+ """Get the default model name from the service
+
+ Args:
+ session: The StreamSession object to use for requests
+ headers: The headers to include in the requests
+
+ Returns:
+ The default model name as a string
+ """
+ # Check the cache for the default model
if cls._default_model:
- model = cls._default_model
- else:
- async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
- data = await response.json()
- if "categories" in data:
- model = data["categories"][-1]["default_model"]
- else:
- RuntimeError(f"Response: {data}")
- cls._default_model = model
- return model
+ return cls._default_model
+ # Get the models data from the service
+ async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+ data = await response.json()
+ if "categories" in data:
+ cls._default_model = data["categories"][-1]["default_model"]
+ else:
+ raise RuntimeError(f"Response: {data}")
+ return cls._default_model
@classmethod
- def create_messages(cls, prompt: str, image_response: ImageResponse = None):
+ def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
+ """Create a list of messages for the user input
+
+ Args:
+ prompt: The user input as a string
+ image_response: The image response object, if any
+
+ Returns:
+ A list of messages with the user input and the image, if any
+ """
+ # Check if there is an image response
if not image_response:
+ # Create a content object with the text type and the prompt
content = {"content_type": "text", "parts": [prompt]}
else:
+ # Create a content object with the multimodal text type and the image and the prompt
content = {
"content_type": "multimodal_text",
"parts": [{
@@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider):
"width": image_response.get("width"),
}, prompt]
}
+ # Create a message object with the user role and the content
messages = [{
"id": str(uuid.uuid4()),
"author": {"role": "user"},
"content": content,
}]
+ # Check if there is an image response
if image_response:
+ # Add the metadata object with the attachments
messages[0]["metadata"] = {
"attachments": [{
"height": image_response.get("height"),
@@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider):
return messages
@classmethod
- async def get_image_response(cls, session: StreamSession, headers: dict, line: dict):
- if "parts" in line["message"]["content"]:
- part = line["message"]["content"]["parts"][0]
- if "asset_pointer" in part and part["metadata"]:
- file_id = part["asset_pointer"].split("file-service://", 1)[1]
- prompt = part["metadata"]["dalle"]["prompt"]
- async with session.get(
- f"{cls.url}/backend-api/files/{file_id}/download",
- headers=headers
- ) as response:
- response.raise_for_status()
- download_url = (await response.json())["download_url"]
- return ImageResponse(download_url, prompt)
+ async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
+ """
+ Retrieves the image response based on the message content.
+
+ :param session: The StreamSession object.
+ :param headers: HTTP headers for the request.
+ :param line: The line of response containing image information.
+ :return: An ImageResponse object with the image details.
+ """
+ if "parts" not in line["message"]["content"]:
+ return
+ first_part = line["message"]["content"]["parts"][0]
+ if "asset_pointer" not in first_part or "metadata" not in first_part:
+ return
+ file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
+ prompt = first_part["metadata"]["dalle"]["prompt"]
+ try:
+ async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
+ response.raise_for_status()
+ download_url = (await response.json())["download_url"]
+ return ImageResponse(download_url, prompt)
+ except Exception as e:
+ raise RuntimeError(f"Error in downloading image: {e}")
+
+ @classmethod
+ async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
+ async with session.patch(
+ f"{cls.url}/backend-api/conversation/{conversation_id}",
+ json={"is_visible": False},
+ headers=headers
+ ) as response:
+ response.raise_for_status()
@classmethod
async def create_async_generator(
@@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider):
response_fields: bool = False,
**kwargs
) -> AsyncResult:
- if model in models:
- model = models[model]
+ """
+ Create an asynchronous generator for the conversation.
+
+ Args:
+ model (str): The model name.
+ messages (Messages): The list of previous messages.
+ proxy (str): Proxy to use for requests.
+ timeout (int): Timeout for requests.
+ access_token (str): Access token for authentication.
+ cookies (dict): Cookies to use for authentication.
+ auto_continue (bool): Flag to automatically continue the conversation.
+ history_disabled (bool): Flag to disable history and training.
+ action (str): Type of action ('next', 'continue', 'variant').
+ conversation_id (str): ID of the conversation.
+ parent_id (str): ID of the parent message.
+ image (ImageType): Image to include in the conversation.
+ response_fields (bool): Flag to include response fields in the output.
+ **kwargs: Additional keyword arguments.
+
+ Yields:
+ AsyncResult: Asynchronous results from the generator.
+
+ Raises:
+ RuntimeError: If an error occurs during processing.
+ """
+ model = MODELS.get(model, model)
if not parent_id:
parent_id = str(uuid.uuid4())
if not cookies:
- cookies = cls._cookies
- if not access_token:
- if not cookies:
- cls._cookies = cookies = get_cookies("chat.openai.com")
- if "access_token" in cookies:
- access_token = cookies["access_token"]
+ cookies = cls._cookies or get_cookies("chat.openai.com")
+ if not access_token and "access_token" in cookies:
+ access_token = cookies["access_token"]
if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n"
- access_token, cookies = cls.browse_access_token(proxy)
+ access_token, cookies = cls._browse_access_token(proxy)
cls._cookies = cookies
- headers = {
- "Authorization": f"Bearer {access_token}",
- }
+
+ headers = {"Authorization": f"Bearer {access_token}"}
+
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome110",
@@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider):
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session:
if not model:
- model = await cls.get_default_model(session, headers)
+ model = await cls._get_default_model(session, headers)
try:
image_response = None
if image:
- image_response = await cls.upload_image(session, headers, image)
+ image_response = await cls._upload_image(session, headers, image)
yield image_response
except Exception as e:
yield e
@@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider):
while not end_turn.is_end:
data = {
"action": action,
- "arkose_token": await cls.get_arkose_token(session),
+ "arkose_token": await cls._get_arkose_token(session),
"conversation_id": conversation_id,
"parent_message_id": parent_id,
"model": model,
@@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider):
}
if action != "continue":
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
- data["messages"] = cls.create_messages(prompt, image_response)
+ data["messages"] = cls._create_messages(prompt, image_response)
async with session.post(
f"{cls.url}/backend-api/conversation",
json=data,
@@ -261,62 +363,80 @@ class OpenaiChat(AsyncGeneratorProvider):
if "message_type" not in line["message"]["metadata"]:
continue
try:
- image_response = await cls.get_image_response(session, headers, line)
+ image_response = await cls._get_generated_image(session, headers, line)
if image_response:
yield image_response
except Exception as e:
yield e
if line["message"]["author"]["role"] != "assistant":
continue
- if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
- conversation_id = line["conversation_id"]
- parent_id = line["message"]["id"]
- if response_fields:
- response_fields = False
- yield ResponseFields(conversation_id, parent_id, end_turn)
- if "parts" in line["message"]["content"]:
- new_message = line["message"]["content"]["parts"][0]
- if len(new_message) > last_message:
- yield new_message[last_message:]
- last_message = len(new_message)
+ if line["message"]["content"]["content_type"] != "text":
+ continue
+ if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
+ continue
+ conversation_id = line["conversation_id"]
+ parent_id = line["message"]["id"]
+ if response_fields:
+ response_fields = False
+ yield ResponseFields(conversation_id, parent_id, end_turn)
+ if "parts" in line["message"]["content"]:
+ new_message = line["message"]["content"]["parts"][0]
+ if len(new_message) > last_message:
+ yield new_message[last_message:]
+ last_message = len(new_message)
if "finish_details" in line["message"]["metadata"]:
if line["message"]["metadata"]["finish_details"]["type"] == "stop":
end_turn.end()
- break
except Exception as e:
- yield e
+ raise e
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
- if history_disabled:
- async with session.patch(
- f"{cls.url}/backend-api/conversation/{conversation_id}",
- json={"is_visible": False},
- headers=headers
- ) as response:
- response.raise_for_status()
+ if history_disabled and auto_continue:
+ await cls._delete_conversation(session, headers, conversation_id)
@classmethod
- def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
+ def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
+ """
+ Browse to obtain an access token.
+
+ Args:
+ proxy (str): Proxy to use for browsing.
+
+ Returns:
+ tuple[str, dict]: A tuple containing the access token and cookies.
+ """
driver = get_browser(proxy=proxy)
try:
driver.get(f"{cls.url}/")
- WebDriverWait(driver, 1200).until(
- EC.presence_of_element_located((By.ID, "prompt-textarea"))
+ WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
+ access_token = driver.execute_script(
+ "let session = await fetch('/api/auth/session');"
+ "let data = await session.json();"
+ "let accessToken = data['accessToken'];"
+ "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);"
+ "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
+ "return accessToken;"
)
- javascript = """
-access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
-expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
-document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
-return access_token;
-"""
- return driver.execute_script(javascript), get_driver_cookies(driver)
+ return access_token, get_driver_cookies(driver)
finally:
driver.quit()
- @classmethod
- async def get_arkose_token(cls, session: StreamSession) -> str:
+ @classmethod
+ async def _get_arkose_token(cls, session: StreamSession) -> str:
+ """
+ Obtain an Arkose token for the session.
+
+ Args:
+ session (StreamSession): The session object.
+
+ Returns:
+ str: The Arkose token.
+
+ Raises:
+ RuntimeError: If unable to retrieve the token.
+ """
config = {
"pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
"surl": "https://tcr9i.chat.openai.com",
@@ -332,26 +452,30 @@ return access_token;
if "token" in decoded_json:
return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}")
-
-class EndTurn():
+
+class EndTurn:
+ """
+ Class to represent the end of a conversation turn.
+ """
def __init__(self):
self.is_end = False
def end(self):
self.is_end = True
-class ResponseFields():
- def __init__(
- self,
- conversation_id: str,
- message_id: str,
- end_turn: EndTurn
- ):
+class ResponseFields:
+ """
+ Class to encapsulate response fields.
+ """
+ def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
self.conversation_id = conversation_id
self.message_id = message_id
self._end_turn = end_turn
class Response():
+ """
+ Class to encapsulate a response from the chat service.
+ """
def __init__(
self,
generator: AsyncResult,
@@ -360,13 +484,13 @@ class Response():
options: dict
):
self._generator = generator
- self.action: str = action
- self.is_end: bool = False
+ self.action = action
+ self.is_end = False
self._message = None
self._messages = messages
self._options = options
self._fields = None
-
+
async def generator(self):
if self._generator:
self._generator = None
@@ -384,19 +508,16 @@ class Response():
def __aiter__(self):
return self.generator()
-
+
@async_cached_property
async def message(self) -> str:
- [_ async for _ in self.generator()]
+ await self.generator()
return self._message
-
+
async def get_fields(self):
- [_ async for _ in self.generator()]
- return {
- "conversation_id": self._fields.conversation_id,
- "parent_id": self._fields.message_id,
- }
-
+ await self.generator()
+ return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
+
async def next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create(
**self._options,
@@ -406,7 +527,7 @@ class Response():
**await self.get_fields(),
**kwargs
)
-
+
async def do_continue(self, **kwargs) -> Response:
fields = await self.get_fields()
if self.is_end:
@@ -418,7 +539,7 @@ class Response():
**fields,
**kwargs
)
-
+
async def variant(self, **kwargs) -> Response:
if self.action != "next":
raise RuntimeError("Can't create variant from continue or variant request.")
@@ -429,11 +550,9 @@ class Response():
**await self.get_fields(),
**kwargs
)
-
+
@async_cached_property
async def messages(self):
messages = self._messages
- messages.append({
- "role": "assistant", "content": await self.message
- })
+ messages.append({"role": "assistant", "content": await self.message})
return messages \ No newline at end of file
diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py
index 4d3e77ac..9cc026fc 100644
--- a/g4f/Provider/retry_provider.py
+++ b/g4f/Provider/retry_provider.py
@@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
-
class RetryProvider(BaseRetryProvider):
+ """
+ A provider class to handle retries for creating completions with different providers.
+
+ Attributes:
+ providers (list): A list of provider instances.
+ shuffle (bool): A flag indicating whether to shuffle providers before use.
+ exceptions (dict): A dictionary to store exceptions encountered during retries.
+ last_provider (BaseProvider): The last provider that was used.
+ """
+
def create_completion(
self,
model: str,
@@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
- if stream:
- providers = [provider for provider in self.providers if provider.supports_stream]
- else:
- providers = self.providers
+ """
+ Create a completion using available providers, with an option to stream the response.
+
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+ stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
+
+ Yields:
+ CreateResult: Tokens or results from the completion.
+
+ Raises:
+ Exception: Any exception encountered during the completion process.
+ """
+ providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
if self.shuffle:
random.shuffle(providers)
@@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously create a completion using available providers.
+
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+
+ Returns:
+ str: The result of the asynchronous completion.
+
+ Raises:
+ Exception: Any exception encountered during the asynchronous completion process.
+ """
providers = self.providers
if self.shuffle:
random.shuffle(providers)
-
+
self.exceptions = {}
for provider in providers:
self.last_provider = provider
@@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider):
self.exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-
+
self.raise_exceptions()
-
+
def raise_exceptions(self) -> None:
+ """
+ Raise a combined exception if any occurred during retries.
+
+ Raises:
+ RetryProviderError: If any provider encountered an exception.
+ RetryNoProviderError: If no provider is found.
+ """
if self.exceptions:
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
]))
-
+
raise RetryNoProviderError("No provider found") \ No newline at end of file
diff --git a/g4f/__init__.py b/g4f/__init__.py
index 68f9ccf6..2b0e5b46 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -15,6 +15,26 @@ def get_model_and_provider(model : Union[Model, str],
ignored : list[str] = None,
ignore_working: bool = False,
ignore_stream: bool = False) -> tuple[str, ProviderType]:
+ """
+ Retrieves the model and provider based on input parameters.
+
+ Args:
+ model (Union[Model, str]): The model to use, either as an object or a string identifier.
+ provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
+ stream (bool): Indicates if the operation should be performed as a stream.
+ ignored (list[str], optional): List of provider names to be ignored.
+ ignore_working (bool, optional): If True, ignores the working status of the provider.
+ ignore_stream (bool, optional): If True, ignores the streaming capability of the provider.
+
+ Returns:
+ tuple[str, ProviderType]: A tuple containing the model name and the provider type.
+
+ Raises:
+ ProviderNotFoundError: If the provider is not found.
+ ModelNotFoundError: If the model is not found.
+ ProviderNotWorkingError: If the provider is not working.
+ StreamNotSupportedError: If streaming is not supported by the provider.
+ """
if debug.version_check:
debug.version_check = False
version.utils.check_version()
@@ -70,7 +90,30 @@ class ChatCompletion:
ignore_stream_and_auth: bool = False,
patch_provider: callable = None,
**kwargs) -> Union[CreateResult, str]:
-
+ """
+ Creates a chat completion using the specified model, provider, and messages.
+
+ Args:
+ model (Union[Model, str]): The model to use, either as an object or a string identifier.
+ messages (Messages): The messages for which the completion is to be created.
+ provider (Union[ProviderType, str, None], optional): The provider to use, either as an object, a string identifier, or None.
+ stream (bool, optional): Indicates if the operation should be performed as a stream.
+ auth (Union[str, None], optional): Authentication token or credentials, if required.
+ ignored (list[str], optional): List of provider names to be ignored.
+ ignore_working (bool, optional): If True, ignores the working status of the provider.
+ ignore_stream_and_auth (bool, optional): If True, ignores the stream and authentication requirement checks.
+ patch_provider (callable, optional): Function to modify the provider.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Union[CreateResult, str]: The result of the chat completion operation.
+
+ Raises:
+ AuthenticationRequiredError: If authentication is required but not provided.
+ ProviderNotFoundError, ModelNotFoundError: If the specified provider or model is not found.
+ ProviderNotWorkingError: If the provider is not operational.
+ StreamNotSupportedError: If streaming is requested but not supported by the provider.
+ """
model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)
if not ignore_stream_and_auth and provider.needs_auth and not auth:
@@ -98,7 +141,24 @@ class ChatCompletion:
ignored : list[str] = None,
patch_provider: callable = None,
**kwargs) -> Union[AsyncResult, str]:
-
+ """
+ Asynchronously creates a completion using the specified model and provider.
+
+ Args:
+ model (Union[Model, str]): The model to use, either as an object or a string identifier.
+ messages (Messages): Messages to be processed.
+ provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
+ stream (bool): Indicates if the operation should be performed as a stream.
+ ignored (list[str], optional): List of provider names to be ignored.
+ patch_provider (callable, optional): Function to modify the provider.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Union[AsyncResult, str]: The result of the asynchronous chat completion operation.
+
+ Raises:
+ StreamNotSupportedError: If streaming is requested but not supported by the provider.
+ """
model, provider = get_model_and_provider(model, provider, False, ignored)
if stream:
@@ -118,7 +178,23 @@ class Completion:
provider : Union[ProviderType, None] = None,
stream : bool = False,
ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
-
+ """
+ Creates a completion based on the provided model, prompt, and provider.
+
+ Args:
+ model (Union[Model, str]): The model to use, either as an object or a string identifier.
+ prompt (str): The prompt text for which the completion is to be created.
+ provider (Union[ProviderType, None], optional): The provider to use, either as an object or None.
+ stream (bool, optional): Indicates if the operation should be performed as a stream.
+ ignored (list[str], optional): List of provider names to be ignored.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Union[CreateResult, str]: The result of the completion operation.
+
+ Raises:
+ ModelNotAllowedError: If the specified model is not allowed for use with this method.
+ """
allowed_models = [
'code-davinci-002',
'text-ada-001',
@@ -137,6 +213,15 @@ class Completion:
return result if stream else ''.join(result)
def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]:
+ """
+ Retrieves the last used provider.
+
+ Args:
+ as_dict (bool, optional): If True, returns the provider information as a dictionary.
+
+ Returns:
+ Union[ProviderType, dict[str, str]]: The last used provider, either as an object or a dictionary.
+ """
last = debug.last_provider
if isinstance(last, BaseRetryProvider):
last = last.last_provider
diff --git a/g4f/base_provider.py b/g4f/base_provider.py
index 1863f6bc..03ae64d6 100644
--- a/g4f/base_provider.py
+++ b/g4f/base_provider.py
@@ -1,7 +1,22 @@
from abc import ABC, abstractmethod
-from .typing import Messages, CreateResult, Union
-
+from typing import Union, List, Dict, Type
+from .typing import Messages, CreateResult
+
class BaseProvider(ABC):
+ """
+ Abstract base class for a provider.
+
+ Attributes:
+ url (str): URL of the provider.
+ working (bool): Indicates if the provider is currently working.
+ needs_auth (bool): Indicates if the provider needs authentication.
+ supports_stream (bool): Indicates if the provider supports streaming.
+ supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
+ supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
+ supports_message_history (bool): Indicates if the provider supports message history.
+ params (str): List parameters for the provider.
+ """
+
url: str = None
working: bool = False
needs_auth: bool = False
@@ -20,6 +35,18 @@ class BaseProvider(ABC):
stream: bool,
**kwargs
) -> CreateResult:
+ """
+ Create a completion with the given parameters.
+
+ Args:
+ model (str): The model to use.
+ messages (Messages): The messages to process.
+ stream (bool): Whether to use streaming.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the creation process.
+ """
raise NotImplementedError()
@classmethod
@@ -30,25 +57,59 @@ class BaseProvider(ABC):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously create a completion with the given parameters.
+
+ Args:
+ model (str): The model to use.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The result of the creation process.
+ """
raise NotImplementedError()
@classmethod
- def get_dict(cls):
+ def get_dict(cls) -> Dict[str, str]:
+ """
+ Get a dictionary representation of the provider.
+
+ Returns:
+ Dict[str, str]: A dictionary with provider's details.
+ """
return {'name': cls.__name__, 'url': cls.url}
class BaseRetryProvider(BaseProvider):
+ """
+ Base class for a provider that implements retry logic.
+
+ Attributes:
+ providers (List[Type[BaseProvider]]): List of providers to use for retries.
+ shuffle (bool): Whether to shuffle the providers list.
+ exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
+ last_provider (Type[BaseProvider]): The last provider used.
+ """
+
__name__: str = "RetryProvider"
supports_stream: bool = True
def __init__(
self,
- providers: list[type[BaseProvider]],
+ providers: List[Type[BaseProvider]],
shuffle: bool = True
) -> None:
- self.providers: list[type[BaseProvider]] = providers
- self.shuffle: bool = shuffle
- self.working: bool = True
- self.exceptions: dict[str, Exception] = {}
- self.last_provider: type[BaseProvider] = None
+ """
+ Initialize the BaseRetryProvider.
+
+ Args:
+ providers (List[Type[BaseProvider]]): List of providers to use.
+ shuffle (bool): Whether to shuffle the providers list.
+ """
+ self.providers = providers
+ self.shuffle = shuffle
+ self.working = True
+ self.exceptions: Dict[str, Exception] = {}
+ self.last_provider: Type[BaseProvider] = None
-ProviderType = Union[type[BaseProvider], BaseRetryProvider] \ No newline at end of file
+ProviderType = Union[Type[BaseProvider], BaseRetryProvider] \ No newline at end of file
diff --git a/g4f/gui/client/css/style.css b/g4f/gui/client/css/style.css
index 2d4c9857..e77410ab 100644
--- a/g4f/gui/client/css/style.css
+++ b/g4f/gui/client/css/style.css
@@ -404,7 +404,7 @@ body {
display: none;
}
-#image {
+#image, #file {
display: none;
}
@@ -412,13 +412,22 @@ label[for="image"]:has(> input:valid){
color: var(--accent);
}
-label[for="image"] {
+label[for="file"]:has(> input:valid){
+ color: var(--accent);
+}
+
+label[for="image"], label[for="file"] {
cursor: pointer;
position: absolute;
top: 10px;
left: 10px;
}
+label[for="file"] {
+ top: 32px;
+ left: 10px;
+}
+
.buttons input[type="checkbox"] {
height: 0;
width: 0;
diff --git a/g4f/gui/client/html/index.html b/g4f/gui/client/html/index.html
index 3f2bb0c0..95489ba4 100644
--- a/g4f/gui/client/html/index.html
+++ b/g4f/gui/client/html/index.html
@@ -118,6 +118,10 @@
<input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/>
<i class="fa-regular fa-image"></i>
</label>
+ <label for="file">
+ <input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .sh, .py, .php, .css, .yaml, .sql, .svg, .log, .csv, .twig, .md" required/>
+ <i class="fa-solid fa-paperclip"></i>
+ </label>
<div id="send-button">
<i class="fa-solid fa-paper-plane-top"></i>
</div>
@@ -125,7 +129,14 @@
</div>
<div class="buttons">
<div class="field">
- <select name="model" id="model"></select>
+ <select name="model" id="model">
+ <option value="">Model: Default</option>
+ <option value="gpt-4">gpt-4</option>
+ <option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
+ <option value="llama2-70b">llama2-70b</option>
+ <option value="gemini-pro">gemini-pro</option>
+ <option value="">----</option>
+ </select>
</div>
<div class="field">
<select name="jailbreak" id="jailbreak" style="display: none;">
@@ -138,7 +149,16 @@
<option value="gpt-evil-1.0">evil 1.0</option>
</select>
<div class="field">
- <select name="provider" id="provider"></select>
+ <select name="provider" id="provider">
+ <option value="">Provider: Auto</option>
+ <option value="Bing">Bing</option>
+ <option value="OpenaiChat">OpenaiChat</option>
+ <option value="HuggingChat">HuggingChat</option>
+ <option value="Bard">Bard</option>
+ <option value="Liaobots">Liaobots</option>
+ <option value="Phind">Phind</option>
+ <option value="">----</option>
+ </select>
</div>
</div>
<div class="field">
diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js
index ccc9461b..8b9bc181 100644
--- a/g4f/gui/client/js/chat.v1.js
+++ b/g4f/gui/client/js/chat.v1.js
@@ -7,7 +7,9 @@ const spinner = box_conversations.querySelector(".spinner");
const stop_generating = document.querySelector(`.stop_generating`);
const regenerate = document.querySelector(`.regenerate`);
const send_button = document.querySelector(`#send-button`);
-const imageInput = document.querySelector('#image') ;
+const imageInput = document.querySelector('#image');
+const fileInput = document.querySelector('#file');
+
let prompt_lock = false;
hljs.addPlugin(new CopyButtonPlugin());
@@ -42,6 +44,11 @@ const handle_ask = async () => {
if (message.length > 0) {
message_input.value = '';
await add_conversation(window.conversation_id, message);
+ if ("text" in fileInput.dataset) {
+ message += '\n```' + fileInput.dataset.type + '\n';
+ message += fileInput.dataset.text;
+ message += '\n```'
+ }
await add_message(window.conversation_id, "user", message);
window.token = message_id();
message_box.innerHTML += `
@@ -55,6 +62,9 @@ const handle_ask = async () => {
</div>
</div>
`;
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
+ hljs.highlightElement(el);
+ });
await ask_gpt();
}
};
@@ -171,17 +181,30 @@ const ask_gpt = async () => {
content_inner.innerHTML += "<p>An error occured, please try again, if the problem persists, please use a other model or provider.</p>";
} else {
html = markdown_render(text);
- html = html.substring(0, html.lastIndexOf('</p>')) + '<span id="cursor"></span></p>';
+ let lastElement, lastIndex = null;
+ for (element of ['</p>', '</code></pre>', '</li>\n</ol>']) {
+ const index = html.lastIndexOf(element)
+ if (index > lastIndex) {
+ lastElement = element;
+ lastIndex = index;
+ }
+ }
+ if (lastIndex) {
+ html = html.substring(0, lastIndex) + '<span id="cursor"></span>' + lastElement;
+ }
content_inner.innerHTML = html;
- document.querySelectorAll('code').forEach((el) => {
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
}
window.scrollTo(0, 0);
- message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
+ if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
+ message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
+ }
}
if (!error && imageInput) imageInput.value = "";
+ if (!error && fileInput) fileInput.value = "";
} catch (e) {
console.error(e);
@@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => {
`;
}
- document.querySelectorAll(`code`).forEach((el) => {
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
@@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => {
`;
}
- document.querySelectorAll(`code`).forEach((el) => {
+ document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
};
@@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true });
(async () => {
response = await fetch('/backend-api/v2/models')
models = await response.json()
-
let select = document.getElementById('model');
- select.textContent = '';
-
- let auto = document.createElement('option');
- auto.value = '';
- auto.text = 'Model: Default';
- select.appendChild(auto);
for (model of models) {
let option = document.createElement('option');
@@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true });
response = await fetch('/backend-api/v2/providers')
providers = await response.json()
-
select = document.getElementById('provider');
- select.textContent = '';
-
- auto = document.createElement('option');
- auto.value = '';
- auto.text = 'Provider: Auto';
- select.appendChild(auto);
for (provider of providers) {
let option = document.createElement('option');
@@ -643,11 +652,34 @@ observer.observe(message_input, { attributes: true });
document.title = 'g4f - gui - ' + versions["version"];
text = "version ~ "
- if (versions["version"] != versions["lastet_version"]) {
- release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"];
- text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["lastet_version"] +'">' + versions["version"] + ' 🆕</a>';
+ if (versions["version"] != versions["latest_version"]) {
+ release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
+ text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["latest_version"] +'">' + versions["version"] + ' 🆕</a>';
} else {
text += versions["version"];
}
document.getElementById("version_text").innerHTML = text
-})() \ No newline at end of file
+})()
+
+fileInput.addEventListener('change', async (event) => {
+ if (fileInput.files.length) {
+ type = fileInput.files[0].type;
+ if (type && type.indexOf('/')) {
+ type = type.split('/').pop().replace('x-', '')
+ type = type.replace('plain', 'plaintext')
+ .replace('shellscript', 'sh')
+ .replace('svg+xml', 'svg')
+ .replace('vnd.trolltech.linguist', 'ts')
+ } else {
+ type = fileInput.files[0].name.split('.').pop()
+ }
+ fileInput.dataset.type = type
+ const reader = new FileReader();
+ reader.addEventListener('load', (event) => {
+ fileInput.dataset.text = event.target.result;
+ });
+ reader.readAsText(fileInput.files[0]);
+ } else {
+ delete fileInput.dataset.text;
+ }
+}); \ No newline at end of file
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index 9d12bea5..4a5cafa8 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -1,6 +1,7 @@
import logging
import json
from flask import request, Flask
+from typing import Generator
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
@@ -11,60 +12,123 @@ from .internet import get_search_message
debug.logging = True
class Backend_Api:
+ """
+ Handles various endpoints in a Flask application for backend operations.
+
+ This class provides methods to interact with models, providers, and to handle
+ various functionalities like conversations, error handling, and version management.
+
+ Attributes:
+ app (Flask): A Flask application instance.
+ routes (dict): A dictionary mapping API endpoints to their respective handlers.
+ """
def __init__(self, app: Flask) -> None:
+ """
+ Initialize the backend API with the given Flask application.
+
+ Args:
+ app (Flask): Flask application instance to attach routes to.
+ """
self.app: Flask = app
self.routes = {
'/backend-api/v2/models': {
- 'function': self.models,
- 'methods' : ['GET']
+ 'function': self.get_models,
+ 'methods': ['GET']
},
'/backend-api/v2/providers': {
- 'function': self.providers,
- 'methods' : ['GET']
+ 'function': self.get_providers,
+ 'methods': ['GET']
},
'/backend-api/v2/version': {
- 'function': self.version,
- 'methods' : ['GET']
+ 'function': self.get_version,
+ 'methods': ['GET']
},
'/backend-api/v2/conversation': {
- 'function': self._conversation,
+ 'function': self.handle_conversation,
'methods': ['POST']
},
'/backend-api/v2/gen.set.summarize:title': {
- 'function': self._gen_title,
+ 'function': self.generate_title,
'methods': ['POST']
},
'/backend-api/v2/error': {
- 'function': self.error,
+ 'function': self.handle_error,
'methods': ['POST']
}
}
- def error(self):
+ def handle_error(self):
+ """
+ Initialize the backend API with the given Flask application.
+
+ Args:
+ app (Flask): Flask application instance to attach routes to.
+ """
print(request.json)
-
return 'ok', 200
- def models(self):
+ def get_models(self):
+ """
+ Return a list of all models.
+
+ Fetches and returns a list of all available models in the system.
+
+ Returns:
+ List[str]: A list of model names.
+ """
return _all_models
- def providers(self):
- return [
- provider.__name__ for provider in __providers__ if provider.working
- ]
+ def get_providers(self):
+ """
+ Return a list of all working providers.
+ """
+ return [provider.__name__ for provider in __providers__ if provider.working]
- def version(self):
+ def get_version(self):
+ """
+ Returns the current and latest version of the application.
+
+ Returns:
+ dict: A dictionary containing the current and latest version.
+ """
return {
"version": version.utils.current_version,
- "lastet_version": version.get_latest_version(),
+ "latest_version": version.get_latest_version(),
}
- def _gen_title(self):
- return {
- 'title': ''
- }
+ def generate_title(self):
+ """
+ Generates and returns a title based on the request data.
+
+ Returns:
+ dict: A dictionary with the generated title.
+ """
+ return {'title': ''}
- def _conversation(self):
+ def handle_conversation(self):
+ """
+ Handles conversation requests and streams responses back.
+
+ Returns:
+ Response: A Flask response object for streaming.
+ """
+ kwargs = self._prepare_conversation_kwargs()
+
+ return self.app.response_class(
+ self._create_response_stream(kwargs),
+ mimetype='text/event-stream'
+ )
+
+ def _prepare_conversation_kwargs(self):
+ """
+ Prepares arguments for chat completion based on the request data.
+
+ Reads the request and prepares the necessary arguments for handling
+ a chat completion request.
+
+ Returns:
+ dict: Arguments prepared for chat completion.
+ """
kwargs = {}
if 'image' in request.files:
file = request.files['image']
@@ -87,47 +151,70 @@ class Backend_Api:
messages[-1]["content"] = get_search_message(messages[-1]["content"])
model = json_data.get('model')
model = model if model else models.default
- provider = json_data.get('provider', '').replace('g4f.Provider.', '')
- provider = provider if provider and provider != "Auto" else None
patch = patch_provider if json_data.get('patch_provider') else None
- def try_response():
- try:
- first = True
- for chunk in ChatCompletion.create(
- model=model,
- provider=provider,
- messages=messages,
- stream=True,
- ignore_stream_and_auth=True,
- patch_provider=patch,
- **kwargs
- ):
- if first:
- first = False
- yield json.dumps({
- 'type' : 'provider',
- 'provider': get_last_provider(True)
- }) + "\n"
- if isinstance(chunk, Exception):
- logging.exception(chunk)
- yield json.dumps({
- 'type' : 'message',
- 'message': get_error_message(chunk),
- }) + "\n"
- else:
- yield json.dumps({
- 'type' : 'content',
- 'content': str(chunk),
- }) + "\n"
- except Exception as e:
- logging.exception(e)
- yield json.dumps({
- 'type' : 'error',
- 'error': get_error_message(e)
- })
-
- return self.app.response_class(try_response(), mimetype='text/event-stream')
+ return {
+ "model": model,
+ "provider": provider,
+ "messages": messages,
+ "stream": True,
+ "ignore_stream_and_auth": True,
+ "patch_provider": patch,
+ **kwargs
+ }
+
+ def _create_response_stream(self, kwargs) -> Generator[str, None, None]:
+ """
+ Creates and returns a streaming response for the conversation.
+
+ Args:
+ kwargs (dict): Arguments for creating the chat completion.
+
+ Yields:
+ str: JSON formatted response chunks for the stream.
+
+ Raises:
+ Exception: If an error occurs during the streaming process.
+ """
+ try:
+ first = True
+ for chunk in ChatCompletion.create(**kwargs):
+ if first:
+ first = False
+ yield self._format_json('provider', get_last_provider(True))
+ if isinstance(chunk, Exception):
+ logging.exception(chunk)
+ yield self._format_json('message', get_error_message(chunk))
+ else:
+ yield self._format_json('content', str(chunk))
+ except Exception as e:
+ logging.exception(e)
+ yield self._format_json('error', get_error_message(e))
+
+ def _format_json(self, response_type: str, content) -> str:
+ """
+ Formats and returns a JSON response.
+
+ Args:
+ response_type (str): The type of the response.
+ content: The content to be included in the response.
+
+ Returns:
+ str: A JSON formatted string.
+ """
+ return json.dumps({
+ 'type': response_type,
+ response_type: content
+ }) + "\n"
def get_error_message(exception: Exception) -> str:
+ """
+ Generates a formatted error message from an exception.
+
+ Args:
+ exception (Exception): The exception to format.
+
+ Returns:
+ str: A formatted error message string.
+ """
return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}" \ No newline at end of file
diff --git a/g4f/image.py b/g4f/image.py
index 01664f4e..cfa22ab1 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -4,9 +4,18 @@ import base64
from .typing import ImageType, Union
from PIL import Image
-ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
def to_image(image: ImageType) -> Image.Image:
+ """
+ Converts the input image to a PIL Image object.
+
+ Args:
+ image (Union[str, bytes, Image.Image]): The input image.
+
+ Returns:
+ Image.Image: The converted PIL Image object.
+ """
if isinstance(image, str):
is_data_uri_an_image(image)
image = extract_data_uri(image)
@@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image:
image = copy
return image
-def is_allowed_extension(filename) -> bool:
+def is_allowed_extension(filename: str) -> bool:
+ """
+ Checks if the given filename has an allowed extension.
+
+ Args:
+ filename (str): The filename to check.
+
+ Returns:
+ bool: True if the extension is allowed, False otherwise.
+ """
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def is_data_uri_an_image(data_uri: str) -> bool:
+ """
+ Checks if the given data URI represents an image.
+
+ Args:
+ data_uri (str): The data URI to check.
+
+ Raises:
+ ValueError: If the data URI is invalid or the image format is not allowed.
+ """
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
- # Extract the image format from the data URI
+ # Extract the image format from the data URI
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
if image_format.lower() not in ALLOWED_EXTENSIONS:
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> bool:
+ """
+ Checks if the given binary data represents an image with an accepted format.
+
+ Args:
+ binary_data (bytes): The binary data to check.
+
+ Raises:
+ ValueError: If the image format is not allowed.
+ """
if binary_data.startswith(b'\xFF\xD8\xFF'):
pass # It's a JPEG image
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
@@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool:
pass # It's a WebP image
else:
raise ValueError("Invalid image format (from magic code).")
-
+
def extract_data_uri(data_uri: str) -> bytes:
+ """
+ Extracts the binary data from the given data URI.
+
+ Args:
+ data_uri (str): The data URI.
+
+ Returns:
+ bytes: The extracted binary data.
+ """
data = data_uri.split(",")[1]
data = base64.b64decode(data)
return data
def get_orientation(image: Image.Image) -> int:
+ """
+ Gets the orientation of the given image.
+
+ Args:
+ image (Image.Image): The image.
+
+ Returns:
+ int: The orientation value.
+ """
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
@@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int:
return orientation
def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
+ """
+ Processes the given image by adjusting its orientation and resizing it.
+
+ Args:
+ img (Image.Image): The image to process.
+ new_width (int): The new width of the image.
+ new_height (int): The new height of the image.
+
+ Returns:
+ Image.Image: The processed image.
+ """
orientation = get_orientation(img)
if orientation:
if orientation > 4:
@@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
img = img.transpose(Image.ROTATE_90)
img.thumbnail((new_width, new_height))
return img
-
+
def to_base64(image: Image.Image, compression_rate: float) -> str:
+ """
+ Converts the given image to a base64-encoded string.
+
+ Args:
+ image (Image.Image): The image to convert.
+ compression_rate (float): The compression rate (0.0 to 1.0).
+
+ Returns:
+ str: The base64-encoded image.
+ """
output_buffer = BytesIO()
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
+ """
+ Formats the given images as a markdown string.
+
+ Args:
+ images: The images to format.
+ prompt (str): The prompt for the images.
+ preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
+
+ Returns:
+ str: The formatted markdown string.
+ """
if isinstance(images, list):
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
images = "\n".join(images)
@@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20
return f"\n{start_flag}{images}\n{end_flag}\n"
def to_bytes(image: Image.Image) -> bytes:
+ """
+ Converts the given image to bytes.
+
+ Args:
+ image (Image.Image): The image to convert.
+
+ Returns:
+ bytes: The image as bytes.
+ """
bytes_io = BytesIO()
image.save(bytes_io, image.format)
image.seek(0)
diff --git a/g4f/models.py b/g4f/models.py
index 03deebf8..dd6e0a2c 100644
--- a/g4f/models.py
+++ b/g4f/models.py
@@ -31,12 +31,21 @@ from .Provider import (
@dataclass(unsafe_hash=True)
class Model:
+ """
+ Represents a machine learning model configuration.
+
+ Attributes:
+ name (str): Name of the model.
+ base_provider (str): Default provider for the model.
+ best_provider (ProviderType): The preferred provider for the model, typically with retry logic.
+ """
name: str
base_provider: str
best_provider: ProviderType = None
@staticmethod
def __all__() -> list[str]:
+ """Returns a list of all model names."""
return _all_models
default = Model(
@@ -298,6 +307,12 @@ pi = Model(
)
class ModelUtils:
+ """
+ Utility class for mapping string identifiers to Model instances.
+
+ Attributes:
+ convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
+ """
convert: dict[str, Model] = {
# gpt-3.5
'gpt-3.5-turbo' : gpt_35_turbo,
diff --git a/g4f/requests.py b/g4f/requests.py
index 1a13dec9..466d5a2a 100644
--- a/g4f/requests.py
+++ b/g4f/requests.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import json
-from contextlib import asynccontextmanager
from functools import partialmethod
from typing import AsyncGenerator
from urllib.parse import urlparse
@@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response
from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies
class StreamResponse:
+ """
+ A wrapper class for handling asynchronous streaming responses.
+
+ Attributes:
+ inner (Response): The original Response object.
+ """
+
def __init__(self, inner: Response) -> None:
+ """Initialize the StreamResponse with the provided Response object."""
self.inner: Response = inner
async def text(self) -> str:
+ """Asynchronously get the response text."""
return await self.inner.atext()
def raise_for_status(self) -> None:
+ """Raise an HTTPError if one occurred."""
self.inner.raise_for_status()
async def json(self, **kwargs) -> dict:
+ """Asynchronously parse the JSON response content."""
return json.loads(await self.inner.acontent(), **kwargs)
async def iter_lines(self) -> AsyncGenerator[bytes, None]:
+ """Asynchronously iterate over the lines of the response."""
async for line in self.inner.aiter_lines():
yield line
async def iter_content(self) -> AsyncGenerator[bytes, None]:
+ """Asynchronously iterate over the response content."""
async for chunk in self.inner.aiter_content():
yield chunk
-
+
async def __aenter__(self):
+ """Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner
self.inner = inner
self.request = inner.request
@@ -39,24 +52,47 @@ class StreamResponse:
self.headers = inner.headers
self.cookies = inner.cookies
return self
-
+
async def __aexit__(self, *args):
+ """Asynchronously exit the runtime context for the response object."""
await self.inner.aclose()
+
class StreamSession(AsyncSession):
+ """
+ An asynchronous session class for handling HTTP requests with streaming.
+
+ Inherits from AsyncSession.
+ """
+
def request(
self, method: str, url: str, **kwargs
) -> StreamResponse:
+ """Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, **kwargs))
+ # Defining HTTP methods as partial methods of the request method.
head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
post = partialmethod(request, "POST")
put = partialmethod(request, "PUT")
patch = partialmethod(request, "PATCH")
delete = partialmethod(request, "DELETE")
-
-def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120):
+
+
+def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
+ """
+ Create a Session object using a WebDriver to handle cookies and headers.
+
+ Args:
+ url (str): The URL to navigate to using the WebDriver.
+ webdriver (WebDriver, optional): The WebDriver instance to use.
+ proxy (str, optional): Proxy server to use for the Session.
+ timeout (int, optional): Timeout in seconds for the WebDriver.
+
+ Returns:
+ Session: A Session object configured with cookies and headers from the WebDriver.
+ """
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver:
bypass_cloudflare(driver, url, timeout)
cookies = get_driver_cookies(driver)
@@ -78,4 +114,4 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str =
proxies={"https": proxy, "http": proxy},
timeout=timeout,
impersonate="chrome110"
- )
+ ) \ No newline at end of file
diff --git a/g4f/version.py b/g4f/version.py
index bb4b7f17..9201c75c 100644
--- a/g4f/version.py
+++ b/g4f/version.py
@@ -5,45 +5,120 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
from subprocess import check_output, CalledProcessError, PIPE
from .errors import VersionNotFoundError
+def get_pypi_version(package_name: str) -> str:
+ """
+ Retrieves the latest version of a package from PyPI.
+
+ Args:
+ package_name (str): The name of the package for which to retrieve the version.
+
+ Returns:
+ str: The latest version of the specified package from PyPI.
+
+ Raises:
+ VersionNotFoundError: If there is an error in fetching the version from PyPI.
+ """
+ try:
+ response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
+ return response["info"]["version"]
+ except requests.RequestException as e:
+ raise VersionNotFoundError(f"Failed to get PyPI version: {e}")
+
+def get_github_version(repo: str) -> str:
+ """
+ Retrieves the latest release version from a GitHub repository.
+
+ Args:
+ repo (str): The name of the GitHub repository.
+
+ Returns:
+ str: The latest release version from the specified GitHub repository.
+
+ Raises:
+ VersionNotFoundError: If there is an error in fetching the version from GitHub.
+ """
+ try:
+ response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
+ return response["tag_name"]
+ except requests.RequestException as e:
+ raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
+
def get_latest_version() -> str:
+ """
+ Retrieves the latest release version of the 'g4f' package from PyPI or GitHub.
+
+ Returns:
+ str: The latest release version of 'g4f'.
+
+ Note:
+ The function first tries to fetch the version from PyPI. If the package is not found,
+ it retrieves the version from the GitHub repository.
+ """
try:
+ # Is installed via package manager?
get_package_version("g4f")
- response = requests.get("https://pypi.org/pypi/g4f/json").json()
- return response["info"]["version"]
+ return get_pypi_version("g4f")
except PackageNotFoundError:
- url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
- response = requests.get(url).json()
- return response["tag_name"]
+ # Else use Github version:
+ return get_github_version("xtekky/gpt4free")
-class VersionUtils():
+class VersionUtils:
+ """
+ Utility class for managing and comparing package versions of 'g4f'.
+ """
@cached_property
def current_version(self) -> str:
+ """
+ Retrieves the current version of the 'g4f' package.
+
+ Returns:
+ str: The current version of 'g4f'.
+
+ Raises:
+ VersionNotFoundError: If the version cannot be determined from the package manager,
+ Docker environment, or git repository.
+ """
# Read from package manager
try:
return get_package_version("g4f")
except PackageNotFoundError:
pass
+
# Read from docker environment
version = environ.get("G4F_VERSION")
if version:
return version
+
# Read from git repository
try:
command = ["git", "describe", "--tags", "--abbrev=0"]
return check_output(command, text=True, stderr=PIPE).strip()
except CalledProcessError:
pass
+
raise VersionNotFoundError("Version not found")
-
+
@cached_property
def latest_version(self) -> str:
+ """
+ Retrieves the latest version of the 'g4f' package.
+
+ Returns:
+ str: The latest version of 'g4f'.
+ """
return get_latest_version()
-
+
def check_version(self) -> None:
+ """
+ Checks if the current version of 'g4f' is up to date with the latest version.
+
+ Note:
+ If a newer version is available, it prints a message with the new version and update instructions.
+ """
try:
if self.current_version != self.latest_version:
print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
except Exception as e:
print(f'Failed to check g4f version: {e}')
-
+
utils = VersionUtils() \ No newline at end of file
diff --git a/g4f/webdriver.py b/g4f/webdriver.py
index da283409..9a83215f 100644
--- a/g4f/webdriver.py
+++ b/g4f/webdriver.py
@@ -1,5 +1,4 @@
from __future__ import annotations
-
from platformdirs import user_config_dir
from selenium.webdriver.remote.webdriver import WebDriver
from undetected_chromedriver import Chrome, ChromeOptions
@@ -21,7 +20,19 @@ def get_browser(
proxy: str = None,
options: ChromeOptions = None
) -> WebDriver:
- if user_data_dir == None:
+ """
+ Creates and returns a Chrome WebDriver with specified options.
+
+ Args:
+ user_data_dir (str, optional): Directory for user data. If None, uses default directory.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
+ proxy (str, optional): Proxy settings for the browser. Defaults to None.
+ options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.
+
+ Returns:
+ WebDriver: An instance of WebDriver configured with the specified options.
+ """
+ if user_data_dir is None:
user_data_dir = user_config_dir("g4f")
if user_data_dir and debug.logging:
print("Open browser with config dir:", user_data_dir)
@@ -39,36 +50,53 @@ def get_browser(
headless=headless
)
-def get_driver_cookies(driver: WebDriver):
- return dict([(cookie["name"], cookie["value"]) for cookie in driver.get_cookies()])
+def get_driver_cookies(driver: WebDriver) -> dict:
+ """
+ Retrieves cookies from the specified WebDriver.
+
+ Args:
+ driver (WebDriver): The WebDriver instance from which to retrieve cookies.
+
+ Returns:
+ dict: A dictionary containing cookies with their names as keys and values as cookie values.
+ """
+ return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
- # Open website
+ """
+ Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
+
+ Args:
+ driver (WebDriver): The WebDriver to use for accessing the URL.
+ url (str): The URL to access.
+ timeout (int): Time in seconds to wait for the page to load.
+
+ Raises:
+ Exception: If there is an error while bypassing Cloudflare or loading the page.
+ """
driver.get(url)
- # Is cloudflare protection
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
if debug.logging:
print("Cloudflare protection detected:", url)
try:
- # Click button in iframe
- WebDriverWait(driver, 5).until(
- EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
- )
driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
- )
- driver.find_element(By.CSS_SELECTOR, "#challenge-stage input").click()
- except:
- pass
+ ).click()
+ except Exception as e:
+ if debug.logging:
+ print(f"Error bypassing Cloudflare: {e}")
finally:
driver.switch_to.default_content()
- # No cloudflare protection
WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
)
-class WebDriverSession():
+class WebDriverSession:
+ """
+ Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
+ """
+
def __init__(
self,
webdriver: WebDriver = None,
@@ -78,12 +106,21 @@ class WebDriverSession():
proxy: str = None,
options: ChromeOptions = None
):
+ """
+ Initializes a new instance of the WebDriverSession.
+
+ Args:
+ webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
+ user_data_dir (str, optional): Directory for user data. Defaults to None.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
+ virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
+ proxy (str, optional): Proxy settings for the browser. Defaults to None.
+ options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
+ """
self.webdriver = webdriver
self.user_data_dir = user_data_dir
self.headless = headless
- self.virtual_display = None
- if has_pyvirtualdisplay and virtual_display:
- self.virtual_display = Display(size=(1920, 1080))
+ self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
self.proxy = proxy
self.options = options
self.default_driver = None
@@ -94,8 +131,18 @@ class WebDriverSession():
headless: bool = False,
virtual_display: bool = False
) -> WebDriver:
- if user_data_dir == None:
- user_data_dir = self.user_data_dir
+ """
+ Reopens the WebDriver session with new settings.
+
+ Args:
+ user_data_dir (str, optional): Directory for user data. Defaults to current value.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
+ virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.
+
+ Returns:
+ WebDriver: The reopened WebDriver instance.
+ """
+ user_data_dir = user_data_data_dir or self.user_data_dir
if self.default_driver:
self.default_driver.quit()
if not virtual_display and self.virtual_display:
@@ -105,6 +152,12 @@ class WebDriverSession():
return self.default_driver
def __enter__(self) -> WebDriver:
+ """
+ Context management method for entering a session. Initializes and returns a WebDriver instance.
+
+ Returns:
+ WebDriver: An instance of WebDriver for this session.
+ """
if self.webdriver:
return self.webdriver
if self.virtual_display:
@@ -113,11 +166,23 @@ class WebDriverSession():
return self.default_driver
def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Context management method for exiting a session. Closes and quits the WebDriver.
+
+ Args:
+ exc_type: Exception type.
+ exc_val: Exception value.
+ exc_tb: Exception traceback.
+
+ Note:
+ Closes the WebDriver and stops the virtual display if used.
+ """
if self.default_driver:
try:
self.default_driver.close()
- except:
- pass
+ except Exception as e:
+ if debug.logging:
+ print(f"Error closing WebDriver: {e}")
self.default_driver.quit()
if self.virtual_display:
self.virtual_display.stop() \ No newline at end of file