Sujit Pal commited on
Commit
a78bf29
1 Parent(s): 96ac3ab

fix: added feature finder and small usability changes

Browse files
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import dashboard_text2image
2
  import dashboard_image2image
 
3
 
4
  import streamlit as st
5
 
6
  PAGES = {
7
  "Text to Image": dashboard_text2image,
8
- "Image to Image": dashboard_image2image
 
9
  }
10
  st.sidebar.title("Navigation")
11
 
 
1
  import dashboard_text2image
2
  import dashboard_image2image
3
+ import dashboard_featurefinder
4
 
5
  import streamlit as st
6
 
7
  PAGES = {
8
  "Text to Image": dashboard_text2image,
9
+ "Image to Image": dashboard_image2image,
10
+ "Feature in Image": dashboard_featurefinder,
11
  }
12
  st.sidebar.title("Navigation")
13
 
dashboard_featurefinder.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import flax
3
+ import matplotlib.pyplot as plt
4
+ import nmslib
5
+ import numpy as np
6
+ import os
7
+ import streamlit as st
8
+
9
+ from tempfile import NamedTemporaryFile
10
+ from torchvision.transforms import Compose, Resize, ToPILImage
11
+ from transformers import CLIPProcessor, FlaxCLIPModel
12
+ from PIL import Image
13
+
14
+
15
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
16
+ # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
17
+ MODEL_PATH = "flax-community/clip-rsicd-v2"
18
+
19
+ # IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
20
+ # IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
21
+ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
22
+
23
+ # IMAGES_DIR = "/home/shared/data/rsicd_images"
24
+ IMAGES_DIR = "./images"
25
+
26
+ 2
27
+ # @st.cache(allow_output_mutation=True)
28
+ # def load_index():
29
+ # filenames, image_vecs = [], []
30
+ # fvec = open(IMAGE_VECTOR_FILE, "r")
31
+ # for line in fvec:
32
+ # cols = line.strip().split('\t')
33
+ # filename = cols[0]
34
+ # image_vec = np.array([float(x) for x in cols[1].split(',')])
35
+ # filenames.append(filename)
36
+ # image_vecs.append(image_vec)
37
+ # V = np.array(image_vecs)
38
+ # index = nmslib.init(method='hnsw', space='cosinesimil')
39
+ # index.addDataPointBatch(V)
40
+ # index.createIndex({'post': 2}, print_progress=True)
41
+ # return filenames, index
42
+
43
+
44
+ @st.cache(allow_output_mutation=True)
45
+ def load_model():
46
+ # model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
47
+ # processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
48
+ model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
49
+ processor = CLIPProcessor.from_pretrained("flax-community/clip-rsicd-v2")
50
+ return model, processor
51
+
52
+
53
+ def split_image(X):
54
+ num_rows = X.shape[0] // 224
55
+ num_cols = X.shape[1] // 224
56
+ Xc = X[0 : num_rows * 224, 0 : num_cols * 224, :]
57
+ patches = []
58
+ for j in range(num_rows):
59
+ for i in range(num_cols):
60
+ patches.append(Xc[j * 224 : (j + 1) * 224,
61
+ i * 224 : (i + 1) * 224,
62
+ :])
63
+ return num_rows, num_cols, patches
64
+
65
+
66
+ def get_patch_probabilities(patches, searched_feature,
67
+ image_preprocesor,
68
+ model, processor):
69
+ images = [image_preprocesor(patch) for patch in patches]
70
+ text = "An aerial image of {:s}".format(searched_feature)
71
+ inputs = processor(images=images,
72
+ text=text,
73
+ return_tensors="jax",
74
+ padding=True)
75
+ outputs = model(**inputs)
76
+ probs = jax.nn.softmax(outputs.logits_per_text, axis=-1)
77
+ probs_np = np.asarray(probs)[0]
78
+ return probs_np
79
+
80
+
81
+ def get_image_ranks(probs):
82
+ temp = np.argsort(-probs)
83
+ ranks = np.empty_like(temp)
84
+ ranks[temp] = np.arange(len(probs))
85
+ return ranks
86
+
87
+
88
+ def app():
89
+ model, processor = load_model()
90
+
91
+ st.title("Find Features in Images")
92
+ st.markdown("""
93
+ The CLIP model from OpenAI is trained in a self-supervised manner using
94
+ contrastive learning to project images and caption text onto a common
95
+ embedding space. We have fine-tuned the model using the RSICD dataset
96
+ (10k images and ~50k captions from the remote sensing domain).
97
+
98
+ This demo shows the ability of the model to find specific features
99
+ (specified as text queries) in the image. As an example, say you wish to
100
+ find the parts of the following image that contain a `beach`, `houses`,
101
+ or `ships`. We partition the image into tiles of (224, 224) and report
102
+ how likely each of them are to contain each text features.
103
+ """)
104
+ st.image("demo-images/st_tropez_1.png")
105
+ st.image("demo-images/st_tropez_2.png")
106
+ st.markdown("""
107
+ For this image and the queries listed above, our model reports that the
108
+ two left tiles are most likely to contain a `beach`, the two top right
109
+ tiles are most likely to contain `houses`, and the two bottom right tiles
110
+ are likely to contain `boats`.
111
+
112
+ You can try it yourself with your own photographs.
113
+ [Unsplash](https://unsplash.com/s/photos/aerial-view) has some good
114
+ aerial photographs. You will need to download from Unsplash to your
115
+ computer and upload it to the demo app.
116
+ """)
117
+ with st.form(key="form_3"):
118
+ buf = st.file_uploader("Upload Image for Analysis")
119
+ searched_feature = st.text_input(label="Feature to find")
120
+ submit_button = st.form_submit_button("Find")
121
+
122
+ if submit_button:
123
+ ftmp = NamedTemporaryFile()
124
+ ftmp.write(buf.getvalue())
125
+ image = plt.imread(ftmp.name)
126
+ if len(image.shape) != 3 and image.shape[2] != 3:
127
+ st.error("Image should be an RGB image")
128
+ if image.shape[0] < 224 or image.shape[1] < 224:
129
+ st.error("Image should be at least (224 x 224")
130
+ st.image(image, caption="Input Image")
131
+ st.markdown("---")
132
+ num_rows, num_cols, patches = split_image(image)
133
+ image_preprocessor = Compose([
134
+ ToPILImage(),
135
+ Resize(224)
136
+ ])
137
+ num_rows, num_cols, patches = split_image(image)
138
+ patch_probs = get_patch_probabilities(
139
+ patches,
140
+ searched_feature,
141
+ image_preprocessor,
142
+ model,
143
+ processor)
144
+ patch_ranks = get_image_ranks(patch_probs)
145
+ for i in range(num_rows):
146
+ row_patches = patches[i * num_cols : (i + 1) * num_cols]
147
+ row_probs = patch_probs[i * num_cols : (i + 1) * num_cols]
148
+ row_ranks = patch_ranks[i * num_cols : (i + 1) * num_cols]
149
+ captions = ["p({:s})={:.3f}, rank={:d}".format(searched_feature, p, r + 1)
150
+ for p, r in zip(row_probs, row_ranks)]
151
+ st.image(row_patches, caption=captions)
dashboard_image2image.py CHANGED
@@ -44,9 +44,27 @@ def load_model():
44
  return model, processor
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def app():
48
  filenames, index = load_index()
49
  model, processor = load_model()
 
 
 
50
 
51
  st.title("Image to Image Retrieval")
52
  st.markdown("""
@@ -63,13 +81,16 @@ def app():
63
  Our fine-tuned CLIP model was previously used to generate image vectors for
64
  our demo, and NMSLib was used for fast vector access.
65
 
66
- You will need an image file name to start, we recommend copy pasting the
67
- file name from one of the results of the text to image search.
68
- """)
 
 
 
 
69
 
70
- image_file = st.text_input("Image Query (filename):")
71
- if st.button("Query"):
72
- image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_file)))
73
  inputs = processor(images=image, return_tensors="jax", padding=True)
74
  query_vec = model.get_image_features(**inputs)
75
  query_vec = np.asarray(query_vec)
@@ -77,7 +98,7 @@ def app():
77
  result_filenames = [filenames[id] for id in ids]
78
  images, captions = [], []
79
  for result_filename, score in zip(result_filenames, distances):
80
- if result_filename == image_file:
81
  continue
82
  images.append(
83
  plt.imread(os.path.join(IMAGES_DIR, result_filename)))
 
44
  return model, processor
45
 
46
 
47
+ @st.cache(allow_output_mutation=True)
48
+ def load_example_images():
49
+ example_images = {}
50
+ image_names = os.listdir(IMAGES_DIR)
51
+ for image_name in image_names:
52
+ if image_name.find("_") < 0:
53
+ continue
54
+ image_class = image_name.split("_")[0]
55
+ if image_class in example_images.keys():
56
+ example_images[image_class].append(image_name)
57
+ else:
58
+ example_images[image_class] = [image_name]
59
+ return example_images
60
+
61
+
62
  def app():
63
  filenames, index = load_index()
64
  model, processor = load_model()
65
+ example_images = load_example_images()
66
+ example_image_list = sorted([v[np.random.randint(0, len(v))]
67
+ for k, v in example_images.items()][0:10])
68
 
69
  st.title("Image to Image Retrieval")
70
  st.markdown("""
 
81
  Our fine-tuned CLIP model was previously used to generate image vectors for
82
  our demo, and NMSLib was used for fast vector access.
83
 
84
+ Here are some randomly generated image files from our corpus. You can
85
+ copy paste one of these below or use one from the results of a text to
86
+ image search -- {:s}
87
+ """.format(", ".join("`{:s}`".format(example) for example in example_image_list)))
88
+
89
+ image_name = st.text_input("Provide an Image File Name")
90
+ submit_button = st.button("Find Similar")
91
 
92
+ if submit_button:
93
+ image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name)))
 
94
  inputs = processor(images=image, return_tensors="jax", padding=True)
95
  query_vec = model.get_image_features(**inputs)
96
  query_vec = np.asarray(query_vec)
 
98
  result_filenames = [filenames[id] for id in ids]
99
  images, captions = [], []
100
  for result_filename, score in zip(result_filenames, distances):
101
+ if result_filename == image_name:
102
  continue
103
  images.append(
104
  plt.imread(os.path.join(IMAGES_DIR, result_filename)))
dashboard_text2image.py CHANGED
@@ -62,8 +62,8 @@ def app():
62
  Our fine-tuned CLIP model was previously used to generate image vectors for
63
  our demo, and NMSLib was used for fast vector access.
64
 
65
- Some suggested queries to start you off with -- "ships", "school house",
66
- "military installations", "mountains", "beaches", "airports", "lakes", etc.
67
  """)
68
 
69
  query = st.text_input("Text Query:")
 
62
  Our fine-tuned CLIP model was previously used to generate image vectors for
63
  our demo, and NMSLib was used for fast vector access.
64
 
65
+ Some suggested queries to start you off with -- `ships`, `school house`,
66
+ `military installations`, `mountains`, `beaches`, `airports`, `lakes`, etc.
67
  """)
68
 
69
  query = st.text_input("Text Query:")
demo-images/st_tropez_1.png ADDED
demo-images/st_tropez_2.png ADDED