Galuh Sahid commited on
Commit
3e7562a
1 Parent(s): a4af9d2

bug: fix cache

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -18,13 +18,17 @@ MODELS = {
18
 
19
  headers = {}
20
 
21
- @st.cache(show_spinner=False, persist=True)
22
- def load_gpt(model_type, text):
23
- print("Loading model...")
24
  model = GPT2LMHeadModel.from_pretrained(MODELS[model_type])
 
 
 
 
 
25
  tokenizer = GPT2Tokenizer.from_pretrained(MODELS[model_type])
26
 
27
- return model, tokenizer
28
 
29
  def get_image(text: str):
30
  url = "https://wikisearch.uncool.ai/get_image/"
@@ -116,7 +120,8 @@ if st.button("Run"):
116
  with st.spinner(text="Getting results..."):
117
 
118
  st.subheader("Result")
119
- model, tokenizer = load_gpt(model_name, text)
 
120
 
121
  input_ids = tokenizer.encode(text, return_tensors='pt')
122
  output = model.generate(input_ids=input_ids,
 
18
 
19
  headers = {}
20
 
21
+ @st.cache(show_spinner=False, persist=True, hash_funcs={transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel: lambda _: None})
22
+ def load_gpt(model_type):
 
23
  model = GPT2LMHeadModel.from_pretrained(MODELS[model_type])
24
+
25
+ return model
26
+
27
+ @st.cache(show_spinner=False, persist=True, hash_funcs={transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer: lambda _: None})
28
+ def load_gpt_tokenizer(model_type):
29
  tokenizer = GPT2Tokenizer.from_pretrained(MODELS[model_type])
30
 
31
+ return tokenizer
32
 
33
  def get_image(text: str):
34
  url = "https://wikisearch.uncool.ai/get_image/"
 
120
  with st.spinner(text="Getting results..."):
121
 
122
  st.subheader("Result")
123
+ model = load_gpt(model_name)
124
+ tokenizer = load_gpt_tokenizer(model_name)
125
 
126
  input_ids = tokenizer.encode(text, return_tensors='pt')
127
  output = model.generate(input_ids=input_ids,