summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--g4f/Provider/base_provider.py102
1 files changed, 29 insertions, 73 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index e8a54f78..a21dc871 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,13 +1,10 @@
from __future__ import annotations
-import asyncio
-import functools
-from asyncio import SelectorEventLoop, AbstractEventLoop
+from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import ABC, abstractmethod
-import browser_cookie3
-
+from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import AsyncGenerator, CreateResult
@@ -40,20 +37,18 @@ class BaseProvider(ABC):
**kwargs
) -> str:
if not loop:
- loop = asyncio.get_event_loop()
-
- partial_func = functools.partial(
- cls.create_completion,
- model,
- messages,
- False,
- **kwargs
- )
- response = await loop.run_in_executor(
+ loop = get_event_loop()
+ def create_func():
+ return "".join(cls.create_completion(
+ model,
+ messages,
+ False,
+ **kwargs
+ ))
+ return await loop.run_in_executor(
executor,
- partial_func
+ create_func
)
- return "".join(response)
@classmethod
@property
@@ -76,11 +71,9 @@ class AsyncProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
- loop = create_event_loop()
- try:
- yield loop.run_until_complete(cls.create_async(model, messages, **kwargs))
- finally:
- loop.close()
+ loop = get_event_loop()
+ coro = cls.create_async(model, messages, **kwargs)
+ yield loop.run_until_complete(coro)
@staticmethod
@abstractmethod
@@ -103,22 +96,19 @@ class AsyncGeneratorProvider(AsyncProvider):
stream: bool = True,
**kwargs
) -> CreateResult:
- loop = create_event_loop()
- try:
- generator = cls.create_async_generator(
- model,
- messages,
- stream=stream,
- **kwargs
- )
- gen = generator.__aiter__()
- while True:
- try:
- yield loop.run_until_complete(gen.__anext__())
- except StopAsyncIteration:
- break
- finally:
- loop.close()
+ loop = get_event_loop()
+ generator = cls.create_async_generator(
+ model,
+ messages,
+ stream=stream,
+ **kwargs
+ )
+ gen = generator.__aiter__()
+ while True:
+ try:
+ yield loop.run_until_complete(gen.__anext__())
+ except StopAsyncIteration:
+ break
@classmethod
async def create_async(
@@ -143,38 +133,4 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: list[dict[str, str]],
**kwargs
) -> AsyncGenerator:
- raise NotImplementedError()
-
-
-# Don't create a new event loop in a running async loop.
-# Force use selector event loop on windows and linux use it anyway.
-def create_event_loop() -> SelectorEventLoop:
- try:
- asyncio.get_running_loop()
- except RuntimeError:
- return SelectorEventLoop()
- raise RuntimeError(
- 'Use "create_async" instead of "create" function in a running event loop.')
-
-
-_cookies = {}
-
-def get_cookies(cookie_domain: str) -> dict:
- if cookie_domain not in _cookies:
- _cookies[cookie_domain] = {}
- try:
- for cookie in browser_cookie3.load(cookie_domain):
- _cookies[cookie_domain][cookie.name] = cookie.value
- except:
- pass
- return _cookies[cookie_domain]
-
-
-def format_prompt(messages: list[dict[str, str]], add_special_tokens=False):
- if add_special_tokens or len(messages) > 1:
- formatted = "\n".join(
- ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
- )
- return f"{formatted}\nAssistant:"
- else:
- return messages[0]["content"] \ No newline at end of file
+ raise NotImplementedError() \ No newline at end of file