diff options
author | Arran Hobson Sayers <ahobsonsayers@gmail.com> | 2023-10-12 03:35:11 +0200 |
---|---|---|
committer | Arran Hobson Sayers <ahobsonsayers@gmail.com> | 2023-10-12 03:35:11 +0200 |
commit | 77697be33381a01350d0818ff069469faea2f4ac (patch) | |
tree | 418bfd7e3a9d01b94a6dcc3077c96ca87e674e73 /g4f/api/__init__.py | |
parent | ~ (diff) | |
download | gpt4free-77697be33381a01350d0818ff069469faea2f4ac.tar gpt4free-77697be33381a01350d0818ff069469faea2f4ac.tar.gz gpt4free-77697be33381a01350d0818ff069469faea2f4ac.tar.bz2 gpt4free-77697be33381a01350d0818ff069469faea2f4ac.tar.lz gpt4free-77697be33381a01350d0818ff069469faea2f4ac.tar.xz gpt4free-77697be33381a01350d0818ff069469faea2f4ac.tar.zst gpt4free-77697be33381a01350d0818ff069469faea2f4ac.zip |
Diffstat (limited to 'g4f/api/__init__.py')
-rw-r--r-- | g4f/api/__init__.py | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py new file mode 100644 index 00000000..c52085dc --- /dev/null +++ b/g4f/api/__init__.py @@ -0,0 +1,162 @@ +import json +import random +import string +import time + +import requests +from flask import Flask, request +from flask_cors import CORS +from transformers import AutoTokenizer + +from g4f import ChatCompletion + +app = Flask(__name__) +CORS(app) + + +@app.route("/") +def index(): + return "interference api, url: http://127.0.0.1:1337" + + +@app.route("/chat/completions", methods=["POST"]) +def chat_completions(): + model = request.get_json().get("model", "gpt-3.5-turbo") + stream = request.get_json().get("stream", False) + messages = request.get_json().get("messages") + + response = ChatCompletion.create(model=model, stream=stream, messages=messages) + + completion_id = "".join(random.choices(string.ascii_letters + string.digits, k=28)) + completion_timestamp = int(time.time()) + + if not stream: + return { + "id": f"chatcmpl-{completion_id}", + "object": "chat.completion", + "created": completion_timestamp, + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + }, + } + + def streaming(): + for chunk in response: + completion_data = { + "id": f"chatcmpl-{completion_id}", + "object": "chat.completion.chunk", + "created": completion_timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": chunk, + }, + "finish_reason": None, + } + ], + } + + content = json.dumps(completion_data, separators=(",", ":")) + yield f"data: {content}\n\n" + time.sleep(0.1) + + end_completion_data = { + "id": f"chatcmpl-{completion_id}", + "object": "chat.completion.chunk", + "created": completion_timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + content = json.dumps(end_completion_data, separators=(",", ":")) + yield f"data: {content}\n\n" + + return app.response_class(streaming(), mimetype="text/event-stream") + + +# Get the embedding from huggingface +def get_embedding(input_text, token): + huggingface_token = token + embedding_model = "sentence-transformers/all-mpnet-base-v2" + max_token_length = 500 + + # Load the tokenizer for the 'all-mpnet-base-v2' model + tokenizer = AutoTokenizer.from_pretrained(embedding_model) + # Tokenize the text and split the tokens into chunks of 500 tokens each + tokens = tokenizer.tokenize(input_text) + token_chunks = [ + tokens[i : i + max_token_length] + for i in range(0, len(tokens), max_token_length) + ] + + # Initialize an empty list + embeddings = [] + + # Create embeddings for each chunk + for chunk in token_chunks: + # Convert the chunk tokens back to text + chunk_text = tokenizer.convert_tokens_to_string(chunk) + + # Use the Hugging Face API to get embeddings for the chunk + api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}" + headers = {"Authorization": f"Bearer {huggingface_token}"} + chunk_text = chunk_text.replace("\n", " ") + + # Make a POST request to get the chunk's embedding + response = requests.post( + api_url, + headers=headers, + json={"inputs": chunk_text, "options": {"wait_for_model": True}}, + ) + + # Parse the response and extract the embedding + chunk_embedding = response.json() + # Append the embedding to the list + embeddings.append(chunk_embedding) + + # averaging all the embeddings + # this isn't very effective + # someone a better idea? + num_embeddings = len(embeddings) + average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)] + embedding = average_embedding + return embedding + + +@app.route("/embeddings", methods=["POST"]) +def embeddings(): + input_text_list = request.get_json().get("input") + input_text = " ".join(map(str, input_text_list)) + token = request.headers.get("Authorization").replace("Bearer ", "") + embedding = get_embedding(input_text, token) + + return { + "data": [{"embedding": embedding, "index": 0, "object": "embedding"}], + "model": "text-embedding-ada-002", + "object": "list", + "usage": {"prompt_tokens": None, "total_tokens": None}, + } + + +def run_api(): + app.run(host="0.0.0.0", port=1337) |