mishig HF staff commited on
Commit
a342b03
1 Parent(s): 7aee0dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import hnswlib
5
+ import gradio as gr
6
+ import numpy as np
7
+
8
+ seperator = "-HFSEP-"
9
+ base_name="intfloat/e5-large-v2"
10
+ device="cuda"
11
+ max_length=512
12
+ tokenizer = AutoTokenizer.from_pretrained(base_name)
13
+ model = AutoModel.from_pretrained(base_name).to(device)
14
+
15
+ def get_embeddings(input_texts):
16
+ batch_dict = tokenizer(
17
+ input_texts,
18
+ max_length=max_length,
19
+ padding=True,
20
+ truncation=True,
21
+ return_tensors='pt'
22
+ ).to(device)
23
+
24
+ with torch.no_grad():
25
+ outputs = model(**batch_dict)
26
+ embeddings = _average_pool(
27
+ outputs.last_hidden_state, batch_dict['attention_mask']
28
+ )
29
+ embeddings = F.normalize(embeddings, p=2, dim=1)
30
+ embeddings_np = embeddings.cpu().numpy()
31
+
32
+ if device == "cuda":
33
+ del embeddings
34
+ torch.cuda.empty_cache()
35
+
36
+ return embeddings_np
37
+
38
+ def _average_pool(
39
+ last_hidden_states,
40
+ attention_mask
41
+ ):
42
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
43
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
44
+
45
+ def create_hnsw_index(embeddings_np, space='ip', ef_construction=100, M=16):
46
+ index = hnswlib.Index(space=space, dim=len(embeddings_np[0]))
47
+ index.init_index(max_elements=len(embeddings_np), ef_construction=ef_construction, M=M)
48
+ ids = np.arange(embeddings_np.shape[0])
49
+ index.add_items(embeddings_np, ids)
50
+ return index
51
+
52
+ def gradio_function(query, paragraph_chunks, top_k):
53
+ paragraph_chunks = paragraph_chunks.split(seperator) # Split the comma-separated values into a list
54
+ paragraph_chunks = [item.strip() for item in paragraph_chunks] # Trim whitespace from each item
55
+
56
+ print("creating embeddings")
57
+ embeddings_np = get_embeddings([query]+paragraph_chunks)
58
+ query_embedding, chunks_embeddings = embeddings_np[0], embeddings_np[1:]
59
+
60
+ print("creating index")
61
+ search_index = create_hnsw_index(chunks_embeddings)
62
+ print("searching index")
63
+ labels, _ = search_index.knn_query(query_embedding, k=min(int(top_k), len(chunks_embeddings)))
64
+ return f"The closes labels are: {labels}"
65
+
66
+ interface = gr.Interface(
67
+ fn=gradio_function,
68
+ inputs=[
69
+ gr.Textbox(placeholder="Enter a user query..."),
70
+ gr.Textbox(placeholder="Enter comma-separated strings..."),
71
+ gr.Number()
72
+ ],
73
+ outputs="text"
74
+ )
75
+
76
+ interface.launch()