summaryrefslogtreecommitdiffstats
path: root/g4f/base_provider.py
blob: 03ae64d6afedd9afa8fca0348c788808ffd0a792 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from abc import ABC, abstractmethod
from typing import Union, List, Dict, Type
from .typing import Messages, CreateResult

class BaseProvider(ABC):
    """
    Abstract base class for a provider.

    Subclasses must override both ``create_completion`` and ``create_async``;
    until they do, the class stays abstract and cannot be instantiated.

    Attributes:
        url (Union[str, None]): URL of the provider; ``None`` unless a subclass sets it.
        working (bool): Indicates if the provider is currently working.
        needs_auth (bool): Indicates if the provider needs authentication.
        supports_stream (bool): Indicates if the provider supports streaming.
        supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
        supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
        supports_message_history (bool): Indicates if the provider supports message history.
        params (str): List parameters for the provider. Only declared here;
            this class assigns no value to it.
    """

    # The default is None, so the annotation has to admit None as well.
    url: Union[str, None] = None
    working: bool = False
    needs_auth: bool = False
    supports_stream: bool = False
    supports_gpt_35_turbo: bool = False
    supports_gpt_4: bool = False
    supports_message_history: bool = False
    # Annotation only: no class attribute named ``params`` is created here.
    params: str

    @classmethod
    @abstractmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """
        Create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            stream (bool): Whether to use streaming.
            **kwargs: Additional keyword arguments.

        Returns:
            CreateResult: The result of the creation process.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The result of the creation process.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def get_dict(cls) -> Dict[str, Union[str, None]]:
        """
        Get a dictionary representation of the provider.

        Returns:
            Dict[str, Union[str, None]]: The class name under ``'name'`` and
            the class-level ``url`` under ``'url'`` (``None`` when no URL is set).
        """
        return {'name': cls.__name__, 'url': cls.url}
    
class BaseRetryProvider(BaseProvider):
    """
    Base class for a provider that retries across a list of providers.

    This class only sets up the state; it defines no retry behaviour itself.

    Attributes:
        providers (List[Type[BaseProvider]]): Providers available for retries.
        shuffle (bool): Whether the providers list should be shuffled.
        exceptions (Dict[str, Exception]): Exceptions encountered; starts empty.
        last_provider (Type[BaseProvider]): The last provider used; ``None``
            right after construction.
    """

    # Defined in the class body, so instances expose this as ``__name__``.
    __name__: str = "RetryProvider"
    supports_stream: bool = True

    def __init__(
        self,
        providers: List[Type[BaseProvider]],
        shuffle: bool = True
    ) -> None:
        """
        Set up a retry provider with its initial state.

        Args:
            providers (List[Type[BaseProvider]]): Providers to use; the list
                is stored as passed, not copied.
            shuffle (bool): Whether to shuffle the providers list.
        """
        # Caller-supplied configuration, kept exactly as given.
        self.providers, self.shuffle = providers, shuffle
        # Instance-level flag; overrides the class-level ``working = False``.
        self.working = True
        # Fresh state: no exceptions recorded, no provider used yet.
        self.exceptions: Dict[str, Exception] = dict()
        self.last_provider: Union[Type[BaseProvider], None] = None
        
# Type alias for a provider reference: either a BaseProvider subclass (the
# class object itself, not an instance) or a BaseRetryProvider instance.
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]