File size: 3,086 Bytes
59392ba
3302270
1d09c47
3302270
ec2682f
 
59392ba
600139e
c129047
e8d5985
c129047
ec2682f
 
600139e
ec2682f
c129047
1d09c47
ecfe3e6
e8d5985
 
ecfe3e6
 
1d09c47
e8d5985
 
 
 
 
 
 
 
600139e
e8d5985
 
ec2682f
1d09c47
e8d5985
ec2682f
 
1d09c47
ec2682f
 
 
c129047
3302270
c129047
59392ba
e8d5985
 
3302270
e8d5985
3302270
e8d5985
3302270
e8d5985
3302270
e8d5985
3302270
e8d5985
 
3302270
e8d5985
 
 
3302270
 
 
e8d5985
 
3302270
e8d5985
 
 
 
3302270
 
767f651
e8d5985
 
 
3302270
1d09c47
e8d5985
3302270
 
e8d5985
3302270
1d09c47
 
ec2682f
3302270
7604512
e8d5985
 
93be2eb
6324e4e
e8d5985
3302270
 
6324e4e
3302270
ec2682f
e8d5985
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Masked-language-model checkpoint used to score candidate words.
model_checkpoint = "facebook/xlm-v-base"

# Load tokenizer and model once at import time; the model is moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

# String form of the tokenizer's mask token, substituted into sentences.
mask_token = tokenizer.mask_token


def add_mask(target_word, text, mask=None):
    """Return *text* with every occurrence of *target_word* replaced by a mask.

    Args:
        target_word: Substring to hide from the model. Must be non-empty.
        text: Sentence that is expected to contain ``target_word``.
        mask: Replacement string. Defaults to the tokenizer's mask token
            (the module-level ``mask_token``).

    Returns:
        The masked sentence. If ``target_word`` does not occur, ``text`` is
        returned unchanged.

    Raises:
        ValueError: If ``target_word`` is empty. ``str.replace`` with an empty
            pattern would otherwise insert the mask between every character.
    """
    if not target_word:
        raise ValueError("target_word must be a non-empty string")
    if mask is None:
        mask = mask_token
    return text.replace(target_word, mask)


def eval_prob(target_word, text):
    """Score every vocabulary token as a filler for ``target_word`` in ``text``.

    ``target_word`` is replaced by the mask token (via ``add_mask``) and the
    masked sentence is run through the masked-language model once.

    Args:
        target_word: Word whose fit in the sentence is being graded.
        text: Sentence that contains ``target_word``.

    Returns:
        A ``(probs, target_idx)`` tuple.
        - ``probs`` is a 1-D CPU tensor with the softmax probability of each
          vocabulary token at the first mask position.
        - ``target_idx`` is the second-to-last id from
          ``tokenizer.encode(target_word)``. Assuming ``encode`` appends a
          single end-of-sequence token (TODO confirm for this tokenizer), this
          is the last real sub-token of the word.

    Raises:
        ValueError: If the masked sentence contains no mask token, i.e.
            ``target_word`` does not occur in ``text``.
    """
    text_masked = add_mask(target_word, text)

    # Index -2 skips the trailing special token added by encode()'s default
    # add_special_tokens=True.
    target_idx = tokenizer.encode(target_word)[-2]

    inputs = tokenizer(text_masked, return_tensors="pt").to(device)

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Positions (in the single batch row) that hold the mask token.
    mask_positions = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_positions.numel() == 0:
        raise ValueError(f"{target_word!r} does not appear in the sentence")

    # Softmax over the vocabulary at the first mask position, computed on the
    # tensor directly instead of round-tripping through a Python list; .cpu()
    # keeps the returned tensor on the CPU as before.
    probs = torch.softmax(token_logits[0, mask_positions[0]], dim=-1).cpu()

    return probs, target_idx


def process_prob(target_word, text):
    """Rank all vocabulary tokens as fillers for ``target_word``'s slot.

    Args:
        target_word: Word whose fit in the sentence is being graded.
        text: Sentence that contains ``target_word``.

    Returns:
        ``(result_rank, result_prob, df)``:
        - ``df`` has one row per vocabulary token, sorted by descending
          probability. Its columns are ``word`` (decoded token), ``score``
          (probability) and ``target`` (1 on the target's row, 0 elsewhere).
        - ``result_rank`` is the 0-based row index of the target.
        - ``result_prob`` is that row's score.
    """
    probs, target_idx = eval_prob(target_word, text)

    # Sort once and reuse both the sorted probabilities and their token ids.
    scores, order = torch.sort(probs, descending=True)
    token_ids = order.tolist()
    words = [tokenizer.decode(idx) for idx in token_ids]

    df = pd.DataFrame(data={'word': words, 'score': scores})

    # Prefer the first token whose decoded text equals target_word. If none
    # matches (e.g. the word splits into several sub-tokens), fall back to the
    # rank of target_idx rather than failing with list.index's ValueError.
    try:
        result_rank = words.index(target_word)
    except ValueError:
        result_rank = token_ids.index(target_idx)
    result_prob = scores[result_rank]

    # 0/1 flag used as the colour channel when plotting.
    target_col = [0] * len(scores)
    target_col[result_rank] = 1
    df["target"] = target_col

    return result_rank, result_prob, df

def plot_results(target_word, text, *, top_n=150):
    """Return a bar chart of the ``top_n`` most probable fillers for the slot.

    Args:
        target_word: Word whose fit in the sentence is being graded.
        text: Sentence that contains ``target_word``.
        top_n: Number of highest-probability tokens to plot (default 150).

    Returns:
        A plotly figure with one bar per candidate word, coloured by the
        ``target`` flag so the target word's bar stands out. The target
        appears only if its rank is below ``top_n``.
    """
    _, _, df = process_prob(target_word, text)

    fig = px.bar(
        df[:top_n],
        x='word',
        y='score',
        color='target',
        color_continuous_scale=px.colors.sequential.Bluered,
    )

    # The figure is handed back to Gradio for rendering. Calling fig.show()
    # here would also invoke plotly's default renderer on the server for
    # every request.
    return fig


# Input components: the word to grade and a sentence that uses it.
word_box = gr.Textbox(label="词语", placeholder="Key in a 词语 or click an example")
sentence_box = gr.Textbox(label="造句", placeholder="造句 with the 词语 or click an example")

# Clickable (word, sentence) examples shown beneath the inputs.
example_pairs = [
    ["与众不同", "他的产品很特别,与众不同,跟别人的不一样。"],
    ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
    ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
]

# Wire plot_results into a web UI and start serving it.
demo = gr.Interface(
    fn=plot_results,
    inputs=[word_box, sentence_box],
    examples=example_pairs,
    outputs=["plot"],
    title="Chinese Sentence Grading",
)
demo.launch()