rabay35 commited on
Commit
2505678
1 Parent(s): eb2a7bb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import torch
3
+ import gradio as gr
4
+ from scipy.spatial.distance import cosine
5
+
6
+ # Disable CUDA
7
+ torch.backends.cudnn.enabled = False
8
+ torch.cuda.is_available = lambda : False
9
+
10
+ # Load model and tokenizer
11
+ modelname = "algolia/algolia-large-en-generic-v2410"
12
+ model = SentenceTransformer(modelname)
13
+ def get_embedding(text):
14
+ embedding = model.encode([text])
15
+ return embedding[0]
16
+
17
+ def compute_similarity(query, documents):
18
+ query_emb = get_embedding(query)
19
+ doc_embeddings = [get_embedding(doc) for doc in documents]
20
+
21
+ # Calculate cosine similarity
22
+ similarities = [1 - cosine(query_emb, doc_emb) for doc_emb in doc_embeddings]
23
+ ranked_docs = sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)
24
+
25
+ # Format output
26
+ return [{"document": doc, "similarity_score": round(sim, 4)} for doc, sim in ranked_docs]
27
+
28
+ # Gradio interface function
29
+ def gradio_compute_similarity(query, documents):
30
+ # Prefix the query string
31
+ query = "query: " + query
32
+ # Split documents by lines for the Gradio input
33
+ documents_list = documents.split("\n")
34
+ results = compute_similarity(query, documents_list)
35
+ return results
36
+
37
+ # Gradio Interface
38
+ iface = gr.Interface(
39
+ fn=gradio_compute_similarity,
40
+ inputs=[
41
+ gr.Textbox(label="Query", placeholder="Enter your query here"),
42
+ gr.Textbox(lines=5, label="Documents", placeholder="Enter a list of documents, one per line")
43
+ ],
44
+ outputs=gr.JSON(label="Ranked Results"),
45
+ allow_flagging="never",
46
+ title="Document Similarity",
47
+ description="Provide a query and a list of documents. See the ranked similarity scores."
48
+ )
49
+
50
+ iface.launch()