summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTekky <98614666+xtekky@users.noreply.github.com>2023-09-22 21:30:06 +0200
committerGitHub <noreply@github.com>2023-09-22 21:30:06 +0200
commit2cb59b4e10dc387ec4c3fafe2c889cadcc59c376 (patch)
treeb0984639158826c8dc1b50d0feeee035eafe8c8e
parent~ (diff)
parentimport AutoTokenizer in app.py (diff)
downloadgpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.tar
gpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.tar.gz
gpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.tar.bz2
gpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.tar.lz
gpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.tar.xz
gpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.tar.zst
gpt4free-2cb59b4e10dc387ec4c3fafe2c889cadcc59c376.zip
-rw-r--r--README.md5
-rw-r--r--interference/app.py70
2 files changed, 71 insertions, 4 deletions
diff --git a/README.md b/README.md
index 13b781e5..763417ba 100644
--- a/README.md
+++ b/README.md
@@ -279,6 +279,9 @@ asyncio.run(run_async())
### interference openai-proxy api (use with openai python package)
+If you want to use the embedding function, you need to get a huggingface token. You can get one at https://huggingface.co/settings/tokens make sure your role is set to write. If you have your token, just use it instead of the OpenAI api-key.
+
+
get requirements:
```sh
@@ -294,7 +297,7 @@ python3 -m interference.app
```py
import openai
-openai.api_key = ""
+openai.api_key = "Empty if you don't use embeddings, otherwise your hugginface token"
openai.api_base = "http://localhost:1337"
diff --git a/interference/app.py b/interference/app.py
index 1b1af22f..15e5bc80 100644
--- a/interference/app.py
+++ b/interference/app.py
@@ -3,10 +3,10 @@ import random
import string
import time
from typing import Any
-
+import requests
from flask import Flask, request
from flask_cors import CORS
-
+from transformers import AutoTokenizer
from g4f import ChatCompletion
app = Flask(__name__)
@@ -88,9 +88,73 @@ def chat_completions():
return app.response_class(streaming(), mimetype="text/event-stream")
+#Get the embedding from huggingface
+def get_embedding(input_text, token):
+ huggingface_token = token
+ embedding_model = "sentence-transformers/all-mpnet-base-v2"
+ max_token_length = 500
+
+ # Load the tokenizer for the "all-mpnet-base-v2" model
+ tokenizer = AutoTokenizer.from_pretrained(embedding_model)
+ # Tokenize the text and split the tokens into chunks of 500 tokens each
+ tokens = tokenizer.tokenize(input_text)
+ token_chunks = [tokens[i:i + max_token_length] for i in range(0, len(tokens), max_token_length)]
+
+ # Initialize an empty list
+ embeddings = []
+
+ # Create embeddings for each chunk
+ for chunk in token_chunks:
+ # Convert the chunk tokens back to text
+ chunk_text = tokenizer.convert_tokens_to_string(chunk)
+
+ # Use the Hugging Face API to get embeddings for the chunk
+ api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}"
+ headers = {"Authorization": f"Bearer {huggingface_token}"}
+ chunk_text = chunk_text.replace("\n", " ")
+
+ # Make a POST request to get the chunk's embedding
+ response = requests.post(api_url, headers=headers, json={"inputs": chunk_text, "options": {"wait_for_model": True}})
+
+ # Parse the response and extract the embedding
+ chunk_embedding = response.json()
+ # Append the embedding to the list
+ embeddings.append(chunk_embedding)
+
+ #averaging all the embeddings
+ #this isn't very effective
+ #someone a better idea?
+ num_embeddings = len(embeddings)
+ average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)]
+ embedding = average_embedding
+ return embedding
+
+
+@app.route("/embeddings", methods=["POST"])
+def embeddings():
+ input_text_list = request.get_json().get("input")
+ input_text = ' '.join(map(str, input_text_list))
+ token = request.headers.get('Authorization').replace("Bearer ", "")
+ embedding = get_embedding(input_text, token)
+ return {
+ "data": [
+ {
+ "embedding": embedding,
+ "index": 0,
+ "object": "embedding"
+ }
+ ],
+ "model": "text-embedding-ada-002",
+ "object": "list",
+ "usage": {
+ "prompt_tokens": None,
+ "total_tokens": None
+ }
+ }
+
def main():
app.run(host="0.0.0.0", port=1337, debug=True)
if __name__ == "__main__":
- main() \ No newline at end of file
+ main()