Jack Wong committed
Commit 207ea71
1 Parent(s): 7761d91

updated app.py

Files changed (1)
  1. app.py +47 -3
app.py CHANGED
@@ -1,9 +1,53 @@
+import time
+import torch
+import tempfile
 import gradio as gr
+import torch.nn.functional as F
 
+from transformers import AutoTokenizer, AutoModel
+from transformers import PreTrainedTokenizer, PreTrainedModel
 
-def greet(name):
-    return "Hello " + name + "!!"
+temp_dir = tempfile.TemporaryDirectory()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+def get_tokenizer() -> PreTrainedTokenizer:
+    return AutoTokenizer.from_pretrained('thenlper/gte-large', trust_remote_code=True, cache_dir=temp_dir.name)
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+def get_model() -> PreTrainedModel:
+    return AutoModel.from_pretrained('thenlper/gte-large', trust_remote_code=True, cache_dir=temp_dir.name).to(device)
+
+def average_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+def normalize_embeddings(embeddings: torch.Tensor) -> list[list[float]]:
+    norm_embeddings = F.normalize(embeddings, p=2, dim=1)
+    return norm_embeddings.tolist()
+
+def get_embeddings(text: str) -> list[list[float]]:
+    tokenizer = get_tokenizer()
+    model = get_model()
+
+    with torch.inference_mode():
+        start = time.time()
+        batch_dict = tokenizer(
+            text,
+            max_length=512,
+            truncation=True,
+            padding=True,
+            return_tensors='pt'
+        ).to(device)
+
+        outputs = model(**batch_dict, return_dict=True)
+        embeddings = average_pooling(
+            last_hidden_states=outputs.last_hidden_state,
+            attention_mask=batch_dict['attention_mask']
+        )
+
+        norm_embeddings = normalize_embeddings(embeddings)
+        end = time.time()
+        print("Execution time: ", end - start)
+        return norm_embeddings
+
+iface = gr.Interface(fn=get_embeddings, inputs="text", outputs="text")
 iface.launch()
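
For reference, the average_pooling step in this commit is a masked mean: masked_fill zeroes the hidden states at padded positions, the sum over the sequence dimension is divided by the count of real tokens, and normalize_embeddings then L2-normalizes the result so dot products behave like cosine similarities. A minimal standalone sketch of the pooling behavior (toy tensor values, not real model outputs):

import torch

def average_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero the hidden states at padded positions, then average over real tokens only.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

# One sequence of 3 tokens with hidden size 2; the third token is padding.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
mask = torch.tensor([[1, 1, 0]])
print(average_pooling(hidden, mask))  # tensor([[2., 3.]]): the padding row is ignored

Once the Space is running, the endpoint can be exercised from Python with gradio_client; the Space id below is a placeholder, and api_name="/predict" is the default route a plain gr.Interface exposes:

from gradio_client import Client

client = Client("owner/space-name")  # placeholder: substitute the actual Space id
result = client.predict("The quick brown fox", api_name="/predict")
print(result)  # text output: the stringified, L2-normalized embedding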