soarhigh committed on
Commit
e1e2421
·
1 Parent(s): 0d44303

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -6,31 +6,39 @@ import nltk
6
  model = NextUsRegressor()
7
  model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
8
  model.eval()
9
-
10
  def shap(txt, tok_level):
11
  batch = [txt]
12
  if tok_level == "word":
13
- print("word")
 
14
  elif tok_level == "sentence":
15
- print("sentence")
 
16
  else:
17
- print("this token granularity not supported")
 
18
  #tokens = nltk
 
 
19
  with torch.no_grad():
20
  y_pred = model(txt)
21
- return txt
 
 
22
 
23
  demo = gr.Interface(shap,
24
  [
25
  gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
26
- gr.Radio(["sentence", "word"], value="sentence", info="문장 단위의 해설은 sentence를 단어 단위의 해설은 word를 선택하세요.")
27
  ],
28
- gr.Textbox(label="Slant Score"))
29
- #gr.HighlightedText(
30
- # label="Diff",
31
- # combine_adjacent=True,
32
- # show_legend=True,
33
- # color_map={"+": "red", "-": "green"}),
 
34
  #theme=gr.themes.Base())
35
 
36
  if __name__ == "__main__":
 
6
  model = NextUsRegressor()
7
  model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
8
  model.eval()
9
+ mask = "[MASKED]"
10
  def shap(txt, tok_level):
11
  batch = [txt]
12
  if tok_level == "word":
13
+ tokens = nltk.word_tokenize(txt)
14
+ #print("word")
15
  elif tok_level == "sentence":
16
+ #print("sentence")
17
+ tokens = nltk.sent_tokenize(txt)
18
  else:
19
+ pass
20
+ #print("this token granularity not supported")
21
  #tokens = nltk
22
+ for i, _ in enumerate(tokens):
23
+ batch.append(" ".join([s for j, s in enumerate(tokens) if j!=i]))
24
  with torch.no_grad():
25
  y_pred = model(txt)
26
+ y_offs = model(batch)
27
+ shaps = (y_offs - y_pred).tolist()[0] # convert to list and make tuple to be returned
28
+ return [token, shap for token, shap in list(zip(tokens, shaps))]
29
 
30
  demo = gr.Interface(shap,
31
  [
32
  gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
33
+ gr.Radio(choices=["sentence", "word"], label="해설 표시 단위", value="sentence", info="문장 단위의 해설은 sentence를 단어 단위의 해설은 word를 선택하세요.")
34
  ],
35
+ #gr.Textbox(label="Slant Score"))
36
+ gr.HighlightedText(
37
+ label="Diff",
38
+ combine_adjacent=True,
39
+ show_legend=True,
40
+ color_map={"+": "red", "-": "green"})
41
+ )
42
  #theme=gr.themes.Base())
43
 
44
  if __name__ == "__main__":