diff options
Diffstat (limited to 'g4f')
-rw-r--r-- | g4f/api/__init__.py | 340 | ||||
-rw-r--r-- | g4f/api/run.py | 2 |
2 files changed, 141 insertions, 201 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index fec5606f..d86364d1 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -1,227 +1,167 @@ -import typing -from .. import BaseProvider -import g4f; g4f.debug.logging = True +from fastapi import FastAPI, Response, Request +from fastapi.middleware.cors import CORSMiddleware +from typing import List, Union, Any, Dict, AnyStr +from ._tokenizer import tokenize +import g4f import time import json import random import string -import logging - -from typing import Union -from loguru import logger -from waitress import serve -from ._logging import hook_logging -from ._tokenizer import tokenize -from flask_cors import CORS -from werkzeug.serving import WSGIRequestHandler -from werkzeug.exceptions import default_exceptions -from werkzeug.middleware.proxy_fix import ProxyFix - -from flask import ( - Flask, - jsonify, - make_response, - request, +import uvicorn +import nest_asyncio + +app = FastAPI() +nest_asyncio.apply() + +origins = [ + "http://localhost", + "http://localhost:1337", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) -class Api: - __default_ip = '127.0.0.1' - __default_port = 1337 - - def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False, - list_ignored_providers:typing.List[typing.Union[str, BaseProvider]]=None) -> None: - self.engine = engine - self.debug = debug - self.sentry = sentry - self.list_ignored_providers = list_ignored_providers - self.log_level = logging.DEBUG if debug else logging.WARN - - hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s') - self.logger = logging.getLogger('waitress') - - self.app = Flask(__name__) - self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1) - self.app.after_request(self.__after_request) - - def run(self, bind_str, threads=8): - host, port = self.__parse_bind(bind_str) - - CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [ - 'Content-Type', - 'Authorization', - 'X-Requested-With', - 'Accept', - 'Origin', - 'Access-Control-Request-Method', - 'Access-Control-Request-Headers', - 'Content-Disposition'], 'max_age': 600}}) - - self.app.route('/v1/models', methods=['GET'])(self.models) - self.app.route('/v1/models/<model_id>', methods=['GET'])(self.model_info) - - self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions) - self.app.route('/v1/completions', methods=['POST'])(self.completions) - - for ex in default_exceptions: - self.app.register_error_handler(ex, self.__handle_error) - - if not self.debug: - self.logger.warning(f'Serving on http://{host}:{port}') - - WSGIRequestHandler.protocol_version = 'HTTP/1.1' - serve(self.app, host=host, port=port, ident=None, threads=threads) - - def __handle_error(self, e: Exception): - self.logger.error(e) - - return make_response(jsonify({ - 'code': e.code, - 'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500) - - @staticmethod - def __after_request(resp): - resp.headers['X-Server'] = f'g4f/{g4f.version}' +JSONObject = Dict[AnyStr, Any] +JSONArray = List[Any] +JSONStructure = Union[JSONArray, JSONObject] + +@app.get("/") +async def read_root(): + return Response(content=json.dumps({"info": "G4F API"}, indent=4), media_type="application/json") + +@app.get("/v1") +async def read_root_v1(): + return Response(content=json.dumps({"info": "Go to /v1/chat/completions or /v1/models."}, indent=4), media_type="application/json") + +@app.get("/v1/models") +async def models(): + model_list = [{ + 'id': model, + 'object': 'model', + 'created': 0, + 'owned_by': 'g4f'} for model in g4f.Model.__all__()] + + return Response(content=json.dumps({ + 'object': 'list', + 'data': model_list}, indent=4), media_type="application/json") + +@app.get("/v1/models/{model_name}") +async def model_info(model_name: str): + try: + model_info = (g4f.ModelUtils.convert[model_name]) - return resp - - def __parse_bind(self, bind_str): - sections = bind_str.split(':', 2) - if len(sections) < 2: - try: - port = int(sections[0]) - return self.__default_ip, port - except ValueError: - return sections[0], self.__default_port - - return sections[0], int(sections[1]) - - async def home(self): - return 'Hello world | https://127.0.0.1:1337/v1' + return Response(content=json.dumps({ + 'id': model_name, + 'object': 'model', + 'created': 0, + 'owned_by': model_info.base_provider + }, indent=4), media_type="application/json") + except: + return Response(content=json.dumps({"error": "The model does not exist."}, indent=4), media_type="application/json") + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request, item: JSONStructure = None): + + item_data = { + 'model': 'gpt-3.5-turbo', + 'stream': False, + } - async def chat_completions(self): - model = request.json.get('model', 'gpt-3.5-turbo') - stream = request.json.get('stream', False) - messages = request.json.get('messages') - - logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}') + item_data.update(item or {}) + model = item_data.get('model') + stream = item_data.get('stream') + messages = item_data.get('messages') + + try: + response = g4f.ChatCompletion.create(model=model, stream=stream, messages=messages) + except: + return Response(content=json.dumps({"error": "An error occurred while generating the response."}, indent=4), media_type="application/json") + + completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) + completion_timestamp = int(time.time()) + + if not stream: + prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages])) + completion_tokens, _ = tokenize(response) + + json_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion', + 'created': completion_timestamp, + 'model': model, + 'choices': [ + { + 'index': 0, + 'message': { + 'role': 'assistant', + 'content': response, + }, + 'finish_reason': 'stop', + } + ], + 'usage': { + 'prompt_tokens': prompt_tokens, + 'completion_tokens': completion_tokens, + 'total_tokens': prompt_tokens + completion_tokens, + }, + } - config = None - proxy = None + return Response(content=json.dumps(json_data, indent=4), media_type="application/json") + def streaming(): try: - config = json.load(open("config.json","r",encoding="utf-8")) - proxy = config["proxy"] + for chunk in response: + completion_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion.chunk', + 'created': completion_timestamp, + 'model': model, + 'choices': [ + { + 'index': 0, + 'delta': { + 'content': chunk, + }, + 'finish_reason': None, + } + ], + } - except Exception: - pass + content = json.dumps(completion_data, separators=(',', ':')) + yield f'data: {content}\n\n' + time.sleep(0.03) - if proxy != None: - response = self.engine.ChatCompletion.create(model=model, - stream=stream, messages=messages, - ignored=self.list_ignored_providers, - proxy=proxy) - else: - response = self.engine.ChatCompletion.create(model=model, - stream=stream, messages=messages, - ignored=self.list_ignored_providers) - - completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) - completion_timestamp = int(time.time()) - - if not stream: - prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages])) - completion_tokens, _ = tokenize(response) - - return { + end_completion_data = { 'id': f'chatcmpl-{completion_id}', - 'object': 'chat.completion', + 'object': 'chat.completion.chunk', 'created': completion_timestamp, 'model': model, 'choices': [ { 'index': 0, - 'message': { - 'role': 'assistant', - 'content': response, - }, + 'delta': {}, 'finish_reason': 'stop', } ], - 'usage': { - 'prompt_tokens': prompt_tokens, - 'completion_tokens': completion_tokens, - 'total_tokens': prompt_tokens + completion_tokens, - }, } - def streaming(): - try: - for chunk in response: - completion_data = { - 'id': f'chatcmpl-{completion_id}', - 'object': 'chat.completion.chunk', - 'created': completion_timestamp, - 'model': model, - 'choices': [ - { - 'index': 0, - 'delta': { - 'content': chunk, - }, - 'finish_reason': None, - } - ], - } + content = json.dumps(end_completion_data, separators=(',', ':')) + yield f'data: {content}\n\n' - content = json.dumps(completion_data, separators=(',', ':')) - yield f'data: {content}\n\n' - time.sleep(0.03) + except GeneratorExit: + pass - end_completion_data = { - 'id': f'chatcmpl-{completion_id}', - 'object': 'chat.completion.chunk', - 'created': completion_timestamp, - 'model': model, - 'choices': [ - { - 'index': 0, - 'delta': {}, - 'finish_reason': 'stop', - } - ], - } - - content = json.dumps(end_completion_data, separators=(',', ':')) - yield f'data: {content}\n\n' - - logger.success(f'model: {model}, stream: {stream}') - - except GeneratorExit: - pass + return Response(content=json.dumps(streaming(), indent=4), media_type="application/json") - return self.app.response_class(streaming(), mimetype='text/event-stream') - - async def completions(self): - return 'not working yet', 500 - - async def model_info(self, model_name): - model_info = (g4f.ModelUtils.convert[model_name]) - - return jsonify({ - 'id' : model_name, - 'object' : 'model', - 'created' : 0, - 'owned_by' : model_info.base_provider - }) - - async def models(self): - model_list = [{ - 'id' : model, - 'object' : 'model', - 'created' : 0, - 'owned_by' : 'g4f'} for model in g4f.Model.__all__()] - - return jsonify({ - 'object': 'list', - 'data': model_list}) -
\ No newline at end of file +@app.post("/v1/completions") +async def completions(): + return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") + +def run(ip): + split_ip = ip.split(":") + uvicorn.run(app, host=split_ip[0], port=int(split_ip[1]), use_colors=False) diff --git a/g4f/api/run.py b/g4f/api/run.py index 12bf9eed..5992ab60 100644 --- a/g4f/api/run.py +++ b/g4f/api/run.py @@ -3,4 +3,4 @@ import g4f.api if __name__ == "__main__": print(f'Starting server... [g4f v-{g4f.version}]') - g4f.api.Api(g4f).run('127.0.0.1:1337', 8)
\ No newline at end of file + g4f.api.run('127.0.0.1:1337') |