Sujit Pal committed
Commit 6d88167
1 Parent(s): c0c0d12

fix: changes based on evaluation

app.py CHANGED
@@ -5,12 +5,19 @@ import dashboard_featurefinder
 import streamlit as st
 
 PAGES = {
-    "Text to Image": dashboard_text2image,
-    "Image to Image": dashboard_image2image,
-    "Feature in Image": dashboard_featurefinder,
+    "Retrieve Images given Text": dashboard_text2image,
+    "Retrieve Images given Image": dashboard_image2image,
+    "Find Feature in Image": dashboard_featurefinder,
 }
-st.sidebar.title("Navigation")
+st.sidebar.title("CLIP-RSICD")
 
+st.sidebar.markdown("""
+    The CLIP model from OpenAI is trained in a self-supervised manner using
+    contrastive learning to project images and caption text onto a common
+    embedding space. We have fine-tuned the model (see [Model card](https://huggingface.co/flax-community/clip-rsicd-v2))
+    using the [RSICD dataset](https://github.com/201528014227051/RSICD_optimal).
+    Click here for [more information about our project](https://github.com/arampacha/CLIP-rsicd).
+""")
 selection = st.sidebar.radio("Go to", list(PAGES.keys()))
 page = PAGES[selection]
 page.app()
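Note: the router above assumes each dashboard module exposes an app() function that renders its page when selected in the sidebar. A minimal sketch of that contract (illustrative only, not code from the repository):

import streamlit as st

def app():
    # Each dashboard_* module is expected to provide a function like this;
    # app.py calls PAGES[selection].app() after the sidebar radio choice.
    st.title("Example page")
    st.write("Rendered when this entry is selected in the sidebar.")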
dashboard_featurefinder.py CHANGED
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 import nmslib
 import numpy as np
 import os
+import requests
 import streamlit as st
 
 from tempfile import NamedTemporaryFile
@@ -61,17 +62,38 @@ def get_image_ranks(probs):
     return ranks
 
 
+def download_and_prepare_image(image_url):
+    """
+    Take input image and resize it to 672x896
+    """
+    try:
+        image_raw = requests.get(image_url, stream=True,).raw
+        image = Image.open(image_raw).convert("RGB")
+        width, height = image.size
+        # print("WID,HGT:", width, height)
+        if width < 224 or height < 224:
+            return None
+        # take the short edge and reduce to 672
+        if width < height:
+            resize_factor = 672 / width
+            image = image.resize((672, int(height * resize_factor)))
+            image = image.crop((0, 0, 672, 896))
+        else:
+            resize_factor = 672 / height
+            image = image.resize((int(width * resize_factor), 896))
+            image = image.crop((0, 0, 896, 672))
+        return np.asarray(image)
+    except Exception as e:
+        # print(e)
+        return None
+
+
+
 def app():
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
 
     st.title("Find Features in Images")
     st.markdown("""
-        The CLIP model from OpenAI is trained in a self-supervised manner using
-        contrastive learning to project images and caption text onto a common
-        embedding space. We have fine-tuned the model (see [Model card](https://huggingface.co/flax-community/clip-rsicd-v2))
-        using the RSICD dataset (10k images and ~50k captions from the remote
-        sensing domain). Click here for [more information about our project](https://github.com/arampacha/CLIP-rsicd).
-
         This demo shows the ability of the model to find specific features
         (specified as text queries) in the image. As an example, say you wish to
         find the parts of the following image that contain a `beach`, `houses`,
@@ -92,46 +114,62 @@ def app():
         for features that you can ask the model to identify.
     """)
     # buf = st.file_uploader("Upload Image for Analysis", type=["png", "jpg"])
-    image_file = st.selectbox("Image File", index=0,
-        options=[
-            "St-Tropez-Port.jpg",
-            "Acopulco-Bay.jpg",
-            "Highway-through-Forest.jpg",
-            "Forest-with-River.jpg",
-            "Eagle-Bay-Coastline.jpg",
-            "Multistoreyed-Buildings.jpg",
-            "Street-View-Malayasia.jpg",
-        ])
-    searched_feature = st.text_input("Feature to find")
+    image_file = st.selectbox(
+        "Sample Image File",
+        options=[
+            "-- select one --",
+            "St-Tropez-Port.jpg",
+            "Acopulco-Bay.jpg",
+            "Highway-through-Forest.jpg",
+            "Forest-with-River.jpg",
+            "Eagle-Bay-Coastline.jpg",
+            "Multistoreyed-Buildings.jpg",
+            "Street-View-Malayasia.jpg",
+        ])
+    image_url = st.text_input(
+        "OR provide an image URL",
+        value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg")
+    searched_feature = st.text_input("Feature to find", value="beach")
 
     if st.button("Find"):
-        # ftmp = NamedTemporaryFile()
-        # ftmp.write(buf.getvalue())
-        # image = plt.imread(ftmp.name)
-        image = plt.imread(os.path.join("demo-images", image_file))
-        if len(image.shape) != 3 and image.shape[2] != 3:
-            st.error("Image should be an RGB image")
-        if image.shape[0] < 224 or image.shape[1] < 224:
-            st.error("Image should be at least (224 x 224")
-        st.image(image, caption="Input Image")
-        st.markdown("---")
-        num_rows, num_cols, patches = split_image(image)
-        image_preprocessor = Compose([
-            ToPILImage(),
-            Resize(224)
-        ])
-        num_rows, num_cols, patches = split_image(image)
-        patch_probs = get_patch_probabilities(
-            patches,
-            searched_feature,
-            image_preprocessor,
-            model,
-            processor)
-        patch_ranks = get_image_ranks(patch_probs)
-        for i in range(num_rows):
-            row_patches = patches[i * num_cols : (i + 1) * num_cols]
-            row_probs = patch_probs[i * num_cols : (i + 1) * num_cols]
-            row_ranks = patch_ranks[i * num_cols : (i + 1) * num_cols]
-            captions = ["p({:s})={:.3f}, rank={:d}".format(searched_feature, p, r + 1)
-                for p, r in zip(row_probs, row_ranks)]
-            st.image(row_patches, caption=captions)
+        # print("image_file:", image_file)
+        # print("image_url:", image_url)
+        if image_file.startswith("--"):
+            image = download_and_prepare_image(image_url)
+        else:
+            image = plt.imread(os.path.join("demo-images", image_file))
+
+        if image is None:
+            st.error("Image could not be downloaded, please try another one")
+        else:
+            st.image(image, caption="Input Image")
+            st.markdown("---")
+            # print("image.shape:", image.shape)
+            num_rows, num_cols, patches = split_image(image)
+            # print("num_rows, num_cols, num(patches:", num_rows, num_cols, len(patches), patches[0].shape)
+            image_preprocessor = Compose([
+                ToPILImage(),
+                Resize(224)
+            ])
+            num_rows, num_cols, patches = split_image(image)
+            patch_probs = get_patch_probabilities(
+                patches,
+                searched_feature,
+                image_preprocessor,
+                model,
+                processor)
+            patch_ranks = get_image_ranks(patch_probs)
+            pid = 0
+            for i in range(num_rows):
+                cols = st.beta_columns(num_cols)
+                for col in cols:
+                    caption = "#{:d} p({:s})={:.3f}".format(
+                        patch_ranks[pid] + 1, searched_feature, patch_probs[pid])
+                    col.image(patches[pid], caption=caption)
+                    pid += 1
+                # row_patches = patches[i * num_cols : (i + 1) * num_cols]
+                # row_probs = patch_probs[i * num_cols : (i + 1) * num_cols]
+                # row_ranks = patch_ranks[i * num_cols : (i + 1) * num_cols]
+                # captions = ["p({:s})={:.3f}, rank={:d}".format(searched_feature, p, r + 1)
+                #     for p, r in zip(row_probs, row_ranks)]
+                # st.image(row_patches, caption=captions)
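Note: the new download_and_prepare_image helper above resizes the short edge of the downloaded image to 672 pixels and crops to 672x896 (portrait) or 896x672 (landscape). Since 672 = 3 x 224 and 896 = 4 x 224, the prepared image tiles cleanly into 224x224 patches for split_image; that reading of the magic numbers is ours, not stated in the commit. A trace with hypothetical dimensions:

# Hypothetical 1024x768 landscape download (width >= height branch).
width, height = 1024, 768
resize_factor = 672 / height                    # short edge scaled to 672
new_size = (int(width * resize_factor), 896)    # (896, 896), as passed to resize() in the diff
# crop((0, 0, 896, 672)) then keeps an 896x672 region,
# i.e. a 4-column by 3-row grid of 224-pixel tiles.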
dashboard_image2image.py CHANGED
@@ -2,6 +2,7 @@ import matplotlib.pyplot as plt
 import nmslib
 import numpy as np
 import os
+import requests
 import streamlit as st
 
 from PIL import Image
@@ -33,25 +34,48 @@ def load_example_images():
             example_images[image_class].append(image_name)
         else:
             example_images[image_class] = [image_name]
-    return example_images
+    example_image_list = sorted([v[np.random.randint(0, len(v))]
+        for k, v in example_images.items()][0:10])
+    return example_image_list
+
 
+def get_image_thumbnail(image_filename):
+    image = Image.open(os.path.join(IMAGES_DIR, image_filename))
+    image = image.resize((100, 100))
+    return image
+
+
+def download_and_prepare_image(image_url):
+    try:
+        image_raw = requests.get(image_url, stream=True,).raw
+        image = Image.open(image_raw).convert("RGB")
+        width, height = image.size
+        # print("width, height:", width, height)
+        resize_mult = width / 224 if width < height else height / 224
+        # print("resize_mult:", resize_mult)
+        # print("resize:", width // resize_mult, height // resize_mult)
+        image = image.resize((int(width // resize_mult),
+                              int(height // resize_mult)))
+        width, height = image.size
+        left = int((width - 224) // 2)
+        top = int((height - 224) // 2)
+        right = int((width + 224) // 2)
+        bottom = int((height + 224) // 2)
+        # print("LTRB:", left, top, right, bottom)
+        image = image.crop((left, top, right, bottom))
+        return image
+    except Exception as e:
+        # print(e)
+        return None
 
 def app():
     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
 
-    example_images = load_example_images()
-    example_image_list = sorted([v[np.random.randint(0, len(v))]
-        for k, v in example_images.items()][0:10])
+    example_image_list = load_example_images()
 
-    st.title("Image to Image Retrieval")
+    st.title("Retrieve Images given Images")
     st.markdown("""
-        The CLIP model from OpenAI is trained in a self-supervised manner using
-        contrastive learning to project images and caption text onto a common
-        embedding space. We have fine-tuned the model (see [Model card](https://huggingface.co/flax-community/clip-rsicd-v2))
-        using the RSICD dataset (10k images and ~50k captions from the remote
-        sensing domain). Click here for [more information about our project](https://github.com/arampacha/CLIP-rsicd).
-
         This demo shows the image to image retrieval capabilities of this model, i.e.,
         given an image file name as a query, we use our fine-tuned CLIP model
         to project the query image to the image/caption embedding space and search
@@ -60,31 +84,92 @@ def app():
         Our fine-tuned CLIP model was previously used to generate image vectors for
         our demo, and NMSLib was used for fast vector access.
 
-        Here are some randomly generated image files from our corpus. You can
-        copy paste one of these below or use one from the results of a text to
-        image search -- {:s}
-        """.format(", ".join("`{:s}`".format(example) for example in example_image_list)))
+        Here are some randomly generated image files from our corpus, that you can
+        find similar images for by selecting the button below it. Alternatively you
+        can upload your own image from the Internet.
+        """)
+
+    suggest_idx = -1
+    col0, col1, col2, col3, col4 = st.beta_columns(5)
+    col0.image(get_image_thumbnail(example_image_list[0]))
+    col1.image(get_image_thumbnail(example_image_list[1]))
+    col2.image(get_image_thumbnail(example_image_list[2]))
+    col3.image(get_image_thumbnail(example_image_list[3]))
+    col4.image(get_image_thumbnail(example_image_list[4]))
+    col0t, col1t, col2t, col3t, col4t = st.beta_columns(5)
+    with col0t:
+        if st.button("Image-1"):
+            suggest_idx = 0
+    with col1t:
+        if st.button("Image-2"):
+            suggest_idx = 1
+    with col2t:
+        if st.button("Image-3"):
+            suggest_idx = 2
+    with col3t:
+        if st.button("Image-4"):
+            suggest_idx = 3
+    with col4t:
+        if st.button("Image-5"):
+            suggest_idx = 4
+    col5, col6, col7, col8, col9 = st.beta_columns(5)
+    col5.image(get_image_thumbnail(example_image_list[5]))
+    col6.image(get_image_thumbnail(example_image_list[6]))
+    col7.image(get_image_thumbnail(example_image_list[7]))
+    col8.image(get_image_thumbnail(example_image_list[8]))
+    col9.image(get_image_thumbnail(example_image_list[9]))
+    col5t, col6t, col7t, col8t, col9t = st.beta_columns(5)
+    with col5t:
+        if st.button("Image-6"):
+            suggest_idx = 5
+    with col6t:
+        if st.button("Image-7"):
+            suggest_idx = 6
+    with col7t:
+        if st.button("Image-8"):
+            suggest_idx = 7
+    with col8t:
+        if st.button("Image-9"):
+            suggest_idx = 8
+    with col9t:
+        if st.button("Image-10"):
+            suggest_idx = 9
 
-    image_name = st.text_input("Provide an Image File Name")
+    image_url = st.text_input(
+        "OR provide an image URL",
+        value="https://media.wired.com/photos/5a8c80647b7bd44d86b88077/master/w_2240,c_limit/Satellite-FINAL.jpg")
+
     submit_button = st.button("Find Similar")
+
+    if submit_button or suggest_idx > -1:
+        image_name = None
+        if suggest_idx > -1:
+            image_name = example_image_list[suggest_idx]
+            image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name)))
+        else:
+            image = download_and_prepare_image(image_url)
+        st.image(image, caption="Input Image")
+        st.markdown("---")
 
-    if submit_button:
-        image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name)))
-        inputs = processor(images=image, return_tensors="jax", padding=True)
-        query_vec = model.get_image_features(**inputs)
-        query_vec = np.asarray(query_vec)
-        ids, distances = index.knnQuery(query_vec, k=11)
-        result_filenames = [filenames[id] for id in ids]
-        images, captions = [], []
-        for result_filename, score in zip(result_filenames, distances):
-            if result_filename == image_name:
-                continue
-            images.append(
-                plt.imread(os.path.join(IMAGES_DIR, result_filename)))
-            captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
-        images = images[0:10]
-        captions = captions[0:10]
-        st.image(images[0:3], caption=captions[0:3])
-        st.image(images[3:6], caption=captions[3:6])
-        st.image(images[6:9], caption=captions[6:9])
-        st.image(images[9:], caption=captions[9:])
+        if image is None:
+            st.error("Image could not be downloaded, please try another one!")
+        else:
+            inputs = processor(images=image, return_tensors="jax", padding=True)
+            query_vec = model.get_image_features(**inputs)
+            query_vec = np.asarray(query_vec)
+            ids, distances = index.knnQuery(query_vec, k=11)
+            result_filenames = [filenames[id] for id in ids]
+            images, captions = [], []
+            for result_filename, score in zip(result_filenames, distances):
+                if image_name is not None and result_filename == image_name:
                    continue
+                images.append(
+                    plt.imread(os.path.join(IMAGES_DIR, result_filename)))
+                captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
+            images = images[0:10]
+            captions = captions[0:10]
+            st.image(images[0:3], caption=captions[0:3])
+            st.image(images[3:6], caption=captions[3:6])
+            st.image(images[6:9], caption=captions[6:9])
+            st.image(images[9:], caption=captions[9:])
+        suggest_idx = -1
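Note: the image-to-image variant of download_and_prepare_image takes a different route from the feature-finder one: it scales the short edge to 224 pixels and then centre-crops a 224x224 square before the CLIP image encoder is applied. A quick trace on a hypothetical 640x480 download (values are ours, not from the repository):

width, height = 640, 480
resize_mult = width / 224 if width < height else height / 224          # 480 / 224
width, height = int(width // resize_mult), int(height // resize_mult)  # (298, 224)
left, top = int((width - 224) // 2), int((height - 224) // 2)          # (37, 0)
right, bottom = int((width + 224) // 2), int((height + 224) // 2)      # (261, 224)
# crop box is exactly 224 x 224, centred along the longer axis.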
dashboard_text2image.py CHANGED
@@ -24,14 +24,8 @@ def app():
     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
 
-    st.title("Text to Image Retrieval")
+    st.title("Retrieve Images given Text")
     st.markdown("""
-        The CLIP model from OpenAI is trained in a self-supervised manner using
-        contrastive learning to project images and caption text onto a common
-        embedding space. We have fine-tuned the model (see [Model card](https://huggingface.co/flax-community/clip-rsicd-v2))
-        using the RSICD dataset (10k images and ~50k captions from the remote
-        sensing domain). Click here for [more information about our project](https://github.com/arampacha/CLIP-rsicd).
-
         This demo shows the image to text retrieval capabilities of this model, i.e.,
         given a text query, we use our fine-tuned CLIP model to project the text query
         to the image/caption embedding space and search for nearby images (by
@@ -40,12 +34,45 @@ def app():
         Our fine-tuned CLIP model was previously used to generate image vectors for
         our demo, and NMSLib was used for fast vector access.
 
-        Some suggested queries to start you off with -- `ships`, `school house`,
-        `military installations`, `mountains`, `beaches`, `airports`, `lakes`, etc.
     """)
+    suggested_query = [
+        "ships",
+        "school house",
+        "military installation",
+        "mountains",
+        "beaches",
+        "airports",
+        "lakes"
+    ]
+    st.text("Some suggested queries to start you off with...")
+    col0, col1, col2, col3, col4, col5, col6 = st.beta_columns(7)
+    # [1, 1.1, 1.3, 1.1, 1, 1, 1])
+    suggest_idx = -1
+    with col0:
+        if st.button(suggested_query[0]):
+            suggest_idx = 0
+    with col1:
+        if st.button(suggested_query[1]):
+            suggest_idx = 1
+    with col2:
+        if st.button(suggested_query[2]):
+            suggest_idx = 2
+    with col3:
+        if st.button(suggested_query[3]):
+            suggest_idx = 3
+    with col4:
+        if st.button(suggested_query[4]):
+            suggest_idx = 4
+    with col5:
+        if st.button(suggested_query[5]):
+            suggest_idx = 5
+    with col6:
+        if st.button(suggested_query[6]):
+            suggest_idx = 6
+    query = st.text_input("OR enter a text Query:")
+    query = suggested_query[suggest_idx] if suggest_idx > -1 else query
 
-    query = st.text_input("Text Query:")
-    if st.button("Query"):
+    if st.button("Query") or suggest_idx > -1:
         inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
         query_vec = model.get_text_features(**inputs)
         query_vec = np.asarray(query_vec)
@@ -60,3 +87,4 @@ def app():
         st.image(images[3:6], caption=captions[3:6])
         st.image(images[6:9], caption=captions[6:9])
         st.image(images[9:], caption=captions[9:])
+        suggest_idx = -1
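Note: whether a suggestion button or the free-text box supplies the query, the retrieval path itself is unchanged: the text is encoded with the CLIP processor, projected with get_text_features, and matched against the precomputed image vectors in the NMSLib index. A condensed sketch (model, processor, index and filenames loaded by the utils helpers as above):

import numpy as np

inputs = processor(text=["beaches"], images=None, return_tensors="jax", padding=True)
query_vec = np.asarray(model.get_text_features(**inputs))
ids, distances = index.knnQuery(query_vec, k=10)
# the dashboards report similarity as 1.0 - distance in the image captions
results = [(filenames[i], 1.0 - d) for i, d in zip(ids, distances)]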
demo-images/Acopulco-Bay.jpg CHANGED
demo-images/Eagle-Bay-Coastline.jpg CHANGED
demo-images/Forest-with-River.jpg CHANGED
demo-images/Highway-through-Forest.jpg CHANGED
demo-images/Multistoreyed-Buildings.jpg CHANGED
demo-images/St-Tropez-Port.jpg CHANGED
demo-images/Street-View-Malayasia.jpg CHANGED
requirements.txt CHANGED
@@ -6,3 +6,4 @@ jaxlib
 flax
 torch==1.9.0
 torchvision==0.10.0
+requests