summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r--g4f/Provider/base_provider.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 35764081..c54b98e5 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
from abc import ABC, abstractmethod
from .helper import get_event_loop, get_cookies, format_prompt
-from ..typing import AsyncGenerator, CreateResult
+from ..typing import CreateResult, AsyncResult, Messages
class BaseProvider(ABC):
@@ -20,7 +20,7 @@ class BaseProvider(ABC):
@abstractmethod
def create_completion(
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
stream: bool,
**kwargs
) -> CreateResult:
@@ -30,7 +30,7 @@ class BaseProvider(ABC):
async def create_async(
cls,
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
*,
loop: AbstractEventLoop = None,
executor: ThreadPoolExecutor = None,
@@ -69,7 +69,7 @@ class AsyncProvider(BaseProvider):
def create_completion(
cls,
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
stream: bool = False,
**kwargs
) -> CreateResult:
@@ -81,7 +81,7 @@ class AsyncProvider(BaseProvider):
@abstractmethod
async def create_async(
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
**kwargs
) -> str:
raise NotImplementedError()
@@ -94,7 +94,7 @@ class AsyncGeneratorProvider(AsyncProvider):
def create_completion(
cls,
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
stream: bool = True,
**kwargs
) -> CreateResult:
@@ -116,7 +116,7 @@ class AsyncGeneratorProvider(AsyncProvider):
async def create_async(
cls,
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
**kwargs
) -> str:
return "".join([
@@ -132,7 +132,7 @@ class AsyncGeneratorProvider(AsyncProvider):
@abstractmethod
def create_async_generator(
model: str,
- messages: list[dict[str, str]],
+ messages: Messages,
**kwargs
- ) -> AsyncGenerator:
+ ) -> AsyncResult:
raise NotImplementedError() \ No newline at end of file