File size: 2,478 Bytes
0b476dd
f8bb73e
0b476dd
f8bb73e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b476dd
 
 
f8bb73e
0b476dd
 
 
 
 
f8bb73e
 
0b476dd
f8bb73e
 
 
 
 
0b476dd
 
 
 
5cda48b
0b476dd
 
f8bb73e
 
 
 
 
 
 
 
 
 
 
 
 
 
0b476dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import BertTokenizerFast, BertForSequenceClassification,GPT2LMHeadModel,BartForConditionalGeneration
import torch
import math

class CHSentenceSmoothScorer():
    """Score how fluent ("smooth") sentences are using a BART language model.

    A sentence's score is the geometric mean of the probabilities the model
    assigns to the sentence's own tokens, i.e. ``exp(mean(log p(token)))``.
    Scores therefore lie in ``[0, 1]``; larger means more predictable text.
    """

    def __init__(self) -> None:
        # Tokenizer and model are both loaded from the same checkpoint name;
        # the vocabulary is read through the BERT tokenizer class.
        self.tokenizer = BertTokenizerFast.from_pretrained(
            "fnlp/bart-base-chinese")
        self.model = BartForConditionalGeneration.from_pretrained(
            "fnlp/bart-base-chinese")

    def __call__(self, sentences, max_length=50):
        """Return one fluency score per sentence.

        Args:
            sentences: List of sentence strings to score.
            max_length: Maximum number of tokens kept per sentence; longer
                sentences are truncated. Defaults to 50.

        Returns:
            A list of floats, one per input sentence, each equal to
            ``exp`` of the mean log-probability of that sentence's
            non-padding tokens. An empty input yields an empty list.
        """
        if not sentences:
            return []

        encoded = self.tokenizer(
            sentences,
            return_tensors='pt',
            padding=True,
            max_length=max_length,
            truncation='longest_first',
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        # Inference only: skip autograd bookkeeping. The attention mask keeps
        # padding in a batch from influencing the other tokens' predictions.
        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits

        # Work in log-space so a vanishing probability becomes a large
        # negative number instead of log(0), which would raise ValueError.
        log_probs = torch.log_softmax(logits, dim=-1)

        # At every position j, pick the log-probability assigned to the
        # input token that sits at position j.
        token_log_probs = log_probs.gather(
            -1, input_ids.unsqueeze(-1)).squeeze(-1)

        # Padding positions contribute nothing to the sum or to the count.
        keep = attention_mask.bool()
        token_log_probs = token_log_probs.masked_fill(~keep, 0.0)
        token_counts = keep.sum(dim=-1).clamp(min=1)
        mean_log_probs = token_log_probs.sum(dim=-1) / token_counts

        return [math.exp(value) for value in mean_log_probs.tolist()]


# Local directory from which both the sentence-check classifier and its
# tokenizer are loaded.
_CHECK_MODEL_DIR = './ch-sent-check-model'

# Sequence classifier and matching tokenizer used by judge().
model = BertForSequenceClassification.from_pretrained(_CHECK_MODEL_DIR)
tokenizer = BertTokenizerFast.from_pretrained(_CHECK_MODEL_DIR)

# Shared fluency scorer, created once at import time and reused per request.
smooth_scorer = CHSentenceSmoothScorer()

def judge(sentence):
    """Classify a sentence and compute its fluency score.

    Args:
        sentence: The sentence text to check.

    Returns:
        A ``(pred_text, smooth_score)`` tuple. ``pred_text`` is
        ``'Incorrect'`` when the classifier predicts class 0 and
        ``'Correct'`` otherwise, followed by the class-1 probability as a
        percentage. ``smooth_score`` is the value returned by
        ``smooth_scorer`` for the sentence, as a percentage rounded to two
        decimals.
    """
    # Truncate to the classifier's position-embedding capacity so that an
    # over-long input is scored on its prefix instead of crashing the model.
    input_ids = tokenizer(
        sentence,
        return_tensors='pt',
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )['input_ids']

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_ids).logits

    prob = torch.softmax(logits, dim=-1)
    pred = torch.argmax(prob, dim=-1).item()
    label = 'Incorrect' if pred == 0 else 'Correct'

    # The reported score is always the probability of class 1, whichever
    # class was predicted.
    correct_prob = prob[0][1].item()
    pred_text = f"{label}, score: {round(correct_prob*100,2)}"

    smooth_score = round(smooth_scorer([sentence])[0]*100, 2)
    return pred_text, smooth_score

# Single-line input box; its label asks the user to enter a Chinese sentence
# to be checked for correctness.
sentence_input = gr.Textbox(
    label="請輸入一段中文句子來檢測正確性",
    lines=1,
)

# Output boxes, in the order judge() returns its values:
# first the correctness check, then the fluency check.
correctness_output = gr.Textbox(label="正確性檢查", lines=1)
fluency_output = gr.Textbox(label="流暢性檢查", lines=1)

# Two sample inputs that differ by a single character.
example_sentences = [
    '請注意用字的鄭確性',
    '請注意用字的正確性',
]

iface = gr.Interface(
    fn=judge,
    inputs=sentence_input,
    outputs=[correctness_output, fluency_output],
    examples=example_sentences,
)
# Start the web UI as soon as the module is run.
iface.launch()