# Author: Annalyn Ng
# Last commit: 3302270 ("add barplot")
# File size: 2.65 kB
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Pretrained multilingual masked-language model used to score candidate words.
model_checkpoint = "xlm-roberta-base"

# Load the tokenizer first, then the model weights, both from the same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_checkpoint),
    AutoModelForMaskedLM.from_pretrained(model_checkpoint),
)

# The tokenizer's mask-token string; it is substituted into sentences in place
# of the word being evaluated.
mask_token = tokenizer.mask_token
def add_mask(target_word, text):
    """Return ``text`` with every occurrence of ``target_word`` replaced by the mask token.

    Args:
        target_word: Substring to hide from the model.
        text: Sentence to mask.

    Returns:
        A new string; ``text`` itself is not modified.
    """
    return text.replace(target_word, mask_token)
def eval_prob(target_word, text):
    """Compute the model's probability distribution for the masked target word.

    ``target_word`` is replaced by the mask token in ``text`` (via ``add_mask``),
    the masked-LM is run once, and the softmax over the logits at the first
    mask position is returned.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        Tuple ``(probs, target_idx)``: ``probs`` is a 1-D tensor of
        probabilities over the whole vocabulary, and ``target_idx`` is the
        vocabulary id used to represent ``target_word``.

    Raises:
        ValueError: If ``target_word`` is empty or does not occur in ``text``
            (no mask position would exist to score).
    """
    # Guard early: an empty word would mask between every character, and a
    # missing word leaves no mask, which previously surfaced as a bare
    # IndexError only after a full model pass.
    if not target_word:
        raise ValueError("target_word must be a non-empty string")
    if target_word not in text:
        raise ValueError(f"target_word {target_word!r} does not occur in text")

    text_mask = add_mask(target_word, text)

    # Index 2 of the encoding is used as the target's vocabulary id.
    # NOTE(review): this presumably skips a leading special token and a
    # standalone SentencePiece word-boundary piece, and keeps only the first
    # piece of a multi-piece word -- confirm per word.
    target_idx = tokenizer.encode(target_word)[2]

    inputs = tokenizer(text_mask, return_tensors="pt")
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Column positions of the mask token in the single encoded sequence.
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    # If the word occurred more than once, every occurrence was masked, but
    # only the first mask position is scored.
    mask_token_logits = token_logits[0, mask_token_index[0], :]

    # Softmax directly on the tensor (no Python-list round-trip).
    probs = torch.softmax(mask_token_logits, dim=-1)
    return probs, target_idx
def plot_results(target_word, text, top_k=100):
    """Bar-plot the ``top_k`` most probable fillers for the masked ``target_word``.

    Args:
        target_word: Word to mask out of ``text`` and highlight in the chart.
        text: Sentence that contains ``target_word``.
        top_k: Number of highest-probability tokens to plot (default 100).

    Returns:
        A Plotly bar figure of decoded token vs. probability. The bar for the
        target word's token, when it ranks within ``top_k``, is coloured
        differently via the ``target`` column.
    """
    probs, target_idx = eval_prob(target_word, text)

    # Sort once and reuse both the ordered scores and their vocabulary ids.
    sorted_scores, sorted_ids = torch.sort(probs, descending=True)

    # Rank of the target token, located by vocabulary id. Matching decoded
    # strings instead fails for multi-token words and can land on a different
    # token that happens to decode to the same text.
    target_rank = int((sorted_ids == target_idx).nonzero(as_tuple=True)[0][0])

    # Only the plotted rows are decoded; decoding the entire vocabulary on
    # every request was the dominant cost.
    top_ids = sorted_ids[:top_k].tolist()
    df = pd.DataFrame(
        {
            "word": [tokenizer.decode(token_id) for token_id in top_ids],
            "score": sorted_scores[:top_k].tolist(),
            "target": [int(rank == target_rank) for rank in range(len(top_ids))],
        }
    )

    # The figure is returned to the caller (the Gradio UI); no fig.show(),
    # which would try to render on the server.
    fig = px.bar(
        df,
        x="word",
        y="score",
        color="target",
        color_continuous_scale=px.colors.sequential.Bluered,
    )
    return fig
# Web UI: a word and a sentence using it go in; a bar chart of the model's
# candidate fillers for that word's position comes out.
demo = gr.Interface(
    fn=plot_results,
    inputs=[
        # Label: "word"; placeholder: "standard".
        gr.Textbox(label="词语", placeholder="标准"),
        # Label: "make a sentence"; placeholder is an example sentence.
        gr.Textbox(label="造句", placeholder="小明朗读课文时发音标准,被老师评为优秀。"),
    ],
    # Clickable (word, sentence) example pairs.
    examples=[
        ["聪明", "小明很聪明,每年考班上第一名。"],
        ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
        ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
    ],
    outputs=["plot"],
    title="Chinese Sentence Grading",
)
demo.launch()