liujch1998 commited on
Commit
5d92357
β€’
1 Parent(s): 5afe135
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import transformers
4
+
5
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
6
+
7
+ class Interactive:
8
+ def __init__(self):
9
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained('liujch1998/cd-pi')
10
+ self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/cd-pi').to(device)
11
+ self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1).to(device)
12
+ self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
13
+ self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
14
+ self.model.eval()
15
+ self.t = 2.2247
16
+
17
+ def run(self, statement):
18
+ input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
19
+ with torch.no_grad():
20
+ output = self.model(input_ids)
21
+ last_hidden_state = output.last_hidden_state.to(device) # (B=1, L, D)
22
+ hidden = last_hidden_state[0, -1, :] # (D)
23
+ logit = self.linear(hidden).squeeze(-1) # ()
24
+ logit_calibrated = logit / self.t
25
+ score = logit.sigmoid()
26
+ score_calibrated = logit_calibrated.sigmoid()
27
+ return {
28
+ 'logit': logit.item(),
29
+ 'logit_calibrated': logit_calibrated.item(),
30
+ 'score': score.item(),
31
+ 'score_calibrated': score_calibrated.item(),
32
+ }
33
+
34
+ interactive = Interactive()
35
+
36
+ def predict(statement, model):
37
+ result = interactive.run(statement)
38
+ return {
39
+ 'True': result['score_calibrated'],
40
+ 'False': 1 - result['score_calibrated'],
41
+ }
42
+
43
+ examples = [
44
+ 'If A sits next to B and B sits next to C, then A must sit next to C.',
45
+ 'If A sits next to B and B sits next to C, then A might not sit next to C.',
46
+ ]
47
+
48
+ input_statement = gr.Dropdown(choices=examples, label='Statement:')
49
+ input_model = gr.Textbox(label='Commonsense statement verification model:', value='liujch1998/cd-pi', interactive=False)
50
+ output = gr.outputs.Label(num_top_classes=2)
51
+
52
+ description = '''This is a demo for a commonsense statement verification model. Under development.'''
53
+
54
+ gr.Interface(
55
+ fn=predict,
56
+ inputs=[input_statement, input_model],
57
+ outputs=output,
58
+ title="cd-pi Demo",
59
+ description=description,
60
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ tokenizers
4
+ sentencepiece