From 58fa409eefcc8ae0233967dc807b046ad77bf6fa Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 20 Nov 2024 02:34:47 +0100 Subject: Add Cerebras and HuggingFace2 provider, Fix RubiksAI provider Add support for image generation in Copilot provider --- g4f/Provider/Blackbox.py | 12 ++- g4f/Provider/Copilot.py | 23 ++++-- g4f/Provider/RubiksAI.py | 124 +++++++++++------------------- g4f/Provider/needs_auth/Cerebras.py | 65 ++++++++++++++++ g4f/Provider/needs_auth/CopilotAccount.py | 7 +- g4f/Provider/needs_auth/HuggingFace2.py | 28 +++++++ g4f/Provider/needs_auth/OpenaiAPI.py | 4 +- g4f/Provider/needs_auth/__init__.py | 2 + g4f/gui/client/index.html | 8 +- g4f/gui/client/static/css/style.css | 4 +- g4f/gui/client/static/js/chat.v1.js | 87 +++++++++++---------- g4f/gui/server/api.py | 5 +- g4f/gui/server/backend.py | 3 +- g4f/requests/raise_for_status.py | 4 +- 14 files changed, 239 insertions(+), 137 deletions(-) create mode 100644 g4f/Provider/needs_auth/Cerebras.py create mode 100644 g4f/Provider/needs_auth/HuggingFace2.py diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py index 97466c04..ba58a511 100644 --- a/g4f/Provider/Blackbox.py +++ b/g4f/Provider/Blackbox.py @@ -28,6 +28,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): image_models = [default_image_model, 'repomap'] text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro'] vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro'] + model_aliases = { + "claude-3.5-sonnet": "claude-sonnet-3.5", + } agentMode = { default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"}, } @@ -198,6 +201,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): async with ClientSession(headers=headers) as session: async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: response.raise_for_status() + is_first = False async for chunk in response.content.iter_any(): text_chunk = chunk.decode(errors="ignore") if model in cls.image_models: @@ -217,5 +221,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): for i, result in enumerate(search_results, 1): formatted_response += f"\n{i}. {result['title']}: {result['link']}" yield formatted_response - else: - yield text_chunk.strip() + elif text_chunk: + if is_first: + is_first = False + yield text_chunk.lstrip() + else: + yield text_chunk \ No newline at end of file diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py index e40278c7..e8eea0a5 100644 --- a/g4f/Provider/Copilot.py +++ b/g4f/Provider/Copilot.py @@ -21,8 +21,9 @@ from .helper import format_prompt from ..typing import CreateResult, Messages, ImageType from ..errors import MissingRequirementsError from ..requests.raise_for_status import raise_for_status +from ..providers.helper import format_cookies from ..requests import get_nodriver -from ..image import to_bytes, is_accepted_format +from ..image import ImageResponse, to_bytes, is_accepted_format from .. import debug class Conversation(BaseConversation): @@ -70,18 +71,21 @@ class Copilot(AbstractProvider): access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy)) else: access_token = conversation.access_token - websocket_url = f"{websocket_url}&acessToken={quote(access_token)}" - headers = {"Authorization": f"Bearer {access_token}"} + debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}") + debug.log(f"Copilot: Cookies: {';'.join([*cookies])}") + websocket_url = f"{websocket_url}&accessToken={quote(access_token)}" + headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)} with Session( timeout=timeout, proxy=proxy, impersonate="chrome", headers=headers, - cookies=cookies + cookies=cookies, ) as session: - response = session.get(f"{cls.url}/") + response = session.get("https://copilot.microsoft.com/c/api/user") raise_for_status(response) + debug.log(f"Copilot: User: {response.json().get('firstName', 'null')}") if conversation is None: response = session.post(cls.conversation_url) raise_for_status(response) @@ -119,6 +123,7 @@ class Copilot(AbstractProvider): is_started = False msg = None + image_prompt: str = None while True: try: msg = wss.recv()[0] @@ -128,7 +133,11 @@ class Copilot(AbstractProvider): if msg.get("event") == "appendText": is_started = True yield msg.get("text") - elif msg.get("event") in ["done", "partCompleted"]: + elif msg.get("event") == "generatingImage": + image_prompt = msg.get("prompt") + elif msg.get("event") == "imageGenerated": + yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")}) + elif msg.get("event") == "done": break if not is_started: raise RuntimeError(f"Last message: {msg}") @@ -152,7 +161,7 @@ class Copilot(AbstractProvider): })() """) if access_token is None: - asyncio.sleep(1) + await asyncio.sleep(1) cookies = {} for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])): cookies[c.name] = c.value diff --git a/g4f/Provider/RubiksAI.py b/g4f/Provider/RubiksAI.py index 7e76d558..c06e6c3d 100644 --- a/g4f/Provider/RubiksAI.py +++ b/g4f/Provider/RubiksAI.py @@ -1,7 +1,6 @@ + from __future__ import annotations -import asyncio -import aiohttp import random import string import json @@ -11,34 +10,24 @@ from aiohttp import ClientSession from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from .helper import format_prompt - +from ..requests.raise_for_status import raise_for_status class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): label = "Rubiks AI" url = "https://rubiks.ai" - api_endpoint = "https://rubiks.ai/search/api.php" + api_endpoint = "https://rubiks.ai/search/api/" working = True supports_stream = True supports_system_message = True supports_message_history = True - default_model = 'llama-3.1-70b-versatile' - models = [default_model, 'gpt-4o-mini'] + default_model = 'gpt-4o-mini' + models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro'] model_aliases = { "llama-3.1-70b": "llama-3.1-70b-versatile", } - @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - @staticmethod def generate_mid() -> str: """ @@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, - websearch: bool = False, + web_search: bool = False, + temperature: float = 0.6, **kwargs ) -> AsyncResult: """ @@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): - model (str): The model to use in the request. - messages (Messages): The messages to send as a prompt. - proxy (str, optional): Proxy URL, if needed. - - websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False. + - web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False. """ model = cls.get_model(model) - prompt = format_prompt(messages) - q_value = prompt mid_value = cls.generate_mid() - referer = cls.create_referer(q=q_value, mid=mid_value, model=model) - - url = cls.api_endpoint - params = { - 'q': q_value, - 'model': model, - 'id': '', - 'mid': mid_value + referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model) + + data = { + "messages": messages, + "model": model, + "search": web_search, + "stream": True, + "temperature": temperature } headers = { @@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): 'sec-ch-ua-mobile': '?0', 'sec-ch-ua-platform': '"Linux"' } - - try: - timeout = aiohttp.ClientTimeout(total=None) - async with ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers, params=params, proxy=proxy) as response: - if response.status != 200: - yield f"Request ended with status code {response.status}" - return - - assistant_text = '' - sources = [] - - async for line in response.content: - decoded_line = line.decode('utf-8').strip() - if not decoded_line.startswith('data: '): - continue - data = decoded_line[6:] - if data in ('[DONE]', '{"done": ""}'): - break - try: - json_data = json.loads(data) - except json.JSONDecodeError: - continue - - if 'url' in json_data and 'title' in json_data: - if websearch: - sources.append({'title': json_data['title'], 'url': json_data['url']}) - - elif 'choices' in json_data: - for choice in json_data['choices']: - delta = choice.get('delta', {}) - content = delta.get('content', '') - role = delta.get('role', '') - if role == 'assistant': - continue - assistant_text += content - - if websearch and sources: - sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)]) - assistant_text += f"\n\n**Source:**\n{sources_text}" - - yield assistant_text - - except asyncio.CancelledError: - yield "The request was cancelled." - except aiohttp.ClientError as e: - yield f"An error occurred during the request: {e}" - except Exception as e: - yield f"An unexpected error occurred: {e}" + async with ClientSession() as session: + async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response: + await raise_for_status(response) + + sources = [] + async for line in response.content: + decoded_line = line.decode('utf-8').strip() + if not decoded_line.startswith('data: '): + continue + data = decoded_line[6:] + if data in ('[DONE]', '{"done": ""}'): + break + try: + json_data = json.loads(data) + except json.JSONDecodeError: + continue + + if 'url' in json_data and 'title' in json_data: + if web_search: + sources.append({'title': json_data['title'], 'url': json_data['url']}) + + elif 'choices' in json_data: + for choice in json_data['choices']: + delta = choice.get('delta', {}) + content = delta.get('content', '') + if content: + yield content + + if web_search and sources: + sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)]) + yield f"\n\n**Source:**\n{sources_text}" \ No newline at end of file diff --git a/g4f/Provider/needs_auth/Cerebras.py b/g4f/Provider/needs_auth/Cerebras.py new file mode 100644 index 00000000..0f94c476 --- /dev/null +++ b/g4f/Provider/needs_auth/Cerebras.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import requests +from aiohttp import ClientSession + +from .OpenaiAPI import OpenaiAPI +from ...typing import AsyncResult, Messages, Cookies +from ...requests.raise_for_status import raise_for_status +from ...cookies import get_cookies + +class Cerebras(OpenaiAPI): + label = "Cerebras Inference" + url = "https://inference.cerebras.ai/" + working = True + default_model = "llama3.1-70b" + fallback_models = [ + "llama3.1-70b", + "llama3.1-8b", + ] + model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"} + + @classmethod + def get_models(cls, api_key: str = None): + if not cls.models: + try: + headers = {} + if api_key: + headers["authorization"] = f"Bearer ${api_key}" + response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers) + raise_for_status(response) + data = response.json() + cls.models = [model.get("model") for model in data.get("models")] + except Exception: + cls.models = cls.fallback_models + return cls.models + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + api_base: str = "https://api.cerebras.ai/v1", + api_key: str = None, + cookies: Cookies = None, + **kwargs + ) -> AsyncResult: + if api_key is None and cookies is None: + cookies = get_cookies(".cerebras.ai") + async with ClientSession(cookies=cookies) as session: + async with session.get("https://inference.cerebras.ai/api/auth/session") as response: + raise_for_status(response) + data = await response.json() + if data: + api_key = data.get("user", {}).get("demoApiKey") + async for chunk in super().create_async_generator( + model, messages, + api_base=api_base, + impersonate="chrome", + api_key=api_key, + headers={ + "User-Agent": "ex/JS 1.5.0", + }, + **kwargs + ): + yield chunk diff --git a/g4f/Provider/needs_auth/CopilotAccount.py b/g4f/Provider/needs_auth/CopilotAccount.py index 76e51278..497aab98 100644 --- a/g4f/Provider/needs_auth/CopilotAccount.py +++ b/g4f/Provider/needs_auth/CopilotAccount.py @@ -1,9 +1,12 @@ from __future__ import annotations +from ..base_provider import ProviderModelMixin from ..Copilot import Copilot -class CopilotAccount(Copilot): +class CopilotAccount(Copilot, ProviderModelMixin): needs_auth = True parent = "Copilot" default_model = "Copilot" - default_vision_model = default_model \ No newline at end of file + default_vision_model = default_model + models = [default_model] + image_models = models \ No newline at end of file diff --git a/g4f/Provider/needs_auth/HuggingFace2.py b/g4f/Provider/needs_auth/HuggingFace2.py new file mode 100644 index 00000000..847d459b --- /dev/null +++ b/g4f/Provider/needs_auth/HuggingFace2.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from .OpenaiAPI import OpenaiAPI +from ..HuggingChat import HuggingChat +from ...typing import AsyncResult, Messages + +class HuggingFace2(OpenaiAPI): + label = "HuggingFace (Inference API)" + url = "https://huggingface.co" + working = True + default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct" + default_vision_model = default_model + models = [ + *HuggingChat.models + ] + + @classmethod + def create_async_generator( + cls, + model: str, + messages: Messages, + api_base: str = "https://api-inference.huggingface.co/v1", + max_tokens: int = 500, + **kwargs + ) -> AsyncResult: + return super().create_async_generator( + model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs + ) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py index 116b5f6f..83268b6d 100644 --- a/g4f/Provider/needs_auth/OpenaiAPI.py +++ b/g4f/Provider/needs_auth/OpenaiAPI.py @@ -34,6 +34,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin): stop: Union[str, list[str]] = None, stream: bool = False, headers: dict = None, + impersonate: str = None, extra_data: dict = {}, **kwargs ) -> AsyncResult: @@ -55,7 +56,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin): async with StreamSession( proxies={"all": proxy}, headers=cls.get_headers(stream, api_key, headers), - timeout=timeout + timeout=timeout, + impersonate=impersonate, ) as session: data = filter_none( messages=messages, diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py index 0f430ab5..1c7fe7c5 100644 --- a/g4f/Provider/needs_auth/__init__.py +++ b/g4f/Provider/needs_auth/__init__.py @@ -1,6 +1,7 @@ from .gigachat import * from .BingCreateImages import BingCreateImages +from .Cerebras import Cerebras from .CopilotAccount import CopilotAccount from .DeepInfra import DeepInfra from .DeepInfraImage import DeepInfraImage @@ -8,6 +9,7 @@ from .Gemini import Gemini from .GeminiPro import GeminiPro from .Groq import Groq from .HuggingFace import HuggingFace +from .HuggingFace2 import HuggingFace2 from .MetaAI import MetaAI from .MetaAIAccount import MetaAIAccount from .OpenaiAPI import OpenaiAPI diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index 3a2197de..48214093 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -128,6 +128,10 @@ +
+ + +
@@ -142,7 +146,7 @@
- +
@@ -192,7 +196,7 @@
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index e435094f..76399703 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -512,9 +512,7 @@ body { @media only screen and (min-width: 40em) { .stop_generating { - left: 50%; - transform: translateX(-50%); - right: auto; + right: 4px; } .toolbar .regenerate span { display: block; diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index 51bf8b81..a3e94ee2 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -215,7 +215,6 @@ const register_message_buttons = async () => { const message_el = el.parentElement.parentElement.parentElement; el.classList.add("clicked"); setTimeout(() => el.classList.remove("clicked"), 1000); - await hide_message(window.conversation_id, message_el.dataset.index); await ask_gpt(message_el.dataset.index, get_message_id()); }) } @@ -317,6 +316,7 @@ async function remove_cancel_button() { regenerate.addEventListener("click", async () => { regenerate.classList.add("regenerate-hidden"); + setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000); stop_generating.classList.remove("stop_generating-hidden"); await hide_message(window.conversation_id); await ask_gpt(-1, get_message_id()); @@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => { return new_messages; } -async function add_message_chunk(message, message_index) { - content_map = content_storage[message_index]; +async function add_message_chunk(message, message_id) { + content_map = content_storage[message_id]; if (message.type == "conversation") { console.info("Conversation used:", message.conversation) } else if (message.type == "provider") { - provider_storage[message_index] = message.provider; + provider_storage[message_id] = message.provider; content_map.content.querySelector('.provider').innerHTML = ` ${message.provider.label ? message.provider.label : message.provider.name} @@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) { } else if (message.type == "message") { console.error(message.message) } else if (message.type == "error") { - error_storage[message_index] = message.error + error_storage[message_id] = message.error console.error(message.error); content_map.inner.innerHTML += `

An error occured: ${message.error}

`; let p = document.createElement("p"); @@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) { } else if (message.type == "preview") { content_map.inner.innerHTML = markdown_render(message.preview); } else if (message.type == "content") { - message_storage[message_index] += message.content; - html = markdown_render(message_storage[message_index]); + message_storage[message_id] += message.content; + html = markdown_render(message_storage[message_id]); let lastElement, lastIndex = null; for (element of ['

', '', '

\n\n', '\n', '\n']) { const index = html.lastIndexOf(element) @@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) { html = html.substring(0, lastIndex) + '' + lastElement; } content_map.inner.innerHTML = html; - content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model); + content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model); highlight(content_map.inner); } else if (message.type == "log") { let p = document.createElement("p"); @@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => { let total_messages = messages.length; messages = prepare_messages(messages, message_index); message_index = total_messages - message_storage[message_index] = ""; + message_storage[message_id] = ""; stop_generating.classList.remove(".stop_generating-hidden"); message_box.scrollTop = message_box.scrollHeight; @@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
`; - controller_storage[message_index] = new AbortController(); + controller_storage[message_id] = new AbortController(); let content_el = document.getElementById(`gpt_${message_id}`) - let content_map = content_storage[message_index] = { + let content_map = content_storage[message_id] = { content: content_el, inner: content_el.querySelector('.content_inner'), count: content_el.querySelector('.count'), @@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => { const file = input && input.files.length > 0 ? input.files[0] : null; const provider = providerSelect.options[providerSelect.selectedIndex].value; const auto_continue = document.getElementById("auto_continue")?.checked; - let api_key = null; - if (provider) { - api_key = document.getElementById(`${provider}-api_key`)?.value || null; - if (api_key == null) - api_key = document.querySelector(`.${provider}-api_key`)?.value || null; - } + let api_key = get_api_key_by_provider(provider); await api("conversation", { id: message_id, conversation_id: window.conversation_id, @@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => { provider: provider, messages: messages, auto_continue: auto_continue, - api_key: api_key - }, file, message_index); - if (!error_storage[message_index]) { - html = markdown_render(message_storage[message_index]); + api_key: api_key, + }, file, message_id); + if (!error_storage[message_id]) { + html = markdown_render(message_storage[message_id]); content_map.inner.innerHTML = html; highlight(content_map.inner); @@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => { } catch (e) { console.error(e); if (e.name != "AbortError") { - error_storage[message_index] = true; + error_storage[message_id] = true; content_map.inner.innerHTML += `

An error occured: ${e}

`; } } - delete controller_storage[message_index]; - if (!error_storage[message_index] && message_storage[message_index]) { - const message_provider = message_index in provider_storage ? provider_storage[message_index] : null; - await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider); + delete controller_storage[message_id]; + if (!error_storage[message_id] && message_storage[message_id]) { + const message_provider = message_id in provider_storage ? provider_storage[message_id] : null; + await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider); await safe_load_conversation(window.conversation_id); } else { let cursorDiv = message_box.querySelector(".cursor"); @@ -1156,7 +1151,7 @@ async function on_api() { evt.preventDefault(); console.log("pressed enter"); prompt_lock = true; - setTimeout(()=>prompt_lock=false, 3); + setTimeout(()=>prompt_lock=false, 3000); await handle_ask(); } else { messageInput.style.removeProperty("height"); @@ -1167,7 +1162,7 @@ async function on_api() { console.log("clicked send"); if (prompt_lock) return; prompt_lock = true; - setTimeout(()=>prompt_lock=false, 3); + setTimeout(()=>prompt_lock=false, 3000); await handle_ask(); }); messageInput.focus(); @@ -1189,8 +1184,8 @@ async function on_api() { providerSelect.appendChild(option); }) - await load_provider_models(appStorage.getItem("provider")); await load_settings_storage() + await load_provider_models(appStorage.getItem("provider")); const hide_systemPrompt = document.getElementById("hide-systemPrompt") const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i"); @@ -1316,7 +1311,7 @@ function get_selected_model() { } } -async function api(ressource, args=null, file=null, message_index=null) { +async function api(ressource, args=null, file=null, message_id=null) { if (window?.pywebview) { if (args !== null) { if (ressource == "models") { @@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) { } return pywebview.api[`get_${ressource}`](); } + let api_key; if (ressource == "models" && args) { + api_key = get_api_key_by_provider(args); ressource = `${ressource}/${args}`; } const url = `/backend-api/v2/${ressource}`; + const headers = {}; + if (api_key) { + headers.authorization = `Bearer ${api_key}`; + } if (ressource == "conversation") { let body = JSON.stringify(args); - const headers = { - accept: 'text/event-stream' - } + headers.accept = 'text/event-stream'; if (file !== null) { const formData = new FormData(); formData.append('file', file); @@ -1345,17 +1344,17 @@ async function api(ressource, args=null, file=null, message_index=null) { } response = await fetch(url, { method: 'POST', - signal: controller_storage[message_index].signal, + signal: controller_storage[message_id].signal, headers: headers, - body: body + body: body, }); - return read_response(response, message_index); + return read_response(response, message_id); } - response = await fetch(url); + response = await fetch(url, {headers: headers}); return await response.json(); } -async function read_response(response, message_index) { +async function read_response(response, message_id) { const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); let buffer = "" while (true) { @@ -1368,7 +1367,7 @@ async function read_response(response, message_index) { continue; } try { - add_message_chunk(JSON.parse(buffer + line), message_index); + add_message_chunk(JSON.parse(buffer + line), message_id); buffer = ""; } catch { buffer += line @@ -1377,6 +1376,16 @@ async function read_response(response, message_index) { } } +function get_api_key_by_provider(provider) { + let api_key = null; + if (provider) { + api_key = document.getElementById(`${provider}-api_key`)?.value || null; + if (api_key == null) + api_key = document.querySelector(`.${provider}-api_key`)?.value || null; + } + return api_key; +} + async function load_provider_models(providerIndex=null) { if (!providerIndex) { providerIndex = providerSelect.selectedIndex; diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 6be77d09..2d871ff3 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -38,10 +38,11 @@ class Api: return models._all_models @staticmethod - def get_provider_models(provider: str) -> list[dict]: + def get_provider_models(provider: str, api_key: str = None) -> list[dict]: if provider in __map__: provider: ProviderType = __map__[provider] if issubclass(provider, ProviderModelMixin): + models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key) return [ { "model": model, @@ -49,7 +50,7 @@ class Api: "vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []), "image": model in getattr(provider, "image_models", []), } - for model in provider.get_models() + for model in models ] return [] diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index dc1b1080..020e49ef 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -94,7 +94,8 @@ class Backend_Api(Api): ) def get_provider_models(self, provider: str): - models = super().get_provider_models(provider) + api_key = None if request.authorization is None else request.authorization.token + models = super().get_provider_models(provider, api_key) if models is None: return 404, "Provider not found" return models diff --git a/g4f/requests/raise_for_status.py b/g4f/requests/raise_for_status.py index fe262f34..8625f552 100644 --- a/g4f/requests/raise_for_status.py +++ b/g4f/requests/raise_for_status.py @@ -11,7 +11,9 @@ class CloudflareError(ResponseStatusError): ... def is_cloudflare(text: str) -> bool: - if "Attention Required! | Cloudflare" in text or 'id="cf-cloudflare-status"' in text: + if "Generated by cloudfront" in text: + return True + elif "Attention Required! | Cloudflare" in text or 'id="cf-cloudflare-status"' in text: return True return '
' in text or "Just a moment..." in text -- cgit v1.2.3 From 5e9e56ed5388250237479ee59acdbc1c7017ea39 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 20 Nov 2024 02:35:35 +0100 Subject: Fix missing provider_handler in client --- g4f/client/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 1f3cdab1..549a244b 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -12,7 +12,7 @@ from typing import Union, AsyncIterator, Iterator, Coroutine from ..providers.base_provider import AsyncGeneratorProvider from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP -from ..typing import Messages, Cookies, Image +from ..typing import Messages, Image from ..providers.types import ProviderType, FinishReason, BaseConversation from ..errors import NoImageResponseError from ..providers.retry_provider import IterListProvider @@ -254,6 +254,8 @@ class Images: provider_handler = self.models.get(model, provider or self.provider or BingCreateImages) elif isinstance(provider, str): provider_handler = convert_to_provider(provider) + else: + provider_handler = provider if provider_handler is None: raise ValueError(f"Unknown model: {model}") if proxy is None: -- cgit v1.2.3