# Author: Abhay Mishra
# Commit: 16c64bc ("initial commit"), 842 bytes
import pickle
import torch
import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
# Encoder used for incoming queries.
# NOTE(review): presumably the same model that produced the stored course
# embeddings — confirm, otherwise the similarity scores are not comparable.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ship a trusted title_to_content_embed.pickle with this app.
with open("title_to_content_embed.pickle", "rb") as pickle_file:
    loaded_map = pickle.load(pickle_file)

# The mapping's keys become the course titles. Its values are stacked, in the
# same key order, into a float32 matrix with one row per title
# (assumes every value is a vector of the same length — TODO confirm).
course_titles = list(loaded_map)
course_content_embeddings = np.array(
    [loaded_map[title] for title in course_titles], dtype=np.float32
)

# Cosine similarity computed along dim 1 (the feature axis of a 2-D input).
cos = torch.nn.CosineSimilarity(eps=1e-6, dim=1)
def give_best_match(query):
    """Score every course against *query* by cosine similarity.

    Args:
        query: Free text to search for; it is encoded with the module-level
            ``model``.

    Returns:
        dict mapping each course title to its similarity score as a Python
        float, inserted in descending order of score.
    """
    # Cast so the query vector has the same dtype (float32) as the stored
    # course matrix before both are handed to torch.
    query_embedding = np.asarray(model.encode(query), dtype=np.float32)
    scores = cos(
        torch.from_numpy(course_content_embeddings),
        torch.from_numpy(query_embedding),
    )
    # Rank with torch directly. The previous np.argsort(tensor) relied on
    # NumPy implicitly converting the tensor and gave back tensor indices.
    ranking = torch.argsort(scores, descending=True).tolist()
    values = scores.tolist()
    return {course_titles[i]: values[i] for i in ranking}
# Web UI: a free-text query box in, and a label widget out that shows the five
# best-scoring course titles produced by give_best_match.
label_output = gr.Label(num_top_classes=5)
demo = gr.Interface(
    fn=give_best_match,
    inputs="text",
    outputs=label_output,
)
demo.launch()