diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/api/__init__.py | 57 |
1 files changed, 52 insertions, 5 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 02ba5260..21e69388 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -4,6 +4,7 @@ import logging import json import uvicorn import secrets +import os from fastapi import FastAPI, Response, Request from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse @@ -13,13 +14,16 @@ from starlette.exceptions import HTTPException from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware +from starlette.responses import FileResponse from pydantic import BaseModel from typing import Union, Optional import g4f import g4f.debug from g4f.client import AsyncClient, ChatCompletion +from g4f.providers.response import BaseConversation from g4f.client.helper import filter_none +from g4f.image import is_accepted_format, images_dir from g4f.typing import Messages from g4f.cookies import read_cookie_files @@ -63,6 +67,7 @@ class ChatCompletionsConfig(BaseModel): api_key: Optional[str] = None web_search: Optional[bool] = None proxy: Optional[str] = None + conversation_id: str = None class ImageGenerationConfig(BaseModel): prompt: str @@ -98,6 +103,7 @@ class Api: self.client = AsyncClient() self.g4f_api_key = g4f_api_key self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key") + self.conversations: dict[str, dict[str, BaseConversation]] = {} def register_authorization(self): @self.app.middleware("http") @@ -179,12 +185,21 @@ class Api: async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None): try: config.provider = provider if config.provider is None else config.provider + if config.provider is None: + config.provider = AppConfig.provider if config.api_key is None and request is not None: auth_header = request.headers.get("Authorization") if auth_header is not None: - auth_header = auth_header.split(None, 1)[-1] - if auth_header and auth_header != "Bearer": - config.api_key = auth_header + api_key = auth_header.split(None, 1)[-1] + if api_key and api_key != "Bearer": + config.api_key = api_key + + conversation = return_conversation = None + if config.conversation_id is not None and config.provider is not None: + return_conversation = True + if config.conversation_id in self.conversations: + if config.provider in self.conversations[config.conversation_id]: + conversation = self.conversations[config.conversation_id][config.provider] # Create the completion response response = self.client.chat.completions.create( @@ -194,6 +209,11 @@ class Api: "provider": AppConfig.provider, "proxy": AppConfig.proxy, **config.dict(exclude_none=True), + **{ + "conversation_id": None, + "return_conversation": return_conversation, + "conversation": conversation + } }, ignored=AppConfig.ignored_providers ), @@ -206,7 +226,13 @@ class Api: async def streaming(): try: async for chunk in response: - yield f"data: {json.dumps(chunk.to_json())}\n\n" + if isinstance(chunk, BaseConversation): + if config.conversation_id is not None and config.provider is not None: + if config.conversation_id not in self.conversations: + self.conversations[config.conversation_id] = {} + self.conversations[config.conversation_id][config.provider] = chunk + else: + yield f"data: {json.dumps(chunk.to_json())}\n\n" except GeneratorExit: pass except Exception as e: @@ -222,7 +248,13 @@ class Api: @self.app.post("/v1/images/generate") @self.app.post("/v1/images/generations") - async def generate_image(config: ImageGenerationConfig): + async def generate_image(config: ImageGenerationConfig, request: Request): + if config.api_key is None: + auth_header = request.headers.get("Authorization") + if auth_header is not None: + api_key = auth_header.split(None, 1)[-1] + if api_key and api_key != "Bearer": + config.api_key = api_key try: response = await self.client.images.generate( prompt=config.prompt, @@ -234,6 +266,9 @@ class Api: proxy = config.proxy ) ) + for image in response.data: + if hasattr(image, "url") and image.url.startswith("/"): + image.url = f"{request.base_url}{image.url.lstrip('/')}" return JSONResponse(response.to_json()) except Exception as e: logger.exception(e) @@ -243,6 +278,18 @@ class Api: async def completions(): return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") + @self.app.get("/images/{filename}") + async def get_image(filename): + target = os.path.join(images_dir, filename) + + if not os.path.isfile(target): + return Response(status_code=404) + + with open(target, "rb") as f: + content_type = is_accepted_format(f.read(12)) + + return FileResponse(target, media_type=content_type) + def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str: last_provider = {} if not image else g4f.get_last_provider(True) provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider |