File size: 2,652 Bytes
59392ba
3302270
1d09c47
3302270
ec2682f
 
59392ba
c129047
ec2682f
c129047
ec2682f
 
 
c129047
1d09c47
ecfe3e6
1d09c47
 
ecfe3e6
 
1d09c47
ecfe3e6
3302270
ec2682f
3302270
b9d5dc1
ec2682f
ecfe3e6
ec2682f
1d09c47
ec2682f
 
 
1d09c47
ec2682f
 
 
c129047
ec2682f
3302270
 
 
c129047
59392ba
3302270
 
1d09c47
3302270
 
 
 
 
1d09c47
3302270
 
 
1d09c47
3302270
 
 
 
 
 
 
 
 
 
 
 
 
1d09c47
3302270
 
 
1d09c47
 
ec2682f
3302270
7604512
3302270
 
93be2eb
6324e4e
3302270
 
 
6324e4e
3302270
ec2682f
3302270
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM


# Hugging Face checkpoint name shared by the tokenizer and the masked-LM model.
model_checkpoint = "xlm-roberta-base"

# Both objects are created once at import time and reused by eval_prob() and
# plot_results() on every call.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
# Mask placeholder string that add_mask() substitutes for the target word.
mask_token = tokenizer.mask_token


def add_mask(target_word, text):
    """Return ``text`` with every occurrence of ``target_word`` masked.

    Args:
        target_word: Non-empty substring to hide from the model.
        text: Sentence in which ``target_word`` is replaced.

    Returns:
        A copy of ``text`` where each occurrence of ``target_word`` is
        replaced by the module-level ``mask_token``. If ``target_word`` does
        not occur, ``text`` is returned unchanged.

    Raises:
        ValueError: If ``target_word`` is empty.
    """
    # str.replace("", x) inserts x between every character, which would turn
    # the whole sentence into masks; reject that input explicitly instead.
    if not target_word:
        raise ValueError("target_word must be a non-empty string")
    return text.replace(target_word, mask_token)


def eval_prob(target_word, text):
    """Score every vocabulary token as a filler for ``target_word`` in ``text``.

    ``target_word`` is replaced by the mask token, the masked sentence is run
    through the masked-LM, and the logits at the first mask position are
    converted into a probability distribution over the vocabulary.

    Args:
        target_word: Word to mask out; it must occur in ``text``.
        text: Sentence containing ``target_word``.

    Returns:
        A ``(probs, target_idx)`` tuple. ``probs`` is a 1-D tensor of softmax
        probabilities indexed by token id. ``target_idx`` is the token id of
        the first sub-word piece of ``target_word`` that decodes to visible
        text.

    Raises:
        ValueError: If ``target_word`` is empty, encodes to no visible piece,
            or does not occur in ``text`` (so nothing was masked).
    """
    if not target_word:
        raise ValueError("target_word must be a non-empty string")

    text_mask = add_mask(target_word, text)

    # Token id of the target word. Encode without special tokens and skip
    # pieces that decode to nothing (a SentencePiece tokenizer may emit a bare
    # word-boundary piece first) instead of relying on a fixed position in the
    # encoded sequence, which breaks for single-piece and empty inputs.
    piece_ids = tokenizer.encode(target_word, add_special_tokens=False)
    target_idx = next(
        (idx for idx in piece_ids if tokenizer.decode([idx]).strip()),
        None,
    )
    if target_idx is None:
        raise ValueError(f"could not tokenize target word {target_word!r}")

    inputs = tokenizer(text_mask, return_tensors="pt")

    # Locate the mask before running the model so a missing target word fails
    # fast with a clear message rather than an IndexError further down.
    mask_positions = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_positions.numel() == 0:
        raise ValueError(
            f"target word {target_word!r} does not occur in the text"
        )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Only the first mask position is scored, as before.
    first_mask_logits = token_logits[0, mask_positions[0], :]
    probs = torch.nn.functional.softmax(first_mask_logits, dim=-1)

    return probs, target_idx


def plot_results(target_word, text, top_k=100):
    """Plot the most probable fillers for ``target_word`` in ``text``.

    Args:
        target_word: Word that is masked out of ``text`` and highlighted in
            the chart when it is among the plotted candidates.
        text: Sentence containing ``target_word``.
        top_k: Number of highest-probability tokens to plot (default 100).

    Returns:
        A Plotly bar figure with one bar per candidate token, ordered by
        descending probability. The ``target`` colour value is 1 for the first
        candidate whose decoded text equals ``target_word`` and 0 otherwise.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    probs, _ = eval_prob(target_word, text)

    # Sort once and decode only the rows that are plotted; decoding the whole
    # vocabulary on every request is needless work.
    ranked = torch.sort(probs, descending=True)
    top_ids = ranked.indices[:top_k].tolist()
    top_scores = ranked.values[:top_k].tolist()
    words = [tokenizer.decode([token_id]) for token_id in top_ids]

    # Flag the target word if it made the cut. A target that is not among the
    # plotted candidates simply yields a chart without a highlighted bar
    # instead of raising from list.index().
    target_col = [0] * len(words)
    if target_word in words:
        target_col[words.index(target_word)] = 1

    df = pd.DataFrame(
        {"word": words, "score": top_scores, "target": target_col}
    )

    fig = px.bar(
        df,
        x="word",
        y="score",
        color="target",
        color_continuous_scale=px.colors.sequential.Bluered,
    )
    # The figure is returned for the caller (the Gradio "plot" output) to
    # render; calling fig.show() here would additionally try to open a
    # renderer on the machine running the app for every request.
    return fig


# Input widgets, in the positional order plot_results() expects:
# the target word first, then the sentence that uses it.
word_box = gr.Textbox(label="词语", placeholder="标准")
sentence_box = gr.Textbox(
    label="造句",
    placeholder="小明朗读课文时发音标准,被老师评为优秀。",
)

# Clickable (word, sentence) example rows shown under the inputs.
sample_rows = [
    ["聪明", "小明很聪明,每年考班上第一名。"],
    ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
    ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
]

# Build the UI around plot_results() and start serving it immediately.
demo = gr.Interface(
    fn=plot_results,
    inputs=[word_box, sentence_box],
    outputs=["plot"],
    examples=sample_rows,
    title="Chinese Sentence Grading",
)
demo.launch()