fashxp committed on
Commit
c4ae669
1 Parent(s): 13283da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -5
app.py CHANGED
@@ -1,10 +1,55 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
9
 
10
- demo.launch()
 
1
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import torch

# One checkpoint name is used for both the tokenizer and the encoder so the
# two always match.
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
11
+
12
def cls_pooling(model_output):
    """Return the first-position hidden state of every sequence.

    Indexes ``model_output.last_hidden_state`` with ``[:, 0]``: for each item
    along the first axis, only position 0 along the second axis is kept.
    """
    return model_output.last_hidden_state[:, 0]
14
+
15
def get_embeddings(text_list):
    """Embed a list of strings with the module-level tokenizer and model.

    The texts are tokenized as one padded, truncated batch of PyTorch
    tensors, moved to ``device``, passed through ``model``, and reduced with
    ``cls_pooling``.

    Args:
        text_list: The strings to embed.

    Returns:
        The tensor produced by ``cls_pooling`` for the batch, on ``device``.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    # The model lives on `device`, so every input tensor must be moved there.
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)
22
+
23
+
24
# Load the dataset, keep its "train" split, and index the "embeddings" column
# with FAISS so it can be queried with get_nearest_examples().
embeddings_doc_dataset = load_dataset("fashxp/pimcore-docs-embeddings")
embeddings_doc_dataset = embeddings_doc_dataset['train']
embeddings_doc_dataset.add_faiss_index(column="embeddings")

import pandas as pd
29
 
30
def find_in_docs(question):
    """Return a plain-text report of the 10 indexed entries nearest to *question*.

    The question is embedded with ``get_embeddings`` and looked up in the
    FAISS index on the "embeddings" column of ``embeddings_doc_dataset``.
    Each hit is rendered as ``HEADING``, ``SCORE`` and ``URL`` lines followed
    by a line of 50 ``=`` characters and a blank line.

    Args:
        question: The query text.

    Returns:
        The concatenated report, or an empty string when there are no hits.
    """
    # The index is queried with a NumPy array, so move the tensor to the CPU.
    question_embedding = get_embeddings([question]).cpu().detach().numpy()

    scores, samples = embeddings_doc_dataset.get_nearest_examples(
        "embeddings", question_embedding, k=10
    )

    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores
    # NOTE(review): rows are listed by descending score. Whether a larger
    # score means a closer match depends on the index metric -- confirm.
    samples_df.sort_values("scores", ascending=False, inplace=True)

    separator = "=" * 50
    return "".join(
        f"HEADING: {row.heading}\n"
        f"SCORE: {row.scores}\n"
        f"URL: {row.url}\n"
        f"{separator}\n\n"
        for _, row in samples_df.iterrows()
    )
48
+
49
+
50
+
51
import gradio as gr

# Single text-in / text-out interface around find_in_docs.
demo = gr.Interface(fn=find_in_docs, inputs="text", outputs="text")

# NOTE(review): share=True asks Gradio for a public share link; confirm this
# is wanted for the environment the app is deployed in.
demo.launch(share=True)