import io
import gradio as gr
import torch
from nextus_regressor_class import *  # provides NextUsRegressor
import nltk
from pprint import pprint
import pandas as pd

# Load the trained regressor and switch it to inference mode.
model = NextUsRegressor()
model.load_state_dict(torch.load("./nextus_regressor1030.pt"))
model.eval()

mask = "[MASKED]"
threshold = 0.05  # minimum |score shift| for a span to be labelled "+" or "-"
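# Note: nltk.word_tokenize / nltk.sent_tokenize below need the "punkt" tokenizer
# data; if it is not already installed, uncomment the next line (assumes network
# access at startup).
# nltk.download("punkt")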
def shap(txt, tok_level):
    """Leave-one-out explanation: score the full text, then re-score copies of it
    with each span removed and report how much the Slant score shifts."""
    if tok_level == "word":
        tokens = nltk.word_tokenize(txt)
    elif tok_level == "sentence":
        tokens = nltk.sent_tokenize(txt)
    else:
        # Unsupported granularity: fall back to sentences so `tokens` is always defined.
        tokens = nltk.sent_tokenize(txt)
    # One leave-one-out variant per span (the full text itself is scored separately
    # as y_pred, so it is not added to the batch).
    batch = []
    for i, _ in enumerate(tokens):
        batch.append(" ".join([s for j, s in enumerate(tokens) if j != i]))
    with torch.no_grad():
        y_pred = model(txt)
        y_offs = model(batch)
    # Score shift caused by removing each span.
    shaps = (y_offs - y_pred).tolist()
    shapss = [s[0] for s in shaps]
    # "+": removing the span lowers the score; "-": removing it raises the score.
    labels = list()
    for s in shapss:
        if s <= -1.0 * threshold:
            labels.append("+")
        elif s >= threshold:
            labels.append("-")
        else:
            labels.append(None)
    pprint(list(zip(tokens, shapss)))
    # Span whose removal produces the largest (positive) score shift.
    largest_shap = torch.max(y_offs - y_pred).item()
    largest_shap_span = tokens[torch.argmax(y_offs - y_pred).item()]
    explanation = ("The text with the largest impact is\n'" + largest_shap_span
                   + "'\nand without it the Slant score\n" + str(round(y_pred.item(), 4))
                   + "\nwould shift by\n" + str(round(largest_shap, 4)) + ".")
    return list(zip(tokens, labels)), explanation
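# NOTE: shap() is not wired into the interface below; the HighlightedText output
# and the sentence/word Radio input that would consume its (tokens, labels) pairs
# and explanation string are left commented out in the gr.Interface definition.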
def parse_file_input(f):
    """Read article texts from the first column of an uploaded csv/xls(x) file
    and return their Slant scores."""
    all_articles = list()
    if ".csv" in f.name:
        all_articles += pd.read_csv(f.name).iloc[:, 0].to_list()
    elif ".xls" in f.name:
        all_articles += pd.read_excel(f.name).iloc[:, 0].to_list()
    else:
        # Unsupported extension: leave the batch empty (kept from the original logic).
        pass
    print(len(all_articles))
    print(all_articles)
    with torch.no_grad():
        scores = model(all_articles)
    return scores
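# The Textbox output below simply renders whatever parse_file_input returns,
# so the scores appear as the string form of the model's output tensor.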
demo = gr.Interface(
    parse_file_input,
    [
        gr.File(file_count="single", file_types=[".csv", ".xls", ".xlsx"],
                type="file", label="Upload an article file (csv/excel)")
        # gr.Textbox(label="Article", lines=30, placeholder="Enter an article."),
        # gr.Radio(choices=["sentence", "word"], label="Explanation granularity", value="sentence",
        #          info="Choose 'sentence' for sentence-level explanations or 'word' for word-level ones.")
    ],
    gr.Textbox(label="Slant Scores"),
    # gr.HighlightedText(
    #     label="Diff",
    #     combine_adjacent=True,
    #     show_legend=True,
    #     color_map={"+": "red", "-": "green"}),
    theme=gr.themes.Base())
demo.launch()