cn91 committed on
Commit 3ec9549
1 Parent(s): aaab29f

Update to 710M Character Model, add RTD
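
RTD here stands for replaced token detection: the commit wires in the ELECTRA discriminator hfl/chinese-electra-180g-large-discriminator and flags every token whose logit is positive as a likely misused character (see get_word_errors in the diff below). A minimal standalone sketch of that check; the model name and the `> 0` rule come from the diff, while the example sentence is made up:

```python
# Sketch: replaced token detection (RTD) with the ELECTRA discriminator
# named in this commit. The sentence below is a made-up example.
import torch
from transformers import AutoTokenizer, ElectraForPreTraining

RTD_MODEL_NAME_CHINESE = "hfl/chinese-electra-180g-large-discriminator"

tokenizer = AutoTokenizer.from_pretrained(RTD_MODEL_NAME_CHINESE)
model = ElectraForPreTraining.from_pretrained(RTD_MODEL_NAME_CHINESE)

sentence = "我今天很高心"  # hypothetical learner sentence (心 standing in for 兴)
inputs = tokenizer(sentence, return_tensors="pt", return_offsets_mapping=True)
offsets = inputs.pop("offset_mapping")[0]  # model forward() does not accept this key

with torch.no_grad():
    logits = model(**inputs).logits[0]

# A positive logit means the discriminator considers the token "replaced",
# i.e. out of place; app.py applies the same `> 0` rule in get_word_errors.
for i, logit in enumerate(logits):
    if logit > 0:
        start, end = offsets[i].tolist()
        print(f"suspicious character {sentence[start:end]!r} at index {start}")
```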

Files changed (1): app.py +108 -32
app.py CHANGED
@@ -1,8 +1,9 @@
-from transformers import pipeline, AutoTokenizer
+from transformers import pipeline, AutoTokenizer, ElectraForPreTraining
 import pandas as pd
 import numpy as np
 import torch
 import streamlit as st
+from annotated_text import annotated_text
 
 USE_GPU = True
 
@@ -11,64 +12,127 @@ if USE_GPU and torch.cuda.is_available():
 else:
     device = torch.device('cpu')
 
-MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"
-#MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"
+MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"
+RTD_MODEL_NAME_CHINESE = "hfl/chinese-electra-180g-large-discriminator"
 
-
-WORD_PROBABILITY_THRESHOLD = 0.02
-TOP_K_WORDS = 200
-
-CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
+WORD_PROBABILITY_THRESHOLD = 0.05
+TOP_K_WORDS = 10
 
 @st.cache_resource
 def get_model_chinese():
     return pipeline("fill-mask", MODEL_NAME_CHINESE, device = device)
 
 @st.cache_resource
-def get_allowed_tokens():
-    df = pd.read_csv('allowed_token_ids.csv')
-    return set(list(df['token']))
+def get_rtd_tokenizer_chinese():
+    return AutoTokenizer.from_pretrained(RTD_MODEL_NAME_CHINESE)
+
+@st.cache_resource
+def get_rtd_model_chinese():
+    return ElectraForPreTraining.from_pretrained(RTD_MODEL_NAME_CHINESE)
+
+@st.cache_resource
+def get_wordlist_chinese():
+    df = pd.read_csv('wordlist_chinese_v2.csv')
+    wordlist = df[df.assess == True]
+    return wordlist['Chinese'].tolist()
+
+@st.cache_resource
+def get_allowed_words():
+    df = pd.read_csv('allowed_words.csv')
+    return set(list(df['word']))
 
 def assess_chinese(word, sentence):
     print("Assessing Chinese")
-    allowed_token_ids = get_allowed_tokens()
-
+    number_of_chars = len(word)
+    assert number_of_chars == 2
+
+    allowed_words = get_allowed_words()
     if sentence.lower().find(word.lower()) == -1:
         print('Sentence does not contain the word!')
         return
 
-    text = sentence.replace(word.lower(), "<mask>")
+    text = sentence.replace(word.lower(), "[MASK]"*number_of_chars)
 
-    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
-    target_word_prediction = mask_filler_chinese(text, targets = word)
+    top_k_prediction = []
+    candidates = mask_filler_chinese(text, top_k=TOP_K_WORDS)[0]
+    for candidate in candidates:
+        temp_text = text.replace("[MASK]", candidate['token_str'], 1)
+        second_predictions = mask_filler_chinese(temp_text, top_k=5)
+        for prediction in second_predictions:
+            prediction['token_str'] = candidate['token_str'] + prediction['token_str']
+            prediction['score'] = candidate['score'] * prediction['score']
+
+        top_k_prediction.extend(second_predictions)
+    top_k_prediction = sorted(top_k_prediction, key = lambda x: x['score'], reverse = True)[:(TOP_K_WORDS*5)]
 
     norm_factor = 0
     for output in top_k_prediction:
-        if output['token'] not in allowed_token_ids:
+        if output['token_str'] not in allowed_words:
            norm_factor += output['score']
 
    top_k_prediction_new = []
    for output in top_k_prediction:
-        if output['token'] in allowed_token_ids:
+        if output['token_str'] in allowed_words:
            output['score'] = output['score']/(1-min(0.5,norm_factor))
            top_k_prediction_new.append(output)
-
-    target_word_prediction[0]['score'] = target_word_prediction[0]['score'] / (1-min(0.5,norm_factor))
-    score = target_word_prediction[0]['score']
+    print (f"NORM_FACTOR: {norm_factor}")
+
+    # Get target word prediction
+    temp_text = text
+    output1 = mask_filler_chinese(text, targets=word[0])[0][0]
+    temp_text = text.replace("[MASK]", word[0], 1)
+    output2 = mask_filler_chinese(temp_text, targets = word[1])[0]
+    output2['token_str'] = output1['token_str'] + output2['token_str']
+    output2['score'] = output1['score'] * output2['score']
+    target_word_prediction = output2
+
+    target_word_prediction['score'] = target_word_prediction['score'] / (1-min(0.5,norm_factor))
+    score = target_word_prediction['score']
 
     # append the original word if its not found in the results
     top_k_prediction_filtered = [output for output in top_k_prediction_new if \
                                  output['token_str'] == word]
     if len(top_k_prediction_filtered) == 0:
-        top_k_prediction_new.extend(target_word_prediction)
+        top_k_prediction_new.extend([target_word_prediction])
 
     return top_k_prediction_new, score
 
 def assess_sentence(word, sentence):
     return assess_chinese(word, sentence)
 
+def get_annotated_sentence(sentence, errors):
+    if len(errors) == 0:
+        return sentence
+
+    output = ["Input sentence: "]
+
+    wrong_char_indices = [e[0].item() for e in errors]
+    curr_ind = 0
+    for i in range(len(wrong_char_indices)):
+        output.append(sentence[curr_ind:wrong_char_indices[i]])
+        output.append((sentence[wrong_char_indices[i]], "", "#F8C8DC"))
+        # output.append((sentence[wrong_char_indices[i]], " ", "#ff4b4b"))
+        curr_ind = wrong_char_indices[i] + 1
+    output.append(sentence[curr_ind:])
+    print(output)
+
+    return output
+
+def get_word_errors(word, sentence):
+    tokens = rtd_tokenizer_chinese(sentence, return_tensors = 'pt', return_offsets_mapping = True)
+    scores = rtd_model_chinese(**rtd_tokenizer_chinese(sentence, return_tensors = 'pt'))[0][0]
+
+    errors = []
+    for i in range(len(scores)):
+        if scores[i] > 0:
+            errors.append(tokens['offset_mapping'][0][i])
+
+    print(errors)
+    return errors
+
+
 def get_chinese_word():
-    possible_words = CHINESE_WORDLIST
+    possible_words = get_wordlist_chinese()
     word = np.random.choice(possible_words)
     return word
 
@@ -77,6 +141,8 @@ def get_word():
 
 mask_filler_chinese = get_model_chinese()
 #wordlist_chinese = get_wordlist_chinese()
+rtd_tokenizer_chinese = get_rtd_tokenizer_chinese()
+rtd_model_chinese = get_rtd_model_chinese()
 
 def highlight_given_word(row):
     color = '#ACE5EE' if row.Words == target_word else 'white'
@@ -101,14 +167,17 @@ def get_top_5_results(top_k_prediction):
     return top_5_df
 
 #### Streamlit Page
-st.title("造句 Auto-marking Demo")
+st.title("造句 Self-marking Demo")
 
 if 'target_word' not in st.session_state:
    st.session_state['target_word'] = get_word()
 target_word = st.session_state['target_word']
+target_word_ind = get_wordlist_chinese().index(target_word)
+
+#st.write("Target word: ", target_word)
+target_word = st.selectbox("Choose a word:", get_wordlist_chinese(), index = target_word_ind)
 
-st.write("Target word: ", target_word)
-if st.button("Get new word"):
+if st.button("Get random word"):
    st.session_state['target_word'] = get_word()
    st.experimental_rerun()
 
@@ -122,16 +191,23 @@ if st.button("Grade"):
    with open('./result01.json', 'w') as outfile:
        outfile.write(str(top_k_prediction))
 
-    st.write(f"Probability: {score:.2%}")
-    st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
+    errors = get_word_errors(target_word, sentence)
+    annotated_sentence = get_annotated_sentence(sentence, errors)
+
+    annotated_text(annotated_sentence)
+
+    st.write(f"Probability score: {score:.1%}. (Target: {WORD_PROBABILITY_THRESHOLD:.1%})")
+    # st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.1%}")
    predictions_df = get_top_5_results(top_k_prediction)
    df_style = predictions_df.style.apply(highlight_given_word, axis=1)
 
    if (score >= WORD_PROBABILITY_THRESHOLD):
        # st.balloons()
-        st.success("Yay good job! 🕺 Practice again with other words", icon="✅")
-        st.table(df_style)
+        if (len(errors) == 0):
+            st.success("Yay good job! 🕺 Practice again with other words", icon="✅")
+        else:
+            st.warning("Potential word errors detected. Try again?")
    else:
-        st.warning("Hmmm.. maybe try again?")
-        st.table(df_style)
+        st.warning("Probability score too low. Maybe try again?")
+    st.table(df_style)
 
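
For context on the new scoring path: assess_chinese now assumes a two-character target word and chains two fill-mask passes, P(word) = P(char1 | context) * P(char2 | context, char1). A standalone sketch of that chain, reusing the model name from the diff; the word and sentence are made-up examples:

```python
# Sketch: two-step fill-mask scoring for a two-character word, as in the
# updated assess_chinese. Word and sentence are made-up examples.
from transformers import pipeline

MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"
mask_filler = pipeline("fill-mask", MODEL_NAME_CHINESE)

word = "高兴"              # hypothetical two-character target word
sentence = "我今天很高兴。"  # hypothetical learner sentence
text = sentence.replace(word, "[MASK]" * len(word))

# Pass 1: with both masks in place the pipeline returns one candidate list
# per mask position, so [0][0] is the first mask's prediction for char 1.
first = mask_filler(text, targets=word[0])[0][0]

# Pass 2: fill the first mask with char 1, then score char 2 in the
# remaining single mask (a plain list of dicts this time).
second = mask_filler(text.replace("[MASK]", word[0], 1), targets=word[1])[0]

# Chain rule: P(word) = P(char1 | context) * P(char2 | context, char1)
score = first["score"] * second["score"]
print(f"P({word}) = {score:.2%}")
```

The diff then renormalizes: candidates outside allowed_words contribute their probability mass to norm_factor, surviving scores are rescaled by 1/(1 - min(0.5, norm_factor)) (so the boost never exceeds 2x), and the final score is compared against WORD_PROBABILITY_THRESHOLD = 0.05.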