# annalyzin's picture
# Update examples
# 87c59dc
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# The same checkpoint name is used to load both the tokenizer and the
# masked-language model, so the two always agree on the vocabulary.
model_checkpoint = "facebook/xlm-v-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

# Mask string as defined by the tokenizer; substituted for the target word.
mask_token = tokenizer.mask_token
def add_mask(target_word, text):
    """Return *text* with every occurrence of *target_word* masked.

    Each occurrence is substituted with the module-level ``mask_token``
    string; the input text itself is not modified.
    """
    return text.replace(target_word, mask_token)
def eval_prob(target_word, text):
    """Compute the model's vocabulary distribution at the masked target word.

    ``target_word`` is replaced by the mask token in ``text`` (via
    ``add_mask``), the masked sentence is run through the masked-language
    model, and the logits at the first mask position are turned into
    softmax probabilities.

    Args:
        target_word: Word to mask; must be non-empty and occur in ``text``.
        text: Sentence containing ``target_word``.

    Returns:
        A ``(probs, target_idx)`` tuple. ``probs`` is a 1-D float32 CPU
        tensor with one probability per vocabulary entry. ``target_idx`` is
        the second-to-last id of ``tokenizer.encode(target_word)``
        (presumably the last sub-token of the word, just before the
        end-of-sequence token -- confirm against the tokenizer).

    Raises:
        ValueError: If ``target_word`` is empty, or if the masked text
            contains no mask token (i.e. ``target_word`` is not in ``text``).
    """
    # An empty needle would make str.replace insert a mask between every
    # character of the text, which is never what the caller wants.
    if not target_word:
        raise ValueError("target_word must not be empty")

    # Replace target_word with mask
    text_masked = add_mask(target_word, text)

    # Get token ID of target_word
    target_idx = tokenizer.encode(target_word)[-2]

    # Convert masked text to token IDs
    inputs = tokenizer(text_masked, return_tensors="pt").to(device)

    # Locate the mask before running the model so a missing target word is
    # reported clearly instead of as an index error on an empty tensor.
    mask_positions = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_positions.numel() == 0:
        raise ValueError(
            f"target word {target_word!r} does not occur in the text, "
            "so there is nothing to mask"
        )

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Only the first mask position is scored, even when the target word
    # occurs (and is therefore masked) more than once.
    first_mask_logits = token_logits[0, mask_positions[0], :]

    # Softmax on a float32 CPU copy; this avoids round-tripping the whole
    # vocabulary through a Python list.
    probs = torch.nn.functional.softmax(first_mask_logits.float().cpu(), dim=-1)
    return probs, target_idx
def process_prob(target_word, text):
    """Rank the whole vocabulary for the masked slot and locate the target.

    Args:
        target_word: Word that was masked in ``text``.
        text: Sentence containing ``target_word``.

    Returns:
        A ``(result_rank, result_prob, df)`` tuple. ``result_rank`` is the
        0-based position of the target in the descending probability
        ranking, ``result_prob`` is the score at that position, and ``df``
        is a DataFrame with columns ``word``, ``score`` and ``target``
        (1 on the target's row, 0 elsewhere), ordered from most to least
        probable.
    """
    probs, target_idx = eval_prob(target_word, text)

    # Sort once so words and scores are guaranteed to share one ordering.
    ranking = torch.sort(probs, descending=True)
    scores = ranking.values
    ranked_ids = ranking.indices.tolist()

    # Decode from plain ints rather than iterating 0-dim tensors.
    words = [tokenizer.decode(token_id) for token_id in ranked_ids]

    # Consolidate results in dataframe
    df = pd.DataFrame(data={'word': words, 'score': scores})

    # Get score rank and probability of target word
    try:
        result_rank = words.index(target_word)
    except ValueError:
        # The target does not decode from any single vocabulary entry
        # (e.g. it is split into several sub-tokens). Fall back to the
        # token id reported by eval_prob instead of failing.
        result_rank = ranked_ids.index(target_idx)
    result_prob = scores[result_rank]

    # Create color code
    target_col = [0] * len(scores)
    target_col[result_rank] = 1
    df["target"] = target_col

    return result_rank, result_prob, df
def plot_results(target_word, text):
    """Build a bar chart of the 150 most probable words for the masked slot.

    Args:
        target_word: Word that is masked in ``text``.
        text: Sentence containing ``target_word``.

    Returns:
        A Plotly figure with one bar per word (x: ``word``, y: ``score``),
        coloured by the ``target`` column so the target word stands out.
    """
    _, _, df = process_prob(target_word, text)

    # Only the head of the ranking is plotted; the frame covers the whole
    # vocabulary and would be unreadable as a bar chart.
    fig = px.bar(
        df[:150],
        x='word',
        y='score',
        color='target',
        color_continuous_scale=px.colors.sequential.Bluered,
    )

    # The figure is returned for the caller to render. It is deliberately
    # not shown here, since that would trigger Plotly's own renderer on
    # every call in addition to the caller's display.
    return fig
# Input widgets, created in the order plot_results expects its arguments:
# first the target word, then the sentence that uses it.
word_box = gr.Textbox(label="词语", placeholder="Key in a 词语 or click an example")
sentence_box = gr.Textbox(label="造句", placeholder="造句 with the 词语 or click an example")

# Clickable (word, sentence) examples offered below the inputs.
example_pairs = [
    ["与众不同", "他的产品很特别,与众不同,跟别人的不一样。"],
    ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
    ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
]

# Wire the plotting function to the UI and start serving it.
demo = gr.Interface(
    fn=plot_results,
    inputs=[word_box, sentence_box],
    examples=example_pairs,
    outputs=["plot"],
    title="Chinese Sentence Grading",
)
demo.launch()