Update app.py
Browse files
app.py
CHANGED
@@ -6,31 +6,39 @@ import nltk
|
|
6 |
model = NextUsRegressor()
|
7 |
model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
|
8 |
model.eval()
|
9 |
-
|
10 |
def shap(txt, tok_level):
|
11 |
batch = [txt]
|
12 |
if tok_level == "word":
|
13 |
-
|
|
|
14 |
elif tok_level == "sentence":
|
15 |
-
print("sentence")
|
|
|
16 |
else:
|
17 |
-
|
|
|
18 |
#tokens = nltk
|
|
|
|
|
19 |
with torch.no_grad():
|
20 |
y_pred = model(txt)
|
21 |
-
|
|
|
|
|
22 |
|
23 |
demo = gr.Interface(shap,
|
24 |
[
|
25 |
gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
|
26 |
-
gr.Radio(["sentence", "word"], value="sentence", info="문장 단위의 해설은 sentence를 단어 단위의 해설은 word를 선택하세요.")
|
27 |
],
|
28 |
-
gr.Textbox(label="Slant Score"))
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
34 |
#theme=gr.themes.Base())
|
35 |
|
36 |
if __name__ == "__main__":
|
|
|
6 |
model = NextUsRegressor()
|
7 |
model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
|
8 |
model.eval()
|
9 |
+
mask = "[MASKED]"
|
10 |
def shap(txt, tok_level):
|
11 |
batch = [txt]
|
12 |
if tok_level == "word":
|
13 |
+
tokens = nltk.word_tokenize(txt)
|
14 |
+
#print("word")
|
15 |
elif tok_level == "sentence":
|
16 |
+
#print("sentence")
|
17 |
+
tokens = nltk.sent_tokenize(txt)
|
18 |
else:
|
19 |
+
pass
|
20 |
+
#print("this token granularity not supported")
|
21 |
#tokens = nltk
|
22 |
+
for i, _ in enumerate(tokens):
|
23 |
+
batch.append(" ".join([s for j, s in enumerate(tokens) if j!=i]))
|
24 |
with torch.no_grad():
|
25 |
y_pred = model(txt)
|
26 |
+
y_offs = model(batch)
|
27 |
+
shaps = (y_offs - y_pred).tolist()[0] # convert to list and make tuple to be returned
|
28 |
+
return [token, shap for token, shap in list(zip(tokens, shaps))]
|
29 |
|
30 |
demo = gr.Interface(shap,
|
31 |
[
|
32 |
gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
|
33 |
+
gr.Radio(choices=["sentence", "word"], label="해설 표시 단위", value="sentence", info="문장 단위의 해설은 sentence를 단어 단위의 해설은 word를 선택하세요.")
|
34 |
],
|
35 |
+
#gr.Textbox(label="Slant Score"))
|
36 |
+
gr.HighlightedText(
|
37 |
+
label="Diff",
|
38 |
+
combine_adjacent=True,
|
39 |
+
show_legend=True,
|
40 |
+
color_map={"+": "red", "-": "green"})
|
41 |
+
)
|
42 |
#theme=gr.themes.Base())
|
43 |
|
44 |
if __name__ == "__main__":
|