File size: 6,195 Bytes
f6cb372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Minimum probability the model must assign to the target word (at its masked
# position in the learner's sentence) for the sentence to be graded a pass.
# One threshold is currently shared by English and Chinese.
WORD_PROBABILITY_THRESHOLD = 0.02
# Number of fill-mask candidates requested from the model for each sentence.
TOP_K_WORDS = 10

# Labels shown in the language selector; also the keys that
# assess_sentence() and get_word() dispatch on.
ENGLISH_LANG = "English"
CHINESE_LANG = "Chinese"

# Fixed pool of Chinese practice words that get_chinese_word() draws from.
CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']

@st.cache_resource
def get_model_chinese():
    """Build the Chinese fill-mask pipeline.

    Wrapped in ``st.cache_resource`` so the model is constructed once and
    shared across script reruns instead of being reloaded each time.
    """
    chinese_filler = pipeline(task="fill-mask", model=MODEL_NAME_CHINESE, device=device)
    return chinese_filler

@st.cache_resource
def get_model_english():
    """Build the English fill-mask pipeline.

    Wrapped in ``st.cache_resource`` so the model is constructed once and
    shared across script reruns instead of being reloaded each time.
    """
    english_filler = pipeline(task="fill-mask", model=MODEL_NAME_ENGLISH, device=device)
    return english_filler

@st.cache_data
def get_wordlist_chinese():
    """Read the Chinese word list CSV (memoised by ``st.cache_data``)."""
    csv_path = 'wordlist_chinese.csv'
    return pd.read_csv(csv_path)

@st.cache_data
def get_wordlist_english():
    """Read the English word list CSV (memoised by ``st.cache_data``)."""
    csv_path = 'wordlist_english.csv'
    return pd.read_csv(csv_path)

def assess_chinese(word, sentence):
    """Score how plausible *word* is in *sentence* using the Chinese model.

    The first occurrence of *word* in *sentence* is replaced by ``<mask>``.
    The Chinese fill-mask pipeline is then queried twice: once for its top
    ``TOP_K_WORDS`` candidates, and once for the probability of *word*
    itself at the masked position.

    Args:
        word: the target word the learner had to use.
        sentence: the learner's sentence.

    Returns:
        ``(top_k_prediction, score)``. ``top_k_prediction`` is the model's
        candidate list, with the target word's own prediction appended when
        it was not among the candidates. ``score`` is the probability
        assigned to *word*.

    Raises:
        ValueError: if *sentence* does not contain *word*.
    """
    import re

    # Raise rather than return None: callers unpack the result as a tuple,
    # so None would surface as a confusing TypeError.
    match = re.search(re.escape(word), sentence, flags=re.IGNORECASE)
    if match is None:
        raise ValueError("Sentence does not contain the target word")

    # Mask only the matched span (the first occurrence). Masking several
    # occurrences would make the pipeline return one result list per mask
    # and break the [0]['score'] indexing below.
    text = sentence[:match.start()] + "<mask>" + sentence[match.end():]

    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    # NOTE(review): if *word* is not a single vocabulary token, transformers
    # scores only its first sub-token. Confirm that is acceptable for
    # multi-character words.
    target_word_prediction = mask_filler_chinese(text, targets=word)
    score = target_word_prediction[0]['score']

    # Guarantee the target word shows up in the returned candidates.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_english(word, sentence):
    """Score how plausible *word* is in *sentence* using the English model.

    The first occurrence of *word* in *sentence* (matched case-insensitively)
    is replaced by ``<mask>``. The English fill-mask pipeline is then queried
    twice: once for its top ``TOP_K_WORDS`` candidates, and once for the
    probability of *word* itself at the masked position.

    Args:
        word: the target word the learner had to use.
        sentence: the learner's sentence.

    Returns:
        ``(top_k_prediction, score)``. ``top_k_prediction`` is the model's
        candidate list, with the target word's own prediction appended when
        it was not among the candidates. ``score`` is the probability
        assigned to *word*.

    Raises:
        ValueError: if *sentence* does not contain *word*.
    """
    import re

    # Search case-insensitively and mask the matched span itself, so a
    # capitalised occurrence (e.g. at the start of the sentence) is masked
    # too. A lowercase-only str.replace would leave it unmasked.
    match = re.search(re.escape(word), sentence, flags=re.IGNORECASE)
    if match is None:
        raise ValueError("Sentence does not contain the target word")

    # Mask only the first occurrence. Several masks would make the pipeline
    # return one result list per mask and break the [0]['score'] indexing.
    text = sentence[:match.start()] + "<mask>" + sentence[match.end():]

    top_k_prediction = mask_filler_english(text, top_k=TOP_K_WORDS)
    # chr(9601) is U+2581, the SentencePiece word-start marker. Prefixing it
    # requests the score of *word* as a whole word. This presumably relies on
    # the model using a SentencePiece vocabulary; re-check if the model changes.
    target_word_prediction = mask_filler_english(text, targets=chr(9601) + word)
    score = target_word_prediction[0]['score']

    # Guarantee the target word shows up in the returned candidates.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_sentence(language, word, sentence):
    """Grade *sentence* with the assessor for *language*.

    Args:
        language: ``ENGLISH_LANG`` or ``CHINESE_LANG``.
        word: the target word the learner had to use.
        sentence: the learner's sentence.

    Returns:
        The result of ``assess_english`` or ``assess_chinese``.

    Raises:
        ValueError: if *language* is not a supported language. Raising is
            preferable to returning None, which callers would fail to unpack.
    """
    if language == ENGLISH_LANG:
        return assess_english(word, sentence)
    if language == CHINESE_LANG:
        return assess_chinese(word, sentence)
    raise ValueError(f"Unsupported language: {language!r}")
    
def get_chinese_word(use_wordlist=False):
    """Pick a random Chinese target word.

    Args:
        use_wordlist: when False (the default), draw from the fixed
            ``CHINESE_WORDLIST``. When True, draw from the loaded
            ``wordlist_chinese`` table instead, restricted to rows whose
            ``assess`` flag is True and whose ``Chinese`` entry is exactly
            two characters long.

    Returns:
        The chosen word as a plain ``str``.

    Raises:
        ValueError: (``use_wordlist=True`` only) if no row qualifies, as
            raised by ``DataFrame.sample``.
    """
    if not use_wordlist:
        # The CSV path is opt-in. Sampling it unconditionally and then
        # discarding the result was wasted work that could also raise.
        return str(np.random.choice(CHINESE_WORDLIST))

    eligible = wordlist_chinese[
        wordlist_chinese.assess.eq(True)
        & wordlist_chinese.Chinese.str.len().eq(2)
    ]
    return str(eligible.sample(1).iloc[0].Chinese)

def get_english_word(use_wordlist=False):
    """Pick a random English target word.

    Args:
        use_wordlist: when False (the default), draw from a small built-in
            practice list. When True, draw from the loaded
            ``wordlist_english`` table, restricted to rows whose ``assess``
            flag is True.

    Returns:
        The chosen word as a plain ``str``.

    Raises:
        ValueError: (``use_wordlist=True`` only) if no row qualifies, as
            raised by ``DataFrame.sample``.
    """
    if not use_wordlist:
        # The CSV path is opt-in. Sampling it unconditionally and then
        # discarding the result was wasted work that could also raise.
        test_words = ["independent", "satisfied", "excited"]
        return str(np.random.choice(test_words))

    eligible = wordlist_english[wordlist_english.assess.eq(True)]
    return str(eligible.sample(1).iloc[0].word)

def get_word(language):
    """Return a random target word for *language*.

    Args:
        language: ``ENGLISH_LANG`` or ``CHINESE_LANG``.

    Raises:
        ValueError: if *language* is not supported. Raising is preferable to
            silently returning None, which would be shown as the target word.
    """
    if language == ENGLISH_LANG:
        return get_english_word()
    if language == CHINESE_LANG:
        return get_chinese_word()
    raise ValueError(f"Unsupported language: {language!r}")

# Module-level handles used by the functions in this file. The loaders are
# Streamlit-cached, so script reruns do not reload the models or re-read the
# CSV files from disk.
mask_filler_chinese, mask_filler_english = get_model_chinese(), get_model_english()
wordlist_chinese, wordlist_english = get_wordlist_chinese(), get_wordlist_english()

def highlight_given_word(row):
    """Return one CSS rule per cell, tinting the row that holds the target.

    Meant for ``DataFrame.style.apply(..., axis=1)``. It compares
    ``row.Words`` with the module-level ``target_word``.
    """
    if row.Words == target_word:
        color = '#ACE5EE'
    else:
        color = 'white'
    css = f'background-color:{color}'
    return [css for _ in range(len(row))]

def get_top_5_results(top_k_prediction, target=None, n=5):
    """Build the results table shown after grading.

    Args:
        top_k_prediction: list of fill-mask result dicts carrying ``score``,
            ``token``, ``token_str`` and ``sequence`` keys.
        target: word that must appear in the table. Defaults to the
            module-level ``target_word`` chosen for this session.
        n: number of top-ranked rows to keep (5 by default).

    Returns:
        A DataFrame with ``Probability`` (formatted as a percentage string)
        and ``Words`` columns. It holds the first *n* rows, plus any rows for
        *target* from further down the list when the target is not among them.
    """
    if target is None:
        target = target_word

    predictions_df = (
        pd.DataFrame(top_k_prediction)
        .drop(columns=["token", "sequence"])
        .rename(columns={"score": "Probability", "token_str": "Words"})
    )

    top_df = predictions_df.head(n)
    if not (top_df.Words == target).any():
        # Target missed the top n: append its row(s) so it is always visible.
        top_df = pd.concat([top_df, predictions_df[predictions_df.Words == target]])

    # assign() returns a new frame, avoiding a write into a slice of
    # predictions_df (SettingWithCopyWarning).
    return top_df.assign(Probability=top_df["Probability"].map(lambda x: f"{x:.2%}"))

#### Streamlit Page
import json  # used to dump the raw predictions after grading

st.title("造句 Auto-marking Demo")
language = st.radio("Select your language", (ENGLISH_LANG, CHINESE_LANG))

# Choose a target word on first load AND whenever the selected language
# changes. Otherwise, after switching languages, a word from the old language
# would be graded by the other language's model.
if ('target_word' not in st.session_state
        or st.session_state.get('target_language') != language):
    st.session_state['target_word'] = get_word(language)
    st.session_state['target_language'] = language
target_word = st.session_state['target_word']

st.write("Target word: ", target_word)
if st.button("Get new word"):
    st.session_state['target_word'] = get_word(language)
    st.session_state['target_language'] = language
    # Newer Streamlit releases replace st.experimental_rerun with st.rerun;
    # use whichever this installation provides.
    (getattr(st, "rerun", None) or st.experimental_rerun)()

st.subheader("Form your sentence and input below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")

if st.button("Grade"):
    # Validate up front so the learner sees a message instead of a traceback:
    # the assessors cannot mask a word that is absent from the sentence.
    if target_word.lower() not in sentence.lower():
        st.warning(f"Your sentence must contain the target word: {target_word}")
    else:
        top_k_prediction, score = assess_sentence(language, target_word, sentence)
        # Debug dump of the raw predictions. Write real JSON in UTF-8 so
        # Chinese tokens are saved correctly regardless of the OS locale.
        with open('./result01.json', 'w', encoding='utf-8') as outfile:
            json.dump(top_k_prediction, outfile, ensure_ascii=False)

        st.write(f"Probability: {score:.2%}")
        st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
        predictions_df = get_top_5_results(top_k_prediction)
        df_style = predictions_df.style.apply(highlight_given_word, axis=1)

        if score >= WORD_PROBABILITY_THRESHOLD:
            st.balloons()
            st.success("Yay good job! That's a great sentence 🕺 Practice again with other word", icon="✅")
        else:
            st.warning("Hmmm.. maybe try again?")
        st.table(df_style)