# Source scraped from a Hugging Face Space page (Space status shown as: Sleeping).
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Hugging Face checkpoint that supplies both the tokenizer and the
# masked-language-model head used to score candidate words.
model_checkpoint = "xlm-roberta-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_checkpoint),
    AutoModelForMaskedLM.from_pretrained(model_checkpoint),
)
# The tokenizer's mask string; add_mask() substitutes it for the target word.
mask_token = tokenizer.mask_token
def add_mask(target_word, text, mask=None):
    """Return *text* with every occurrence of *target_word* replaced by a mask.

    Every occurrence is replaced, not just the first, so the word does not
    remain visible anywhere in the context the model sees.

    Args:
        target_word: Non-empty substring to mask out.
        text: Sentence to mask.
        mask: Replacement string. Defaults to the module-level ``mask_token``
            (the tokenizer's mask token).

    Returns:
        The masked text, or ``text`` unchanged if ``target_word`` does not
        occur in it.

    Raises:
        ValueError: If ``target_word`` is empty. ``str.replace`` with an empty
            pattern would insert a mask between every character.
    """
    if not target_word:
        raise ValueError("target_word must be a non-empty string")
    if mask is None:
        mask = mask_token
    return text.replace(target_word, mask)
def eval_prob(target_word, text):
    """Return the model's distribution over the vocabulary at the masked word.

    ``target_word`` is masked in ``text`` with ``add_mask``. The masked
    sentence is run through the masked-LM, and the logits at the FIRST mask
    position are converted to softmax probabilities.

    Args:
        target_word: The word being graded; it must occur in ``text``.
        text: A sentence that uses ``target_word``.

    Returns:
        A tuple ``(probs, target_idx)``:

        - ``probs`` is a 1-D tensor with one probability per vocabulary id.
        - ``target_idx`` is the vocabulary id chosen to represent
          ``target_word``, so ``float(probs[target_idx])`` is its probability.

    Raises:
        ValueError: If ``target_word`` yields no usable token, or if it does
            not occur in ``text`` (so there is no masked position to score).
    """
    # The earlier code used ``tokenizer.encode(target_word)[2]``. That skips
    # the BOS id plus one more piece, presumably the standalone "\u2581"
    # word-boundary piece SentencePiece emits before CJK text. When no such
    # piece exists (e.g. "\u2581hello"), index 2 lands on the end-of-sequence
    # id, and an empty word raises IndexError. Dropping bare "\u2581" pieces
    # and taking the first remaining one gives the same id in the CJK case
    # and a correct id otherwise.
    # NOTE(review): relies on the SentencePiece "\u2581" convention of
    # xlm-roberta-base; confirm if the checkpoint changes.
    pieces = [p for p in tokenizer.tokenize(target_word) if p != "\u2581"]
    if not pieces:
        raise ValueError(f"target_word {target_word!r} produced no tokens")
    target_idx = tokenizer.convert_tokens_to_ids(pieces[0])

    inputs = tokenizer(add_mask(target_word, text), return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Positions of the mask token in the single (batch index 0) sequence.
    mask_positions = torch.where(
        inputs["input_ids"][0] == tokenizer.mask_token_id
    )[0]
    if mask_positions.numel() == 0:
        raise ValueError(f"target_word {target_word!r} does not occur in text")

    # Softmax over the vocabulary at the first masked position.
    probs = torch.softmax(token_logits[0, mask_positions[0]], dim=-1)
    return probs, target_idx
def plot_results(target_word, text, top_k=100):
    """Bar-chart the model's most probable fillers for the masked target word.

    ``eval_prob`` scores every vocabulary token for the first masked position.
    The ``top_k`` most probable tokens are plotted in descending order. The
    highest-ranked bar whose decoded text equals ``target_word`` gets
    ``target=1`` and all others get 0, which drives the bar colour.

    Args:
        target_word: Word being graded.
        text: Sentence that uses ``target_word``.
        top_k: Number of highest-probability tokens to plot (default 100).

    Returns:
        A plotly figure, which the Gradio "plot" output renders.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be >= 1")
    probs, _ = eval_prob(target_word, text)

    # Only the top_k entries are plotted, so only those are selected and
    # decoded. Earlier code decoded the entire vocabulary and sorted twice.
    # A target ranked below top_k was never visible in the plot anyway.
    top_probs, top_ids = torch.topk(probs, min(top_k, probs.numel()))
    words = [tokenizer.decode(int(idx)) for idx in top_ids]

    # Highlight the first (highest-ranked) matching token. When the word is
    # not among the plotted tokens (e.g. it spans several tokens), nothing is
    # highlighted, rather than list.index raising ValueError.
    target = [0] * len(words)
    if target_word in words:
        target[words.index(target_word)] = 1

    df = pd.DataFrame(
        {"word": words, "score": top_probs.tolist(), "target": target}
    )
    # The figure is only returned to Gradio for display. fig.show() is not
    # called, because it would also invoke plotly's default renderer on the
    # server.
    return px.bar(
        df,
        x="word",
        y="score",
        color="target",
        color_continuous_scale=px.colors.sequential.Bluered,
    )
# Gradio UI. Inputs: a target word (label 词语, "word") and a sentence that
# uses it (label 造句, "make a sentence"). Output: the bar chart returned by
# plot_results.
demo = gr.Interface(
    fn=plot_results,
    title="Chinese Sentence Grading",
    inputs=[
        gr.Textbox(label="词语", placeholder="标准"),
        gr.Textbox(label="造句", placeholder="小明朗读课文时发音标准,被老师评为优秀。"),
    ],
    outputs=["plot"],
    examples=[
        ["聪明", "小明很聪明,每年考班上第一名。"],
        ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
        ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
    ],
)
demo.launch()