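"""Gradio demo: grade how well a Chinese word (词语) fits a sentence (造句).

The target word is replaced with a mask token, and the masked language model
facebook/xlm-v-base scores every vocabulary token for that position. The app
plots the top predictions and highlights where the target word ranks.
"""
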
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_checkpoint = "facebook/xlm-v-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
model = model.to(device)
mask_token = tokenizer.mask_token

def add_mask(target_word, text):
    # Replace every occurrence of the target word with the tokenizer's mask token
    text_masked = text.replace(target_word, mask_token)
    return text_masked
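
# Illustrative example (hypothetical sentence; the exact mask string depends on
# the tokenizer, typically "<mask>" for this model family):
#   add_mask("尴尬", "他感到很尴尬。")  ->  "他感到很<mask>。"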

def eval_prob(target_word, text):
    # Replace target_word with the mask token
    text_masked = add_mask(target_word, text)
    # Token ID of target_word ([-2] skips the trailing special token;
    # assumes the word is encoded as a single token)
    target_idx = tokenizer.encode(target_word)[-2]
    # Convert masked text to token IDs
    inputs = tokenizer(text_masked, return_tensors="pt").to(device)
    # Compute logits for every vocabulary token at every position
    with torch.no_grad():
        token_logits = model(**inputs).logits
    # Find the position of the mask and extract the logits for that position
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]
    # Convert the logits at the (first) mask position to softmax probabilities
    probs = torch.nn.functional.softmax(mask_token_logits[0], dim=-1).cpu()
    return probs, target_idx
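
# Illustrative usage (hypothetical sentence): probs is a probability distribution
# over the full vocabulary at the mask position, so the returned target_idx can be
# used to look up the model's probability for the exact target token:
#   probs, target_idx = eval_prob("尴尬", "他感到很尴尬。")
#   target_prob = probs[target_idx]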

def process_prob(target_word, text):
    probs, target_idx = eval_prob(target_word, text)
    # Sort the vocabulary by probability score (highest first)
    sorted_probs = torch.sort(probs, descending=True)
    words = [tokenizer.decode(idx) for idx in sorted_probs.indices]
    scores = sorted_probs.values
    # Consolidate results in a dataframe
    df = pd.DataFrame(data={'word': words, 'score': scores})
    # Get score rank and probability of the target word
    result_rank = words.index(target_word)
    result_prob = scores[result_rank]
    # Colour code: flag the target word so it can be highlighted in the plot
    target_col = [0] * len(scores)
    target_col[result_rank] = 1
    df["target"] = target_col
    return result_rank, result_prob, df
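
# Illustrative usage (hypothetical sentence): a low result_rank means the model
# considers the target word a natural fit for the masked position:
#   rank, prob, df = process_prob("尴尬", "他感到很尴尬。")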

def plot_results(target_word, text):
    _, _, df = process_prob(target_word, text)
    # Bar chart of the top 150 predictions, with the target word highlighted
    fig = px.bar(
        df[:150],
        x='word',
        y='score',
        color='target',
        color_continuous_scale=px.colors.sequential.Bluered,
    )
    # fig.update(layout_coloraxis_showscale=False)
    # fig.show()
    return fig

gr.Interface(
    fn=plot_results,
    inputs=[
        gr.Textbox(label="词语", placeholder="Key in a 词语 or click an example"),
        gr.Textbox(label="造句", placeholder="造句 with the 词语 or click an example"),
    ],
    examples=[
        ["与众不同", "他的产品很特别,与众不同,跟别人的不一样。"],
        ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
        ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
    ],
    outputs=["plot"],
    title="Chinese Sentence Grading",
).launch()