Spaces:
Sleeping
Sleeping
File size: 2,652 Bytes
59392ba 3302270 1d09c47 3302270 ec2682f 59392ba c129047 ec2682f c129047 ec2682f c129047 1d09c47 ecfe3e6 1d09c47 ecfe3e6 1d09c47 ecfe3e6 3302270 ec2682f 3302270 b9d5dc1 ec2682f ecfe3e6 ec2682f 1d09c47 ec2682f 1d09c47 ec2682f c129047 ec2682f 3302270 c129047 59392ba 3302270 1d09c47 3302270 1d09c47 3302270 1d09c47 3302270 1d09c47 3302270 1d09c47 ec2682f 3302270 7604512 3302270 93be2eb 6324e4e 3302270 6324e4e 3302270 ec2682f 3302270 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Hugging Face checkpoint name shared by the tokenizer and the masked-LM model.
model_checkpoint = "xlm-roberta-base"

# Both objects are created once at module level; the functions below read them
# as globals. The two loads are independent of each other.
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Mask string defined by the tokenizer; substituted for the target word.
mask_token = tokenizer.mask_token
def add_mask(target_word: str, text: str) -> str:
    """Return *text* with every occurrence of *target_word* masked.

    Each occurrence is substituted with the module-level ``mask_token``
    string. When *target_word* does not occur in *text*, the text comes back
    unchanged.
    """
    return text.replace(target_word, mask_token)
def eval_prob(target_word, text):
    """Score every vocabulary token as a filler for *target_word* in *text*.

    The target word is replaced by the mask token, the masked sentence is run
    through the masked language model, and the logits at the first mask
    position are turned into a probability distribution.

    Args:
        target_word: Word to mask out; must be non-empty and occur in *text*.
        text: Sentence containing *target_word*.

    Returns:
        A ``(probs, target_idx)`` tuple. ``probs`` is a 1-D tensor of softmax
        probabilities over the vocabulary for the first mask position.
        ``target_idx`` is the vocabulary id of the first content piece of
        *target_word* (a word split into several pieces is represented by its
        first piece only).

    Raises:
        ValueError: If *target_word* is empty, encodes to no content piece,
            or does not occur in *text* (so there is nothing to mask).
    """
    if not target_word:
        # str.replace("", mask) would insert a mask between every character.
        raise ValueError("target_word must be a non-empty string")

    text_mask = add_mask(target_word, text)

    # Encode without <s>/</s> so that no fixed position has to be assumed.
    # SentencePiece marks a word start with "\u2581" and may emit it as a
    # standalone piece ahead of the word; that piece carries no content, so it
    # is skipped rather than mistaken for the target.
    piece_ids = tokenizer.encode(target_word, add_special_tokens=False)
    pieces = tokenizer.convert_ids_to_tokens(piece_ids)
    content_ids = [
        piece_id
        for piece_id, piece in zip(piece_ids, pieces)
        if piece != "\u2581"
    ]
    if not content_ids:
        raise ValueError(f"target_word {target_word!r} has no encodable content")
    target_idx = content_ids[0]

    inputs = tokenizer(text_mask, return_tensors="pt")

    # Column indices of every mask token in the (single) input sequence.
    mask_positions = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_positions.numel() == 0:
        raise ValueError(
            f"target_word {target_word!r} does not occur in the text, "
            "so there is nothing to mask"
        )

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # When the word occurs more than once, only the first mask is scored.
    first_mask_logits = token_logits[0, mask_positions[0], :]
    probs = torch.nn.functional.softmax(first_mask_logits, dim=-1)

    return probs, target_idx
def plot_results(target_word, text):
    """Build a bar chart of the most probable fillers for the masked word.

    Calls ``eval_prob`` to obtain the probability distribution at the masked
    position, keeps the highest-scoring candidates, and highlights the bar
    whose decoded text equals *target_word* (if it is among them).

    Args:
        target_word: Word that is masked out of *text* and looked up among
            the candidates.
        text: Sentence containing *target_word*.

    Returns:
        A Plotly figure with one bar per candidate word (x = decoded word,
        y = probability), coloured by a 0/1 ``target`` column.
    """
    top_k = 100  # number of highest-probability candidates shown in the chart

    probs, _ = eval_prob(target_word, text)

    # Sort once and keep only the rows that are actually plotted; decoding
    # the entire vocabulary on every request is unnecessary work.
    ranked = torch.sort(probs, descending=True)
    top_ids = ranked.indices[:top_k].tolist()
    top_scores = ranked.values[:top_k].tolist()
    words = [tokenizer.decode(token_id) for token_id in top_ids]

    # Flag the first (most probable) candidate spelled exactly like the
    # target. If the target is not among the candidates, nothing is
    # highlighted and the chart is still returned.
    target_col = [0] * len(words)
    if target_word in words:
        target_col[words.index(target_word)] = 1

    df = pd.DataFrame(
        {"word": words, "score": top_scores, "target": target_col}
    )

    fig = px.bar(
        df,
        x="word",
        y="score",
        color="target",
        color_continuous_scale=px.colors.sequential.Bluered,
    )
    # No fig.show() here: the figure is rendered by the caller that receives
    # it; show() would additionally invoke a renderer in the server process.
    return fig
# Input widgets, in the positional order plot_results expects:
# the target word first, then the sentence that uses it.
word_box = gr.Textbox(label="词语", placeholder="标准")
sentence_box = gr.Textbox(
    label="造句",
    placeholder="小明朗读课文时发音标准,被老师评为优秀。",
)

# Ready-made (word, sentence) pairs offered as examples in the UI.
sample_rows = [
    ["聪明", "小明很聪明,每年考班上第一名。"],
    ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
    ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
]

# Wire the plotting function to the widgets and start the app.
demo = gr.Interface(
    fn=plot_results,
    inputs=[word_box, sentence_box],
    outputs=["plot"],
    examples=sample_rows,
    title="Chinese Sentence Grading",
)
demo.launch()
|