Apoorv Saxena committed on
Commit 7a1c034
1 Parent(s): 5312aec

Update app.py

Files changed (1)
  1. app.py +69 -5
app.py CHANGED
@@ -1,7 +1,71 @@
- import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import torch
+ import numpy as np
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


+ def getScores(ids, scores, pad_token_id):
+     """Compute sequence log-probability scores from model.generate output."""
+     scores = torch.stack(scores, dim=1)
+     log_probs = torch.log_softmax(scores, dim=2)
+     # drop the start token so ids align with the per-step scores
+     ids = ids[:, 1:]
+     # gather the log-prob of each generated token
+     x = ids.unsqueeze(-1).expand(log_probs.shape)
+     needed_logits = torch.gather(log_probs, 2, x)
+     final_logits = needed_logits[:, :, 0]
+     # zero out padding positions so they do not affect the sum
+     padded_mask = (ids == pad_token_id)
+     final_logits[padded_mask] = 0
+     final_scores = final_logits.sum(dim=-1)
+     return final_scores.cpu().detach().numpy()
+
+ def topkSample(input, model, tokenizer,
+                num_samples=5,
+                num_beams=1,
+                max_output_length=30):
+     tokenized = tokenizer(input, return_tensors="pt")
+     out = model.generate(**tokenized,
+                          do_sample=True,
+                          num_return_sequences=num_samples,
+                          num_beams=num_beams,
+                          eos_token_id=tokenizer.eos_token_id,
+                          pad_token_id=tokenizer.pad_token_id,
+                          output_scores=True,
+                          return_dict_in_generate=True,
+                          max_length=max_output_length)
+     out_tokens = out.sequences
+     out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
+     out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)
+
+     # pair each decoded string with its score and sort best-first
+     pair_list = list(zip(out_str, out_scores))
+     sorted_pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
+     return sorted_pair_list
+
+ def greedyPredict(input, model, tokenizer):
+     input_ids = tokenizer([input], return_tensors="pt").input_ids
+     out_tokens = model.generate(input_ids)
+     out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
+     return out_str[0]
+
+ def predict_tail(entity, relation):
+     global model, tokenizer
+     input = entity + "| " + relation
+     out = topkSample(input, model, tokenizer, num_samples=5)
+     out_dict = {}
+     for k, v in out:
+         # convert the summed log-prob back to a probability for display
+         out_dict[k] = np.exp(v).item()
+     return out_dict
+
+
+ tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
+ model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-base-wikikg90mv2")
+
+
+ ent_input = gr.inputs.Textbox(lines=1, default="World War II")
+ rel_input = gr.inputs.Textbox(lines=1, default="followed by")
+ output = gr.outputs.Label()
+
+
+ iface = gr.Interface(fn=predict_tail, inputs=[ent_input, rel_input], outputs=output)
+ iface.launch()
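
For reference, a minimal sketch of exercising the new prediction path outside the Gradio UI. This is a hypothetical usage example, not part of the commit: it assumes the checkpoints above load successfully and reuses the predict_tail defined in app.py.

    # hypothetical smoke test (assumed usage, not in the commit)
    preds = predict_tail("World War II", "followed by")
    for entity, prob in sorted(preds.items(), key=lambda kv: -kv[1]):
        print(f"{entity}\t{prob:.4f}")

Since topkSample returns summed token log-probs, predict_tail exponentiates each score, so the printed values are unnormalized sequence probabilities for the sampled tail entities.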