summaryrefslogtreecommitdiffstats
path: root/g4f/providers/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/providers/base_provider.py')
-rw-r--r--g4f/providers/base_provider.py35
1 files changed, 18 insertions, 17 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index ee5bcbb8..f3483fc2 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -6,8 +6,9 @@ from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
-from ..typing import CreateResult, AsyncResult, Messages, Union
-from .types import BaseProvider
+from typing import Callable, Union
+from ..typing import CreateResult, AsyncResult, Messages
+from .types import BaseProvider, FinishReason
from ..errors import NestAsyncioError, ModelNotSupportedError
from .. import debug
@@ -21,17 +22,23 @@ if sys.platform == 'win32':
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-def get_running_loop() -> Union[AbstractEventLoop, None]:
+def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try:
loop = asyncio.get_running_loop()
- if not hasattr(loop.__class__, "_nest_patched"):
- raise NestAsyncioError(
- 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
- )
+ if check_nested and not hasattr(loop.__class__, "_nest_patched"):
+ try:
+ import nest_asyncio
+ nest_asyncio.apply(loop)
+ except ImportError:
+ raise NestAsyncioError('Install "nest_asyncio" package')
return loop
except RuntimeError:
pass
+# Fix for RuntimeError: async generator ignored GeneratorExit
+async def await_callback(callback: Callable):
+ return await callback()
+
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
@@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider):
Returns:
CreateResult: The result of the completion creation.
"""
- get_running_loop()
+ get_running_loop(check_nested=True)
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider):
"""
raise NotImplementedError()
-
class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
@@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
CreateResult: The result of the streaming completion creation.
"""
- loop = get_running_loop()
+ loop = get_running_loop(check_nested=True)
new_loop = False
- if not loop:
+ if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
new_loop = True
@@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider):
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
- # Fix for RuntimeError: async generator ignored GeneratorExit
- async def await_callback(callback):
- return await callback()
-
try:
while True:
yield loop.run_until_complete(await_callback(gen.__anext__))
except StopAsyncIteration:
...
- # Fix for: ResourceWarning: unclosed event loop
finally:
if new_loop:
loop.close()
@@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
- if not isinstance(chunk, Exception)
+ if not isinstance(chunk, (Exception, FinishReason))
])
@staticmethod