diff options
author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-01-01 17:48:57 +0100 |
---|---|---|
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-01-01 17:48:57 +0100 |
commit | c617b18d12c2f9d82ce7c73aae46d353b83f625a (patch) | |
tree | 898f5090865a8aea64fb87e56f9ebfc979a6b706 /g4f/Provider/base_provider.py | |
parent | Patch event loop on win, Check event loop closed (diff) | |
download | gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.gz gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.bz2 gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.lz gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.xz gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.zst gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.zip |
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 48 |
1 files changed, 21 insertions, 27 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 62029f5d..6da7f6c6 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,12 +1,14 @@ from __future__ import annotations import sys +import asyncio from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor -from abc import ABC, abstractmethod +from abc import abstractmethod from inspect import signature, Parameter from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages +from ..typing import CreateResult, AsyncResult, Messages, Union +from ..base_provider import BaseProvider if sys.version_info < (3, 10): NoneType = type(None) @@ -20,25 +22,7 @@ if sys.platform == 'win32': ): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -class BaseProvider(ABC): - url: str - working: bool = False - needs_auth: bool = False - supports_stream: bool = False - supports_gpt_35_turbo: bool = False - supports_gpt_4: bool = False - supports_message_history: bool = False - - @staticmethod - @abstractmethod - def create_completion( - model: str, - messages: Messages, - stream: bool, - **kwargs - ) -> CreateResult: - raise NotImplementedError() - +class AbstractProvider(BaseProvider): @classmethod async def create_async( cls, @@ -60,9 +44,12 @@ class BaseProvider(ABC): **kwargs )) - return await loop.run_in_executor( - executor, - create_func + return await asyncio.wait_for( + loop.run_in_executor( + executor, + create_func + ), + timeout=kwargs.get("timeout", 0) ) @classmethod @@ -102,16 +89,19 @@ class BaseProvider(ABC): return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" -class AsyncProvider(BaseProvider): +class AsyncProvider(AbstractProvider): @classmethod def create_completion( cls, model: str, messages: Messages, stream: bool = False, + *, + loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - loop = get_event_loop() + if not loop: + loop = get_event_loop() coro = cls.create_async(model, messages, **kwargs) yield loop.run_until_complete(coro) @@ -134,9 +124,12 @@ class AsyncGeneratorProvider(AsyncProvider): model: str, messages: Messages, stream: bool = True, + *, + loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - loop = get_event_loop() + if not loop: + loop = get_event_loop() generator = cls.create_async_generator( model, messages, @@ -171,6 +164,7 @@ class AsyncGeneratorProvider(AsyncProvider): def create_async_generator( model: str, messages: Messages, + stream: bool = True, **kwargs ) -> AsyncResult: raise NotImplementedError() |