summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHeiner Lohaus <heiner@lohaus.eu>2023-09-20 15:01:33 +0200
committerHeiner Lohaus <heiner@lohaus.eu>2023-09-20 15:01:33 +0200
commit587f4ad2c9c1daaa246cabb02dbec7ec013de754 (patch)
treebcb7db3a2d2d99dfa9b599a89a7c00e389b91856
parentAdd check_running_loop requirement (diff)
downloadgpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.tar
gpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.tar.gz
gpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.tar.bz2
gpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.tar.lz
gpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.tar.xz
gpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.tar.zst
gpt4free-587f4ad2c9c1daaa246cabb02dbec7ec013de754.zip
-rw-r--r--g4f/Provider/base_provider.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 9d45aa44..1e2d4c64 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -47,9 +47,11 @@ class AsyncProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
- check_running_loop()
-
- yield asyncio.run(cls.create_async(model, messages, **kwargs))
+ loop = create_event_loop()
+ try:
+ yield loop.run_until_complete(cls.create_async(model, messages, **kwargs))
+ finally:
+ loop.close()
@staticmethod
@abstractmethod
@@ -70,10 +72,7 @@ class AsyncGeneratorProvider(AsyncProvider):
stream: bool = True,
**kwargs
) -> CreateResult:
- check_running_loop()
-
- # Force use selector event loop on windows
- loop = asyncio.SelectorEventLoop()
+ loop = get_new_event_loop()
try:
generator = cls.create_async_generator(
model,
@@ -108,12 +107,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> AsyncGenerator:
raise NotImplementedError()
-# Don't create a new loop in a running loop
-def check_running_loop():
+
+def create_event_loop():
+ # Don't create a new loop in a running loop
if asyncio.events._get_running_loop() is not None:
raise RuntimeError(
'Use "create_async" instead of "create" function in a async loop.')
+ # Force use selector event loop on windows
+ return asyncio.SelectorEventLoop()
+
+
_cookies = {}
def get_cookies(cookie_domain: str) -> dict: