summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--g4f/Provider/base_provider.py40
1 files changed, 18 insertions, 22 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 56d79ee6..d5f23931 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -9,20 +9,19 @@ import math
class BaseProvider(ABC):
url: str
- working = False
- needs_auth = False
- supports_stream = False
+ working = False
+ needs_auth = False
+ supports_stream = False
supports_gpt_35_turbo = False
- supports_gpt_4 = False
+ supports_gpt_4 = False
@staticmethod
@abstractmethod
def create_completion(
model: str,
messages: list[dict[str, str]],
- stream: bool,
- **kwargs: Any,
- ) -> CreateResult:
+ stream: bool, **kwargs: Any) -> CreateResult:
+
raise NotImplementedError()
@classmethod
@@ -42,8 +41,10 @@ _cookies = {}
def get_cookies(cookie_domain: str) -> dict:
if cookie_domain not in _cookies:
_cookies[cookie_domain] = {}
+
for cookie in browser_cookie3.load(cookie_domain):
_cookies[cookie_domain][cookie.name] = cookie.value
+
return _cookies[cookie_domain]
@@ -53,18 +54,15 @@ class AsyncProvider(BaseProvider):
cls,
model: str,
messages: list[dict[str, str]],
- stream: bool = False,
- **kwargs: Any
- ) -> CreateResult:
+ stream: bool = False, **kwargs: Any) -> CreateResult:
+
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@abstractmethod
async def create_async(
model: str,
- messages: list[dict[str, str]],
- **kwargs: Any,
- ) -> str:
+ messages: list[dict[str, str]], **kwargs: Any) -> str:
raise NotImplementedError()
@@ -74,9 +72,8 @@ class AsyncGeneratorProvider(AsyncProvider):
cls,
model: str,
messages: list[dict[str, str]],
- stream: bool = True,
- **kwargs: Any
- ) -> CreateResult:
+ stream: bool = True, **kwargs: Any) -> CreateResult:
+
if stream:
yield from run_generator(cls.create_async_generator(model, messages, **kwargs))
else:
@@ -86,9 +83,8 @@ class AsyncGeneratorProvider(AsyncProvider):
async def create_async(
cls,
model: str,
- messages: list[dict[str, str]],
- **kwargs: Any,
- ) -> str:
+ messages: list[dict[str, str]], **kwargs: Any) -> str:
+
chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)]
if chunks:
return "".join(chunks)
@@ -97,14 +93,14 @@ class AsyncGeneratorProvider(AsyncProvider):
@abstractmethod
def create_async_generator(
model: str,
- messages: list[dict[str, str]],
- ) -> AsyncGenerator:
+ messages: list[dict[str, str]]) -> AsyncGenerator:
+
raise NotImplementedError()
def run_generator(generator: AsyncGenerator[Union[Any, str], Any]):
loop = asyncio.new_event_loop()
- gen = generator.__aiter__()
+ gen = generator.__aiter__()
while True:
try: