summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/Blackbox.py12
-rw-r--r--g4f/Provider/Copilot.py23
-rw-r--r--g4f/Provider/RubiksAI.py124
-rw-r--r--g4f/Provider/needs_auth/Cerebras.py65
-rw-r--r--g4f/Provider/needs_auth/CopilotAccount.py7
-rw-r--r--g4f/Provider/needs_auth/HuggingFace2.py28
-rw-r--r--g4f/Provider/needs_auth/OpenaiAPI.py4
-rw-r--r--g4f/Provider/needs_auth/__init__.py2
-rw-r--r--g4f/client/__init__.py4
-rw-r--r--g4f/gui/client/index.html8
-rw-r--r--g4f/gui/client/static/css/style.css4
-rw-r--r--g4f/gui/client/static/js/chat.v1.js87
-rw-r--r--g4f/gui/server/api.py5
-rw-r--r--g4f/gui/server/backend.py3
-rw-r--r--g4f/requests/raise_for_status.py4
15 files changed, 242 insertions, 138 deletions
diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py
index 97466c04..ba58a511 100644
--- a/g4f/Provider/Blackbox.py
+++ b/g4f/Provider/Blackbox.py
@@ -28,6 +28,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
image_models = [default_image_model, 'repomap']
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro']
+ model_aliases = {
+ "claude-3.5-sonnet": "claude-sonnet-3.5",
+ }
agentMode = {
default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
}
@@ -198,6 +201,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=headers) as session:
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
+ is_first = False
async for chunk in response.content.iter_any():
text_chunk = chunk.decode(errors="ignore")
if model in cls.image_models:
@@ -217,5 +221,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
for i, result in enumerate(search_results, 1):
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
yield formatted_response
- else:
- yield text_chunk.strip()
+ elif text_chunk:
+ if is_first:
+ is_first = False
+ yield text_chunk.lstrip()
+ else:
+ yield text_chunk \ No newline at end of file
diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py
index e40278c7..e8eea0a5 100644
--- a/g4f/Provider/Copilot.py
+++ b/g4f/Provider/Copilot.py
@@ -21,8 +21,9 @@ from .helper import format_prompt
from ..typing import CreateResult, Messages, ImageType
from ..errors import MissingRequirementsError
from ..requests.raise_for_status import raise_for_status
+from ..providers.helper import format_cookies
from ..requests import get_nodriver
-from ..image import to_bytes, is_accepted_format
+from ..image import ImageResponse, to_bytes, is_accepted_format
from .. import debug
class Conversation(BaseConversation):
@@ -70,18 +71,21 @@ class Copilot(AbstractProvider):
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
else:
access_token = conversation.access_token
- websocket_url = f"{websocket_url}&acessToken={quote(access_token)}"
- headers = {"Authorization": f"Bearer {access_token}"}
+ debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}")
+ debug.log(f"Copilot: Cookies: {';'.join([*cookies])}")
+ websocket_url = f"{websocket_url}&accessToken={quote(access_token)}"
+ headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)}
with Session(
timeout=timeout,
proxy=proxy,
impersonate="chrome",
headers=headers,
- cookies=cookies
+ cookies=cookies,
) as session:
- response = session.get(f"{cls.url}/")
+ response = session.get("https://copilot.microsoft.com/c/api/user")
raise_for_status(response)
+ debug.log(f"Copilot: User: {response.json().get('firstName', 'null')}")
if conversation is None:
response = session.post(cls.conversation_url)
raise_for_status(response)
@@ -119,6 +123,7 @@ class Copilot(AbstractProvider):
is_started = False
msg = None
+ image_prompt: str = None
while True:
try:
msg = wss.recv()[0]
@@ -128,7 +133,11 @@ class Copilot(AbstractProvider):
if msg.get("event") == "appendText":
is_started = True
yield msg.get("text")
- elif msg.get("event") in ["done", "partCompleted"]:
+ elif msg.get("event") == "generatingImage":
+ image_prompt = msg.get("prompt")
+ elif msg.get("event") == "imageGenerated":
+ yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
+ elif msg.get("event") == "done":
break
if not is_started:
raise RuntimeError(f"Last message: {msg}")
@@ -152,7 +161,7 @@ class Copilot(AbstractProvider):
})()
""")
if access_token is None:
- asyncio.sleep(1)
+ await asyncio.sleep(1)
cookies = {}
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
cookies[c.name] = c.value
diff --git a/g4f/Provider/RubiksAI.py b/g4f/Provider/RubiksAI.py
index 7e76d558..c06e6c3d 100644
--- a/g4f/Provider/RubiksAI.py
+++ b/g4f/Provider/RubiksAI.py
@@ -1,7 +1,6 @@
+
from __future__ import annotations
-import asyncio
-import aiohttp
import random
import string
import json
@@ -11,34 +10,24 @@ from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from .helper import format_prompt
-
+from ..requests.raise_for_status import raise_for_status
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Rubiks AI"
url = "https://rubiks.ai"
- api_endpoint = "https://rubiks.ai/search/api.php"
+ api_endpoint = "https://rubiks.ai/search/api/"
working = True
supports_stream = True
supports_system_message = True
supports_message_history = True
- default_model = 'llama-3.1-70b-versatile'
- models = [default_model, 'gpt-4o-mini']
+ default_model = 'gpt-4o-mini'
+ models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
model_aliases = {
"llama-3.1-70b": "llama-3.1-70b-versatile",
}
- @classmethod
- def get_model(cls, model: str) -> str:
- if model in cls.models:
- return model
- elif model in cls.model_aliases:
- return cls.model_aliases[model]
- else:
- return cls.default_model
-
@staticmethod
def generate_mid() -> str:
"""
@@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
proxy: str = None,
- websearch: bool = False,
+ web_search: bool = False,
+ temperature: float = 0.6,
**kwargs
) -> AsyncResult:
"""
@@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
- model (str): The model to use in the request.
- messages (Messages): The messages to send as a prompt.
- proxy (str, optional): Proxy URL, if needed.
- - websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
+ - web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
"""
model = cls.get_model(model)
- prompt = format_prompt(messages)
- q_value = prompt
mid_value = cls.generate_mid()
- referer = cls.create_referer(q=q_value, mid=mid_value, model=model)
-
- url = cls.api_endpoint
- params = {
- 'q': q_value,
- 'model': model,
- 'id': '',
- 'mid': mid_value
+ referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model)
+
+ data = {
+ "messages": messages,
+ "model": model,
+ "search": web_search,
+ "stream": True,
+ "temperature": temperature
}
headers = {
@@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"Linux"'
}
-
- try:
- timeout = aiohttp.ClientTimeout(total=None)
- async with ClientSession(timeout=timeout) as session:
- async with session.get(url, headers=headers, params=params, proxy=proxy) as response:
- if response.status != 200:
- yield f"Request ended with status code {response.status}"
- return
-
- assistant_text = ''
- sources = []
-
- async for line in response.content:
- decoded_line = line.decode('utf-8').strip()
- if not decoded_line.startswith('data: '):
- continue
- data = decoded_line[6:]
- if data in ('[DONE]', '{"done": ""}'):
- break
- try:
- json_data = json.loads(data)
- except json.JSONDecodeError:
- continue
-
- if 'url' in json_data and 'title' in json_data:
- if websearch:
- sources.append({'title': json_data['title'], 'url': json_data['url']})
-
- elif 'choices' in json_data:
- for choice in json_data['choices']:
- delta = choice.get('delta', {})
- content = delta.get('content', '')
- role = delta.get('role', '')
- if role == 'assistant':
- continue
- assistant_text += content
-
- if websearch and sources:
- sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
- assistant_text += f"\n\n**Source:**\n{sources_text}"
-
- yield assistant_text
-
- except asyncio.CancelledError:
- yield "The request was cancelled."
- except aiohttp.ClientError as e:
- yield f"An error occurred during the request: {e}"
- except Exception as e:
- yield f"An unexpected error occurred: {e}"
+ async with ClientSession() as session:
+ async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response:
+ await raise_for_status(response)
+
+ sources = []
+ async for line in response.content:
+ decoded_line = line.decode('utf-8').strip()
+ if not decoded_line.startswith('data: '):
+ continue
+ data = decoded_line[6:]
+ if data in ('[DONE]', '{"done": ""}'):
+ break
+ try:
+ json_data = json.loads(data)
+ except json.JSONDecodeError:
+ continue
+
+ if 'url' in json_data and 'title' in json_data:
+ if web_search:
+ sources.append({'title': json_data['title'], 'url': json_data['url']})
+
+ elif 'choices' in json_data:
+ for choice in json_data['choices']:
+ delta = choice.get('delta', {})
+ content = delta.get('content', '')
+ if content:
+ yield content
+
+ if web_search and sources:
+ sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
+ yield f"\n\n**Source:**\n{sources_text}" \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/Cerebras.py b/g4f/Provider/needs_auth/Cerebras.py
new file mode 100644
index 00000000..0f94c476
--- /dev/null
+++ b/g4f/Provider/needs_auth/Cerebras.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import requests
+from aiohttp import ClientSession
+
+from .OpenaiAPI import OpenaiAPI
+from ...typing import AsyncResult, Messages, Cookies
+from ...requests.raise_for_status import raise_for_status
+from ...cookies import get_cookies
+
+class Cerebras(OpenaiAPI):
+ label = "Cerebras Inference"
+ url = "https://inference.cerebras.ai/"
+ working = True
+ default_model = "llama3.1-70b"
+ fallback_models = [
+ "llama3.1-70b",
+ "llama3.1-8b",
+ ]
+ model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
+
+ @classmethod
+ def get_models(cls, api_key: str = None):
+ if not cls.models:
+ try:
+ headers = {}
+ if api_key:
+ headers["authorization"] = f"Bearer ${api_key}"
+ response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
+ raise_for_status(response)
+ data = response.json()
+ cls.models = [model.get("model") for model in data.get("models")]
+ except Exception:
+ cls.models = cls.fallback_models
+ return cls.models
+
+ @classmethod
+ async def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ api_base: str = "https://api.cerebras.ai/v1",
+ api_key: str = None,
+ cookies: Cookies = None,
+ **kwargs
+ ) -> AsyncResult:
+ if api_key is None and cookies is None:
+ cookies = get_cookies(".cerebras.ai")
+ async with ClientSession(cookies=cookies) as session:
+ async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
+ raise_for_status(response)
+ data = await response.json()
+ if data:
+ api_key = data.get("user", {}).get("demoApiKey")
+ async for chunk in super().create_async_generator(
+ model, messages,
+ api_base=api_base,
+ impersonate="chrome",
+ api_key=api_key,
+ headers={
+ "User-Agent": "ex/JS 1.5.0",
+ },
+ **kwargs
+ ):
+ yield chunk
diff --git a/g4f/Provider/needs_auth/CopilotAccount.py b/g4f/Provider/needs_auth/CopilotAccount.py
index 76e51278..497aab98 100644
--- a/g4f/Provider/needs_auth/CopilotAccount.py
+++ b/g4f/Provider/needs_auth/CopilotAccount.py
@@ -1,9 +1,12 @@
from __future__ import annotations
+from ..base_provider import ProviderModelMixin
from ..Copilot import Copilot
-class CopilotAccount(Copilot):
+class CopilotAccount(Copilot, ProviderModelMixin):
needs_auth = True
parent = "Copilot"
default_model = "Copilot"
- default_vision_model = default_model \ No newline at end of file
+ default_vision_model = default_model
+ models = [default_model]
+ image_models = models \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/HuggingFace2.py b/g4f/Provider/needs_auth/HuggingFace2.py
new file mode 100644
index 00000000..847d459b
--- /dev/null
+++ b/g4f/Provider/needs_auth/HuggingFace2.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from .OpenaiAPI import OpenaiAPI
+from ..HuggingChat import HuggingChat
+from ...typing import AsyncResult, Messages
+
+class HuggingFace2(OpenaiAPI):
+ label = "HuggingFace (Inference API)"
+ url = "https://huggingface.co"
+ working = True
+ default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+ default_vision_model = default_model
+ models = [
+ *HuggingChat.models
+ ]
+
+ @classmethod
+ def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ api_base: str = "https://api-inference.huggingface.co/v1",
+ max_tokens: int = 500,
+ **kwargs
+ ) -> AsyncResult:
+ return super().create_async_generator(
+ model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs
+ ) \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py
index 116b5f6f..83268b6d 100644
--- a/g4f/Provider/needs_auth/OpenaiAPI.py
+++ b/g4f/Provider/needs_auth/OpenaiAPI.py
@@ -34,6 +34,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
stop: Union[str, list[str]] = None,
stream: bool = False,
headers: dict = None,
+ impersonate: str = None,
extra_data: dict = {},
**kwargs
) -> AsyncResult:
@@ -55,7 +56,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
async with StreamSession(
proxies={"all": proxy},
headers=cls.get_headers(stream, api_key, headers),
- timeout=timeout
+ timeout=timeout,
+ impersonate=impersonate,
) as session:
data = filter_none(
messages=messages,
diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py
index 0f430ab5..1c7fe7c5 100644
--- a/g4f/Provider/needs_auth/__init__.py
+++ b/g4f/Provider/needs_auth/__init__.py
@@ -1,6 +1,7 @@
from .gigachat import *
from .BingCreateImages import BingCreateImages
+from .Cerebras import Cerebras
from .CopilotAccount import CopilotAccount
from .DeepInfra import DeepInfra
from .DeepInfraImage import DeepInfraImage
@@ -8,6 +9,7 @@ from .Gemini import Gemini
from .GeminiPro import GeminiPro
from .Groq import Groq
from .HuggingFace import HuggingFace
+from .HuggingFace2 import HuggingFace2
from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount
from .OpenaiAPI import OpenaiAPI
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index 1f3cdab1..549a244b 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -12,7 +12,7 @@ from typing import Union, AsyncIterator, Iterator, Coroutine
from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
-from ..typing import Messages, Cookies, Image
+from ..typing import Messages, Image
from ..providers.types import ProviderType, FinishReason, BaseConversation
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
@@ -254,6 +254,8 @@ class Images:
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
elif isinstance(provider, str):
provider_handler = convert_to_provider(provider)
+ else:
+ provider_handler = provider
if provider_handler is None:
raise ValueError(f"Unknown model: {model}")
if proxy is None:
diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html
index 3a2197de..48214093 100644
--- a/g4f/gui/client/index.html
+++ b/g4f/gui/client/index.html
@@ -129,6 +129,10 @@
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder="&quot;_U&quot; cookie"></textarea>
</div>
<div class="field box">
+ <label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
+ <textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
+ </div>
+ <div class="field box">
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
</div>
@@ -142,7 +146,7 @@
</div>
<div class="field box">
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
- <textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="api_key"></textarea>
+ <textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFace2-api_key" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
@@ -192,7 +196,7 @@
<div class="stop_generating stop_generating-hidden">
<button id="cancelButton">
<span>Stop Generating</span>
- <i class="fa-regular fa-stop"></i>
+ <i class="fa-solid fa-stop"></i>
</button>
</div>
<div class="regenerate">
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css
index e435094f..76399703 100644
--- a/g4f/gui/client/static/css/style.css
+++ b/g4f/gui/client/static/css/style.css
@@ -512,9 +512,7 @@ body {
@media only screen and (min-width: 40em) {
.stop_generating {
- left: 50%;
- transform: translateX(-50%);
- right: auto;
+ right: 4px;
}
.toolbar .regenerate span {
display: block;
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index 51bf8b81..a3e94ee2 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -215,7 +215,6 @@ const register_message_buttons = async () => {
const message_el = el.parentElement.parentElement.parentElement;
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
- await hide_message(window.conversation_id, message_el.dataset.index);
await ask_gpt(message_el.dataset.index, get_message_id());
})
}
@@ -317,6 +316,7 @@ async function remove_cancel_button() {
regenerate.addEventListener("click", async () => {
regenerate.classList.add("regenerate-hidden");
+ setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
stop_generating.classList.remove("stop_generating-hidden");
await hide_message(window.conversation_id);
await ask_gpt(-1, get_message_id());
@@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => {
return new_messages;
}
-async function add_message_chunk(message, message_index) {
- content_map = content_storage[message_index];
+async function add_message_chunk(message, message_id) {
+ content_map = content_storage[message_id];
if (message.type == "conversation") {
console.info("Conversation used:", message.conversation)
} else if (message.type == "provider") {
- provider_storage[message_index] = message.provider;
+ provider_storage[message_id] = message.provider;
content_map.content.querySelector('.provider').innerHTML = `
<a href="${message.provider.url}" target="_blank">
${message.provider.label ? message.provider.label : message.provider.name}
@@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) {
} else if (message.type == "message") {
console.error(message.message)
} else if (message.type == "error") {
- error_storage[message_index] = message.error
+ error_storage[message_id] = message.error
console.error(message.error);
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${message.error}</p>`;
let p = document.createElement("p");
@@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) {
} else if (message.type == "preview") {
content_map.inner.innerHTML = markdown_render(message.preview);
} else if (message.type == "content") {
- message_storage[message_index] += message.content;
- html = markdown_render(message_storage[message_index]);
+ message_storage[message_id] += message.content;
+ html = markdown_render(message_storage[message_id]);
let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
const index = html.lastIndexOf(element)
@@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) {
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
}
content_map.inner.innerHTML = html;
- content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model);
+ content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
highlight(content_map.inner);
} else if (message.type == "log") {
let p = document.createElement("p");
@@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
let total_messages = messages.length;
messages = prepare_messages(messages, message_index);
message_index = total_messages
- message_storage[message_index] = "";
+ message_storage[message_id] = "";
stop_generating.classList.remove(".stop_generating-hidden");
message_box.scrollTop = message_box.scrollHeight;
@@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
</div>
`;
- controller_storage[message_index] = new AbortController();
+ controller_storage[message_id] = new AbortController();
let content_el = document.getElementById(`gpt_${message_id}`)
- let content_map = content_storage[message_index] = {
+ let content_map = content_storage[message_id] = {
content: content_el,
inner: content_el.querySelector('.content_inner'),
count: content_el.querySelector('.count'),
@@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
const file = input && input.files.length > 0 ? input.files[0] : null;
const provider = providerSelect.options[providerSelect.selectedIndex].value;
const auto_continue = document.getElementById("auto_continue")?.checked;
- let api_key = null;
- if (provider) {
- api_key = document.getElementById(`${provider}-api_key`)?.value || null;
- if (api_key == null)
- api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
- }
+ let api_key = get_api_key_by_provider(provider);
await api("conversation", {
id: message_id,
conversation_id: window.conversation_id,
@@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
provider: provider,
messages: messages,
auto_continue: auto_continue,
- api_key: api_key
- }, file, message_index);
- if (!error_storage[message_index]) {
- html = markdown_render(message_storage[message_index]);
+ api_key: api_key,
+ }, file, message_id);
+ if (!error_storage[message_id]) {
+ html = markdown_render(message_storage[message_id]);
content_map.inner.innerHTML = html;
highlight(content_map.inner);
@@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => {
} catch (e) {
console.error(e);
if (e.name != "AbortError") {
- error_storage[message_index] = true;
+ error_storage[message_id] = true;
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${e}</p>`;
}
}
- delete controller_storage[message_index];
- if (!error_storage[message_index] && message_storage[message_index]) {
- const message_provider = message_index in provider_storage ? provider_storage[message_index] : null;
- await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider);
+ delete controller_storage[message_id];
+ if (!error_storage[message_id] && message_storage[message_id]) {
+ const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
+ await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
await safe_load_conversation(window.conversation_id);
} else {
let cursorDiv = message_box.querySelector(".cursor");
@@ -1156,7 +1151,7 @@ async function on_api() {
evt.preventDefault();
console.log("pressed enter");
prompt_lock = true;
- setTimeout(()=>prompt_lock=false, 3);
+ setTimeout(()=>prompt_lock=false, 3000);
await handle_ask();
} else {
messageInput.style.removeProperty("height");
@@ -1167,7 +1162,7 @@ async function on_api() {
console.log("clicked send");
if (prompt_lock) return;
prompt_lock = true;
- setTimeout(()=>prompt_lock=false, 3);
+ setTimeout(()=>prompt_lock=false, 3000);
await handle_ask();
});
messageInput.focus();
@@ -1189,8 +1184,8 @@ async function on_api() {
providerSelect.appendChild(option);
})
- await load_provider_models(appStorage.getItem("provider"));
await load_settings_storage()
+ await load_provider_models(appStorage.getItem("provider"));
const hide_systemPrompt = document.getElementById("hide-systemPrompt")
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i");
@@ -1316,7 +1311,7 @@ function get_selected_model() {
}
}
-async function api(ressource, args=null, file=null, message_index=null) {
+async function api(ressource, args=null, file=null, message_id=null) {
if (window?.pywebview) {
if (args !== null) {
if (ressource == "models") {
@@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) {
}
return pywebview.api[`get_${ressource}`]();
}
+ let api_key;
if (ressource == "models" && args) {
+ api_key = get_api_key_by_provider(args);
ressource = `${ressource}/${args}`;
}
const url = `/backend-api/v2/${ressource}`;
+ const headers = {};
+ if (api_key) {
+ headers.authorization = `Bearer ${api_key}`;
+ }
if (ressource == "conversation") {
let body = JSON.stringify(args);
- const headers = {
- accept: 'text/event-stream'
- }
+ headers.accept = 'text/event-stream';
if (file !== null) {
const formData = new FormData();
formData.append('file', file);
@@ -1345,17 +1344,17 @@ async function api(ressource, args=null, file=null, message_index=null) {
}
response = await fetch(url, {
method: 'POST',
- signal: controller_storage[message_index].signal,
+ signal: controller_storage[message_id].signal,
headers: headers,
- body: body
+ body: body,
});
- return read_response(response, message_index);
+ return read_response(response, message_id);
}
- response = await fetch(url);
+ response = await fetch(url, {headers: headers});
return await response.json();
}
-async function read_response(response, message_index) {
+async function read_response(response, message_id) {
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
let buffer = ""
while (true) {
@@ -1368,7 +1367,7 @@ async function read_response(response, message_index) {
continue;
}
try {
- add_message_chunk(JSON.parse(buffer + line), message_index);
+ add_message_chunk(JSON.parse(buffer + line), message_id);
buffer = "";
} catch {
buffer += line
@@ -1377,6 +1376,16 @@ async function read_response(response, message_index) {
}
}
+function get_api_key_by_provider(provider) {
+ let api_key = null;
+ if (provider) {
+ api_key = document.getElementById(`${provider}-api_key`)?.value || null;
+ if (api_key == null)
+ api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
+ }
+ return api_key;
+}
+
async function load_provider_models(providerIndex=null) {
if (!providerIndex) {
providerIndex = providerSelect.selectedIndex;
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py
index 6be77d09..2d871ff3 100644
--- a/g4f/gui/server/api.py
+++ b/g4f/gui/server/api.py
@@ -38,10 +38,11 @@ class Api:
return models._all_models
@staticmethod
- def get_provider_models(provider: str) -> list[dict]:
+ def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
if provider in __map__:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
+ models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
return [
{
"model": model,
@@ -49,7 +50,7 @@ class Api:
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
"image": model in getattr(provider, "image_models", []),
}
- for model in provider.get_models()
+ for model in models
]
return []
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index dc1b1080..020e49ef 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -94,7 +94,8 @@ class Backend_Api(Api):
)
def get_provider_models(self, provider: str):
- models = super().get_provider_models(provider)
+ api_key = None if request.authorization is None else request.authorization.token
+ models = super().get_provider_models(provider, api_key)
if models is None:
return 404, "Provider not found"
return models
diff --git a/g4f/requests/raise_for_status.py b/g4f/requests/raise_for_status.py
index fe262f34..8625f552 100644
--- a/g4f/requests/raise_for_status.py
+++ b/g4f/requests/raise_for_status.py
@@ -11,7 +11,9 @@ class CloudflareError(ResponseStatusError):
...
def is_cloudflare(text: str) -> bool:
- if "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
+ if "Generated by cloudfront" in text:
+ return True
+ elif "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
return True
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text