"""Gradio demo that scores how well a word fits a sentence.

Every occurrence of the target word in the sentence is replaced by the
tokenizer's mask token, a masked-language model predicts the first masked
position, and the most probable tokens are shown as a bar chart with the
target word highlighted when it appears among them.
"""

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_checkpoint = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
mask_token = tokenizer.mask_token

# Number of highest-probability tokens shown in the bar chart.
TOP_K = 100


def add_mask(target_word: str, text: str) -> str:
    """Return *text* with every occurrence of *target_word* masked.

    Args:
        target_word: Substring to hide from the model.
        text: Sentence that is expected to contain ``target_word``.

    Returns:
        ``text`` with each occurrence of ``target_word`` replaced by the
        tokenizer's mask token. If the word does not occur, ``text`` is
        returned unchanged.
    """
    return text.replace(target_word, mask_token)


def eval_prob(target_word: str, text: str) -> "tuple[torch.Tensor, int]":
    """Compute the model's token distribution at the masked position.

    Args:
        target_word: Word whose fit in the sentence is being evaluated.
        text: Sentence containing ``target_word``.

    Returns:
        A ``(probs, target_idx)`` pair. ``probs`` is a 1-D tensor of softmax
        probabilities over the vocabulary for the first masked position, and
        ``target_idx`` is the token id looked up for ``target_word``, so
        ``probs[target_idx]`` is the probability assigned to that token.

    Raises:
        ValueError: If ``target_word`` is empty, cannot be mapped to a token
            id, or does not occur in ``text`` (so nothing gets masked).
    """
    if not target_word:
        # str.replace("") would insert a mask between every character.
        raise ValueError("target_word must not be empty")

    text_mask = add_mask(target_word, text)

    # NOTE(review): index 2 skips the first two ids returned by encode();
    # presumably the start-of-sequence token and a leading SentencePiece
    # whitespace piece. Confirm against the tokenizer: for other
    # tokenisations this may select an end token or only the first
    # sub-word piece of the target.
    target_ids = tokenizer.encode(target_word)
    if len(target_ids) < 3:
        raise ValueError(
            f"target_word {target_word!r} could not be mapped to a token id"
        )
    target_idx = target_ids[2]

    inputs = tokenizer(text_mask, return_tensors="pt")
    mask_positions = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_positions.numel() == 0:
        raise ValueError(
            f"target_word {target_word!r} does not occur in the text, "
            "so there is no masked position to score"
        )

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Only the first masked position is scored, even if the word repeats.
    first_mask_logits = token_logits[0, mask_positions[0], :]
    probs = torch.nn.functional.softmax(first_mask_logits, dim=-1)
    return probs, target_idx


def plot_results(target_word, text):
    """Build a bar chart of the most likely tokens for the masked word.

    Args:
        target_word: Word whose fit in the sentence is being evaluated.
        text: Sentence containing ``target_word``.

    Returns:
        A Plotly bar figure of the ``TOP_K`` highest-probability tokens.
        The first bar whose decoded token equals ``target_word`` gets
        ``target == 1`` (and therefore a distinct colour); if the word is
        not among the plotted tokens, no bar is highlighted.

    Raises:
        ValueError: Propagated from ``eval_prob`` for unusable input.
    """
    probs, _target_idx = eval_prob(target_word, text)

    # Sort once, and decode only the tokens that will actually be plotted
    # rather than the entire vocabulary.
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    top_scores = sorted_probs[:TOP_K].tolist()
    words = [tokenizer.decode(idx) for idx in sorted_ids[:TOP_K].tolist()]

    df = pd.DataFrame({"word": words, "score": top_scores})

    # Flag the first plotted token matching the target word, if any.
    target_col = [0] * len(words)
    if target_word in words:
        target_col[words.index(target_word)] = 1
    df["target"] = target_col

    # The figure is returned for Gradio to render; it is not shown here.
    fig = px.bar(
        df,
        x="word",
        y="score",
        color="target",
        color_continuous_scale=px.colors.sequential.Bluered,
    )
    return fig


demo = gr.Interface(
    fn=plot_results,
    inputs=[
        # Label "词语": the word; label "造句": a sentence using the word.
        gr.Textbox(label="词语", placeholder="标准"),
        gr.Textbox(
            label="造句",
            placeholder="小明朗读课文时发音标准,被老师评为优秀。",
        ),
    ],
    examples=[
        ["聪明", "小明很聪明,每年考班上第一名。"],
        ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
        ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
    ],
    outputs=["plot"],
    title="Chinese Sentence Grading",
)
demo.launch()