embed / app.py
import gradio as gr
import json
import numpy as np
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
model.max_seq_length = 256

# JSON encoder that serializes NumPy arrays as plain Python lists.
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)
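
# Illustrative example (not part of the app): with this encoder,
# json.dumps(np.array([0.1, 0.2]), cls=NumpyEncoder) yields the string
# "[0.1, 0.2]" instead of raising a TypeError for the ndarray.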

def text_to_embedding(text):
    # Tokenize with the model's underlying tokenizer so we get a flat list of
    # token strings that can be sliced into chunks.
    tokens = model.tokenizer.tokenize(text)

    # If the token count exceeds the model's maximum sequence length, split the
    # text into chunks, embed each chunk, and average the chunk embeddings.
    if len(tokens) > model.max_seq_length:
        chunks = []
        for i in range(0, len(tokens), model.max_seq_length):
            chunk = tokens[i:i + model.max_seq_length]
            chunks.append(model.tokenizer.convert_tokens_to_string(chunk))

        # Encode each chunk and store the embeddings
        embeddings = []
        for chunk in chunks:
            embeddings.append(model.encode(chunk))

        # The final embedding is the element-wise mean of the chunk embeddings.
        avg_embedding = np.mean(embeddings, axis=0)
    else:
        # Within the limit: encode the input text directly.
        avg_embedding = model.encode(text)

    return json.dumps(avg_embedding, cls=NumpyEncoder)
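
# Quick local sanity check (a sketch, not part of the Space): all-MiniLM-L6-v2
# produces 384-dimensional embeddings, so the decoded JSON should be a list of
# 384 floats. Uncomment to try it outside the hosted app:
#
# vector = json.loads(text_to_embedding("Hello, embeddings!"))
# assert len(vector) == 384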

inputs = gr.Textbox(value="Type text here.")
outputs = gr.Textbox()
app = gr.Interface(fn=text_to_embedding, inputs=inputs, outputs=outputs, title="Text to Embedding")
app.launch()
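
# Calling the hosted Interface from a separate client process (a sketch; the
# Space id "mikesoylu/embed" and the default "/predict" endpoint are assumptions
# based on this repo's name). Keep this commented out inside app.py and run it
# from a client script instead:
#
# from gradio_client import Client
# client = Client("mikesoylu/embed")
# embedding_json = client.predict("Some text to embed", api_name="/predict")
# embedding = json.loads(embedding_json)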