gchhablani committed
Commit 4b29c6a
1 Parent(s): 0cb8576

Fix state model issue

Files changed (2):
  1. apps/mlm.py +7 -6
  2. apps/vqa.py +3 -3
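The substance of the commit is a rename of the per-app model attribute from a generic `model` to `mlm_model` and `vqa_model`. Judging by the commit title, the generic name presumably let the two Streamlit pages overwrite each other's cached model on the shared state object. A minimal sketch of the namespaced pattern, assuming both apps receive the same session-state-like object (`SimpleNamespace`, the getter names, and the placeholder strings below are stand-ins, not the project's actual `load_model` or checkpoints):

```python
from types import SimpleNamespace

# Stand-in for the shared session state passed into app(state).
state = SimpleNamespace(mlm_model=None, vqa_model=None)


def get_mlm_model(state):
    # Before this commit both pages cached under a shared `state.model`,
    # so loading the VQA model could silently replace the MLM one.
    if state.mlm_model is None:  # app-specific attribute avoids the clash
        state.mlm_model = "mlm-checkpoint"  # placeholder for load_model(mlm_checkpoints[0])
    return state.mlm_model


def get_vqa_model(state):
    if state.vqa_model is None:
        state.vqa_model = "vqa-checkpoint"  # placeholder for load_model(vqa_checkpoints[0])
    return state.vqa_model


print(get_mlm_model(state))  # mlm-checkpoint
print(get_vqa_model(state))  # vqa-checkpoint -- both models now coexist on one state object
```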
apps/mlm.py CHANGED

@@ -27,12 +27,13 @@ def app(state):
 
     # @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
     def predict(transformed_image, caption_inputs):
-        outputs = mlm_state.model(pixel_values=transformed_image, **caption_inputs)
-        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[
-            1
-        ][0]
+        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
+        print(outputs.logits.shape)
+        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[1][0]
+        print(indices)
         preds = outputs.logits[0][indices]
         scores = np.array(preds)
+        print(scores)
         return scores
 
     # @st.cache(persist=False)
@@ -56,10 +57,10 @@ def app(state):
     image = plt.imread(image_path)
     mlm_state.mlm_image = image
 
-    if mlm_state.model is None:
+    if mlm_state.mlm_model is None:
         # Display Top-5 Predictions
         with st.spinner("Loading model..."):
-            mlm_state.model = load_model(mlm_checkpoints[0])
+            mlm_state.mlm_model = load_model(mlm_checkpoints[0])
 
     if st.button(
         "Get a random example",
apps/vqa.py CHANGED

@@ -31,7 +31,7 @@ def app(state):
     # @st.cache(persist=False)
     def predict(transformed_image, question_inputs):
         return np.array(
-            vqa_state.model(pixel_values=transformed_image, **question_inputs)[0][0]
+            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
         )
 
     # @st.cache(persist=False)
@@ -65,9 +65,9 @@ def app(state):
     image = plt.imread(image_path)
     vqa_state.vqa_image = image
 
-    if vqa_state.model is None:
+    if vqa_state.vqa_model is None:
         with st.spinner("Loading model..."):
-            vqa_state.model = load_model(vqa_checkpoints[0])
+            vqa_state.vqa_model = load_model(vqa_checkpoints[0])
 
     # Display Top-5 Predictions
 
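Both files feed the scores returned by `predict` into the "Display Top-5 Predictions" section referenced in the context lines above. One common way that selection is done, sketched with a hypothetical score vector and a stand-in label mapping (none of these names or sizes come from the repo):

```python
import numpy as np

# Hypothetical score vector over an answer vocabulary (VQA) or token vocabulary (MLM).
rng = np.random.default_rng(0)
scores = rng.random(3129)  # 3129 answers is an assumption, not the project's actual vocab

# argsort is ascending, so take the last five indices and reverse for best-first order.
top5_indices = np.argsort(scores)[-5:][::-1]
top5_scores = scores[top5_indices]

# Mapping indices back to human-readable answers is app-specific; this dict is a stand-in.
answer_reverse_mapping = {int(i): f"answer_{int(i)}" for i in top5_indices}
for idx, score in zip(top5_indices, top5_scores):
    print(f"{answer_reverse_mapping[int(idx)]}: {score:.3f}")
```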