import streamlit as st
import pandas as pd
from streamlit import cli as stcli
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import sys

HISTORY_WEIGHT = 100  # weight applied when a predicted token matches the user's keyword history; matching tokens are prioritized by this weight
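# e.g., a candidate token with model score 0.05 that matches a history keyword is boosted to 0.05 * 100 = 5.0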

@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def get_model(model):
    return pipeline("fill-mask", model=model, top_k=10)  # retrieve at most the top 10 candidate tokens per inference

def hash_func(inp):
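    # constant hash so st.cache skips hashing unhashable tokenizer objects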
    return True

@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def loading_models(model='roberta-base'):
    return get_model(model), SentenceTransformer('all-MiniLM-L6-v2')

@st.cache(allow_output_mutation=True, 
          suppress_st_warning=True,
          hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
def infer(text):
    # uses the global `nlp` pipeline; append the mask token so fill-mask predicts the next word
    return nlp(text + ' ' + nlp.tokenizer.mask_token)


@st.cache(allow_output_mutation=True, 
          suppress_st_warning=True,
          hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
def sim(predicted_seq, sem_list):
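    # encode the predicted sequences and the semantic-history phrases into sentence embeddings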
    return semantic_model.encode(predicted_seq, convert_to_tensor=True), \
            semantic_model.encode(sem_list, convert_to_tensor=True)
    
@st.cache(allow_output_mutation=True, 
          suppress_st_warning=True,
          hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
def main(text,semantic_text,history_keyword_text):
    global semantic_model, data_load_state
    data_load_state.text('Running model inference...')
    result = infer(text)
    sem_list = [semantic_text.strip()]
    data_load_state.text('Checking similarity...')
    if len(semantic_text):
        predicted_seq=[rec['sequence'] for rec in result]
        predicted_embeddings, semantic_history_embeddings = sim(predicted_seq, sem_list)
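        # cosine_scores has shape (len(predicted_seq), len(sem_list)): pairwise cosine similarities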
        cosine_scores = util.cos_sim(predicted_embeddings, semantic_history_embeddings)
    data_load_state.text('Similarity check completed...')
    
    for index, r in enumerate(result):
        if len(semantic_text):
            if len(r['token_str']) > 2:  # skip very short tokens and special characters such as "?"
                result[index]['score'] += float(sum(cosine_scores[index])) * HISTORY_WEIGHT
        if r['token_str'].lower().strip() in history_keyword_text.lower().strip() and len(r['token_str'].lower().strip()) > 1:
            # token found in the user's keyword history, so boost its score
            result[index]['score'] *= HISTORY_WEIGHT
    data_load_state.text('Score updated...')
            
    # sort results by adjusted score, highest first
    df = pd.DataFrame(result).sort_values(by='score', ascending=False)
    return df
    
    
if __name__ == '__main__':
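    # st._is_running_with_streamlit is a private flag that is True when the script is launched via `streamlit run`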
    if st._is_running_with_streamlit:
        st.markdown("""
# Auto-Complete
This is an example of an auto-complete approach where the next token is suggested based on the user's history: keyword matches and semantic similarity to the user's history (log).
Each candidate token is ranked by the model's probability and boosted by a weight if it appears in the user's keyword history or is semantically similar to it.
""")
        history_keyword_text = st.text_input("Enter the user's history <Keyword Match> (optional, e.g., 'Gates')", value="")
        
        semantic_text = st.text_input("Enter the user's history <Semantic> (optional, e.g., 'Microsoft' or 'President')", value="Microsoft")
        
        text = st.text_input("Enter text to auto-complete...", value='Where is Bill')
        model = st.selectbox("Choose a model", ["roberta-base", "bert-base-uncased"])
        
        data_load_state = st.text('Loading model...')

        nlp, semantic_model = loading_models(model)
        
        df = main(text, semantic_text, history_keyword_text)
        # show the ranked results as a table
        st.table(df)
        data_load_state.text('')
    else:
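        # not running under Streamlit: re-launch this script through the Streamlit CLI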
        sys.argv = ['streamlit', 'run', sys.argv[0]]
        sys.exit(stcli.main())
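
# To try this app locally (assuming this file is saved as app.py and the
# dependencies streamlit, transformers, and sentence-transformers are installed):
#   streamlit run app.py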