diff options
Diffstat (limited to 'g4f/__init__.py')
-rw-r--r-- | g4f/__init__.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py index 4b1e4b80..89b3e6d1 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations from requests import get from g4f.models import Model, ModelUtils from .Provider import BaseProvider -from .typing import CreateResult, Union +from .typing import Messages, CreateResult, Union from .debug import logging version = '0.1.5.6' @@ -27,19 +27,19 @@ def get_model_and_provider(model : Union[Model, str], if model in ModelUtils.convert: model = ModelUtils.convert[model] else: - raise Exception(f'The model: {model} does not exist') + raise ValueError(f'The model: {model} does not exist') if not provider: provider = model.best_provider if not provider: - raise Exception(f'No provider found for model: {model}') + raise RuntimeError(f'No provider found for model: {model}') if not provider.working: - raise Exception(f'{provider.__name__} is not working') + raise RuntimeError(f'{provider.__name__} is not working') if not provider.supports_stream and stream: - raise Exception(f'ValueError: {provider.__name__} does not support "stream" argument') + raise ValueError(f'{provider.__name__} does not support "stream" argument') if logging: print(f'Using {provider.__name__} provider') @@ -48,17 +48,20 @@ def get_model_and_provider(model : Union[Model, str], class ChatCompletion: @staticmethod - def create(model: Union[Model, str], - messages : list[dict[str, str]], + def create( + model: Union[Model, str], + messages : Messages, provider : Union[type[BaseProvider], None] = None, stream : bool = False, - auth : Union[str, None] = None, **kwargs) -> Union[CreateResult, str]: + auth : Union[str, None] = None, + **kwargs + ) -> Union[CreateResult, str]: model, provider = get_model_and_provider(model, provider, stream) if provider.needs_auth and not auth: - raise Exception( - f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)') + raise ValueError( + f'{provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)') if provider.needs_auth: kwargs['auth'] = auth @@ -69,10 +72,14 @@ class ChatCompletion: @staticmethod async def create_async( model: Union[Model, str], - messages: list[dict[str, str]], + messages: Messages, provider: Union[type[BaseProvider], None] = None, + stream: bool = False, **kwargs ) -> str: + if stream: + raise ValueError(f'"create_async" does not support "stream" argument') + model, provider = get_model_and_provider(model, provider, False) return await provider.create_async(model.name, messages, **kwargs) |