summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTekky <98614666+xtekky@users.noreply.github.com>2023-09-22 21:40:59 +0200
committerGitHub <noreply@github.com>2023-09-22 21:40:59 +0200
commitba287e89b55118965ff0e151e54636b1f50d3b38 (patch)
treedc69218fecae4971c90ae391ff6919c032b93540
parent~ | gpt-3.5-turbo-0613 (diff)
parentAdd RetryProvider (diff)
downloadgpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.tar
gpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.tar.gz
gpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.tar.bz2
gpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.tar.lz
gpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.tar.xz
gpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.tar.zst
gpt4free-ba287e89b55118965ff0e151e54636b1f50d3b38.zip
-rw-r--r--README.md65
-rw-r--r--g4f/Provider/Aivvm.py2
-rw-r--r--g4f/Provider/Bard.py17
-rw-r--r--g4f/Provider/ChatgptLogin.py9
-rw-r--r--g4f/Provider/CodeLinkAva.py7
-rw-r--r--g4f/Provider/H2o.py16
-rw-r--r--g4f/Provider/HuggingChat.py12
-rw-r--r--g4f/Provider/PerplexityAi.py87
-rw-r--r--g4f/Provider/Vitalentum.py4
-rw-r--r--g4f/Provider/__init__.py8
-rw-r--r--g4f/Provider/base_provider.py50
-rw-r--r--g4f/Provider/retry_provider.py81
-rw-r--r--g4f/__init__.py74
-rw-r--r--g4f/models.py31
-rw-r--r--testing/test_chat_completion.py23
-rw-r--r--testing/test_providers.py22
16 files changed, 381 insertions, 127 deletions
diff --git a/README.md b/README.md
index 763417ba..ed8dc577 100644
--- a/README.md
+++ b/README.md
@@ -238,43 +238,42 @@ response = g4f.ChatCompletion.create(
##### Async Support:
-To enhance speed and overall performance, execute providers asynchronously. The total execution time will be determined by the duration of the slowest provider's execution.
+To enhance speed and overall performance, execute providers asynchronously.
+The total execution time will be determined by the duration of the slowest provider's execution.
```py
import g4f, asyncio
-async def run_async():
- _providers = [
- g4f.Provider.AItianhu,
- g4f.Provider.Acytoo,
- g4f.Provider.Aichat,
- g4f.Provider.Ails,
- g4f.Provider.Aivvm,
- g4f.Provider.ChatBase,
- g4f.Provider.ChatgptAi,
- g4f.Provider.ChatgptLogin,
- g4f.Provider.CodeLinkAva,
- g4f.Provider.DeepAi,
- g4f.Provider.Opchatgpts,
- g4f.Provider.Vercel,
- g4f.Provider.Vitalentum,
- g4f.Provider.Wewordle,
- g4f.Provider.Ylokh,
- g4f.Provider.You,
- g4f.Provider.Yqcloud,
- ]
- responses = [
- provider.create_async(
- model=g4f.models.default,
- messages=[{"role": "user", "content": "Hello"}],
- )
- for provider in _providers
- ]
- responses = await asyncio.gather(*responses)
- for idx, provider in enumerate(_providers):
- print(f"{provider.__name__}:", responses[idx])
-
-asyncio.run(run_async())
+_providers = [
+ g4f.Provider.Aichat,
+ g4f.Provider.Aivvm,
+ g4f.Provider.ChatBase,
+ g4f.Provider.Bing,
+ g4f.Provider.CodeLinkAva,
+ g4f.Provider.DeepAi,
+ g4f.Provider.GptGo,
+ g4f.Provider.Wewordle,
+ g4f.Provider.You,
+ g4f.Provider.Yqcloud,
+]
+
+async def run_provider(provider: g4f.Provider.AsyncProvider):
+ try:
+ response = await provider.create_async(
+ model=g4f.models.default.name,
+ messages=[{"role": "user", "content": "Hello"}],
+ )
+ print(f"{provider.__name__}:", response)
+ except Exception as e:
+ print(f"{provider.__name__}:", e)
+
+async def run_all():
+ calls = [
+ run_provider(provider) for provider in _providers
+ ]
+ await asyncio.gather(*calls)
+
+asyncio.run(run_all())
```
### interference openai-proxy api (use with openai python package)
diff --git a/g4f/Provider/Aivvm.py b/g4f/Provider/Aivvm.py
index dbfc588d..b2d7c139 100644
--- a/g4f/Provider/Aivvm.py
+++ b/g4f/Provider/Aivvm.py
@@ -41,7 +41,7 @@ class Aivvm(AsyncGeneratorProvider):
headers = {
"User-Agent" : "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
"Accept" : "*/*",
- "Accept-language" : "en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3",
+ "Accept-Language" : "en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3",
"Origin" : cls.url,
"Referer" : cls.url + "/",
"Sec-Fetch-Dest" : "empty",
diff --git a/g4f/Provider/Bard.py b/g4f/Provider/Bard.py
index 2137d820..4e076378 100644
--- a/g4f/Provider/Bard.py
+++ b/g4f/Provider/Bard.py
@@ -13,6 +13,7 @@ class Bard(AsyncProvider):
url = "https://bard.google.com"
needs_auth = True
working = True
+ _snlm0e = None
@classmethod
async def create_async(
@@ -31,7 +32,6 @@ class Bard(AsyncProvider):
headers = {
'authority': 'bard.google.com',
- 'content-type': 'application/x-www-form-urlencoded;charset=UTF-8',
'origin': 'https://bard.google.com',
'referer': 'https://bard.google.com/',
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
@@ -42,13 +42,14 @@ class Bard(AsyncProvider):
cookies=cookies,
headers=headers
) as session:
- async with session.get(cls.url, proxy=proxy) as response:
- text = await response.text()
+ if not cls._snlm0e:
+ async with session.get(cls.url, proxy=proxy) as response:
+ text = await response.text()
- match = re.search(r'SNlM0e\":\"(.*?)\"', text)
- if not match:
- raise RuntimeError("No snlm0e value.")
- snlm0e = match.group(1)
+ match = re.search(r'SNlM0e\":\"(.*?)\"', text)
+ if not match:
+ raise RuntimeError("No snlm0e value.")
+ cls._snlm0e = match.group(1)
params = {
'bl': 'boq_assistant-bard-web-server_20230326.21_p0',
@@ -57,7 +58,7 @@ class Bard(AsyncProvider):
}
data = {
- 'at': snlm0e,
+ 'at': cls._snlm0e,
'f.req': json.dumps([None, json.dumps([[prompt]])])
}
diff --git a/g4f/Provider/ChatgptLogin.py b/g4f/Provider/ChatgptLogin.py
index 8b868f8e..3eb55a64 100644
--- a/g4f/Provider/ChatgptLogin.py
+++ b/g4f/Provider/ChatgptLogin.py
@@ -52,7 +52,14 @@ class ChatgptLogin(AsyncProvider):
}
async with session.post("https://opchatgpts.net/wp-admin/admin-ajax.php", data=data) as response:
response.raise_for_status()
- return (await response.json())["data"]
+ data = await response.json()
+ if "data" in data:
+ return data["data"]
+ elif "msg" in data:
+ raise RuntimeError(data["msg"])
+ else:
+ raise RuntimeError(f"Response: {data}")
+
@classmethod
@property
diff --git a/g4f/Provider/CodeLinkAva.py b/g4f/Provider/CodeLinkAva.py
index 3ab4e264..e3b3eb3e 100644
--- a/g4f/Provider/CodeLinkAva.py
+++ b/g4f/Provider/CodeLinkAva.py
@@ -40,11 +40,12 @@ class CodeLinkAva(AsyncGeneratorProvider):
}
async with session.post("https://ava-alpha-api.codelink.io/api/chat", json=data) as response:
response.raise_for_status()
- start = "data: "
async for line in response.content:
line = line.decode()
- if line.startswith("data: ") and not line.startswith("data: [DONE]"):
- line = json.loads(line[len(start):-1])
+ if line.startswith("data: "):
+ if line.startswith("data: [DONE]"):
+ break
+ line = json.loads(line[6:-1])
content = line["choices"][0]["delta"].get("content")
if content:
yield content
diff --git a/g4f/Provider/H2o.py b/g4f/Provider/H2o.py
index 30090a58..d92bd6d1 100644
--- a/g4f/Provider/H2o.py
+++ b/g4f/Provider/H2o.py
@@ -23,7 +23,7 @@ class H2o(AsyncGeneratorProvider):
**kwargs
) -> AsyncGenerator:
model = model if model else cls.model
- headers = {"Referer": "https://gpt-gm.h2o.ai/"}
+ headers = {"Referer": cls.url + "/"}
async with ClientSession(
headers=headers
@@ -36,14 +36,14 @@ class H2o(AsyncGeneratorProvider):
"searchEnabled": "true",
}
async with session.post(
- "https://gpt-gm.h2o.ai/settings",
+ f"{cls.url}/settings",
proxy=proxy,
data=data
) as response:
response.raise_for_status()
async with session.post(
- "https://gpt-gm.h2o.ai/conversation",
+ f"{cls.url}/conversation",
proxy=proxy,
json={"model": model},
) as response:
@@ -71,7 +71,7 @@ class H2o(AsyncGeneratorProvider):
},
}
async with session.post(
- f"https://gpt-gm.h2o.ai/conversation/{conversationId}",
+ f"{cls.url}/conversation/{conversationId}",
proxy=proxy,
json=data
) as response:
@@ -83,6 +83,14 @@ class H2o(AsyncGeneratorProvider):
if not line["token"]["special"]:
yield line["token"]["text"]
+ async with session.delete(
+ f"{cls.url}/conversation/{conversationId}",
+ proxy=proxy,
+ json=data
+ ) as response:
+ response.raise_for_status()
+
+
@classmethod
@property
def params(cls):
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py
index 85f879f3..7702c9dd 100644
--- a/g4f/Provider/HuggingChat.py
+++ b/g4f/Provider/HuggingChat.py
@@ -25,10 +25,10 @@ class HuggingChat(AsyncGeneratorProvider):
**kwargs
) -> AsyncGenerator:
model = model if model else cls.model
- if not cookies:
- cookies = get_cookies(".huggingface.co")
if proxy and "://" not in proxy:
proxy = f"http://{proxy}"
+ if not cookies:
+ cookies = get_cookies(".huggingface.co")
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
@@ -37,7 +37,7 @@ class HuggingChat(AsyncGeneratorProvider):
cookies=cookies,
headers=headers
) as session:
- async with session.post("https://huggingface.co/chat/conversation", proxy=proxy, json={"model": model}) as response:
+ async with session.post(f"{cls.url}/conversation", proxy=proxy, json={"model": model}) as response:
conversation_id = (await response.json())["conversationId"]
send = {
@@ -62,7 +62,7 @@ class HuggingChat(AsyncGeneratorProvider):
"web_search_id": ""
}
}
- async with session.post(f"https://huggingface.co/chat/conversation/{conversation_id}", proxy=proxy, json=send) as response:
+ async with session.post(f"{cls.url}/conversation/{conversation_id}", proxy=proxy, json=send) as response:
if not stream:
data = await response.json()
if "error" in data:
@@ -76,8 +76,6 @@ class HuggingChat(AsyncGeneratorProvider):
first = True
async for line in response.content:
line = line.decode("utf-8")
- if not line:
- continue
if line.startswith(start):
line = json.loads(line[len(start):-1])
if "token" not in line:
@@ -89,7 +87,7 @@ class HuggingChat(AsyncGeneratorProvider):
else:
yield line["token"]["text"]
- async with session.delete(f"https://huggingface.co/chat/conversation/{conversation_id}", proxy=proxy) as response:
+ async with session.delete(f"{cls.url}/conversation/{conversation_id}", proxy=proxy) as response:
response.raise_for_status()
diff --git a/g4f/Provider/PerplexityAi.py b/g4f/Provider/PerplexityAi.py
new file mode 100644
index 00000000..269cdafd
--- /dev/null
+++ b/g4f/Provider/PerplexityAi.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import json
+import time
+import base64
+from curl_cffi.requests import AsyncSession
+
+from .base_provider import AsyncProvider, format_prompt
+
+
+class PerplexityAi(AsyncProvider):
+ url = "https://www.perplexity.ai"
+ working = True
+ supports_gpt_35_turbo = True
+ _sources = []
+
+ @classmethod
+ async def create_async(
+ cls,
+ model: str,
+ messages: list[dict[str, str]],
+ proxy: str = None,
+ **kwargs
+ ) -> str:
+ url = cls.url + "/socket.io/?EIO=4&transport=polling"
+ async with AsyncSession(proxies={"https": proxy}, impersonate="chrome107") as session:
+ url_session = "https://www.perplexity.ai/api/auth/session"
+ response = await session.get(url_session)
+
+ response = await session.get(url, params={"t": timestamp()})
+ response.raise_for_status()
+ sid = json.loads(response.text[1:])["sid"]
+
+ data = '40{"jwt":"anonymous-ask-user"}'
+ response = await session.post(url, params={"t": timestamp(), "sid": sid}, data=data)
+ response.raise_for_status()
+
+ data = "424" + json.dumps([
+ "perplexity_ask",
+ format_prompt(messages),
+ {
+ "version":"2.1",
+ "source":"default",
+ "language":"en",
+ "timezone": time.tzname[0],
+ "search_focus":"internet",
+ "mode":"concise"
+ }
+ ])
+ response = await session.post(url, params={"t": timestamp(), "sid": sid}, data=data)
+ response.raise_for_status()
+
+ while True:
+ response = await session.get(url, params={"t": timestamp(), "sid": sid})
+ response.raise_for_status()
+ for line in response.text.splitlines():
+ if line.startswith("434"):
+ result = json.loads(json.loads(line[3:])[0]["text"])
+
+ cls._sources = [{
+ "name": source["name"],
+ "url": source["url"],
+ "snippet": source["snippet"]
+ } for source in result["web_results"]]
+
+ return result["answer"]
+
+ @classmethod
+ def get_sources(cls):
+ return cls._sources
+
+
+ @classmethod
+ @property
+ def params(cls):
+ params = [
+ ("model", "str"),
+ ("messages", "list[dict[str, str]]"),
+ ("stream", "bool"),
+ ("proxy", "str"),
+ ]
+ param = ", ".join([": ".join(p) for p in params])
+ return f"g4f.provider.{cls.__name__} supports: ({param})"
+
+
+def timestamp() -> str:
+ return base64.urlsafe_b64encode(int(time.time()-1407782612).to_bytes(4, 'big')).decode() \ No newline at end of file
diff --git a/g4f/Provider/Vitalentum.py b/g4f/Provider/Vitalentum.py
index 31ad8b80..d5265428 100644
--- a/g4f/Provider/Vitalentum.py
+++ b/g4f/Provider/Vitalentum.py
@@ -46,7 +46,9 @@ class Vitalentum(AsyncGeneratorProvider):
response.raise_for_status()
async for line in response.content:
line = line.decode()
- if line.startswith("data: ") and not line.startswith("data: [DONE]"):
+ if line.startswith("data: "):
+ if line.startswith("data: [DONE]"):
+ break
line = json.loads(line[6:-1])
content = line["choices"][0]["delta"].get("content")
if content:
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index c36782b4..b9ee2544 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -24,6 +24,7 @@ from .Lockchat import Lockchat
from .Opchatgpts import Opchatgpts
from .OpenaiChat import OpenaiChat
from .OpenAssistant import OpenAssistant
+from .PerplexityAi import PerplexityAi
from .Raycast import Raycast
from .Theb import Theb
from .Vercel import Vercel
@@ -37,10 +38,14 @@ from .FastGpt import FastGpt
from .V50 import V50
from .Wuguokai import Wuguokai
-from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider
+from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider
+from .retry_provider import RetryProvider
__all__ = [
'BaseProvider',
+ 'AsyncProvider',
+ 'AsyncGeneratorProvider',
+ 'RetryProvider',
'Acytoo',
'Aichat',
'Ails',
@@ -67,6 +72,7 @@ __all__ = [
'Raycast',
'OpenaiChat',
'OpenAssistant',
+ 'PerplexityAi',
'Theb',
'Vercel',
'Vitalentum',
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 79f8f617..0cceb220 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,11 +1,12 @@
from __future__ import annotations
import asyncio
+from asyncio import SelectorEventLoop
from abc import ABC, abstractmethod
import browser_cookie3
-from ..typing import Any, AsyncGenerator, CreateResult, Union
+from ..typing import Any, AsyncGenerator, CreateResult
class BaseProvider(ABC):
@@ -21,10 +22,13 @@ class BaseProvider(ABC):
def create_completion(
model: str,
messages: list[dict[str, str]],
- stream: bool, **kwargs: Any) -> CreateResult:
+ stream: bool,
+ **kwargs
+ ) -> CreateResult:
raise NotImplementedError()
+
@classmethod
@property
def params(cls):
@@ -46,13 +50,19 @@ class AsyncProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
- yield asyncio.run(cls.create_async(model, messages, **kwargs))
+ loop = create_event_loop()
+ try:
+ yield loop.run_until_complete(cls.create_async(model, messages, **kwargs))
+ finally:
+ loop.close()
@staticmethod
@abstractmethod
async def create_async(
model: str,
- messages: list[dict[str, str]], **kwargs: Any) -> str:
+ messages: list[dict[str, str]],
+ **kwargs
+ ) -> str:
raise NotImplementedError()
@@ -67,10 +77,14 @@ class AsyncGeneratorProvider(AsyncProvider):
stream: bool = True,
**kwargs
) -> CreateResult:
- loop = asyncio.new_event_loop()
+ loop = create_event_loop()
try:
- asyncio.set_event_loop(loop)
- generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
+ generator = cls.create_async_generator(
+ model,
+ messages,
+ stream=stream,
+ **kwargs
+ )
gen = generator.__aiter__()
while True:
try:
@@ -78,10 +92,8 @@ class AsyncGeneratorProvider(AsyncProvider):
except StopAsyncIteration:
break
finally:
- asyncio.set_event_loop(None)
loop.close()
-
@classmethod
async def create_async(
cls,
@@ -89,7 +101,14 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: list[dict[str, str]],
**kwargs
) -> str:
- return "".join([chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)])
+ return "".join([
+ chunk async for chunk in cls.create_async_generator(
+ model,
+ messages,
+ stream=False,
+ **kwargs
+ )
+ ])
@staticmethod
@abstractmethod
@@ -101,6 +120,17 @@ class AsyncGeneratorProvider(AsyncProvider):
raise NotImplementedError()
+# Don't create a new event loop in a running async loop.
+# Force use selector event loop on windows and linux use it anyway.
+def create_event_loop() -> SelectorEventLoop:
+ try:
+ asyncio.get_running_loop()
+ except RuntimeError:
+ return SelectorEventLoop()
+ raise RuntimeError(
+ 'Use "create_async" instead of "create" function in a async loop.')
+
+
_cookies = {}
def get_cookies(cookie_domain: str) -> dict:
diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py
new file mode 100644
index 00000000..e1a9cd1f
--- /dev/null
+++ b/g4f/Provider/retry_provider.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import random
+
+from ..typing import CreateResult
+from .base_provider import BaseProvider, AsyncProvider
+
+
+class RetryProvider(AsyncProvider):
+ __name__ = "RetryProvider"
+ working = True
+ needs_auth = False
+ supports_stream = True
+ supports_gpt_35_turbo = False
+ supports_gpt_4 = False
+
+ def __init__(
+ self,
+ providers: list[type[BaseProvider]],
+ shuffle: bool = True
+ ) -> None:
+ self.providers = providers
+ self.shuffle = shuffle
+
+
+ def create_completion(
+ self,
+ model: str,
+ messages: list[dict[str, str]],
+ stream: bool = False,
+ **kwargs
+ ) -> CreateResult:
+ if stream:
+ providers = [provider for provider in self.providers if provider.supports_stream]
+ else:
+ providers = self.providers
+ if self.shuffle:
+ random.shuffle(providers)
+
+ self.exceptions = {}
+ started = False
+ for provider in providers:
+ try:
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ self.exceptions[provider.__name__] = e
+ if started:
+ break
+
+ self.raise_exceptions()
+
+ async def create_async(
+ self,
+ model: str,
+ messages: list[dict[str, str]],
+ **kwargs
+ ) -> str:
+ providers = [provider for provider in self.providers if issubclass(provider, AsyncProvider)]
+ if self.shuffle:
+ random.shuffle(providers)
+
+ self.exceptions = {}
+ for provider in providers:
+ try:
+ return await provider.create_async(model, messages, **kwargs)
+ except Exception as e:
+ self.exceptions[provider.__name__] = e
+
+ self.raise_exceptions()
+
+ def raise_exceptions(self):
+ if self.exceptions:
+ raise RuntimeError("\n".join(["All providers failed:"] + [
+ f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions
+ ]))
+
+ raise RuntimeError("No provider found") \ No newline at end of file
diff --git a/g4f/__init__.py b/g4f/__init__.py
index a49e60e9..8fdfe21f 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -1,11 +1,36 @@
from __future__ import annotations
from g4f import models
-from .Provider import BaseProvider
+from .Provider import BaseProvider, AsyncProvider
from .typing import Any, CreateResult, Union
import random
logging = False
+def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseProvider], stream: bool):
+ if isinstance(model, str):
+ if model in models.ModelUtils.convert:
+ model = models.ModelUtils.convert[model]
+ else:
+ raise Exception(f'The model: {model} does not exist')
+
+ if not provider:
+ provider = model.best_provider
+
+ if not provider:
+ raise Exception(f'No provider found for model: {model}')
+
+ if not provider.working:
+ raise Exception(f'{provider.__name__} is not working')
+
+ if not provider.supports_stream and stream:
+ raise Exception(
+ f'ValueError: {provider.__name__} does not support "stream" argument')
+
+ if logging:
+ print(f'Using {provider.__name__} provider')
+
+ return model, provider
+
class ChatCompletion:
@staticmethod
def create(
@@ -13,28 +38,11 @@ class ChatCompletion:
messages : list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
- auth : Union[str, None] = None, **kwargs: Any) -> Union[CreateResult, str]:
+ auth : Union[str, None] = None,
+ **kwargs
+ ) -> Union[CreateResult, str]:
- if isinstance(model, str):
- if model in models.ModelUtils.convert:
- model = models.ModelUtils.convert[model]
- else:
- raise Exception(f'The model: {model} does not exist')
-
- if not provider:
- if isinstance(model.best_provider, list):
- if stream:
- provider = random.choice([p for p in model.best_provider if p.supports_stream])
- else:
- provider = random.choice(model.best_provider)
- else:
- provider = model.best_provider
-
- if not provider:
- raise Exception(f'No provider found')
-
- if not provider.working:
- raise Exception(f'{provider.__name__} is not working')
+ model, provider = get_model_and_provider(model, provider, stream)
if provider.needs_auth and not auth:
raise Exception(
@@ -43,12 +51,20 @@ class ChatCompletion:
if provider.needs_auth:
kwargs['auth'] = auth
- if not provider.supports_stream and stream:
- raise Exception(
- f'ValueError: {provider.__name__} does not support "stream" argument')
-
- if logging:
- print(f'Using {provider.__name__} provider')
-
result = provider.create_completion(model.name, messages, stream, **kwargs)
return result if stream else ''.join(result)
+
+ @staticmethod
+ async def create_async(
+ model : Union[models.Model, str],
+ messages : list[dict[str, str]],
+ provider : Union[type[BaseProvider], None] = None,
+ **kwargs
+ ) -> str:
+
+ model, provider = get_model_and_provider(model, provider, False)
+
+ if not issubclass(type(provider), AsyncProvider):
+ raise Exception(f"Provider: {provider.__name__} doesn't support create_async")
+
+ return await provider.create_async(model.name, messages, **kwargs)
diff --git a/g4f/models.py b/g4f/models.py
index 22311910..9889f0d5 100644
--- a/g4f/models.py
+++ b/g4f/models.py
@@ -1,20 +1,26 @@
from __future__ import annotations
from dataclasses import dataclass
from .typing import Union
-from .Provider import BaseProvider
+from .Provider import BaseProvider, RetryProvider
from .Provider import (
ChatgptLogin,
- CodeLinkAva,
ChatgptAi,
ChatBase,
Vercel,
DeepAi,
Aivvm,
Bard,
- H2o
+ H2o,
+ GptGo,
+ Bing,
+ PerplexityAi,
+ Wewordle,
+ Yqcloud,
+ AItianhu,
+ Aichat,
)
-@dataclass
+@dataclass(unsafe_hash=True)
class Model:
name: str
base_provider: str
@@ -24,15 +30,24 @@ class Model:
# Works for Liaobots, H2o, OpenaiChat, Yqcloud, You
default = Model(
name = "",
- base_provider = "huggingface")
+ base_provider = "",
+ best_provider = RetryProvider([
+ Bing, # Not fully GPT 3 or 4
+ PerplexityAi, # Adds references to sources
+ Wewordle, # Responds with markdown
+ Yqcloud, # Answers short questions in chinese
+ ChatBase, # Don't want to answer creatively
+ DeepAi, ChatgptLogin, ChatgptAi, Aivvm, GptGo, AItianhu, Aichat,
+ ])
+)
# GPT-3.5 / GPT-4
gpt_35_turbo = Model(
name = 'gpt-3.5-turbo',
base_provider = 'openai',
- best_provider = [
- DeepAi, CodeLinkAva, ChatgptLogin, ChatgptAi, ChatBase, Aivvm
- ]
+ best_provider = RetryProvider([
+ DeepAi, ChatgptLogin, ChatgptAi, Aivvm, GptGo, AItianhu, Aichat,
+ ])
)
gpt_4 = Model(
diff --git a/testing/test_chat_completion.py b/testing/test_chat_completion.py
index fbaa3169..d901e697 100644
--- a/testing/test_chat_completion.py
+++ b/testing/test_chat_completion.py
@@ -3,10 +3,23 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
-import g4f
+import g4f, asyncio
-response = g4f.ChatCompletion.create(
+print("create:", end=" ", flush=True)
+for response in g4f.ChatCompletion.create(
model=g4f.models.gpt_35_turbo,
- messages=[{"role": "user", "content": "hello, are you GPT 4?"}]
-)
-print(response) \ No newline at end of file
+ provider=g4f.Provider.GptGo,
+ messages=[{"role": "user", "content": "hello!"}],
+):
+ print(response, end="", flush=True)
+print()
+
+async def run_async():
+ response = await g4f.ChatCompletion.create_async(
+ model=g4f.models.gpt_35_turbo,
+ provider=g4f.Provider.GptGo,
+ messages=[{"role": "user", "content": "hello!"}],
+ )
+ print("create_async:", response)
+
+asyncio.run(run_async())
diff --git a/testing/test_providers.py b/testing/test_providers.py
index be04e7a3..5240119b 100644
--- a/testing/test_providers.py
+++ b/testing/test_providers.py
@@ -1,6 +1,6 @@
import sys
from pathlib import Path
-from colorama import Fore
+from colorama import Fore, Style
sys.path.append(str(Path(__file__).parent.parent))
@@ -8,10 +8,6 @@ from g4f import BaseProvider, models, Provider
logging = False
-class Styles:
- ENDC = "\033[0m"
- BOLD = "\033[1m"
- UNDERLINE = "\033[4m"
def main():
providers = get_providers()
@@ -29,11 +25,11 @@ def main():
print()
if failed_providers:
- print(f"{Fore.RED + Styles.BOLD}Failed providers:{Styles.ENDC}")
+ print(f"{Fore.RED + Style.BRIGHT}Failed providers:{Style.RESET_ALL}")
for _provider in failed_providers:
print(f"{Fore.RED}{_provider.__name__}")
else:
- print(f"{Fore.GREEN + Styles.BOLD}All providers are working")
+ print(f"{Fore.GREEN + Style.BRIGHT}All providers are working")
def get_providers() -> list[type[BaseProvider]]:
@@ -45,21 +41,15 @@ def get_providers() -> list[type[BaseProvider]]:
"AsyncProvider",
"AsyncGeneratorProvider"
]
- provider_names = [
- provider_name
+ return [
+ getattr(Provider, provider_name)
for provider_name in provider_names
if not provider_name.startswith("__") and provider_name not in ignore_names
]
- return [getattr(Provider, provider_name) for provider_name in provider_names]
def create_response(_provider: type[BaseProvider]) -> str:
- if _provider.supports_gpt_35_turbo:
- model = models.gpt_35_turbo.name
- elif _provider.supports_gpt_4:
- model = models.gpt_4.name
- else:
- model = models.default.name
+ model = models.gpt_35_turbo.name if _provider.supports_gpt_35_turbo else models.default.name
response = _provider.create_completion(
model=model,
messages=[{"role": "user", "content": "Hello, who are you? Answer in detail much as possible."}],