diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/api/__init__.py | 92 |
1 files changed, 51 insertions, 41 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index e7f87260..02ba5260 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -14,11 +14,12 @@ from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZE from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from typing import Union, Optional, Iterator +from typing import Union, Optional import g4f import g4f.debug -from g4f.client import Client, ChatCompletion, ChatCompletionChunk, ImagesResponse +from g4f.client import AsyncClient, ChatCompletion +from g4f.client.helper import filter_none from g4f.typing import Messages from g4f.cookies import read_cookie_files @@ -47,6 +48,10 @@ def create_app(g4f_api_key: str = None): return app +def create_app_debug(g4f_api_key: str = None): + g4f.debug.logging = True + return create_app(g4f_api_key) + class ChatCompletionsConfig(BaseModel): messages: Messages model: str @@ -62,13 +67,19 @@ class ChatCompletionsConfig(BaseModel): class ImageGenerationConfig(BaseModel): prompt: str model: Optional[str] = None + provider: Optional[str] = None response_format: str = "url" + api_key: Optional[str] = None + proxy: Optional[str] = None class AppConfig: ignored_providers: Optional[list[str]] = None g4f_api_key: Optional[str] = None ignore_cookie_files: bool = False - defaults: dict = {} + model: str = None, + provider: str = None + image_provider: str = None + proxy: str = None @classmethod def set_config(cls, **data): @@ -84,7 +95,7 @@ def set_list_ignored_providers(ignored: list[str]): class Api: def __init__(self, app: FastAPI, g4f_api_key=None) -> None: self.app = app - self.client = Client() + self.client = AsyncClient() self.g4f_api_key = g4f_api_key self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key") @@ -133,8 +144,8 @@ class Api: @self.app.get("/v1") async def read_root_v1(): return HTMLResponse('g4f API: Go to ' - '<a href="/v1/chat/completions">chat/completions</a>, ' - '<a href="/v1/models">models</a>, or ' + '<a href="/v1/models">models</a>, ' + '<a href="/v1/chat/completions">chat/completions</a>, or ' '<a href="/v1/images/generate">images/generate</a>.') @self.app.get("/v1/models") @@ -177,31 +188,24 @@ class Api: # Create the completion response response = self.client.chat.completions.create( - **{ - **AppConfig.defaults, - **config.dict(exclude_none=True), - }, - ignored=AppConfig.ignored_providers + **filter_none( + **{ + "model": AppConfig.model, + "provider": AppConfig.provider, + "proxy": AppConfig.proxy, + **config.dict(exclude_none=True), + }, + ignored=AppConfig.ignored_providers + ), ) - # Check if the response is synchronous or asynchronous - if isinstance(response, ChatCompletion): - # Synchronous response - return JSONResponse(response.to_json()) - if not config.stream: - # If the response is an iterator but not streaming, collect the result - response_list = list(response) if isinstance(response, Iterator) else [response] - return JSONResponse(response_list[0].to_json()) - - # Streaming response - async def async_generator(sync_gen): - for item in sync_gen: - yield item + response: ChatCompletion = await response + return JSONResponse(response.to_json()) async def streaming(): try: - async for chunk in async_generator(response): + async for chunk in response: yield f"data: {json.dumps(chunk.to_json())}\n\n" except GeneratorExit: pass @@ -217,30 +221,38 @@ class Api: return Response(content=format_exception(e, config), status_code=500, media_type="application/json") @self.app.post("/v1/images/generate") + @self.app.post("/v1/images/generations") async def generate_image(config: ImageGenerationConfig): try: - response: ImagesResponse = await self.client.images.async_generate( + response = await self.client.images.generate( prompt=config.prompt, model=config.model, - response_format=config.response_format + provider=AppConfig.image_provider if config.provider is None else config.provider, + **filter_none( + response_format = config.response_format, + api_key = config.api_key, + proxy = config.proxy + ) ) - # Convert Image objects to dictionaries - response_data = [{"url": image.url, "b64_json": image.b64_json} for image in response.data] - return JSONResponse({"data": response_data}) + return JSONResponse(response.to_json()) except Exception as e: logger.exception(e) - return Response(content=format_exception(e, config), status_code=500, media_type="application/json") + return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json") @self.app.post("/v1/completions") async def completions(): return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") -def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig]) -> str: - last_provider = g4f.get_last_provider(True) +def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str: + last_provider = {} if not image else g4f.get_last_provider(True) + provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider + model = AppConfig.model if config.model is None else config.model return json.dumps({ "error": {"message": f"{e.__class__.__name__}: {e}"}, - "model": last_provider.get("model") if last_provider else getattr(config, 'model', None), - "provider": last_provider.get("name") if last_provider else getattr(config, 'provider', None) + "model": last_provider.get("model") if model is None else model, + **filter_none( + provider=last_provider.get("name") if provider is None else provider + ) }) def run_api( @@ -250,21 +262,19 @@ def run_api( debug: bool = False, workers: int = None, use_colors: bool = None, - g4f_api_key: str = None + reload: bool = False ) -> None: print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else "")) if use_colors is None: use_colors = debug if bind is not None: host, port = bind.split(":") - if debug: - g4f.debug.logging = True uvicorn.run( - "g4f.api:create_app", + f"g4f.api:create_app{'_debug' if debug else ''}", host=host, port=int(port), workers=workers, use_colors=use_colors, factory=True, - reload=debug - ) + reload=reload + )
\ No newline at end of file |