cn91 committed on
Commit
f6cb372
·
1 Parent(s): abf1538

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py CHANGED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Minimum fill-mask probability the target word must reach for a sentence to
# be accepted. Per-language thresholds were sketched at one point but are not
# in use; both languages share this single value.
WORD_PROBABILITY_THRESHOLD = 0.02
# Number of candidate words requested from the fill-mask model.
TOP_K_WORDS = 10

# Labels offered by the language selector.
ENGLISH_LANG = "English"
CHINESE_LANG = "Chinese"

# Curated Chinese target words that get_chinese_word() draws from.
CHINESE_WORDLIST = [
    '一定', '一样', '不得了', '主观', '从此', '便于', '俗话', '倒霉', '候选', '充沛',
    '分别', '反倒', '只好', '同情', '吹捧', '咳嗽', '围绕', '如意', '实行', '将近',
    '就职', '应该', '归还', '当面', '忘记', '急忙', '恢复', '悲哀', '感冒', '成长',
    '截至', '打架', '把握', '报告', '抱怨', '担保', '拒绝', '拜访', '拥护', '拳头',
    '拼搏', '损坏', '接待', '握手', '揭发', '攀登', '显示', '普遍', '未免', '欣赏',
    '正式', '比如', '流浪', '涂抹', '深刻', '演绎', '留念', '瞻仰', '确保', '稍微',
    '立刻', '精心', '结算', '罕见', '访问', '请示', '责怪', '起初', '转达', '辅导',
    '过瘾', '运动', '连忙', '适合', '遭受', '重叠', '镇静',
]
10
+
11
@st.cache_resource
def get_model_chinese():
    """Build the Chinese fill-mask pipeline; cached by Streamlit as a resource."""
    chinese_mask_filler = pipeline("fill-mask", MODEL_NAME_CHINESE, device=device)
    return chinese_mask_filler
14
+
15
@st.cache_resource
def get_model_english():
    """Build the English fill-mask pipeline; cached by Streamlit as a resource."""
    english_mask_filler = pipeline("fill-mask", MODEL_NAME_ENGLISH, device=device)
    return english_mask_filler
18
+
19
@st.cache_data
def get_wordlist_chinese():
    """Read the Chinese word list CSV into a DataFrame; cached by Streamlit."""
    chinese_words = pd.read_csv('wordlist_chinese.csv')
    return chinese_words
22
+
23
@st.cache_data
def get_wordlist_english():
    """Read the English word list CSV into a DataFrame; cached by Streamlit."""
    english_words = pd.read_csv('wordlist_english.csv')
    return english_words
26
+
27
def assess_chinese(word, sentence):
    """Score how plausible `word` is in `sentence` with the Chinese fill-mask model.

    The first occurrence of `word` is replaced by the mask token and the model
    is queried twice: once for its top ``TOP_K_WORDS`` candidates and once for
    the probability of `word` itself.

    Args:
        word: Target word the sentence is expected to contain.
        sentence: Sentence written by the user.

    Returns:
        A ``(predictions, score)`` tuple: the list of prediction dicts (with
        the target word's own prediction appended when it is not among the
        top candidates) and the probability assigned to `word`.

    Raises:
        ValueError: If `word` is empty or does not occur in `sentence`.
    """
    import re

    print("Assessing Chinese")
    if not word:
        raise ValueError("Target word must not be empty")

    # Mask only the first occurrence: several masks would make the pipeline
    # return one result list per mask, which the indexing below cannot handle.
    # Matching and replacing in one step keeps the "is it there?" check and
    # the substitution from disagreeing about case.
    # NOTE(review): "<mask>" assumes the Chinese model's tokenizer uses that
    # mask token — confirm against MODEL_NAME_CHINESE.
    text, replaced = re.subn(
        re.escape(word), "<mask>", sentence, count=1, flags=re.IGNORECASE
    )
    if not replaced:
        raise ValueError("Sentence does not contain the target word")

    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_chinese(text, targets=word)

    score = target_word_prediction[0]['score']

    # Append the target word's prediction if it is not among the top candidates.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score
47
+
48
def assess_english(word, sentence):
    """Score how plausible `word` is in `sentence` with the English fill-mask model.

    The first occurrence of `word` (matched case-insensitively) is replaced by
    the mask token and the model is queried twice: once for its top
    ``TOP_K_WORDS`` candidates and once for the probability of `word` itself.

    Args:
        word: Target word the sentence is expected to contain.
        sentence: Sentence written by the user.

    Returns:
        A ``(predictions, score)`` tuple: the list of prediction dicts (with
        the target word's own prediction appended when it is not among the
        top candidates) and the probability assigned to `word`.

    Raises:
        ValueError: If `word` is empty or does not occur in `sentence`.
    """
    import re

    if not word:
        raise ValueError("Target word must not be empty")

    # The previous check ignored case but the replacement did not, so a
    # capitalised occurrence passed the check and was then never masked.
    # Only the first occurrence is masked: several masks would make the
    # pipeline return one result list per mask, breaking the indexing below.
    text, replaced = re.subn(
        re.escape(word), "<mask>", sentence, count=1, flags=re.IGNORECASE
    )
    if not replaced:
        raise ValueError("Sentence does not contain the target word")

    top_k_prediction = mask_filler_english(text, top_k=TOP_K_WORDS)
    # chr(9601) is "▁" (U+2581), prefixed to the target before it is looked up.
    # NOTE(review): presumably the tokenizer's word-start marker — confirm
    # against MODEL_NAME_ENGLISH.
    target_word_prediction = mask_filler_english(text, targets=chr(9601) + word)

    score = target_word_prediction[0]['score']

    # Append the target word's prediction if it is not among the top candidates.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score
66
+
67
def assess_sentence(language, word, sentence):
    """Grade `sentence` for `word` with the assessor matching `language`.

    Args:
        language: ``ENGLISH_LANG`` or ``CHINESE_LANG``.
        word: Target word the sentence must use.
        sentence: Sentence written by the user.

    Returns:
        Whatever the language-specific assessor returns.

    Raises:
        ValueError: If `language` is neither supported value. Previously this
            case fell through and returned None, which callers that unpack
            the result could not handle.
    """
    if language == ENGLISH_LANG:
        return assess_english(word, sentence)
    if language == CHINESE_LANG:
        return assess_chinese(word, sentence)
    raise ValueError(f"Unsupported language: {language!r}")
72
+
73
def get_chinese_word():
    """Pick a random Chinese target word from ``CHINESE_WORDLIST``.

    Returns:
        One word chosen at random from the curated list.
    """
    # NOTE: this function used to sample a two-character word flagged
    # `assess` from `wordlist_chinese` first, but that value was always
    # overwritten by the curated-list draw, so the discarded lookup (and the
    # errors it could raise on a missing column or empty selection) is gone.
    return np.random.choice(CHINESE_WORDLIST)
80
+
81
def get_english_word():
    """Pick a random English target word from the built-in test words.

    Returns:
        One of "independent", "satisfied" or "excited", chosen at random.
    """
    # NOTE: this function used to sample a word flagged `assess` from
    # `wordlist_english` first, but that value was always overwritten by the
    # draw below, so the discarded lookup (and the errors it could raise on a
    # missing column or empty selection) is gone.
    test_words = ["independent", "satisfied", "excited"]
    return np.random.choice(test_words)
88
+
89
def get_word(language):
    """Return a random target word for `language`, or None if it is not recognised."""
    if language == ENGLISH_LANG:
        return get_english_word()
    if language == CHINESE_LANG:
        return get_chinese_word()
    return None
94
+
95
# Load both fill-mask pipelines, then both word lists, once at import time;
# the getters are wrapped in Streamlit cache decorators.
mask_filler_chinese, mask_filler_english = get_model_chinese(), get_model_english()
wordlist_chinese, wordlist_english = get_wordlist_chinese(), get_wordlist_english()
99
+
100
def highlight_given_word(row):
    """Return one CSS background style per cell, tinting the target word's row.

    Reads the module-level ``target_word``; the matching row is light blue
    and every other row white.
    """
    if row.Words == target_word:
        color = '#ACE5EE'
    else:
        color = 'white'
    cell_style = f'background-color:{color}'
    return [cell_style for _ in range(len(row))]
103
+
104
def get_top_5_results(top_k_prediction, target=None):
    """Build the results table: the five best predictions plus the target word.

    Args:
        top_k_prediction: List of fill-mask prediction dicts carrying the keys
            ``score``, ``token``, ``token_str`` and ``sequence``.
        target: Word that must appear in the table. Defaults to the
            module-level ``target_word`` when omitted.

    Returns:
        A DataFrame with ``Probability`` (formatted as percentage strings) and
        ``Words`` columns holding the first five predictions, followed by the
        target word's row(s) when it is not already among those five.
    """
    if target is None:
        target = target_word

    predictions_df = (
        pd.DataFrame(top_k_prediction)
        .drop(columns=["token", "sequence"])
        .rename(columns={"score": "Probability", "token_str": "Words"})
    )

    # Work on a copy so the column assignment below does not write into a
    # slice of predictions_df (pandas SettingWithCopyWarning).
    top_5_df = predictions_df[:5].copy()

    if not (top_5_df.Words == target).any():
        print("target word not in top 5")
        target_word_df = predictions_df[predictions_df.Words == target]
        print(target_word_df)
        top_5_df = pd.concat([top_5_df, target_word_df])

    # Format after the branch so every row, including an appended target
    # row, is rendered as a percentage.
    top_5_df['Probability'] = top_5_df['Probability'].apply(lambda x: f"{x:.2%}")

    return top_5_df
121
+
122
#### Streamlit Page
import json

st.title("造句 Auto-marking Demo")
language = st.radio("Select your language", (ENGLISH_LANG, CHINESE_LANG))

# Draw a word on first load and again whenever the selected language changes,
# so the stored word always belongs to the language being practised.
if ('target_word' not in st.session_state
        or st.session_state.get('target_language') != language):
    st.session_state['target_word'] = get_word(language)
    st.session_state['target_language'] = language
target_word = st.session_state['target_word']

st.write("Target word: ", target_word)
if st.button("Get new word"):
    st.session_state['target_word'] = get_word(language)
    st.session_state['target_language'] = language
    # Newer Streamlit releases provide st.rerun in place of the deprecated
    # st.experimental_rerun; use whichever this installation has.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()

st.subheader("Form your sentence and input below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")

if st.button("Grade"):
    # Validate here so a missing word yields a readable message instead of an
    # exception (or a failed tuple-unpack) from the assessor.
    if target_word.lower() not in sentence.lower():
        st.error("Your sentence does not contain the target word.")
    else:
        top_k_prediction, score = assess_sentence(language, target_word, sentence)

        # Dump the raw predictions as real JSON; UTF-8 keeps Chinese tokens
        # intact regardless of the platform's default encoding. default=str
        # ensures this debug dump can never abort grading on an odd value.
        with open('./result01.json', 'w', encoding='utf-8') as outfile:
            json.dump(top_k_prediction, outfile, ensure_ascii=False, default=str)

        st.write(f"Probability: {score:.2%}")
        st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
        predictions_df = get_top_5_results(top_k_prediction)
        df_style = predictions_df.style.apply(highlight_given_word, axis=1)

        if score >= WORD_PROBABILITY_THRESHOLD:
            st.balloons()
            st.success("Yay good job! That's a great sentence 🕺 Practice again with other word", icon="✅")
        else:
            st.warning("Hmmm.. maybe try again?")
        st.table(df_style)
156
+