Diffstat
-rw-r--r--  g4f/api/__init__.py    | 346
-rw-r--r--  g4f/api/_logging.py    |  32
-rw-r--r--  g4f/api/_tokenizer.py  |   9
-rw-r--r--  g4f/api/run.py         |   5
4 files changed, 239 insertions, 153 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index b19a721b..ecc70a13 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -1,162 +1,206 @@
+import g4f
+import time
 import json
 import random
 import string
-import time
-
-# import requests
-from flask import Flask, request
-from flask_cors import CORS
-# from transformers import AutoTokenizer
-
-from g4f import ChatCompletion
-
-app = Flask(__name__)
-CORS(app)
-
-
-@app.route("/")
-def index():
-    return "interference api, url: http://127.0.0.1:1337"
-
-
-@app.route("/chat/completions", methods=["POST"])
-def chat_completions():
-    model = request.get_json().get("model", "gpt-3.5-turbo")
-    stream = request.get_json().get("stream", False)
-    messages = request.get_json().get("messages")
-
-    response = ChatCompletion.create(model=model, stream=stream, messages=messages)
-
-    completion_id = "".join(random.choices(string.ascii_letters + string.digits, k=28))
-    completion_timestamp = int(time.time())
-
-    if not stream:
-        return {
-            "id": f"chatcmpl-{completion_id}",
-            "object": "chat.completion",
-            "created": completion_timestamp,
-            "model": model,
-            "choices": [
-                {
-                    "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": response,
-                    },
-                    "finish_reason": "stop",
-                }
-            ],
-            "usage": {
-                "prompt_tokens": None,
-                "completion_tokens": None,
-                "total_tokens": None,
-            },
-        }
-
-    def streaming():
-        for chunk in response:
-            completion_data = {
-                "id": f"chatcmpl-{completion_id}",
-                "object": "chat.completion.chunk",
-                "created": completion_timestamp,
-                "model": model,
-                "choices": [
+import logging
+
+from typing import Union
+from loguru import logger
+from waitress import serve
+from ._logging import hook_logging
+from ._tokenizer import tokenize
+from flask_cors import CORS
+from werkzeug.serving import WSGIRequestHandler
+from werkzeug.exceptions import default_exceptions
+from werkzeug.middleware.proxy_fix import ProxyFix
+
+from flask import (
+    Flask,
+    jsonify,
+    make_response,
+    request,
+)
+
+class Api:
+    __default_ip = '127.0.0.1'
+    __default_port = 1337
+
+    def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False) -> None:
+        self.engine = engine
+        self.debug = debug
+        self.sentry = sentry
+        self.log_level = logging.DEBUG if debug else logging.WARN
+
+        hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s')
+        self.logger = logging.getLogger('waitress')
+
+        self.app = Flask(__name__)
+        self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1)
+        self.app.after_request(self.__after_request)
+
+    def run(self, bind_str, threads=8):
+        host, port = self.__parse_bind(bind_str)
+
+        CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [
+            'Content-Type',
+            'Authorization',
+            'X-Requested-With',
+            'Accept',
+            'Origin',
+            'Access-Control-Request-Method',
+            'Access-Control-Request-Headers',
+            'Content-Disposition'], 'max_age': 600}})
+
+        self.app.route('/v1/models', methods=['GET'])(self.models)
+        self.app.route('v1/models/<model_id>', methods=['GET'])(self.model_info)
+
+        self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions)
+        self.app.route('/v1/completions', methods=['POST'])(self.completions)
+
+        for ex in default_exceptions:
+            self.app.register_error_handler(ex, self.__handle_error)
+
+        if not self.debug:
+            self.logger.warning('Serving on http://{}:{}'.format(host, port))
+
+        WSGIRequestHandler.protocol_version = 'HTTP/1.1'
+        serve(self.app, host=host, port=port, ident=None, threads=threads)
+
+    def __handle_error(self, e: Exception):
+        self.logger.error(e)
+
+        return make_response(jsonify({
+            'code': e.code,
+            'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500)
+
+    @staticmethod
+    def __after_request(resp):
+        resp.headers['X-Server'] = 'g4f/%s' % g4f.version
+
+        return resp
+
+    def __parse_bind(self, bind_str):
+        sections = bind_str.split(':', 2)
+        if len(sections) < 2:
+            try:
+                port = int(sections[0])
+                return self.__default_ip, port
+            except ValueError:
+                return sections[0], self.__default_port
+
+        return sections[0], int(sections[1])
+
+    async def home(self):
+        return 'Hello world | https://127.0.0.1:1337/v1'
+
+    async def chat_completions(self):
+        model = request.json.get('model', 'gpt-3.5-turbo')
+        stream = request.json.get('stream', False)
+        messages = request.json.get('messages')
+
+        logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}')
+
+        response = self.engine.ChatCompletion.create(model=model,
+            stream=stream, messages=messages)
+
+        completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
+        completion_timestamp = int(time.time())
+
+        if not stream:
+            prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
+            completion_tokens, _ = tokenize(response)
+
+            return {
+                'id': f'chatcmpl-{completion_id}',
+                'object': 'chat.completion',
+                'created': completion_timestamp,
+                'model': model,
+                'choices': [
                     {
-                        "index": 0,
-                        "delta": {
-                            "content": chunk,
+                        'index': 0,
+                        'message': {
+                            'role': 'assistant',
+                            'content': response,
                         },
-                        "finish_reason": None,
+                        'finish_reason': 'stop',
                     }
                 ],
+                'usage': {
+                    'prompt_tokens': prompt_tokens,
+                    'completion_tokens': completion_tokens,
+                    'total_tokens': prompt_tokens + completion_tokens,
+                },
             }
 
-            content = json.dumps(completion_data, separators=(",", ":"))
-            yield f"data: {content}\n\n"
-            time.sleep(0.1)
+        def streaming():
+            try:
+                for chunk in response:
+                    completion_data = {
+                        'id': f'chatcmpl-{completion_id}',
+                        'object': 'chat.completion.chunk',
+                        'created': completion_timestamp,
+                        'model': model,
+                        'choices': [
+                            {
+                                'index': 0,
+                                'delta': {
+                                    'content': chunk,
+                                },
+                                'finish_reason': None,
+                            }
+                        ],
+                    }
 
-        end_completion_data = {
-            "id": f"chatcmpl-{completion_id}",
-            "object": "chat.completion.chunk",
-            "created": completion_timestamp,
-            "model": model,
-            "choices": [
-                {
-                    "index": 0,
-                    "delta": {},
-                    "finish_reason": "stop",
+                    content = json.dumps(completion_data, separators=(',', ':'))
+                    yield f'data: {content}\n\n'
+                    time.sleep(0.03)
+
+                end_completion_data = {
+                    'id': f'chatcmpl-{completion_id}',
+                    'object': 'chat.completion.chunk',
+                    'created': completion_timestamp,
+                    'model': model,
+                    'choices': [
+                        {
+                            'index': 0,
+                            'delta': {},
+                            'finish_reason': 'stop',
+                        }
+                    ],
                 }
-            ],
-        }
-        content = json.dumps(end_completion_data, separators=(",", ":"))
-        yield f"data: {content}\n\n"
-
-    return app.response_class(streaming(), mimetype="text/event-stream")
-
-
-# Get the embedding from huggingface
-# def get_embedding(input_text, token):
-#     huggingface_token = token
-#     embedding_model = "sentence-transformers/all-mpnet-base-v2"
-#     max_token_length = 500
-
-#     # Load the tokenizer for the 'all-mpnet-base-v2' model
-#     tokenizer = AutoTokenizer.from_pretrained(embedding_model)
-#     # Tokenize the text and split the tokens into chunks of 500 tokens each
-#     tokens = tokenizer.tokenize(input_text)
-#     token_chunks = [
-#         tokens[i : i + max_token_length]
-#         for i in range(0, len(tokens), max_token_length)
-#     ]
-
-#     # Initialize an empty list
-#     embeddings = []
-
-#     # Create embeddings for each chunk
-#     for chunk in token_chunks:
-#         # Convert the chunk tokens back to text
-#         chunk_text = tokenizer.convert_tokens_to_string(chunk)
-
-#         # Use the Hugging Face API to get embeddings for the chunk
-#         api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}"
-#         headers = {"Authorization": f"Bearer {huggingface_token}"}
-#         chunk_text = chunk_text.replace("\n", " ")
-
-#         # Make a POST request to get the chunk's embedding
-#         response = requests.post(
-#             api_url,
-#             headers=headers,
-#             json={"inputs": chunk_text, "options": {"wait_for_model": True}},
-#         )
-
-#         # Parse the response and extract the embedding
-#         chunk_embedding = response.json()
-#         # Append the embedding to the list
-#         embeddings.append(chunk_embedding)
-
-#     # averaging all the embeddings
-#     # this isn't very effective
-#     # someone a better idea?
-#     num_embeddings = len(embeddings)
-#     average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)]
-#     embedding = average_embedding
-#     return embedding
-
-
-# @app.route("/embeddings", methods=["POST"])
-# def embeddings():
-#     input_text_list = request.get_json().get("input")
-#     input_text = " ".join(map(str, input_text_list))
-#     token = request.headers.get("Authorization").replace("Bearer ", "")
-#     embedding = get_embedding(input_text, token)
-
-#     return {
-#         "data": [{"embedding": embedding, "index": 0, "object": "embedding"}],
-#         "model": "text-embedding-ada-002",
-#         "object": "list",
-#         "usage": {"prompt_tokens": None, "total_tokens": None},
-#     }
-
-
-def run_api():
-    app.run(host="0.0.0.0", port=1337)
+
+                content = json.dumps(end_completion_data, separators=(',', ':'))
+                yield f'data: {content}\n\n'
+
+                logger.success(f'model: {model}, stream: {stream}')
+
+            except GeneratorExit:
+                pass
+
+        return self.app.response_class(streaming(), mimetype='text/event-stream')
+
+    async def completions(self):
+        return 'not working yet', 500
+
+    async def model_info(self, model_name):
+        model_info = (g4f.ModelUtils.convert[model_name])
+
+        return jsonify({
+            'id' : model_name,
+            'object' : 'model',
+            'created' : 0,
+            'owned_by' : model_info.base_provider
+        })
+
+    async def models(self):
+        model_list = [{
+            'id' : model,
+            'object' : 'model',
+            'created' : 0,
+            'owned_by' : 'g4f'} for model in g4f.Model.__all__()]
+
+        return jsonify({
+            'object': 'list',
+            'data': model_list})
\ No newline at end of file
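
For reference, a minimal client-side sketch of how the new OpenAI-style route above might be exercised once the server is up. The `requests` dependency, the example payload and the 'Hello' message are illustrative assumptions, not part of the change itself:

    import requests

    # Hypothetical smoke test against the new /v1/chat/completions route;
    # assumes the interference API is listening on the default 127.0.0.1:1337.
    resp = requests.post(
        'http://127.0.0.1:1337/v1/chat/completions',
        json={
            'model': 'gpt-3.5-turbo',
            'stream': False,
            'messages': [{'role': 'user', 'content': 'Hello'}],
        },
    )
    # The non-streaming handler returns the OpenAI-shaped body built above.
    print(resp.json()['choices'][0]['message']['content'])
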
diff --git a/g4f/api/_logging.py b/g4f/api/_logging.py
new file mode 100644
index 00000000..e91dff76
--- /dev/null
+++ b/g4f/api/_logging.py
@@ -0,0 +1,32 @@
+import sys,logging
+
+from loguru import logger
+
+def __exception_handle(e_type, e_value, e_traceback):
+    if issubclass(e_type, KeyboardInterrupt):
+        print('\nBye...')
+        sys.exit(0)
+
+    sys.__excepthook__(e_type, e_value, e_traceback)
+
+class __InterceptHandler(logging.Handler):
+    def emit(self, record):
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        frame, depth = logging.currentframe(), 2
+        while frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+def hook_except_handle():
+    sys.excepthook = __exception_handle
+
+def hook_logging(**kwargs):
+    logging.basicConfig(handlers=[__InterceptHandler()], **kwargs)
diff --git a/g4f/api/_tokenizer.py b/g4f/api/_tokenizer.py
new file mode 100644
index 00000000..fd8f9d5a
--- /dev/null
+++ b/g4f/api/_tokenizer.py
@@ -0,0 +1,9 @@
+import tiktoken
+from typing import Union
+
+def tokenize(text: str, model: str = 'gpt-3.5-turbo') -> Union[int, str]:
+    encoding = tiktoken.encoding_for_model(model)
+    encoded = encoding.encode(text)
+    num_tokens = len(encoded)
+
+    return num_tokens, encoded
\ No newline at end of file
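
A rough sketch of how the new helper is consumed by chat_completions above; the sample text and printed values are illustrative only, and the tiktoken package must be installed for encoding_for_model to resolve:

    from g4f.api._tokenizer import tokenize

    # tokenize() returns a (token_count, encoded_ids) pair for the given model,
    # which chat_completions uses to fill the usage block of the response.
    num_tokens, encoded = tokenize('Hello world', model='gpt-3.5-turbo')
    print(num_tokens)    # 2 for this input with the gpt-3.5-turbo encoding
    print(encoded)       # the raw token ids produced by tiktoken
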
diff --git a/g4f/api/run.py b/g4f/api/run.py
index 6e9b63f3..d214aae7 100644
--- a/g4f/api/run.py
+++ b/g4f/api/run.py
@@ -1,4 +1,5 @@
-from g4f.api import run_api
+import g4f
+import g4f.api
 
 if __name__ == "__main__":
-    run_api()
+    g4f.api.Api(g4f).run('localhost:1337', 8)
\ No newline at end of file
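
The updated entry point constructs the Api wrapper directly; per __parse_bind above, the bind string may also be a bare port or a bare host, falling back to 127.0.0.1 and port 1337 respectively. A hedged sketch of programmatic use, where the host, port and thread count are illustrative choices:

    import g4f
    import g4f.api

    # Same effect as running g4f/api/run.py, but binding on all interfaces
    # with the default waitress thread pool; '1337' alone (bare port) or
    # 'localhost' alone (bare host) are also accepted bind strings.
    g4f.api.Api(g4f, debug=False).run('0.0.0.0:1337', threads=8)
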