sujitpal committed
Commit 357b0b8
1 Parent(s): e44b0e6

new: initial revision (copied from main repo)

app.py ADDED
@@ -0,0 +1,14 @@
+ import dashboard_text2image
+ import dashboard_image2image
+
+ import streamlit as st
+
+ PAGES = {
+     "Text to Image": dashboard_text2image,
+     "Image to Image": dashboard_image2image
+ }
+ st.sidebar.title("Navigation")
+
+ selection = st.sidebar.radio("Go to", list(PAGES.keys()))
+ page = PAGES[selection]
+ page.app()
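
app.py is a thin multipage router: each module registered in PAGES is assumed to expose a module-level app() callable, which app.py invokes for whichever page is selected in the sidebar. A minimal sketch of such a page module (hypothetical, not part of this commit):

import streamlit as st

def app():
    # Called by app.py when this entry is chosen in the sidebar radio.
    st.title("Example Page")
    st.write("Page content goes here.")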
dashboard_image2image.py ADDED
@@ -0,0 +1,88 @@
+ import matplotlib.pyplot as plt
+ import nmslib
+ import numpy as np
+ import os
+ import streamlit as st
+
+ from PIL import Image
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+ # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
+ MODEL_PATH = "flax-community/clip-rsicd"
+
+ # IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
+ # IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+
+ # IMAGES_DIR = "/home/shared/data/rsicd_images"
+ IMAGES_DIR = "./images"
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_index():
+     filenames, image_vecs = [], []
+     fvec = open(IMAGE_VECTOR_FILE, "r")
+     for line in fvec:
+         cols = line.strip().split('\t')
+         filename = cols[0]
+         image_vec = np.array([float(x) for x in cols[1].split(',')])
+         filenames.append(filename)
+         image_vecs.append(image_vec)
+     V = np.array(image_vecs)
+     index = nmslib.init(method='hnsw', space='cosinesimil')
+     index.addDataPointBatch(V)
+     index.createIndex({'post': 2}, print_progress=True)
+     return filenames, index
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
+     processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
+     return model, processor
+
+
+ def app():
+     filenames, index = load_index()
+     model, processor = load_model()
+
+     st.title("Image to Image Retrieval")
+     st.markdown("""
+         The CLIP model from OpenAI is trained in a self-supervised manner using
+         contrastive learning to project images and caption text onto a common
+         embedding space. We have fine-tuned the model using the RSICD dataset
+         (10k images and ~50k captions from the remote sensing domain).
+
+         This demo shows the image to image retrieval capabilities of this model, i.e.,
+         given an image file name as a query (we suggest copy pasting the file name
+         from the result of a text to image query), we use our fine-tuned CLIP model
+         to project the query image to the image/caption embedding space and search
+         for nearby images (by cosine similarity) in this space.
+
+         Our fine-tuned CLIP model was previously used to generate image vectors for
+         our demo, and NMSLib was used for fast vector access.
+     """)
+
+     image_file = st.text_input("Image Query (filename):")
+     if st.button("Query"):
+         image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_file)))
+         inputs = processor(images=image, return_tensors="jax", padding=True)
+         query_vec = model.get_image_features(**inputs)
+         query_vec = np.asarray(query_vec)
+         ids, distances = index.knnQuery(query_vec, k=11)
+         result_filenames = [filenames[id] for id in ids]
+         images, captions = [], []
+         for result_filename, score in zip(result_filenames, distances):
+             if result_filename == image_file:
+                 continue
+             images.append(
+                 plt.imread(os.path.join(IMAGES_DIR, result_filename)))
+             captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
+         images = images[0:10]
+         captions = captions[0:10]
+         st.image(images[0:3], caption=captions[0:3])
+         st.image(images[3:6], caption=captions[3:6])
+         st.image(images[6:9], caption=captions[6:9])
+         st.image(images[9:], caption=captions[9:])
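
For reference, load_index() expects IMAGE_VECTOR_FILE to be a TSV with one image per line: the filename, a tab, and the comma-separated image embedding. The query asks for k=11 neighbors because the query image itself is usually the nearest hit and is filtered out, leaving up to 10 results to display. A minimal sketch of the row format and the distance-to-score conversion (the example values are made up):

# Hypothetical TSV row in the format load_index() parses.
row = "airport_1.jpg\t1.2345678e-01,-4.5678901e-02,7.8901234e-03"
filename, vec_str = row.split("\t")
vec = [float(x) for x in vec_str.split(",")]

# nmslib's "cosinesimil" space returns cosine distances, so the captions
# above convert each distance back to a similarity: score = 1.0 - distance.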
dashboard_text2image.py ADDED
@@ -0,0 +1,81 @@
+ import matplotlib.pyplot as plt
+ import nmslib
+ import numpy as np
+ import os
+ import streamlit as st
+
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+ # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
+ MODEL_PATH = "flax-community/clip-rsicd"
+
+ # IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
+ # IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+
+ # IMAGES_DIR = "/home/shared/data/rsicd_images"
+ IMAGES_DIR = "./images"
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_index():
+     filenames, image_vecs = [], []
+     fvec = open(IMAGE_VECTOR_FILE, "r")
+     for line in fvec:
+         cols = line.strip().split('\t')
+         filename = cols[0]
+         image_vec = np.array([float(x) for x in cols[1].split(',')])
+         filenames.append(filename)
+         image_vecs.append(image_vec)
+     V = np.array(image_vecs)
+     index = nmslib.init(method='hnsw', space='cosinesimil')
+     index.addDataPointBatch(V)
+     index.createIndex({'post': 2}, print_progress=True)
+     return filenames, index
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
+     processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
+     return model, processor
+
+
+ def app():
+     filenames, index = load_index()
+     model, processor = load_model()
+
+     st.title("Text to Image Retrieval")
+     st.markdown("""
+         The CLIP model from OpenAI is trained in a self-supervised manner using
+         contrastive learning to project images and caption text onto a common
+         embedding space. We have fine-tuned the model using the RSICD dataset
+         (10k images and ~50k captions from the remote sensing domain).
+
+         This demo shows the text to image retrieval capabilities of this model, i.e.,
+         given a text query, we use our fine-tuned CLIP model to project the text query
+         to the image/caption embedding space and search for nearby images (by
+         cosine similarity) in this space.
+
+         Our fine-tuned CLIP model was previously used to generate image vectors for
+         our demo, and NMSLib was used for fast vector access.
+     """)
+
+     query = st.text_input("Text Query:")
+     if st.button("Query"):
+         inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
+         query_vec = model.get_text_features(**inputs)
+         query_vec = np.asarray(query_vec)
+         ids, distances = index.knnQuery(query_vec, k=10)
+         result_filenames = [filenames[id] for id in ids]
+         images, captions = [], []
+         for result_filename, score in zip(result_filenames, distances):
+             images.append(
+                 plt.imread(os.path.join(IMAGES_DIR, result_filename)))
+             captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
+         st.image(images[0:3], caption=captions[0:3])
+         st.image(images[3:6], caption=captions[3:6])
+         st.image(images[6:9], caption=captions[6:9])
+         st.image(images[9:], caption=captions[9:])
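
The text dashboard mirrors the image one: the query string is projected into the shared embedding space with the fine-tuned text tower and searched against the same pre-built image index. A minimal standalone sketch of that encoding step, using the same checkpoints as above and an illustrative query string, e.g. for sanity-checking the index outside Streamlit:

import numpy as np
from transformers import CLIPProcessor, FlaxCLIPModel

# Same checkpoints used by the dashboards above.
model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

inputs = processor(text=["two planes parked near a terminal"],
                   return_tensors="jax", padding=True)
query_vec = np.asarray(model.get_text_features(**inputs))  # shape (1, embed_dim)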
demo-image-encoder.py ADDED
@@ -0,0 +1,69 @@
+ import argparse
+ import jax
+ import jax.numpy as jnp
+ import json
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import requests
+ import os
+
+ from PIL import Image
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+
+ def encode_image(image_file, model, processor):
+     image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_file)))
+     inputs = processor(images=image, return_tensors="jax")
+     image_vec = model.get_image_features(**inputs)
+     return np.array(image_vec).reshape(-1)
+
+
+ DATA_DIR = "/home/shared/data"
+ IMAGES_DIR = os.path.join(DATA_DIR, "rsicd_images")
+ CAPTIONS_FILE = os.path.join(DATA_DIR, "dataset_rsicd.json")
+ VECTORS_DIR = os.path.join(DATA_DIR, "vectors")
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model_dir", help="Path to model to use for encoding")
+ args = parser.parse_args()
+
+ print("Loading image list...", end="")
+ image2captions = {}
+ with open(CAPTIONS_FILE, "r") as fcap:
+     data = json.loads(fcap.read())
+     for image in data["images"]:
+         if image["split"] == "test":
+             filename = image["filename"]
+             sentences = []
+             for sentence in image["sentences"]:
+                 sentences.append(sentence["raw"])
+             image2captions[filename] = sentences
+
+ print("{:d} images".format(len(image2captions)))
+
+
+ print("Loading model...")
+ if args.model_dir == "baseline":
+     model = FlaxCLIPModel.from_pretrained(BASELINE_MODEL)
+ else:
+     model = FlaxCLIPModel.from_pretrained(args.model_dir)
+ processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
+
+
+ model_basename = "-".join(args.model_dir.split("/")[-2:])
+ vector_file = os.path.join(VECTORS_DIR, "test-{:s}.tsv".format(model_basename))
+ print("Vectors written to {:s}".format(vector_file))
+ num_written = 0
+ fvec = open(vector_file, "w")
+ for image_file in image2captions.keys():
+     if num_written % 100 == 0:
+         print("{:d} images processed".format(num_written))
+     image_vec = encode_image(image_file, model, processor)
+     image_vec_s = ",".join(["{:.7e}".format(x) for x in image_vec])
+     fvec.write("{:s}\t{:s}\n".format(image_file, image_vec_s))
+     num_written += 1
+
+ print("{:d} images processed, COMPLETE".format(num_written))
+ fvec.close()
+
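
demo-image-encoder.py is the offline step that produces the *.tsv vector files the dashboards load: it takes a single positional model_dir argument (the literal "baseline" or a checkpoint directory) and names the output after the last two path components of that argument. Hypothetical invocations (paths are illustrative, matching the commented-out constants above):

python demo-image-encoder.py baseline
python demo-image-encoder.py /home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1

The first would write /home/shared/data/vectors/test-baseline.tsv, the second test-bs128x8-lr5e-6-adam-ckpt-1.tsv in the same directory.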
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ streamlit==0.84.1
+ nmslib==2.1.1
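
Note that requirements.txt pins only streamlit and nmslib; the app code also imports transformers, jax/flax, numpy, matplotlib, and Pillow, which are assumed to be provided by the hosting environment. A typical local run, assuming those extra dependencies are installed:

pip install -r requirements.txt
streamlit run app.py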