Spaces:
Runtime error
Runtime error
add st.cache
Browse files
app.py
CHANGED
@@ -7,27 +7,37 @@ import sys
|
|
7 |
|
8 |
HISTORY_WEIGHT = 100 # history weight: if any keyword from history is found, it is prioritized according to this weight
|
9 |
|
10 |
-
@st.cache(allow_output_mutation=True)
|
11 |
def get_model(model):
|
12 |
return pipeline("fill-mask", model=model, top_k=10)#set the maximum of tokens to be retrieved after each inference to model
|
13 |
|
14 |
-
|
|
|
|
|
|
|
15 |
def loading_models(model='roberta-base'):
|
16 |
return get_model(model), SentenceTransformer('all-MiniLM-L6-v2')
|
17 |
|
|
|
|
|
|
|
18 |
def infer(text):
|
19 |
global nlp
|
20 |
return nlp(text+' '+nlp.tokenizer.mask_token)
|
21 |
|
|
|
|
|
|
|
|
|
22 |
def sim(predicted_seq, sem_list):
|
23 |
return semantic_model.encode(predicted_seq, convert_to_tensor=True), \
|
24 |
semantic_model.encode(sem_list, convert_to_tensor=True)
|
25 |
|
26 |
-
|
27 |
-
#bypass hash function
|
28 |
-
return True
|
29 |
|
30 |
-
@st.cache(allow_output_mutation=True,
|
|
|
|
|
31 |
def main(text,semantic_text,history_keyword_text):
|
32 |
global semantic_model, data_load_state
|
33 |
data_load_state.text('Inference from model...')
|
|
|
7 |
|
8 |
HISTORY_WEIGHT = 100 # history weight: if any keyword from history is found, it is prioritized according to this weight
|
9 |
|
10 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def get_model(model):
    """Build (and cache) a fill-mask pipeline for the given model name.

    top_k=10 caps the number of candidate tokens returned per inference.
    """
    fill_mask = pipeline("fill-mask", model=model, top_k=10)
    return fill_mask
|
13 |
|
14 |
+
def hash_func(inp):
    """Constant hash used to make st.cache skip hashing of *inp*.

    Returning the same value for every input tells Streamlit to treat all
    objects of the registered types (unhashable tokenizer objects) as equal,
    effectively bypassing cache invalidation on them.
    """
    return True
|
16 |
+
|
17 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def loading_models(model='roberta-base'):
    """Load the fill-mask pipeline and the sentence-similarity encoder.

    Returns a (fill_mask_pipeline, sentence_encoder) pair; both loads are
    cached by Streamlit.
    """
    fill_mask_nlp = get_model(model)
    sentence_encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return fill_mask_nlp, sentence_encoder
|
20 |
|
21 |
+
@st.cache(allow_output_mutation=True,
          suppress_st_warning=True,
          hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
def infer(text):
    """Append the model's mask token to *text* and run the fill-mask pipeline.

    hash_funcs bypasses hashing of the unhashable tokenizer objects held by
    the cached pipeline.
    """
    global nlp
    masked_input = text + ' ' + nlp.tokenizer.mask_token
    return nlp(masked_input)
|
27 |
|
28 |
+
|
29 |
+
@st.cache(allow_output_mutation=True,
          suppress_st_warning=True,
          hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
def sim(predicted_seq, sem_list):
    """Encode both the predicted sequences and the reference list as tensors.

    Returns a (predicted_embeddings, reference_embeddings) pair produced by
    the module-level semantic_model.
    """
    predicted_emb = semantic_model.encode(predicted_seq, convert_to_tensor=True)
    reference_emb = semantic_model.encode(sem_list, convert_to_tensor=True)
    return predicted_emb, reference_emb
|
35 |
|
36 |
+
|
|
|
|
|
37 |
|
38 |
+
@st.cache(allow_output_mutation=True,
|
39 |
+
suppress_st_warning=True,
|
40 |
+
hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
|
41 |
def main(text,semantic_text,history_keyword_text):
|
42 |
global semantic_model, data_load_state
|
43 |
data_load_state.text('Inference from model...')
|