From 53192b86b129380660f7454170fa987faf2da3c5 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 10 Oct 2023 09:49:29 +0200 Subject: Some small fixes --- g4f/Provider/Acytoo.py | 4 ++-- g4f/Provider/GptGo.py | 6 +++--- g4f/Provider/H2o.py | 6 +++--- g4f/Provider/Myshell.py | 6 +++--- g4f/Provider/Phind.py | 6 +++--- g4f/Provider/base_provider.py | 18 +++++++++--------- g4f/Provider/deprecated/AiService.py | 4 ++-- g4f/Provider/helper.py | 4 ++-- g4f/Provider/retry_provider.py | 13 +++++-------- 9 files changed, 32 insertions(+), 35 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/Acytoo.py b/g4f/Provider/Acytoo.py index 0ac3425c..cefdd1ac 100644 --- a/g4f/Provider/Acytoo.py +++ b/g4f/Provider/Acytoo.py @@ -23,7 +23,7 @@ class Acytoo(AsyncGeneratorProvider): headers=_create_header() ) as session: async with session.post( - cls.url + '/api/completions', + f'{cls.url}/api/completions', proxy=proxy, json=_create_payload(messages, **kwargs) ) as response: @@ -40,7 +40,7 @@ def _create_header(): } -def _create_payload(messages: list[dict[str, str]], temperature: float = 0.5, **kwargs): +def _create_payload(messages: Messages, temperature: float = 0.5, **kwargs): return { 'key' : '', 'model' : 'gpt-3.5-turbo', diff --git a/g4f/Provider/GptGo.py b/g4f/Provider/GptGo.py index 5f6cc362..f9b94a5c 100644 --- a/g4f/Provider/GptGo.py +++ b/g4f/Provider/GptGo.py @@ -3,7 +3,7 @@ from __future__ import annotations from aiohttp import ClientSession import json -from ..typing import AsyncGenerator +from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, format_prompt @@ -16,10 +16,10 @@ class GptGo(AsyncGeneratorProvider): async def create_async_generator( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, proxy: str = None, **kwargs - ) -> AsyncGenerator: + ) -> AsyncResult: headers = { "User-Agent" : "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36", "Accept" : "*/*", diff --git a/g4f/Provider/H2o.py b/g4f/Provider/H2o.py index d92bd6d1..65429a28 100644 --- a/g4f/Provider/H2o.py +++ b/g4f/Provider/H2o.py @@ -5,7 +5,7 @@ import uuid from aiohttp import ClientSession -from ..typing import AsyncGenerator +from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, format_prompt @@ -18,10 +18,10 @@ class H2o(AsyncGeneratorProvider): async def create_async_generator( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, proxy: str = None, **kwargs - ) -> AsyncGenerator: + ) -> AsyncResult: model = model if model else cls.model headers = {"Referer": cls.url + "/"} diff --git a/g4f/Provider/Myshell.py b/g4f/Provider/Myshell.py index 6ed4fd7a..847bac2f 100644 --- a/g4f/Provider/Myshell.py +++ b/g4f/Provider/Myshell.py @@ -6,7 +6,7 @@ from aiohttp import ClientSession from aiohttp.http import WSMsgType import asyncio -from ..typing import AsyncGenerator +from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, format_prompt @@ -27,11 +27,11 @@ class Myshell(AsyncGeneratorProvider): async def create_async_generator( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, proxy: str = None, timeout: int = 90, **kwargs - ) -> AsyncGenerator: + ) -> AsyncResult: if not model: bot_id = models["samantha"] elif model in models: diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py index ae4de686..d7c6f7c7 100644 --- a/g4f/Provider/Phind.py +++ b/g4f/Provider/Phind.py @@ -3,7 +3,7 @@ from __future__ import annotations import random from datetime import datetime -from ..typing import AsyncGenerator +from ..typing import AsyncResult, Messages from ..requests import StreamSession from .base_provider import AsyncGeneratorProvider, format_prompt @@ -17,11 +17,11 @@ class Phind(AsyncGeneratorProvider): async def create_async_generator( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, proxy: str = None, timeout: int = 120, **kwargs - ) -> AsyncGenerator: + ) -> AsyncResult: chars = 'abcdefghijklmnopqrstuvwxyz0123456789' user_id = ''.join(random.choice(chars) for _ in range(24)) data = { diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 35764081..c54b98e5 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from abc import ABC, abstractmethod from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import AsyncGenerator, CreateResult +from ..typing import CreateResult, AsyncResult, Messages class BaseProvider(ABC): @@ -20,7 +20,7 @@ class BaseProvider(ABC): @abstractmethod def create_completion( model: str, - messages: list[dict[str, str]], + messages: Messages, stream: bool, **kwargs ) -> CreateResult: @@ -30,7 +30,7 @@ class BaseProvider(ABC): async def create_async( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, *, loop: AbstractEventLoop = None, executor: ThreadPoolExecutor = None, @@ -69,7 +69,7 @@ class AsyncProvider(BaseProvider): def create_completion( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, stream: bool = False, **kwargs ) -> CreateResult: @@ -81,7 +81,7 @@ class AsyncProvider(BaseProvider): @abstractmethod async def create_async( model: str, - messages: list[dict[str, str]], + messages: Messages, **kwargs ) -> str: raise NotImplementedError() @@ -94,7 +94,7 @@ class AsyncGeneratorProvider(AsyncProvider): def create_completion( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, stream: bool = True, **kwargs ) -> CreateResult: @@ -116,7 +116,7 @@ class AsyncGeneratorProvider(AsyncProvider): async def create_async( cls, model: str, - messages: list[dict[str, str]], + messages: Messages, **kwargs ) -> str: return "".join([ @@ -132,7 +132,7 @@ class AsyncGeneratorProvider(AsyncProvider): @abstractmethod def create_async_generator( model: str, - messages: list[dict[str, str]], + messages: Messages, **kwargs - ) -> AsyncGenerator: + ) -> AsyncResult: raise NotImplementedError() \ No newline at end of file diff --git a/g4f/Provider/deprecated/AiService.py b/g4f/Provider/deprecated/AiService.py index 9b41e3c8..d1d15859 100644 --- a/g4f/Provider/deprecated/AiService.py +++ b/g4f/Provider/deprecated/AiService.py @@ -2,7 +2,7 @@ from __future__ import annotations import requests -from ...typing import Any, CreateResult +from ...typing import Any, CreateResult, Messages from ..base_provider import BaseProvider @@ -14,7 +14,7 @@ class AiService(BaseProvider): @staticmethod def create_completion( model: str, - messages: list[dict[str, str]], + messages: Messages, stream: bool, **kwargs: Any, ) -> CreateResult: diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index 5a9a9329..db19adc1 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -4,7 +4,7 @@ import asyncio import sys from asyncio import AbstractEventLoop from os import path -from typing import Dict, List +from ..typing import Dict, List, Messages import browser_cookie3 # Change event loop policy on windows @@ -53,7 +53,7 @@ def get_cookies(cookie_domain: str) -> Dict[str, str]: return _cookies[cookie_domain] -def format_prompt(messages: List[Dict[str, str]], add_special_tokens=False) -> str: +def format_prompt(messages: Messages, add_special_tokens=False) -> str: if add_special_tokens or len(messages) > 1: formatted = "\n".join( [ diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py index b49020b2..94b9b90a 100644 --- a/g4f/Provider/retry_provider.py +++ b/g4f/Provider/retry_provider.py @@ -2,7 +2,7 @@ from __future__ import annotations import random from typing import List, Type, Dict -from ..typing import CreateResult +from ..typing import CreateResult, Messages from .base_provider import BaseProvider, AsyncProvider from ..debug import logging @@ -10,10 +10,7 @@ from ..debug import logging class RetryProvider(AsyncProvider): __name__: str = "RetryProvider" working: bool = True - needs_auth: bool = False supports_stream: bool = True - supports_gpt_35_turbo: bool = False - supports_gpt_4: bool = False def __init__( self, @@ -27,7 +24,7 @@ class RetryProvider(AsyncProvider): def create_completion( self, model: str, - messages: List[Dict[str, str]], + messages: Messages, stream: bool = False, **kwargs ) -> CreateResult: @@ -54,17 +51,17 @@ class RetryProvider(AsyncProvider): if logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") if started: - break + raise e self.raise_exceptions() async def create_async( self, model: str, - messages: List[Dict[str, str]], + messages: Messages, **kwargs ) -> str: - providers = [provider for provider in self.providers] + providers = self.providers if self.shuffle: random.shuffle(providers) -- cgit v1.2.3