diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/providers/base_provider.py | 37 |
1 files changed, 19 insertions, 18 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index ee5bcbb8..37f4af15 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter -from ..typing import CreateResult, AsyncResult, Messages, Union -from .types import BaseProvider -from ..errors import NestAsyncioError, ModelNotSupportedError +from typing import Callable, Union +from ..typing import CreateResult, AsyncResult, Messages +from .types import BaseProvider, FinishReason +from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError from .. import debug if sys.version_info < (3, 10): @@ -21,17 +22,23 @@ if sys.platform == 'win32': if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -def get_running_loop() -> Union[AbstractEventLoop, None]: +def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: try: loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) + if check_nested and not hasattr(loop.__class__, "_nest_patched"): + try: + import nest_asyncio + nest_asyncio.apply(loop) + except ImportError: + raise MissingRequirementsError('Install "nest_asyncio" package') return loop except RuntimeError: pass +# Fix for RuntimeError: async generator ignored GeneratorExit +async def await_callback(callback: Callable): + return await callback() + class AbstractProvider(BaseProvider): """ Abstract class for providing asynchronous functionality to derived classes. @@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider): Returns: CreateResult: The result of the completion creation. """ - get_running_loop() + get_running_loop(check_nested=True) yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider): """ raise NotImplementedError() - class AsyncGeneratorProvider(AsyncProvider): """ Provides asynchronous generator functionality for streaming results. @@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - loop = get_running_loop() + loop = get_running_loop(check_nested=True) new_loop = False - if not loop: + if loop is None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) new_loop = True @@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider): generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() - # Fix for RuntimeError: async generator ignored GeneratorExit - async def await_callback(callback): - return await callback() - try: while True: yield loop.run_until_complete(await_callback(gen.__anext__)) except StopAsyncIteration: ... - # Fix for: ResourceWarning: unclosed event loop finally: if new_loop: loop.close() @@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider): """ return "".join([ chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) - if not isinstance(chunk, Exception) + if not isinstance(chunk, (Exception, FinishReason)) ]) @staticmethod |