gchhablani committed
Commit 2c8f495
1 Parent(s): 405f2d4

Add mask filling app

Files changed (6)
  1. app.py +8 -3
  2. apps/mlm.py +49 -49
  3. apps/utils.py +1 -0
  4. apps/vqa.py +44 -42
  5. multiapp.py +10 -3
  6. resize_images.py +10 -3
app.py CHANGED
@@ -1,13 +1,17 @@
 from apps import mlm, vqa
 import os
 import streamlit as st
+from session import _get_state
 from multiapp import MultiApp
 
+
 def read_markdown(path, parent="./sections/"):
     with open(os.path.join(parent, path)) as f:
         return f.read()
 
+
 def main():
+    state = _get_state()
     st.set_page_config(
         page_title="Multilingual VQA",
         layout="wide",
@@ -30,7 +34,7 @@ def main():
     st.write(read_markdown("abstract.md"))
     st.write(read_markdown("caveats.md"))
     st.write("## Methodology")
-    col1, col2 = st.beta_columns([1,1])
+    col1, col2 = st.beta_columns([1, 1])
     col1.image(
         "./misc/article/Multilingual-VQA.png",
         caption="Masked LM model for Image-text Pretraining.",
@@ -43,10 +47,11 @@ def main():
     st.write(read_markdown("checkpoints.md"))
     st.write(read_markdown("acknowledgements.md"))
 
-    app = MultiApp()
+    app = MultiApp(state)
    app.add_app("Visual Question Answering", vqa.app)
     app.add_app("Mask Filling", mlm.app)
     app.run()
+    state.sync()
 
 if __name__ == "__main__":
-    main()
+    main()
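The commit threads one session-state object through the whole app: `_get_state()` (from the existing `session.py` helper) returns an attribute-style store that survives Streamlit reruns, each page function now receives it as an argument, and `state.sync()` persists it at the end of the script. A minimal sketch of that pattern, assuming only the behaviour visible in this diff (unset attributes read as None, arbitrary attributes can be assigned, `sync()` is called once per run); `counter_page` is a hypothetical page used only for illustration:

import streamlit as st
from session import _get_state

def counter_page(state):
    # Unset attributes are assumed to read as None on the first run.
    if state.count is None:
        state.count = 0
    if st.button("Increment"):
        state.count += 1
    st.write("Count:", state.count)

def main():
    state = _get_state()   # persistent, attribute-style session store
    counter_page(state)    # page functions receive the shared state
    state.sync()           # persist attribute values for the next rerun

if __name__ == "__main__":
    main()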
apps/mlm.py CHANGED
@@ -1,11 +1,9 @@
-
 from .utils import (
     get_text_attributes,
     get_top_5_predictions,
     get_transformed_image,
     plotly_express_horizontal_bar_plot,
-    translate_labels,
-    bert_tokenizer
+    bert_tokenizer,
 )
 
 import streamlit as st
@@ -13,97 +11,99 @@ import numpy as np
 import pandas as pd
 import os
 import matplotlib.pyplot as plt
-
-from session import _get_state
+from mtranslate import translate
 
 
 from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
     FlaxCLIPVisionBertForMaskedLM,
 )
 
+
 def softmax(logits):
     return np.exp(logits) / np.sum(np.exp(logits), axis=0)
 
-def app():
-    state = _get_state()
+def app(state):
+    mlm_state = state
 
-    @st.cache(persist=False)
+    # @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
     def predict(transformed_image, caption_inputs):
-        outputs = state.model(pixel_values=transformed_image, **caption_inputs)
-        indices = np.where(caption_inputs['input_ids']==bert_tokenizer.mask_token_id)
-        preds = outputs.logits[indices][0]
-        sorted_indices = np.argsort(preds)[::-1] # Get reverse sorted scores
-        top_5_indices = sorted_indices[:5]
-        top_5_tokens = bert_tokenizer.convert_ids_to_tokens(top_5_indices)
-        top_5_scores = np.array(preds[top_5_indices])
-        return top_5_tokens, top_5_scores
-
-
-    @st.cache(persist=False)
+        outputs = model(pixel_values=transformed_image, **caption_inputs)
+        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[
+            1
+        ][0]
+        preds = outputs.logits[0][indices]
+        scores = np.array(preds)
+        return scores
+
+    # @st.cache(persist=False)
     def load_model(ckpt):
         return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)
 
-    mlm_checkpoints = ['flax-community/clip-vision-bert-cc12m-70k']
+    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
     dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
 
     first_index = 20
-    # Init Session State
-    if state.image_file is None:
-        state.image_file = dummy_data.loc[first_index, "image_file"]
+    # Init Session mlm_state
+    if mlm_state.mlm_image_file is None:
+        mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
         caption = dummy_data.loc[first_index, "caption"].strip("- ")
-        ids = bert_tokenizer(caption)
-        ids[np.random.randint(0, len(ids))] = bert_tokenizer.mask_token_id
-        state.caption = bert_tokenizer.decode(ids)
-        state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
+        ids = bert_tokenizer.encode(caption)
+        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
+        mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
+        mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]
 
-    image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
+    image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
     image = plt.imread(image_path)
-    state.image = image
+    mlm_state.mlm_image = image
 
-    if state.model is None:
-        # Display Top-5 Predictions
-        with st.spinner("Loading model..."):
-            state.model = load_model(mlm_checkpoints[0])
+    # if model is None:
+    # Display Top-5 Predictions
+    with st.spinner("Loading model..."):
+        model = load_model(mlm_checkpoints[0])
 
     if st.button(
         "Get a random example",
         help="Get a random example from the 100 `seeded` image-text pairs.",
     ):
         sample = dummy_data.sample(1).reset_index()
-        state.image_file = sample.loc[0, "image_file"]
+        mlm_state.mlm_image_file = sample.loc[0, "image_file"]
         caption = sample.loc[0, "caption"].strip("- ")
-        ids = bert_tokenizer(caption)
-        ids[np.random.randint(0, len(ids))] = bert_tokenizer.mask_token_id
-        state.caption = bert_tokenizer.decode(ids)
-        state.caption_lang_id = sample.loc[0, "lang_id"]
+        ids = bert_tokenizer.encode(caption)
+        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
+        mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
+        mlm_state.caption_lang_id = sample.loc[0, "lang_id"]
 
-        image_path = os.path.join("cc12m_data/images_vqa", state.image_file)
+        image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
         image = plt.imread(image_path)
-        state.image = image
+        mlm_state.mlm_image = image
 
-    transformed_image = get_transformed_image(state.image)
+    transformed_image = get_transformed_image(mlm_state.mlm_image)
 
     new_col1, new_col2 = st.beta_columns([5, 5])
 
     # Display Image
-    new_col1.image(state.image, use_column_width="always")
-
+    new_col1.image(mlm_state.mlm_image, use_column_width="always")
 
     # Display caption
     new_col2.write("Write your text with exactly one [MASK] token.")
     caption = new_col2.text_input(
         label="Text",
-        value=state.caption,
+        value=mlm_state.caption,
         help="Type your masked caption regarding the image above in one of the four languages.",
     )
 
+    new_col2.markdown(
+        f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
+    )
     caption_inputs = get_text_attributes(caption)
 
     # Display Top-5 Predictions
-
     with st.spinner("Predicting..."):
-        logits = predict(transformed_image, dict(caption_inputs))
-        logits = softmax(logits)
-        labels, values = get_top_5_predictions(logits)
+        scores = predict(transformed_image, dict(caption_inputs))
+        scores = softmax(scores)
+        labels, values = get_top_5_predictions(scores)
+        # newer_col1, newer_col2 = st.beta_columns([6,4])
         fig = plotly_express_horizontal_bar_plot(values, labels)
-    st.plotly_chart(fig, use_container_width=True)
+        st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
+        st.plotly_chart(fig, use_container_width=True)
+
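For reference, the mask-filling path in the new predict() can be summarised outside Streamlit. This is a sketch, not part of the commit: it assumes model logits of shape (batch, sequence, vocab) and `input_ids` of shape (1, sequence), and reproduces the steps visible above (locate the single [MASK] position, softmax the scores at that position, keep the top five tokens):

import numpy as np
from transformers import BertTokenizerFast

bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")

def top_5_mask_fills(logits, input_ids):
    # Column index of the single [MASK] token in the first (and only) sequence.
    mask_pos = np.where(input_ids == bert_tokenizer.mask_token_id)[1][0]
    scores = logits[0][mask_pos]
    probs = np.exp(scores) / np.sum(np.exp(scores), axis=0)  # softmax, as in the app
    top_5 = np.argsort(probs)[::-1][:5]                      # reverse-sorted scores
    return bert_tokenizer.convert_ids_to_tokens(top_5.tolist()), probs[top_5]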
apps/utils.py CHANGED
@@ -40,6 +40,7 @@ def get_transformed_image(image):
 
 bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")
 
+
 def get_text_attributes(text):
     return bert_tokenizer([text], return_token_type_ids=True, return_tensors="np")
 
apps/vqa.py CHANGED
@@ -1,4 +1,3 @@
-
 from .utils import (
     get_text_attributes,
     get_top_5_predictions,
@@ -15,29 +14,33 @@ import matplotlib.pyplot as plt
 import json
 
 from mtranslate import translate
-from session import _get_state
 
 
 from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
     FlaxCLIPVisionBertForSequenceClassification,
 )
 
+
 def softmax(logits):
     return np.exp(logits) / np.sum(np.exp(logits), axis=0)
 
-def app():
-    state = _get_state()
 
-    @st.cache(persist=True)
-    def predict(transformed_image, question_inputs):
-        return np.array(state.model(pixel_values=transformed_image, **question_inputs)[0][0])
+def app(state):
+    vqa_state = state
 
+    # @st.cache(persist=False)
+    def predict(transformed_image, question_inputs):
+        return np.array(
+            model(pixel_values=transformed_image, **question_inputs)[0][0]
+        )
 
-    @st.cache(persist=True)
+    # @st.cache(persist=False)
     def load_model(ckpt):
         return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
 
-    vqa_checkpoints = ["flax-community/clip-vision-bert-vqa-ft-6k"] # TODO: Maybe add more checkpoints?
+    vqa_checkpoints = [
+        "flax-community/clip-vision-bert-vqa-ft-6k"
+    ]  # TODO: Maybe add more checkpoints?
     dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
     code_to_name = {
         "en": "English",
@@ -46,77 +49,76 @@ def app():
         "es": "Spanish",
     }
 
-
     with open("answer_reverse_mapping.json") as f:
         answer_reverse_mapping = json.load(f)
 
     first_index = 20
-    # Init Session State
-    if state.image_file is None:
-        state.image_file = dummy_data.loc[first_index, "image_file"]
-        state.question = dummy_data.loc[first_index, "question"].strip("- ")
-        state.answer_label = dummy_data.loc[first_index, "answer_label"]
-        state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
-        state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
-
-    image_path = os.path.join("resized_images", state.image_file)
+    # Init Session vqa_state
+    if vqa_state.vqa_image_file is None:
+        vqa_state.vqa_image_file = dummy_data.loc[first_index, "image_file"]
+        vqa_state.question = dummy_data.loc[first_index, "question"].strip("- ")
+        vqa_state.answer_label = dummy_data.loc[first_index, "answer_label"]
+        vqa_state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
+        vqa_state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
+
+    image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
     image = plt.imread(image_path)
-    state.image = image
+    vqa_state.vqa_image = image
 
-    if state.model is None:
-        # Display Top-5 Predictions
-        with st.spinner("Loading model..."):
-            state.model = load_model(vqa_checkpoints[0])
+    # if model is None:
+
+    # Display Top-5 Predictions
+    with st.spinner("Loading model..."):
+        model = load_model(vqa_checkpoints[0])
 
     if st.button(
         "Get a random example",
         help="Get a random example from the 100 `seeded` image-text pairs.",
     ):
         sample = dummy_data.sample(1).reset_index()
-        state.image_file = sample.loc[0, "image_file"]
-        state.question = sample.loc[0, "question"].strip("- ")
-        state.answer_label = sample.loc[0, "answer_label"]
-        state.question_lang_id = sample.loc[0, "lang_id"]
-        state.answer_lang_id = sample.loc[0, "lang_id"]
+        vqa_state.vqa_image_file = sample.loc[0, "image_file"]
+        vqa_state.question = sample.loc[0, "question"].strip("- ")
+        vqa_state.answer_label = sample.loc[0, "answer_label"]
+        vqa_state.question_lang_id = sample.loc[0, "lang_id"]
+        vqa_state.answer_lang_id = sample.loc[0, "lang_id"]
 
-        image_path = os.path.join("resized_images", state.image_file)
+        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
         image = plt.imread(image_path)
-        state.image = image
+        vqa_state.vqa_image = image
 
-    transformed_image = get_transformed_image(state.image)
+    transformed_image = get_transformed_image(vqa_state.vqa_image)
 
     new_col1, new_col2 = st.beta_columns([5, 5])
 
     # Display Image
-    new_col1.image(state.image, use_column_width="always")
-
+    new_col1.image(vqa_state.vqa_image, use_column_width="always")
 
     # Display Question
     question = new_col2.text_input(
         label="Question",
-        value=state.question,
+        value=vqa_state.question,
         help="Type your question regarding the image above in one of the four languages.",
     )
     new_col2.markdown(
-        f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
+        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
     )
 
     question_inputs = get_text_attributes(question)
 
     # Select Language
     options = ["en", "de", "es", "fr"]
-    state.answer_lang_id = new_col2.selectbox(
+    vqa_state.answer_lang_id = new_col2.selectbox(
         "Answer Language",
-        index=options.index(state.answer_lang_id),
+        index=options.index(vqa_state.answer_lang_id),
         options=options,
         format_func=lambda x: code_to_name[x],
         help="The language to be used to show the top-5 labels.",
     )
 
-    actual_answer = answer_reverse_mapping[str(state.answer_label)]
+    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
     new_col2.markdown(
         "**Actual Answer**: "
-        + translate_labels([actual_answer], state.answer_lang_id)[0]
+        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
         + " ("
         + actual_answer
        + ")"
@@ -126,6 +128,6 @@ def app():
         logits = predict(transformed_image, dict(question_inputs))
         logits = softmax(logits)
         labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
-        translated_labels = translate_labels(labels, state.answer_lang_id)
+        translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
         fig = plotly_express_horizontal_bar_plot(values, translated_labels)
-    st.plotly_chart(fig, use_container_width=True)
+        st.plotly_chart(fig, use_container_width=True)
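The VQA page follows the same shape: classification logits over the answer vocabulary are softmaxed, the top five indices are mapped to answer strings, and the labels are translated into the selected answer language. The real logic lives in get_top_5_predictions and translate_labels in apps/utils.py, which this commit does not show; the following is only a rough sketch of that step, assuming the mapping is keyed by stringified label ids as in answer_reverse_mapping[str(...)] above:

import numpy as np

def top_5_answers(logits, answer_reverse_mapping):
    probs = np.exp(logits) / np.sum(np.exp(logits), axis=0)   # softmax
    top_5 = np.argsort(probs)[::-1][:5]                        # best five label ids
    labels = [answer_reverse_mapping[str(i)] for i in top_5]   # ids -> answer strings
    return labels, probs[top_5]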
multiapp.py CHANGED
@@ -1,10 +1,17 @@
 import streamlit as st
+from session import _get_state
+
 class MultiApp:
-    def __init__(self):
+    def __init__(self, state):
         self.apps = []
+        self.state = state
+
     def add_app(self, title, func):
         self.apps.append({"title": title, "function": func})
+
     def run(self):
         st.sidebar.header("Tasks")
-        app = st.sidebar.radio("", self.apps, format_func=lambda app: app["title"])
-        app["function"]()
+        app = st.sidebar.radio(
+            "", self.apps, format_func=lambda app: app["title"]
+        )
+        app["function"](self.state)
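Pages register with MultiApp as (title, function) pairs; run() draws a sidebar radio of titles and calls the selected page function with the shared state. A usage sketch mirroring app.py (the `hello` page is hypothetical):

import streamlit as st
from session import _get_state
from multiapp import MultiApp

def hello(state):
    # Each page receives the shared session state object.
    st.write("Hello from a registered page.")

state = _get_state()
app = MultiApp(state)
app.add_app("Hello", hello)
app.run()
state.sync()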
resize_images.py CHANGED
@@ -7,7 +7,11 @@ def resize_images(path, new_path, num_pixels=300):
     if not os.path.exists(new_path):
         os.makedirs(new_path)
     for filename in os.listdir(path):
-        if not filename.startswith('.') and (filename.endswith('.jpg') or filename.endswith('.jpeg') or filename.endswith('.png')):
+        if not filename.startswith(".") and (
+            filename.endswith(".jpg")
+            or filename.endswith(".jpeg")
+            or filename.endswith(".png")
+        ):
             img = cv2.imread(os.path.join(path, filename))
             height, width, channels = img.shape
             if height > width:
@@ -16,8 +20,11 @@
             else:
                 new_width = num_pixels
                 new_height = int(height * new_width / width)
-            img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
+            img = cv2.resize(
+                img, (new_width, new_height), interpolation=cv2.INTER_CUBIC
+            )
             cv2.imwrite(os.path.join(new_path, filename), img)
 
+
 # resize_images('./images/val2014', './resized_images/val2014')
-resize_images('./misc/article', './misc/article/resized', 500)
+resize_images("./misc/article", "./misc/article/resized", 500)