"""Gradio demo that scores how well a Chinese word fits a sentence.

The target word is replaced by the tokenizer's mask token and a masked
language model predicts a probability for every vocabulary entry at that
position.  The top predictions are shown as a bar chart with the target
word highlighted.
"""

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_checkpoint = "facebook/xlm-v-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
model = model.to(device)
mask_token = tokenizer.mask_token


def add_mask(target_word, text):
    """Return ``text`` with every occurrence of ``target_word`` masked.

    Args:
        target_word: Substring to hide from the model.
        text: Sentence containing ``target_word``.

    Returns:
        ``text`` with each occurrence of ``target_word`` replaced by the
        tokenizer's mask token.
    """
    text_masked = text.replace(target_word, mask_token)
    return text_masked


def eval_prob(target_word, text):
    """Compute the vocabulary distribution at the first masked position.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        A tuple ``(probs, target_idx)`` where ``probs`` is a 1-D CPU tensor
        of softmax probabilities over the model's vocabulary and
        ``target_idx`` is the ID of the last token of ``target_word``
        (the token just before the trailing special token).

    Raises:
        ValueError: If ``target_word`` is empty or does not occur in
            ``text``, so there is no masked position to evaluate.
    """
    # An empty needle would make str.replace insert a mask between every
    # character, which is never what the user meant.
    if not target_word:
        raise ValueError("target_word must not be empty")

    # Replace target_word with mask
    text_masked = add_mask(target_word, text)

    # Get token ID of target_word
    target_idx = tokenizer.encode(target_word)[-2]

    # Convert masked text to token IDs
    inputs = tokenizer(text_masked, return_tensors="pt").to(device)

    # Calculate logits score (for each token, for each position).
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Find the position of the mask and extract logits for that position
    mask_token_index = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_token_index.numel() == 0:
        raise ValueError(
            f"target word {target_word!r} was not found in the text"
        )
    # Only the first mask is scored, even if the word occurs several times.
    mask_token_logits = token_logits[0, mask_token_index[0], :]

    # Convert logits to softmax probability (on CPU, in float32, without
    # the round trip through a Python list).
    logits = mask_token_logits.detach().float().cpu()
    probs = torch.nn.functional.softmax(logits, dim=-1)

    return probs, target_idx


def process_prob(target_word, text):
    """Rank the whole vocabulary for the masked slot and locate the target.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        A tuple ``(result_rank, result_prob, df)``: the zero-based rank of
        the target word among all predictions, its probability (a 0-d
        tensor), and a DataFrame with columns ``word``, ``score`` and
        ``target`` (1 for the target row, 0 elsewhere) sorted by
        descending score.
    """
    probs, target_idx = eval_prob(target_word, text)

    # Sort tokens based on probability scores (one sort gives both the
    # ordered scores and the ordered token IDs).
    ranked = torch.sort(probs, descending=True)
    scores = ranked.values
    ranked_ids = ranked.indices
    words = [tokenizer.decode(idx) for idx in ranked_ids.tolist()]

    # Consolidate results in dataframe; hand pandas a NumPy array so the
    # score column is a plain numeric column.
    d = {'word': words, 'score': scores.numpy()}
    df = pd.DataFrame(data=d)

    # Get score rank and probability of target word.  The decoded string
    # match is tried first; when the word is not a single vocabulary entry
    # fall back to the rank of its last token ID instead of crashing.
    try:
        result_rank = words.index(target_word)
    except ValueError:
        matches = (ranked_ids == target_idx).nonzero(as_tuple=True)[0]
        result_rank = int(matches[0])
    result_prob = scores[result_rank]

    # Create color code
    target_col = [0] * len(scores)
    target_col[result_rank] = 1
    df["target"] = target_col

    return result_rank, result_prob, df


def plot_results(target_word, text):
    """Build the bar chart of the 150 most probable fill-ins.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        A Plotly figure with one bar per candidate word, coloured by the
        ``target`` column so the target word stands out.

    Raises:
        gr.Error: If the input cannot be scored (empty word, or word not
            present in the sentence).
    """
    try:
        _, _, df = process_prob(target_word, text)
    except ValueError as err:
        # Surface validation problems as a message in the Gradio UI.
        raise gr.Error(str(err)) from err

    # Plot
    fig = px.bar(
        df[:150],
        x='word',
        y='score',
        color='target',
        color_continuous_scale=px.colors.sequential.Bluered,
    )
    # NOTE: fig.update(layout_coloraxis_showscale=False) hides the colour bar.
    # The figure is returned for Gradio to render; calling fig.show() here
    # would try to open a renderer on the server for every request.
    return fig


demo = gr.Interface(
    fn=plot_results,
    inputs=[
        gr.Textbox(label="词语", placeholder="Key in a 词语 or click an example"),
        gr.Textbox(label="造句", placeholder="造句 with the 词语 or click an example"),
    ],
    examples=[
        ["与众不同", "他的产品很特别,与众不同,跟别人的不一样。"],
        ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
        ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
    ],
    outputs=["plot"],
    title="Chinese Sentence Grading",
)
demo.launch()