diff options
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 50 |
1 files changed, 40 insertions, 10 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 79f8f617..0cceb220 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,11 +1,12 @@ from __future__ import annotations import asyncio +from asyncio import SelectorEventLoop from abc import ABC, abstractmethod import browser_cookie3 -from ..typing import Any, AsyncGenerator, CreateResult, Union +from ..typing import Any, AsyncGenerator, CreateResult class BaseProvider(ABC): @@ -21,10 +22,13 @@ class BaseProvider(ABC): def create_completion( model: str, messages: list[dict[str, str]], - stream: bool, **kwargs: Any) -> CreateResult: + stream: bool, + **kwargs + ) -> CreateResult: raise NotImplementedError() + @classmethod @property def params(cls): @@ -46,13 +50,19 @@ class AsyncProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: - yield asyncio.run(cls.create_async(model, messages, **kwargs)) + loop = create_event_loop() + try: + yield loop.run_until_complete(cls.create_async(model, messages, **kwargs)) + finally: + loop.close() @staticmethod @abstractmethod async def create_async( model: str, - messages: list[dict[str, str]], **kwargs: Any) -> str: + messages: list[dict[str, str]], + **kwargs + ) -> str: raise NotImplementedError() @@ -67,10 +77,14 @@ class AsyncGeneratorProvider(AsyncProvider): stream: bool = True, **kwargs ) -> CreateResult: - loop = asyncio.new_event_loop() + loop = create_event_loop() try: - asyncio.set_event_loop(loop) - generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) + generator = cls.create_async_generator( + model, + messages, + stream=stream, + **kwargs + ) gen = generator.__aiter__() while True: try: @@ -78,10 +92,8 @@ class AsyncGeneratorProvider(AsyncProvider): except StopAsyncIteration: break finally: - asyncio.set_event_loop(None) loop.close() - @classmethod async def create_async( cls, @@ -89,7 +101,14 @@ class AsyncGeneratorProvider(AsyncProvider): messages: list[dict[str, str]], **kwargs ) -> str: - return "".join([chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)]) + return "".join([ + chunk async for chunk in cls.create_async_generator( + model, + messages, + stream=False, + **kwargs + ) + ]) @staticmethod @abstractmethod @@ -101,6 +120,17 @@ class AsyncGeneratorProvider(AsyncProvider): raise NotImplementedError() +# Don't create a new event loop in a running async loop. +# Force use selector event loop on windows and linux use it anyway. +def create_event_loop() -> SelectorEventLoop: + try: + asyncio.get_running_loop() + except RuntimeError: + return SelectorEventLoop() + raise RuntimeError( + 'Use "create_async" instead of "create" function in a async loop.') + + _cookies = {} def get_cookies(cookie_domain: str) -> dict: |