import io
import gradio as gr
import torch
from nextus_regressor_class import *
import nltk
from pprint import pprint
import pandas as pd

# Load the trained slant regressor once, at import time.
# `map_location="cpu"` lets a checkpoint saved on a GPU machine be deserialized
# on a CPU-only host; load_state_dict then copies the weights into the model's
# own parameters. torch.load unpickles the file, so only point it at a trusted
# checkpoint.
model = NextUsRegressor()
model.load_state_dict(torch.load("./nextus_regressor1012.pt", map_location="cpu"))
model.eval()

mask = "[MASKED]"  # not referenced by the code below
# Minimum |score change| for a removed span to be labelled as influential.
threshold = 0.05


def shap(txt, tok_level):
    """Explain the slant score of ``txt`` by leave-one-out ablation.

    The text is split into word or sentence spans with NLTK (its tokenizers
    need the 'punkt' data package installed). Each span is removed in turn and
    the model is re-run; the change in score relative to the full text is that
    span's attribution.

    Args:
        txt: Article text to explain.
        tok_level: ``"word"`` or ``"sentence"``; the span granularity.

    Returns:
        A ``(spans, explanation)`` pair.

        ``spans`` is a list of ``(token, label)`` tuples in the format used by
        the (commented-out) ``gr.HighlightedText`` output. The label is:

        - ``"+"`` when removing the token lowers the score by at least
          ``threshold``, i.e. the token pushes the score up;
        - ``"-"`` when removing the token raises the score by at least
          ``threshold``;
        - ``None`` otherwise.

        ``explanation`` is a Korean summary naming the span whose removal
        gives the largest (signed) score change. Text that tokenizes to
        nothing yields ``([], "")``.

    Raises:
        ValueError: if ``tok_level`` is neither ``"word"`` nor ``"sentence"``.
    """
    if tok_level == "word":
        tokens = nltk.word_tokenize(txt)
    elif tok_level == "sentence":
        tokens = nltk.sent_tokenize(txt)
    else:
        # Previously fell through with `tokens` unbound (UnboundLocalError).
        raise ValueError(f"unsupported token granularity: {tok_level!r}")

    if not tokens:
        # Nothing to ablate; torch.max/argmax on an empty tensor would fail.
        return [], ""

    # batch[0] is the full text, and batch[i + 1] is the text with tokens[i]
    # removed. The full text is kept in the batch so the model sees the same
    # batch composition as before.
    batch = [txt] + [
        " ".join(s for j, s in enumerate(tokens) if j != i)
        for i in range(len(tokens))
    ]

    with torch.no_grad():
        y_pred = model(txt)
        y_offs = model(batch)

    # Drop row 0 (the unablated text) so that row i lines up with tokens[i].
    # Assumes the model returns one score per input, shaped (batch, 1), as the
    # `row[0]` / `.item()` accesses imply. TODO: confirm against NextUsRegressor.
    diffs = y_offs[1:] - y_pred
    shapss = [row[0] for row in diffs.tolist()]

    labels = [
        "+" if s <= -threshold else "-" if s >= threshold else None
        for s in shapss
    ]

    pprint(list(zip(tokens, shapss)))

    largest_shap = torch.max(diffs).item()
    largest_shap_span = tokens[torch.argmax(diffs).item()]
    explanation = (
        "가장 큰 영향을 미친 텍스트는\n'"
        + largest_shap_span
        + "'\n이며, 해당 텍스트가 없을 경우 Slant 스코어\n"
        + str(round(y_pred.item(), 4))
        + "\n에서\n"
        + str(round(largest_shap, 4))
        + "\n만큼 벗어납니다."
    )
    return list(zip(tokens, labels)), explanation


def parse_file_input(f):
    """Score every article in an uploaded CSV or Excel file.

    The first column of the file is taken as the article texts (pandas'
    default header handling applies). All articles are scored in one batch.

    Args:
        f: The upload from ``gr.File(type="file")``. This is an object with a
            ``.name`` path; a plain path string is also accepted.

    Returns:
        The model's output for the batch of articles, computed without
        autograd so that no ``grad_fn`` leaks into the displayed text.

    Raises:
        gr.Error: if no file is given, the extension is not csv/xls/xlsx, or
            the file contains no articles.
    """
    if f is None:
        raise gr.Error("기사 파일(csv/excel)을 업로드하세요")

    path = getattr(f, "name", f)
    lowered = path.lower()
    # Match on the suffix, not a substring, so names like "a.csv.xlsx" are
    # routed correctly.
    if lowered.endswith(".csv"):
        frame = pd.read_csv(path)
    elif lowered.endswith((".xls", ".xlsx")):
        frame = pd.read_excel(path)
    else:
        raise gr.Error("csv 또는 excel 파일만 지원합니다.")

    all_articles = frame.iloc[:, 0].to_list()
    if not all_articles:
        raise gr.Error("파일에 기사가 없습니다.")

    with torch.no_grad():
        scores = model(all_articles)
    return scores


# Batch-scoring UI. The commented-out components are the alternative
# single-article UI that serves `shap` instead of `parse_file_input`.
demo = gr.Interface(
    parse_file_input,
    [
        gr.File(
            file_count="single",
            file_types=[".csv", ".xls", ".xlsx"],
            type="file",
            label="기사 파일(csv/excel)을 업로드하세요",
        )
        # gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
        # gr.Radio(choices=["sentence", "word"], label="해설 표시 단위", value="sentence", info="문장 단위의 해설은 sentence를, 단어 단위의 해설은 word를 선택하세요.")
    ],
    gr.Textbox(label="Slant Scores"),
    # gr.HighlightedText(
    #     label="Diff",
    #     combine_adjacent=True,
    #     show_legend=True,
    #     color_map={"+": "red", "-": "green"}),
    theme=gr.themes.Base(),
)

demo.launch()