|
import gradio as gr |
|
import torch |
|
from transformers import BertJapaneseTokenizer, BertForSequenceClassification |
|
|
|
|
|
# Pretrained Japanese BERT (Tohoku Univ., whole-word-masking variant) used
# as the tokenizer base and the original fine-tuning checkpoint.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# User-facing description shown in the Gradio UI (kept in Japanese on
# purpose — it is rendered to end users, not a code comment).
descriptions = '''BERTをchABSA-datasetでファインチューニングしたもの。

chABSA-datasetは上場企業の有価証券報告書をベースに作成されたネガポジ用データセット'''


# Tokenizer comes from the public hub model; the classifier weights are
# loaded from the local "models/" directory — presumably the chABSA
# fine-tuned checkpoint described above (TODO confirm directory contents).
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

bert_sc_ = BertForSequenceClassification.from_pretrained("models/")

# Pin the model to CPU for inference (no GPU assumed at deploy time).
bert_sc = bert_sc_.to("cpu")
|
|
|
def func(text):
    """Classify a Japanese business text as negative or positive.

    Args:
        text: Input string to run sentiment analysis on.

    Returns:
        A tuple ``(label, text)`` where ``label`` is the Japanese
        sentiment label ("ネガティブ" or "ポジティブ") and ``text`` is
        the original input echoed back for display.
    """
    # Tokenize. truncation=True caps the sequence at the model's maximum
    # length (512 for BERT); without it, long inputs crash at inference.
    encoding = tokenizer(
        text,
        padding="longest",
        truncation=True,
        return_tensors="pt",
    )
    # The model was moved to CPU at load time, so keep tensors on CPU too.
    encoding = {k: v.cpu() for k, v in encoding.items()}

    # Inference only — disable gradient tracking to save memory/compute.
    with torch.no_grad():
        output = bert_sc(**encoding)
        scores = output.logits.argmax(-1)

    # Class index 0 = negative, any other index = positive.
    label = "ネガティブ" if scores.item() == 0 else "ポジティブ"

    return label, text
|
|
|
# Wire the classifier into a Gradio UI: one text box in, a sentiment
# label plus an echo of the analyzed text out.
app = gr.Interface(
    fn=func,
    inputs="text",
    outputs=["label", "text"],
    title="ビジネス文書/ネガポジ分析",
    description=descriptions,
)

# Start the local web server for the interface.
app.launch()