summaryrefslogtreecommitdiffstats
path: root/g4f/api
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/api/__init__.py346
-rw-r--r--g4f/api/_logging.py32
-rw-r--r--g4f/api/_tokenizer.py9
-rw-r--r--g4f/api/run.py5
4 files changed, 239 insertions, 153 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index b19a721b..ecc70a13 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -1,162 +1,206 @@
+import g4f
+import time
import json
import random
import string
-import time
-
-# import requests
-from flask import Flask, request
-from flask_cors import CORS
-# from transformers import AutoTokenizer
-
-from g4f import ChatCompletion
-
-app = Flask(__name__)
-CORS(app)
-
-
-@app.route("/")
-def index():
- return "interference api, url: http://127.0.0.1:1337"
-
-
-@app.route("/chat/completions", methods=["POST"])
-def chat_completions():
- model = request.get_json().get("model", "gpt-3.5-turbo")
- stream = request.get_json().get("stream", False)
- messages = request.get_json().get("messages")
-
- response = ChatCompletion.create(model=model, stream=stream, messages=messages)
-
- completion_id = "".join(random.choices(string.ascii_letters + string.digits, k=28))
- completion_timestamp = int(time.time())
-
- if not stream:
- return {
- "id": f"chatcmpl-{completion_id}",
- "object": "chat.completion",
- "created": completion_timestamp,
- "model": model,
- "choices": [
- {
- "index": 0,
- "message": {
- "role": "assistant",
- "content": response,
- },
- "finish_reason": "stop",
- }
- ],
- "usage": {
- "prompt_tokens": None,
- "completion_tokens": None,
- "total_tokens": None,
- },
- }
-
- def streaming():
- for chunk in response:
- completion_data = {
- "id": f"chatcmpl-{completion_id}",
- "object": "chat.completion.chunk",
- "created": completion_timestamp,
- "model": model,
- "choices": [
+import logging
+
+from typing import Union
+from loguru import logger
+from waitress import serve
+from ._logging import hook_logging
+from ._tokenizer import tokenize
+from flask_cors import CORS
+from werkzeug.serving import WSGIRequestHandler
+from werkzeug.exceptions import default_exceptions
+from werkzeug.middleware.proxy_fix import ProxyFix
+
+from flask import (
+ Flask,
+ jsonify,
+ make_response,
+ request,
+)
+
+class Api:
+ __default_ip = '127.0.0.1'
+ __default_port = 1337
+
+ def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False) -> None:
+ self.engine = engine
+ self.debug = debug
+ self.sentry = sentry
+ self.log_level = logging.DEBUG if debug else logging.WARN
+
+ hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s')
+ self.logger = logging.getLogger('waitress')
+
+ self.app = Flask(__name__)
+ self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1)
+ self.app.after_request(self.__after_request)
+
+ def run(self, bind_str, threads=8):
+ host, port = self.__parse_bind(bind_str)
+
+ CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [
+ 'Content-Type',
+ 'Authorization',
+ 'X-Requested-With',
+ 'Accept',
+ 'Origin',
+ 'Access-Control-Request-Method',
+ 'Access-Control-Request-Headers',
+ 'Content-Disposition'], 'max_age': 600}})
+
+ self.app.route('/v1/models', methods=['GET'])(self.models)
+ self.app.route('v1/models/<model_id>', methods=['GET'])(self.model_info)
+
+ self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions)
+ self.app.route('/v1/completions', methods=['POST'])(self.completions)
+
+ for ex in default_exceptions:
+ self.app.register_error_handler(ex, self.__handle_error)
+
+ if not self.debug:
+ self.logger.warning('Serving on http://{}:{}'.format(host, port))
+
+ WSGIRequestHandler.protocol_version = 'HTTP/1.1'
+ serve(self.app, host=host, port=port, ident=None, threads=threads)
+
+ def __handle_error(self, e: Exception):
+ self.logger.error(e)
+
+ return make_response(jsonify({
+ 'code': e.code,
+ 'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500)
+
+ @staticmethod
+ def __after_request(resp):
+ resp.headers['X-Server'] = 'g4f/%s' % g4f.version
+
+ return resp
+
+ def __parse_bind(self, bind_str):
+ sections = bind_str.split(':', 2)
+ if len(sections) < 2:
+ try:
+ port = int(sections[0])
+ return self.__default_ip, port
+ except ValueError:
+ return sections[0], self.__default_port
+
+ return sections[0], int(sections[1])
+
+ async def home(self):
+ return 'Hello world | https://127.0.0.1:1337/v1'
+
+ async def chat_completions(self):
+ model = request.json.get('model', 'gpt-3.5-turbo')
+ stream = request.json.get('stream', False)
+ messages = request.json.get('messages')
+
+ logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}')
+
+ response = self.engine.ChatCompletion.create(model=model,
+ stream=stream, messages=messages)
+
+ completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
+ completion_timestamp = int(time.time())
+
+ if not stream:
+ prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
+ completion_tokens, _ = tokenize(response)
+
+ return {
+ 'id': f'chatcmpl-{completion_id}',
+ 'object': 'chat.completion',
+ 'created': completion_timestamp,
+ 'model': model,
+ 'choices': [
{
- "index": 0,
- "delta": {
- "content": chunk,
+ 'index': 0,
+ 'message': {
+ 'role': 'assistant',
+ 'content': response,
},
- "finish_reason": None,
+ 'finish_reason': 'stop',
}
],
+ 'usage': {
+ 'prompt_tokens': prompt_tokens,
+ 'completion_tokens': completion_tokens,
+ 'total_tokens': prompt_tokens + completion_tokens,
+ },
}
- content = json.dumps(completion_data, separators=(",", ":"))
- yield f"data: {content}\n\n"
- time.sleep(0.1)
+ def streaming():
+ try:
+ for chunk in response:
+ completion_data = {
+ 'id': f'chatcmpl-{completion_id}',
+ 'object': 'chat.completion.chunk',
+ 'created': completion_timestamp,
+ 'model': model,
+ 'choices': [
+ {
+ 'index': 0,
+ 'delta': {
+ 'content': chunk,
+ },
+ 'finish_reason': None,
+ }
+ ],
+ }
- end_completion_data = {
- "id": f"chatcmpl-{completion_id}",
- "object": "chat.completion.chunk",
- "created": completion_timestamp,
- "model": model,
- "choices": [
- {
- "index": 0,
- "delta": {},
- "finish_reason": "stop",
+ content = json.dumps(completion_data, separators=(',', ':'))
+ yield f'data: {content}\n\n'
+ time.sleep(0.03)
+
+ end_completion_data = {
+ 'id': f'chatcmpl-{completion_id}',
+ 'object': 'chat.completion.chunk',
+ 'created': completion_timestamp,
+ 'model': model,
+ 'choices': [
+ {
+ 'index': 0,
+ 'delta': {},
+ 'finish_reason': 'stop',
+ }
+ ],
}
- ],
- }
- content = json.dumps(end_completion_data, separators=(",", ":"))
- yield f"data: {content}\n\n"
-
- return app.response_class(streaming(), mimetype="text/event-stream")
-
-
-# Get the embedding from huggingface
-# def get_embedding(input_text, token):
-# huggingface_token = token
-# embedding_model = "sentence-transformers/all-mpnet-base-v2"
-# max_token_length = 500
-
-# # Load the tokenizer for the 'all-mpnet-base-v2' model
-# tokenizer = AutoTokenizer.from_pretrained(embedding_model)
-# # Tokenize the text and split the tokens into chunks of 500 tokens each
-# tokens = tokenizer.tokenize(input_text)
-# token_chunks = [
-# tokens[i : i + max_token_length]
-# for i in range(0, len(tokens), max_token_length)
-# ]
-
-# # Initialize an empty list
-# embeddings = []
-
-# # Create embeddings for each chunk
-# for chunk in token_chunks:
-# # Convert the chunk tokens back to text
-# chunk_text = tokenizer.convert_tokens_to_string(chunk)
-
-# # Use the Hugging Face API to get embeddings for the chunk
-# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}"
-# headers = {"Authorization": f"Bearer {huggingface_token}"}
-# chunk_text = chunk_text.replace("\n", " ")
-
-# # Make a POST request to get the chunk's embedding
-# response = requests.post(
-# api_url,
-# headers=headers,
-# json={"inputs": chunk_text, "options": {"wait_for_model": True}},
-# )
-
-# # Parse the response and extract the embedding
-# chunk_embedding = response.json()
-# # Append the embedding to the list
-# embeddings.append(chunk_embedding)
-
-# # averaging all the embeddings
-# # this isn't very effective
-# # someone a better idea?
-# num_embeddings = len(embeddings)
-# average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)]
-# embedding = average_embedding
-# return embedding
-
-
-# @app.route("/embeddings", methods=["POST"])
-# def embeddings():
-# input_text_list = request.get_json().get("input")
-# input_text = " ".join(map(str, input_text_list))
-# token = request.headers.get("Authorization").replace("Bearer ", "")
-# embedding = get_embedding(input_text, token)
-
-# return {
-# "data": [{"embedding": embedding, "index": 0, "object": "embedding"}],
-# "model": "text-embedding-ada-002",
-# "object": "list",
-# "usage": {"prompt_tokens": None, "total_tokens": None},
-# }
-
-
-def run_api():
- app.run(host="0.0.0.0", port=1337)
+
+ content = json.dumps(end_completion_data, separators=(',', ':'))
+ yield f'data: {content}\n\n'
+
+ logger.success(f'model: {model}, stream: {stream}')
+
+ except GeneratorExit:
+ pass
+
+ return self.app.response_class(streaming(), mimetype='text/event-stream')
+
+ async def completions(self):
+ return 'not working yet', 500
+
+ async def model_info(self, model_name):
+ model_info = (g4f.ModelUtils.convert[model_name])
+
+ return jsonify({
+ 'id' : model_name,
+ 'object' : 'model',
+ 'created' : 0,
+ 'owned_by' : model_info.base_provider
+ })
+
+ async def models(self):
+ model_list = [{
+ 'id' : model,
+ 'object' : 'model',
+ 'created' : 0,
+ 'owned_by' : 'g4f'} for model in g4f.Model.__all__()]
+
+ return jsonify({
+ 'object': 'list',
+ 'data': model_list})
+ \ No newline at end of file
diff --git a/g4f/api/_logging.py b/g4f/api/_logging.py
new file mode 100644
index 00000000..e91dff76
--- /dev/null
+++ b/g4f/api/_logging.py
@@ -0,0 +1,32 @@
+import sys,logging
+
+from loguru import logger
+
+def __exception_handle(e_type, e_value, e_traceback):
+ if issubclass(e_type, KeyboardInterrupt):
+ print('\nBye...')
+ sys.exit(0)
+
+ sys.__excepthook__(e_type, e_value, e_traceback)
+
+class __InterceptHandler(logging.Handler):
+ def emit(self, record):
+ try:
+ level = logger.level(record.levelname).name
+ except ValueError:
+ level = record.levelno
+
+ frame, depth = logging.currentframe(), 2
+ while frame.f_code.co_filename == logging.__file__:
+ frame = frame.f_back
+ depth += 1
+
+ logger.opt(depth=depth, exception=record.exc_info).log(
+ level, record.getMessage()
+ )
+
+def hook_except_handle():
+ sys.excepthook = __exception_handle
+
+def hook_logging(**kwargs):
+ logging.basicConfig(handlers=[__InterceptHandler()], **kwargs)
diff --git a/g4f/api/_tokenizer.py b/g4f/api/_tokenizer.py
new file mode 100644
index 00000000..fd8f9d5a
--- /dev/null
+++ b/g4f/api/_tokenizer.py
@@ -0,0 +1,9 @@
+import tiktoken
+from typing import Union
+
+def tokenize(text: str, model: str = 'gpt-3.5-turbo') -> Union[int, str]:
+ encoding = tiktoken.encoding_for_model(model)
+ encoded = encoding.encode(text)
+ num_tokens = len(encoded)
+
+ return num_tokens, encoded \ No newline at end of file
diff --git a/g4f/api/run.py b/g4f/api/run.py
index 6e9b63f3..d214aae7 100644
--- a/g4f/api/run.py
+++ b/g4f/api/run.py
@@ -1,4 +1,5 @@
-from g4f.api import run_api
+import g4f
+import g4f.api
if __name__ == "__main__":
- run_api()
+ g4f.api.Api(g4f).run('localhost:1337', 8) \ No newline at end of file