Annalyn Ng committed on
Commit
e8d5985
1 Parent(s): 3302270

change model to xlm-v-base

Browse files
Files changed (1) hide show
  1. app.py +41 -30
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from transformers import AutoTokenizer, AutoModelForMaskedLM
7
 
8
 
9
- model_checkpoint = "xlm-roberta-base"
10
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
12
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
@@ -14,21 +14,25 @@ mask_token = tokenizer.mask_token
14
 
15
 
16
  def add_mask(target_word, text):
17
- text_mask = text.replace(target_word, mask_token)
18
- return text_mask
19
 
20
 
21
  def eval_prob(target_word, text):
22
- text_mask = add_mask(target_word, text)
23
-
24
- # Get index of target_word
25
- target_idx = tokenizer.encode(target_word)[2]
26
-
27
- # Get logits
28
- inputs = tokenizer(text_mask, return_tensors="pt")
 
 
 
 
29
  token_logits = model(**inputs).logits
30
 
31
- # Find the location of the MASK and extract its logits
32
  mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
33
  mask_token_logits = token_logits[0, mask_token_index, :]
34
 
@@ -36,55 +40,62 @@ def eval_prob(target_word, text):
36
  logits = mask_token_logits[0].tolist()
37
  probs = torch.nn.functional.softmax(torch.tensor([logits]), dim=1)[0]
38
 
39
- # Get probability of target word filling the MASK
40
- # result = float(probs[target_idx])
41
-
42
  return probs, target_idx
43
 
44
 
45
- def plot_results(target_word, text):
 
46
  probs, target_idx = eval_prob(target_word, text)
47
-
48
  # Sort tokens based on probability scores
49
- words = [
50
- tokenizer.decode(idx) for idx in torch.sort(probs, descending=True).indices
51
- ]
52
  scores = torch.sort(probs, descending=True).values
53
-
54
  # Consolidate results in dataframe
55
- d = {"word": words, "score": scores}
56
  df = pd.DataFrame(data=d)
57
-
58
- # Get score rank of target word
59
  result_rank = words.index(target_word)
 
 
 
60
  target_col = [0] * len(scores)
61
  target_col[result_rank] = 1
62
  df["target"] = target_col
 
 
63
 
 
 
 
 
64
  # Plot
65
  fig = px.bar(
66
  df[:100],
67
- x="word",
68
- y="score",
69
- color="target",
70
  color_continuous_scale=px.colors.sequential.Bluered,
71
  )
 
72
  # fig.update(layout_coloraxis_showscale=False)
73
  fig.show()
 
74
  return fig
75
 
76
 
77
  gr.Interface(
78
  fn=plot_results,
79
  inputs=[
80
- gr.Textbox(label="词语", placeholder="标准"),
81
- gr.Textbox(label="造句", placeholder="小明朗读课文时发音标准,被老师评为优秀。"),
82
  ],
83
  examples=[
84
- ["聪明", "小明很聪明,每年考班上第一名。"],
85
  ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
86
  ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
87
  ],
88
  outputs=["plot"],
89
  title="Chinese Sentence Grading",
90
- ).launch()
 
6
  from transformers import AutoTokenizer, AutoModelForMaskedLM
7
 
8
 
# Checkpoint identifier handed to both ``from_pretrained`` calls below, so
# the tokenizer and the masked-LM weights always come from the same source.
model_checkpoint = "facebook/xlm-v-base"

# Loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
 
14
 
15
 
def add_mask(target_word, text):
    """Return ``text`` with every occurrence of ``target_word`` masked.

    Each occurrence is substituted with the module-level ``mask_token``.
    """
    return text.replace(target_word, mask_token)
19
 
20
 
def eval_prob(target_word, text):
    """Score every vocabulary token as a filler for the masked target word.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        A ``(probs, target_idx)`` tuple. ``probs`` is a 1-D tensor of softmax
        probabilities over the vocabulary at the first mask position.
        ``target_idx`` is the second-to-last ID produced by
        ``tokenizer.encode(target_word)``.

    Raises:
        ValueError: If ``target_word`` is empty or does not occur in ``text``.
    """
    # Fail early with a readable message: an empty word would put a mask
    # between every character, and a missing word leaves no mask to score
    # (which previously surfaced as an opaque IndexError further down).
    if not target_word:
        raise ValueError("target_word must not be empty")
    if target_word not in text:
        raise ValueError(f"{target_word!r} does not appear in the sentence")

    # Replace target_word with mask
    text_masked = add_mask(target_word, text)

    # Get token ID of target_word.
    # NOTE(review): [-2] is presumably the last word piece before the closing
    # special token; a word split into several pieces is then represented by
    # that final piece only -- confirm against the tokenizer.
    target_idx = tokenizer.encode(target_word)[-2]

    # Convert masked text to token IDs
    inputs = tokenizer(text_masked, return_tensors="pt")

    # Calculate logits score (for each token, for each position).
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        token_logits = model(**inputs).logits

    # Find the position of the mask and extract logits for that position
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Softmax over the vocabulary at the first mask position, applied
    # directly to the tensor instead of round-tripping through a Python list.
    probs = torch.nn.functional.softmax(mask_token_logits[0], dim=-1)

    return probs, target_idx
44
 
45
 
def process_prob(target_word, text):
    """Rank all vocabulary tokens for the masked slot and locate the target.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        A ``(result_rank, result_prob, df)`` tuple. ``result_rank`` is the
        0-based position of the target in the descending-probability ranking,
        ``result_prob`` is the score at that position, and ``df`` is a
        DataFrame with columns ``word``, ``score`` and ``target`` (1 on the
        target's row, 0 elsewhere), ordered by descending score.
    """
    probs, target_idx = eval_prob(target_word, text)

    # Sort tokens based on probability scores. A single sort keeps the
    # decoded words and their scores guaranteed to be in the same order.
    ranked = torch.sort(probs, descending=True)
    scores = ranked.values
    token_ids = ranked.indices
    words = [tokenizer.decode(idx) for idx in token_ids]

    # Consolidate results in dataframe
    df = pd.DataFrame(data={'word': words, 'score': scores})

    # Get score rank and probability of target word. Prefer an exact match on
    # the decoded text; when no decoded token equals the target string, fall
    # back to the rank of the token ID reported by eval_prob instead of
    # failing with a bare ValueError.
    try:
        result_rank = words.index(target_word)
    except ValueError:
        result_rank = int(torch.nonzero(token_ids == target_idx)[0, 0])
    result_prob = scores[result_rank]

    # Create color code: flag only the target's row
    target_col = [0] * len(scores)
    target_col[result_rank] = 1
    df["target"] = target_col

    return result_rank, result_prob, df
68
 
def plot_results(target_word, text):
    """Build a bar chart of the highest-scoring fillers for the masked word.

    Args:
        target_word: Word to mask out of ``text``.
        text: Sentence that contains ``target_word``.

    Returns:
        The figure produced by ``px.bar`` over the first 100 rows of the
        ranking, with bars colored by the ``target`` column.
    """
    _, _, df = process_prob(target_word, text)

    # Only the first 100 rows of the ranking are plotted.
    fig = px.bar(
        df[:100],
        x='word',
        y='score',
        color='target',
        color_continuous_scale=px.colors.sequential.Bluered,
    )

    # The figure is handed back to the caller for display; it is deliberately
    # not shown here, since fig.show() would try to render it through
    # plotly's own renderer inside the serving process.
    return fig
86
 
87
 
# Two free-text inputs: the word to grade and the sentence that uses it.
_word_box = gr.Textbox(label="词语", placeholder="Key in a 词语 or click an example")
_sentence_box = gr.Textbox(label="造句", placeholder="造句 with the 词语 or click an example")

# Clickable (word, sentence) pairs offered below the inputs.
_examples = [
    ["与众不同", "他的产品很特别,与众不同,跟别人的不一样。"],
    ["尴尬", "小明去朋友的生日庆祝会,忘了带礼物,感到很尴尬。"],
    ["标准", "小明朗读课文时发音标准,被老师评为优秀。"],
]

# Wire plot_results into a Gradio interface and start serving it.
gr.Interface(
    fn=plot_results,
    inputs=[_word_box, _sentence_box],
    examples=_examples,
    outputs=["plot"],
    title="Chinese Sentence Grading",
).launch()