mbahrami commited on
Commit
4c5d55d
1 Parent(s): 680b98d

add st.cache

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -7,27 +7,37 @@ import sys
7
 
8
  HISTORY_WEIGHT = 100 # set history weight (if found any keyword from history, it will priorities based on its weight)
9
 
10
- @st.cache(allow_output_mutation=True)
11
  def get_model(model):
12
  return pipeline("fill-mask", model=model, top_k=10)#set the maximum of tokens to be retrieved after each inference to model
13
 
14
- @st.cache(allow_output_mutation=True)
 
 
 
15
  def loading_models(model='roberta-base'):
16
  return get_model(model), SentenceTransformer('all-MiniLM-L6-v2')
17
 
 
 
 
18
  def infer(text):
19
  global nlp
20
  return nlp(text+' '+nlp.tokenizer.mask_token)
21
 
 
 
 
 
22
  def sim(predicted_seq, sem_list):
23
  return semantic_model.encode(predicted_seq, convert_to_tensor=True), \
24
  semantic_model.encode(sem_list, convert_to_tensor=True)
25
 
26
- def hash_func(inp):
27
- #bypass hash function
28
- return True
29
 
30
- @st.cache(allow_output_mutation=True, hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
 
 
31
  def main(text,semantic_text,history_keyword_text):
32
  global semantic_model, data_load_state
33
  data_load_state.text('Inference from model...')
 
7
 
8
  HISTORY_WEIGHT = 100 # set history weight (if found any keyword from history, it will priorities based on its weight)
9
 
10
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
11
  def get_model(model):
12
  return pipeline("fill-mask", model=model, top_k=10)#set the maximum of tokens to be retrieved after each inference to model
13
 
14
+ def hash_func(inp):
15
+ return True
16
+
17
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
18
  def loading_models(model='roberta-base'):
19
  return get_model(model), SentenceTransformer('all-MiniLM-L6-v2')
20
 
21
+ @st.cache(allow_output_mutation=True,
22
+ suppress_st_warning=True,
23
+ hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
24
  def infer(text):
25
  global nlp
26
  return nlp(text+' '+nlp.tokenizer.mask_token)
27
 
28
+
29
+ @st.cache(allow_output_mutation=True,
30
+ suppress_st_warning=True,
31
+ hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
32
  def sim(predicted_seq, sem_list):
33
  return semantic_model.encode(predicted_seq, convert_to_tensor=True), \
34
  semantic_model.encode(sem_list, convert_to_tensor=True)
35
 
36
+
 
 
37
 
38
+ @st.cache(allow_output_mutation=True,
39
+ suppress_st_warning=True,
40
+ hash_funcs={'tokenizers.Tokenizer': hash_func, 'tokenizers.AddedToken': hash_func})
41
  def main(text,semantic_text,history_keyword_text):
42
  global semantic_model, data_load_state
43
  data_load_state.text('Inference from model...')