Diffstat (limited to 'g4f/gui')
-rw-r--r--  g4f/gui/client/index.html            |  5
-rw-r--r--  g4f/gui/client/static/css/style.css  |  1
-rw-r--r--  g4f/gui/client/static/js/chat.v1.js  | 16
-rw-r--r--  g4f/gui/server/api.py                | 99
-rw-r--r--  g4f/gui/server/backend.py            |  7
-rw-r--r--  g4f/gui/server/internet.py           | 20
6 files changed, 73 insertions, 75 deletions
diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html
index 48214093..6c2ad8b6 100644
--- a/g4f/gui/client/index.html
+++ b/g4f/gui/client/index.html
@@ -112,6 +112,11 @@
                 <label for="hide-systemPrompt" class="toogle" title="For more space on phones"></label>
             </div>
             <div class="field">
+                <span class="label">Download generated images</span>
+                <input type="checkbox" id="download_images" checked/>
+                <label for="download_images" class="toogle" title="Download and save generated images to /generated_images"></label>
+            </div>
+            <div class="field">
                 <span class="label">Auto continue in ChatGPT</span>
                 <input id="auto_continue" type="checkbox" name="auto_continue" checked/>
                 <label for="auto_continue" class="toogle" title="Continue large responses in OpenaiChat"></label>
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css
index 76399703..c4b61d87 100644
--- a/g4f/gui/client/static/css/style.css
+++ b/g4f/gui/client/static/css/style.css
@@ -498,6 +498,7 @@ body {
     gap: 12px;
     cursor: pointer;
     animation: show_popup 0.4s;
+    height: 28px;
 }
 
 .toolbar .regenerate {
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index a3e94ee2..0136f9c4 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -428,10 +428,14 @@ async function add_message_chunk(message, message_id) {
         p.innerText = message.log;
         log_storage.appendChild(p);
     }
-    window.scrollTo(0, 0);
-    if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
-        message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
+    let scroll_down = ()=>{
+        if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
+            window.scrollTo(0, 0);
+            message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
+        }
     }
+    setTimeout(scroll_down, 200);
+    setTimeout(scroll_down, 1000);
 }
 
 cameraInput?.addEventListener("click", (e) => {
@@ -492,6 +496,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
     const file = input && input.files.length > 0 ? input.files[0] : null;
     const provider = providerSelect.options[providerSelect.selectedIndex].value;
     const auto_continue = document.getElementById("auto_continue")?.checked;
+    const download_images = document.getElementById("download_images")?.checked;
     let api_key = get_api_key_by_provider(provider);
     await api("conversation", {
         id: message_id,
@@ -501,13 +506,13 @@ const ask_gpt = async (message_index = -1, message_id) => {
         provider: provider,
         messages: messages,
         auto_continue: auto_continue,
+        download_images: download_images,
         api_key: api_key,
     }, file, message_id);
     if (!error_storage[message_id]) {
         html = markdown_render(message_storage[message_id]);
         content_map.inner.innerHTML = html;
         highlight(content_map.inner);
-
         if (imageInput) imageInput.value = "";
         if (cameraInput) cameraInput.value = "";
         if (fileInput) fileInput.value = "";
@@ -1239,8 +1244,7 @@ async function load_version() {
     if (versions["version"] != versions["latest_version"]) {
         let release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
         let title = `New version: ${versions["latest_version"]}`;
-        text += `<a href="${release_url}" target="_blank" title="${title}">${versions["version"]}</a> `;
-        text += `<i class="fa-solid fa-rotate"></i>`
+        text += `<a href="${release_url}" target="_blank" title="${title}">${versions["version"]}</a> 🆕`;
     } else {
         text += versions["version"];
     }
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py
index 29fc34e2..00eb7182 100644
--- a/g4f/gui/server/api.py
+++ b/g4f/gui/server/api.py
@@ -2,34 +2,22 @@ from __future__ import annotations
 
 import logging
 import os
-import uuid
 import asyncio
-import time
-from aiohttp import ClientSession
-from typing import Iterator, Optional
+from typing import Iterator
 from flask import send_from_directory
+from inspect import signature
 
 from g4f import version, models
 from g4f import get_last_provider, ChatCompletion
 from g4f.errors import VersionNotFoundError
-from g4f.typing import Cookies
-from g4f.image import ImagePreview, ImageResponse, is_accepted_format, extract_data_uri
-from g4f.requests.aiohttp import get_connector
+from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
 from g4f.Provider import ProviderType, __providers__, __map__
-from g4f.providers.base_provider import ProviderModelMixin, FinishReason
-from g4f.providers.conversation import BaseConversation
+from g4f.providers.base_provider import ProviderModelMixin
+from g4f.providers.response import BaseConversation, FinishReason
+from g4f.client.service import convert_to_provider
 from g4f import debug
 
 logger = logging.getLogger(__name__)
-
-# Define the directory for generated images
-images_dir = "./generated_images"
-
-# Function to ensure the images directory exists
-def ensure_images_dir():
-    if not os.path.exists(images_dir):
-        os.makedirs(images_dir)
-
 conversations: dict[dict[str, BaseConversation]] = {}
 
 class Api:
@@ -42,7 +30,10 @@ class Api:
         if provider in __map__:
            provider: ProviderType = __map__[provider]
            if issubclass(provider, ProviderModelMixin):
-                models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
+                if api_key is not None and "api_key" in signature(provider.get_models).parameters:
+                    models = provider.get_models(api_key=api_key)
+                else:
+                    models = provider.get_models()
                 return [
                     {
                         "model": model,
@@ -90,7 +81,7 @@ class Api:
     def get_providers() -> list[str]:
         return {
             provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
-            + (" (Image Generation)" if hasattr(provider, "image_models") else "")
+            + (" (Image Generation)" if getattr(provider, "image_models", None) else "")
             + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
             + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
             + (" (Auth)" if provider.needs_auth else "")
@@ -120,16 +111,23 @@ class Api:
         api_key = json_data.get("api_key")
         if api_key is not None:
             kwargs["api_key"] = api_key
-        if json_data.get('web_search'):
-            if provider:
-                kwargs['web_search'] = True
-            else:
-                from .internet import get_search_message
-                messages[-1]["content"] = get_search_message(messages[-1]["content"])
+        do_web_search = json_data.get('web_search')
+        if do_web_search and provider:
+            provider_handler = convert_to_provider(provider)
+            if hasattr(provider_handler, "get_parameters"):
+                if "web_search" in provider_handler.get_parameters():
+                    kwargs['web_search'] = True
+                    do_web_search = False
+        if do_web_search:
+            from .internet import get_search_message
+            messages[-1]["content"] = get_search_message(messages[-1]["content"])
+        if json_data.get("auto_continue"):
+            kwargs['auto_continue'] = True
 
         conversation_id = json_data.get("conversation_id")
-        if conversation_id and provider in conversations and conversation_id in conversations[provider]:
-            kwargs["conversation"] = conversations[provider][conversation_id]
+        if conversation_id and provider:
+            if provider in conversations and conversation_id in conversations[provider]:
+                kwargs["conversation"] = conversations[provider][conversation_id]
 
         return {
             "model": model,
@@ -141,7 +139,7 @@
             **kwargs
         }
 
-    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
+    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
         if debug.logging:
             debug.logs = []
             print_callback = debug.log_handler
@@ -163,18 +161,22 @@ class Api:
                         first = False
                         yield self._format_json("provider", get_last_provider(True))
                     if isinstance(chunk, BaseConversation):
-                        if provider not in conversations:
-                            conversations[provider] = {}
-                        conversations[provider][conversation_id] = chunk
-                        yield self._format_json("conversation", conversation_id)
+                        if provider:
+                            if provider not in conversations:
+                                conversations[provider] = {}
+                            conversations[provider][conversation_id] = chunk
+                            yield self._format_json("conversation", conversation_id)
                     elif isinstance(chunk, Exception):
                         logger.exception(chunk)
                         yield self._format_json("message", get_error_message(chunk))
                     elif isinstance(chunk, ImagePreview):
                         yield self._format_json("preview", chunk.to_string())
                     elif isinstance(chunk, ImageResponse):
-                        images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies")))
-                        yield self._format_json("content", str(ImageResponse(images, chunk.alt)))
+                        images = chunk
+                        if download_images:
+                            images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
+                            images = ImageResponse(images, chunk.alt)
+                        yield self._format_json("content", str(images))
                     elif not isinstance(chunk, FinishReason):
                         yield self._format_json("content", str(chunk))
             if debug.logs:
@@ -185,31 +187,6 @@ class Api:
             logger.exception(e)
             yield self._format_json('error', get_error_message(e))
 
-    async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None):
-        ensure_images_dir()
-        async with ClientSession(
-            connector=get_connector(None, os.environ.get("G4F_PROXY")),
-            cookies=cookies
-        ) as session:
-            async def copy_image(image: str) -> str:
-                target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
-                if image.startswith("data:"):
-                    with open(target, "wb") as f:
-                        f.write(extract_data_uri(image))
-                else:
-                    async with session.get(image) as response:
-                        with open(target, "wb") as f:
-                            async for chunk in response.content.iter_any():
-                                f.write(chunk)
-                with open(target, "rb") as f:
-                    extension = is_accepted_format(f.read(12)).split("/")[-1]
-                    extension = "jpg" if extension == "jpeg" else extension
-                new_target = f"{target}.{extension}"
-                os.rename(target, new_target)
-                return f"/images/{os.path.basename(new_target)}"
-
-            return await asyncio.gather(*[copy_image(image) for image in images])
-
     def _format_json(self, response_type: str, content):
         return {
             'type': response_type,
@@ -221,4 +198,4 @@ def get_error_message(exception: Exception) -> str:
     provider = get_last_provider()
     if provider is None:
         return message
-    return f"{provider.__name__}: {message}"
+    return f"{provider.__name__}: {message}"
\ No newline at end of file
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index 020e49ef..917d779e 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -89,7 +89,12 @@ class Backend_Api(Api):
         kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
 
         return self.app.response_class(
-            self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")),
+            self._create_response_stream(
+                kwargs,
+                json_data.get("conversation_id"),
+                json_data.get("provider"),
+                json_data.get("download_images", True),
+            ),
             mimetype='text/event-stream'
         )
 
diff --git a/g4f/gui/server/internet.py b/g4f/gui/server/internet.py
index b41b5eae..bafa3af7 100644
--- a/g4f/gui/server/internet.py
+++ b/g4f/gui/server/internet.py
@@ -8,12 +8,14 @@ try:
 except ImportError:
     has_requirements = False
 from ...errors import MissingRequirementsError
-
+from ... import debug
+
 import asyncio
 
 class SearchResults():
-    def __init__(self, results: list):
+    def __init__(self, results: list, used_words: int):
         self.results = results
+        self.used_words = used_words
 
     def __iter__(self):
         yield from self.results
@@ -104,7 +106,8 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
                 region="wt-wt",
                 safesearch="moderate",
                 timelimit="y",
-                max_results=n_results
+                max_results=n_results,
+                backend="html"
             ):
                 results.append(SearchResultEntry(
                     result["title"],
@@ -120,6 +123,7 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
     texts = await asyncio.gather(*requests)
 
     formatted_results = []
+    used_words = 0
     left_words = max_words
     for i, entry in enumerate(results):
         if add_text:
@@ -132,13 +136,14 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
                 left_words -= entry.snippet.count(" ")
             if 0 > left_words:
                 break
+        used_words = max_words - left_words
         formatted_results.append(entry)
 
-    return SearchResults(formatted_results)
+    return SearchResults(formatted_results, used_words)
 
-def get_search_message(prompt) -> str:
+def get_search_message(prompt, n_results: int = 5, max_words: int = 2500) -> str:
     try:
-        search_results = asyncio.run(search(prompt))
+        search_results = asyncio.run(search(prompt, n_results, max_words))
         message = f"""
 {search_results}
 
@@ -149,7 +154,8 @@ Make sure to add the sources of cites using [[Number]](Url) notation after the r
 User request:
 {prompt}
 """
+        debug.log(f"Web search: '{prompt.strip()[:50]}...' {search_results.used_words} Words")
         return message
     except Exception as e:
-        print("Couldn't do web search:", e)
+        debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
         return prompt
\ No newline at end of file