summaryrefslogtreecommitdiffstats
path: root/g4f/api
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/api/__init__.py103
1 files changed, 85 insertions, 18 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index 292164fa..628d7512 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -5,8 +5,10 @@ import json
import uvicorn
import secrets
import os
+import shutil
-from fastapi import FastAPI, Response, Request
+import os.path
+from fastapi import FastAPI, Response, Request, UploadFile
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.security import APIKeyHeader
@@ -16,16 +18,17 @@ from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse
from pydantic import BaseModel
-from typing import Union, Optional
+from typing import Union, Optional, List
import g4f
import g4f.debug
-from g4f.client import AsyncClient, ChatCompletion
+from g4f.client import AsyncClient, ChatCompletion, convert_to_provider
from g4f.providers.response import BaseConversation
from g4f.client.helper import filter_none
from g4f.image import is_accepted_format, images_dir
from g4f.typing import Messages
-from g4f.cookies import read_cookie_files
+from g4f.errors import ProviderNotFoundError
+from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__
logger = logging.getLogger(__name__)
@@ -78,6 +81,18 @@ class ImageGenerationConfig(BaseModel):
api_key: Optional[str] = None
proxy: Optional[str] = None
+class ProviderResponseModel(BaseModel):
+ id: str
+ object: str = "provider"
+ created: int
+ owned_by: Optional[str]
+
+class ModelResponseModel(BaseModel):
+ id: str
+ object: str = "model"
+ created: int
+ owned_by: Optional[str]
+
class AppConfig:
ignored_providers: Optional[list[str]] = None
g4f_api_key: Optional[str] = None
@@ -109,7 +124,7 @@ class Api:
def register_authorization(self):
@self.app.middleware("http")
async def authorization(request: Request, call_next):
- if self.g4f_api_key and request.url.path in ["/v1/chat/completions", "/v1/completions", "/v1/images/generate"]:
+ if self.g4f_api_key and request.url.path not in ("/", "/v1"):
try:
user_g4f_api_key = await self.get_g4f_api_key(request)
except HTTPException as e:
@@ -123,9 +138,7 @@ class Api:
status_code=HTTP_403_FORBIDDEN,
content=jsonable_encoder({"detail": "Invalid G4F API key"}),
)
-
- response = await call_next(request)
- return response
+ return await call_next(request)
def register_validation_exception_handler(self):
@self.app.exception_handler(RequestValidationError)
@@ -158,22 +171,21 @@ class Api:
'<a href="/docs">/docs</a>')
@self.app.get("/v1/models")
- async def models():
+ async def models() -> list[ModelResponseModel]:
model_list = dict(
(model, g4f.models.ModelUtils.convert[model])
for model in g4f.Model.__all__()
)
- model_list = [{
+ return [{
'id': model_id,
'object': 'model',
'created': 0,
'owned_by': model.base_provider
} for model_id, model in model_list.items()]
- return JSONResponse(model_list)
@self.app.get("/v1/models/{model_name}")
async def model_info(model_name: str):
- try:
+ if model_name in g4f.models.ModelUtils.convert:
model_info = g4f.models.ModelUtils.convert[model_name]
return JSONResponse({
'id': model_name,
@@ -181,8 +193,7 @@ class Api:
'created': 0,
'owned_by': model_info.base_provider
})
- except:
- return JSONResponse({"error": "The model does not exist."})
+ return JSONResponse({"error": "The model does not exist."}, 404)
@self.app.post("/v1/chat/completions")
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
@@ -277,12 +288,68 @@ class Api:
logger.exception(e)
return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json")
- @self.app.post("/v1/completions")
- async def completions():
- return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
+ @self.app.get("/v1/providers")
+ async def providers() -> list[ProviderResponseModel]:
+ return [{
+ 'id': provider.__name__,
+ 'object': 'provider',
+ 'created': 0,
+ 'url': provider.url,
+ 'label': getattr(provider, "label", None),
+ } for provider in __providers__ if provider.working]
+
+ @self.app.get("/v1/providers/{provider}")
+ async def providers_info(provider: str) -> ProviderResponseModel:
+ if provider not in ProviderUtils.convert:
+ return JSONResponse({"error": "The provider does not exist."}, 404)
+ provider: ProviderType = ProviderUtils.convert[provider]
+ def safe_get_models(provider: ProviderType) -> list[str]:
+ try:
+ return provider.get_models() if hasattr(provider, "get_models") else []
+ except:
+ return []
+ return {
+ 'id': provider.__name__,
+ 'object': 'provider',
+ 'created': 0,
+ 'url': provider.url,
+ 'label': getattr(provider, "label", None),
+ 'models': safe_get_models(provider),
+ 'image_models': getattr(provider, "image_models", []) or [],
+ 'vision_models': [model for model in [getattr(provider, "default_vision_model", None)] if model],
+ 'params': [*provider.get_parameters()] if hasattr(provider, "get_parameters") else []
+ }
+
+ @self.app.post("/v1/upload_cookies")
+ def upload_cookies(files: List[UploadFile]):
+ response_data = []
+ for file in files:
+ try:
+ if file and file.filename.endswith(".json") or file.filename.endswith(".har"):
+ filename = os.path.basename(file.filename)
+ with open(os.path.join(get_cookies_dir(), filename), 'wb') as f:
+ shutil.copyfileobj(file.file, f)
+ response_data.append({"filename": filename})
+ finally:
+ file.file.close()
+ return response_data
+
+ @self.app.get("/v1/synthesize/{provider}")
+ async def synthesize(request: Request, provider: str):
+ try:
+ provider_handler = convert_to_provider(provider)
+ except ProviderNotFoundError:
+ return Response("Provider not found", 404)
+ if not hasattr(provider_handler, "synthesize"):
+ return Response("Provider doesn't support synthesize", 500)
+ if len(request.query_params) == 0:
+ return Response("Missing query params", 500)
+ response_data = provider_handler.synthesize({**request.query_params})
+ content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
+ return StreamingResponse(response_data, media_type=content_type)
@self.app.get("/images/{filename}")
- async def get_image(filename):
+ async def get_image(filename) -> FileResponse:
target = os.path.join(images_dir, filename)
if not os.path.isfile(target):