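"""Gradio app for checking Chinese sentences: a fine-tuned BERT classifier
judges correctness, and a BART-based scorer estimates fluency."""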
import math

import gradio as gr
import torch
from transformers import (
    BartForConditionalGeneration,
    BertForSequenceClassification,
    BertTokenizerFast,
)


class CHSentenceSmoothScorer:
    """Fluency scorer: geometric-mean token probability under Chinese BART."""

    def __init__(self) -> None:
        super().__init__()
        # fnlp/bart-base-chinese ships a BERT-style vocabulary, so the
        # tokenizer is loaded with BertTokenizerFast.
        self.tokenizer = BertTokenizerFast.from_pretrained(
            "fnlp/bart-base-chinese")
        self.model = BartForConditionalGeneration.from_pretrained(
            "fnlp/bart-base-chinese")
    def __call__(self, sentences):
        input_ids = self.tokenizer.batch_encode_plus(
            sentences, return_tensors='pt',
            padding=True,
            max_length=50,
            truncation='longest_first'
        )['input_ids']
        # Run the whole batch once; no gradients are needed at inference time.
        with torch.no_grad():
            logits = self.model(input_ids).logits
        softmax = torch.softmax(logits, dim=-1)

        out = []
        for i, sentence in enumerate(sentences):
            # Drop padding so it does not dilute the average log-probability.
            sent_token_ids = [
                token_id for token_id in input_ids[i].tolist()
                if token_id != self.tokenizer.pad_token_id
            ]

            # Average the log-probability the model assigns to each token,
            # then exponentiate to get a probability-like score in (0, 1].
            log_prob_sum = 0.0
            for j, token_id in enumerate(sent_token_ids):
                log_prob_sum += math.log(softmax[i][j][token_id].item())
            avg_neg_log_prob = -log_prob_sum / len(sent_token_ids)
            prob_score = math.exp(-avg_neg_log_prob)
            out.append(prob_score)
        return out


# Correctness classifier: a fine-tuned BERT model stored alongside the app.
model = BertForSequenceClassification.from_pretrained('./ch-sent-check-model')
tokenizer = BertTokenizerFast.from_pretrained('./ch-sent-check-model')
smooth_scorer = CHSentenceSmoothScorer()


def judge(sentence):
    """Return (correctness verdict with score, fluency score) for a sentence."""
    input_ids = tokenizer(sentence, return_tensors='pt')['input_ids']
    with torch.no_grad():
        out = model(input_ids)
    prob = torch.softmax(out.logits, dim=-1)
    # Class index 0 = incorrect, 1 = correct.
    pred = torch.argmax(prob, dim=-1).item()
    pred_text = 'Incorrect' if pred == 0 else 'Correct'
    correct_prob = prob[0][1].item()
    pred_text = pred_text + f", score: {round(correct_prob * 100, 2)}"

    # Fluency score from the BART-based scorer, scaled to 0-100.
    smooth_score = round(smooth_scorer([sentence])[0] * 100, 2)
    return pred_text, smooth_score


# Gradio UI: one Chinese-sentence input, two text outputs.
iface = gr.Interface(
    fn=judge,
    inputs=gr.Textbox(
        label="請輸入一段中文句子來檢測正確性",  # "Enter a Chinese sentence to check its correctness"
        lines=1,
    ),
    outputs=[
        gr.Textbox(
            label="正確性檢查",  # correctness check
            lines=1,
        ),
        gr.Textbox(
            label="流暢性檢查",  # fluency check
            lines=1,
        ),
    ],
    examples=[
        '請注意用字的鄭確性',  # misspelled example: 鄭確性 should be 正確性
        '請注意用字的正確性',  # correctly spelled example
    ],
)

iface.launch()