path: root/g4f/Provider/base_provider.py
author     Tekky <98614666+xtekky@users.noreply.github.com>  2023-08-28 22:08:23 +0200
committer  GitHub <noreply@github.com>                       2023-08-28 22:08:23 +0200
commit     7e687b3d178c00a27d7e5ae2613fe88ee7844639 (patch)
tree       4034e8ae9fc7ca9af295f04358bb00516b464e0b /g4f/Provider/base_provider.py
parent     Merge pull request #851 from Luneye/patch-1 (diff)
parent     Merge branch 'main' into hugging (diff)
download   gpt4free-0.0.2.6.tar
           gpt4free-0.0.2.6.tar.gz
           gpt4free-0.0.2.6.tar.bz2
           gpt4free-0.0.2.6.tar.lz
           gpt4free-0.0.2.6.tar.xz
           gpt4free-0.0.2.6.tar.zst
           gpt4free-0.0.2.6.zip
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--  g4f/Provider/base_provider.py | 36
1 file changed, 23 insertions(+), 13 deletions(-)
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index d5f23931..def2cd6d 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -4,8 +4,7 @@ from ..typing import Any, CreateResult, AsyncGenerator, Union
import browser_cookie3
import asyncio
-from time import time
-import math
+
class BaseProvider(ABC):
    url: str
@@ -48,6 +47,17 @@ def get_cookies(cookie_domain: str) -> dict:
    return _cookies[cookie_domain]
+def format_prompt(messages: list[dict[str, str]], add_special_tokens=False):
+    if add_special_tokens or len(messages) > 1:
+        formatted = "\n".join(
+            ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
+        )
+        return f"{formatted}\nAssistant:"
+    else:
+        return messages.pop()["content"]
+
+
+
class AsyncProvider(BaseProvider):
    @classmethod
    def create_completion(
@@ -72,20 +82,19 @@ class AsyncGeneratorProvider(AsyncProvider):
        cls,
        model: str,
        messages: list[dict[str, str]],
-        stream: bool = True, **kwargs: Any) -> CreateResult:
-
-        if stream:
-            yield from run_generator(cls.create_async_generator(model, messages, **kwargs))
-        else:
-            yield from AsyncProvider.create_completion(cls=cls, model=model, messages=messages, **kwargs)
+        stream: bool = True,
+        **kwargs
+    ) -> CreateResult:
+        yield from run_generator(cls.create_async_generator(model, messages, stream=stream, **kwargs))
    @classmethod
    async def create_async(
        cls,
        model: str,
-        messages: list[dict[str, str]], **kwargs: Any) -> str:
-
-        chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)]
+        messages: list[dict[str, str]],
+        **kwargs
+    ) -> str:
+        chunks = [chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)]
        if chunks:
            return "".join(chunks)
@@ -93,8 +102,9 @@ class AsyncGeneratorProvider(AsyncProvider):
    @abstractmethod
    def create_async_generator(
        model: str,
-        messages: list[dict[str, str]]) -> AsyncGenerator:
-
+        messages: list[dict[str, str]],
+        **kwargs
+    ) -> AsyncGenerator:
        raise NotImplementedError()
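
For reference, a minimal usage sketch of the format_prompt helper added in this commit. Its behaviour follows directly from the function body in the diff above; the example messages and the import path (g4f.Provider.base_provider) are illustrative assumptions, not part of the change itself.

# Usage sketch for format_prompt (hypothetical example data).
from g4f.Provider.base_provider import format_prompt

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# With more than one message (or add_special_tokens=True), roles and contents
# are joined line by line and a trailing "Assistant:" cue is appended.
print(format_prompt(messages))
# System: You are a helpful assistant.
# User: Hello!
# Assistant:

# A single message with add_special_tokens=False returns only its content.
# Note that messages.pop() mutates the list passed by the caller.
print(format_prompt([{"role": "user", "content": "Hi"}]))
# Hi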
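
The other half of the commit threads **kwargs (and an explicit stream flag) through AsyncGeneratorProvider. Below is a hypothetical subclass sketch showing the signature the revised abstract create_async_generator now expects; the EchoProvider name, its url value, and its body are invented for illustration, and the import paths assume the package layout shown in this diff.

from g4f.typing import AsyncGenerator
from g4f.Provider.base_provider import AsyncGeneratorProvider


class EchoProvider(AsyncGeneratorProvider):
    # Hypothetical provider used only to illustrate the new signature.
    url = "https://example.invalid"

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> AsyncGenerator:
        # Options such as stream arrive via **kwargs (create_completion passes
        # stream=stream, create_async passes stream=False); a concrete provider
        # may honour or ignore them.
        for message in messages:
            yield message["content"]

Routing stream and other options through **kwargs lets the shared create_completion and create_async wrappers stay agnostic about provider-specific parameters.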