summaryrefslogblamecommitdiffstats
path: root/g4f/Provider/needs_auth/Openai.py
blob: f73c10117c58ff1b791e7f8b5b68df6a938d6465 (plain) (tree)
1
2
3
4
5
6
7
8
9



                                  
                                
                                                                                    
                                                            
                                                       
                                                     

                                                         
                        

















                                                    
                                           
                             

                              

                     
                                              


                                                     
                                                              

                           









                                           
            

                                                                                                       

                                                
                                         













                                                                  
                                                 










                                                                                           

                 
                                                                   
                                                                             
                                                        
 






                                                                                             

                                                                                          
                
                                                                            
                                               





                                                         
        
from __future__ import annotations

import json

from ..helper import filter_none
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
from ...typing import Union, Optional, AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...errors import MissingAuthError, ResponseError

class Openai(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider for OpenAI-compatible ``/chat/completions`` endpoints.

    Defaults to the official OpenAI API; any compatible server can be targeted
    by passing a different ``api_base``. An ``api_key`` is required because
    ``needs_auth`` is True.
    """
    label = "OpenAI API"
    url = "https://openai.com"
    working = True
    needs_auth = True
    supports_message_history = True
    supports_system_message = True

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 120,
        api_key: str = None,
        api_base: str = "https://api.openai.com/v1",
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        stop: Union[str, list[str]] = None,
        stream: bool = False,
        headers: dict = None,
        extra_data: dict = None,
        **kwargs
    ) -> AsyncResult:
        """POST a chat completion request and yield the response.

        Args:
            model: Model name, resolved through ``cls.get_model``.
            messages: Chat history sent as the ``messages`` field.
            proxy: Proxy URL, passed to ``StreamSession`` as ``{"all": proxy}``.
            timeout: Passed through to ``StreamSession``.
            api_key: Sent as a Bearer token; mandatory when ``needs_auth``.
            api_base: Base URL; a trailing slash is tolerated.
            temperature, max_tokens, top_p, stop: Optional sampling options;
                ``None`` values are dropped from the payload by ``filter_none``.
            stream: Request a server-sent-events stream instead of one JSON body.
            headers: Extra headers merged over the defaults from ``get_headers``.
            extra_data: Additional top-level payload fields (``None`` = none).
            **kwargs: Accepted for interface compatibility and ignored.

        Yields:
            Text chunks (``str``), followed by a ``FinishReason`` when the
            server reports one.

        Raises:
            MissingAuthError: ``needs_auth`` is set and no ``api_key`` was given.
            ResponseError: The response body carries an error object.
        """
        if cls.needs_auth and api_key is None:
            raise MissingAuthError('Add a "api_key"')
        async with StreamSession(
            proxies={"all": proxy},
            headers=cls.get_headers(stream, api_key, headers),
            timeout=timeout
        ) as session:
            data = filter_none(
                messages=messages,
                model=cls.get_model(model),
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stop=stop,
                stream=stream,
                # None instead of a shared mutable {} default.
                **({} if extra_data is None else extra_data)
            )

            async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
                await raise_for_status(response)
                if not stream:
                    data = await response.json()
                    cls.raise_error(data)
                    choice = data["choices"][0]
                    # "content" may be present but null (e.g. tool-call replies).
                    content = choice["message"].get("content")
                    if content is not None:
                        yield content.strip()
                    finish = cls.read_finish_reason(choice)
                    if finish is not None:
                        yield finish
                else:
                    first = True
                    async for line in response.iter_lines():
                        # SSE allows "data:" with or without a following space.
                        if not line.startswith(b"data:"):
                            continue
                        chunk = line[5:].strip()
                        if chunk == b"[DONE]":
                            break
                        if not chunk:
                            continue
                        data = json.loads(chunk)
                        cls.raise_error(data)
                        # Some servers send chunks with an empty "choices" list
                        # (e.g. a final usage-only chunk); nothing to emit there.
                        if not data.get("choices"):
                            continue
                        choice = data["choices"][0]
                        delta = (choice.get("delta") or {}).get("content")
                        if delta:
                            # Trim leading whitespace only until the first
                            # non-empty piece of text has been emitted.
                            if first:
                                delta = delta.lstrip()
                            if delta:
                                first = False
                                yield delta
                        finish = cls.read_finish_reason(choice)
                        if finish is not None:
                            yield finish

    @staticmethod
    def read_finish_reason(choice: dict) -> Optional[FinishReason]:
        """Return the choice's ``finish_reason`` as a ``FinishReason``, or None if absent/null."""
        reason = choice.get("finish_reason")
        if reason is not None:
            return FinishReason(reason)
        return None

    @staticmethod
    def raise_error(data: dict) -> None:
        """Raise ``ResponseError`` if the payload carries an error.

        Handles a top-level ``error_message`` field as well as an ``error``
        field that is either an object (``code``/``type`` + ``message``) or a
        plain value. An ``error`` that is ``null`` is treated as no error.
        """
        if "error_message" in data:
            raise ResponseError(data["error_message"])
        if "error" not in data:
            return
        error = data["error"]
        if error is None:
            return
        if isinstance(error, dict):
            # Not every compatible server fills in "code"; fall back to "type".
            code = error.get("code", error.get("type"))
            raise ResponseError(f'Error {code}: {error.get("message")}')
        raise ResponseError(f"Error: {error}")

    @classmethod
    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
        """Build request headers.

        Sets ``Accept`` according to ``stream``, adds a Bearer ``Authorization``
        header when auth is needed and a key is given, then lets ``headers``
        override any of these.
        """
        return {
            "Accept": "text/event-stream" if stream else "application/json",
            "Content-Type": "application/json",
            **(
                {"Authorization": f"Bearer {api_key}"}
                if cls.needs_auth and api_key is not None
                else {}
            ),
            **({} if headers is None else headers)
        }