hen committed on
Commit fc0f846
1 Parent(s): b311943
app.py ADDED
@@ -0,0 +1,104 @@
+ import numpy as np
+ import streamlit as st
+ from PIL import Image
+ import torch
+ import clip
+ from torchray.attribution.grad_cam import grad_cam
+ from miniclip.imageWrangle import heatmap, min_max_norm, torch_to_rgba
+
+ st.set_page_config(layout="wide")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ @st.cache(show_spinner=True, allow_output_mutation=True)
+ def get_model():
+     return clip.load("RN50", device=device, jit=False)
+
+
+ # OPTIONS:
+
+ st.sidebar.header('Options')
+ alpha = st.sidebar.radio("select alpha", [0.5, 0.7, 0.8], index=1)
+ layer = st.sidebar.selectbox("select saliency layer", ['layer4.2.relu'], index=0)
+
+ st.header("Saliency Map demo for CLIP")
+ st.write(
+     "a quick experiment by [Hendrik Strobelt](http://hendrik.strobelt.com) ([MIT-IBM Watson AI Lab](https://mitibmwatsonailab.mit.edu/)) ")
+ with st.beta_expander('1. Upload Image', expanded=True):
+     imageFile = st.file_uploader("Select a file:", type=[".jpg", ".png", ".jpeg"])
+
+ with st.beta_expander('2. Write Descriptions', expanded=True):
+     textarea = st.text_area("Descriptions separated by semicolon", "a car; a dog; a cat")
+     prefix = st.text_input("(optional) Prefix all descriptions with: ", "an image of")
+
+ if imageFile:
+     st.markdown("<hr style='border:black solid;'>", unsafe_allow_html=True)
+     image_raw = Image.open(imageFile)
+     model, preprocess = get_model()
+
+     # preprocess image:
+     image = preprocess(image_raw).unsqueeze(0).to(device)
+
+     # preprocess text: split on semicolons and optionally prepend the prefix
+     prefix = prefix.strip()
+     if len(prefix) > 0:
+         categories = [f"{prefix} {x.strip()}" for x in textarea.split(';')]
+     else:
+         categories = [x.strip() for x in textarea.split(';')]
+     text = clip.tokenize(categories).to(device)
+
+     with torch.no_grad():
+         image_features = model.encode_image(image)
+         text_features = model.encode_text(text)
+         image_features_norm = image_features.norm(dim=-1, keepdim=True)
+         image_features_new = image_features / image_features_norm
+         text_features_norm = text_features.norm(dim=-1, keepdim=True)
+         text_features_new = text_features / text_features_norm
+         logit_scale = model.logit_scale.exp()
+         logits_per_image = logit_scale * image_features_new @ text_features_new.t()
+         probs = logits_per_image.softmax(dim=-1).cpu().numpy().tolist()
+
+     # Grad-CAM with respect to the image embedding itself
+     saliency = grad_cam(model.visual, image.type(model.dtype), image_features, saliency_layer=layer)
+     hm_image = heatmap(image[0], saliency[0][0,].detach().type(torch.float32).cpu(), alpha=alpha)
+
+     collect_images = []
+     for i in range(len(categories)):
+         # multiply the normalized text embedding by the image norm to get an approximate image embedding
+         text_prediction = (text_features_new[[i]] * image_features_norm)
+         saliency = grad_cam(model.visual, image.type(model.dtype), text_prediction, saliency_layer=layer)
+         hm = heatmap(image[0], saliency[0][0,].detach().type(torch.float32).cpu(), alpha=alpha)
+         collect_images.append(hm)
+     logits = logits_per_image.cpu().numpy().tolist()[0]
+     st.write("### Grad Cam for text embeddings")
+     st.image(collect_images,
+              width=256,
+              caption=[f"{x} - {str(round(y, 3))}/{str(round(l, 2))}" for (x, y, l) in
+                       zip(categories, probs[0], logits)])
+
+     st.write("### Original Image and Grad Cam for image embedding")
+     # use the image-embedding heatmap computed above, not the last per-text heatmap
+     st.image([Image.fromarray((torch_to_rgba(image[0]).numpy() * 255.).astype(np.uint8)), hm_image],
+              caption=["original", "image gradcam"])
+
+
+ # @st.cache
+ def get_readme():
+     with open('README.md') as f:
+         return "\n".join([x.strip() for x in f.readlines()])
+
+
+ st.markdown("<hr style='border:black solid;'>", unsafe_allow_html=True)
+ with st.beta_expander('Description', expanded=True):
+     st.markdown(get_readme(), unsafe_allow_html=True)
+
+ hide_streamlit_style = """
+ <style>
+ #MainMenu {visibility: hidden;}
+ footer {visibility: hidden;}
+ </style>
+
+ """
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)
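For readers who want to exercise the Grad-CAM-for-text trick outside of Streamlit, the sketch below condenses the relevant lines from app.py into a standalone script (run from the repository root so the miniclip package is importable). The image path, the prompt, and the output filename are placeholders; the model name, saliency layer, and the construction of the Grad-CAM target mirror the code above.

# Minimal, non-Streamlit sketch of the same Grad-CAM-for-text computation as app.py.
# "example.jpg", the prompt, and the output filename are placeholders.
import torch
import clip
from torchray.attribution.grad_cam import grad_cam
from PIL import Image
from miniclip.imageWrangle import heatmap

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device, jit=False)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image
text = clip.tokenize(["an image of a dog"]).to(device)                 # placeholder description

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# scale the unit text embedding to the image embedding's norm and use it as the
# Grad-CAM target, so the saliency shows what in the image matches the text
target = text_features * image_features.norm(dim=-1, keepdim=True)
saliency = grad_cam(model.visual, image.type(model.dtype), target, saliency_layer='layer4.2.relu')
overlay = heatmap(image[0], saliency[0][0,].detach().type(torch.float32).cpu(), alpha=0.7)
overlay.save("saliency_overlay.png")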
assets/clipper_example_coffeeMeeting.jpg ADDED
assets/clipper_example_room.jpg ADDED
assets/clipper_image_book_attack.jpg ADDED
assets/clipper_image_primes.jpg ADDED
assets/miniclip_teaser.jpg ADDED
assets/pharao.jpg ADDED
miniclip/imageWrangle.py ADDED
@@ -0,0 +1,57 @@
+ from PIL import Image
+ import numpy as np
+ import torch
+ from matplotlib import cm
+
+
+ def min_max_norm(array):
+     # rescale values into the [0, 1] range
+     lim = [array.min(), array.max()]
+     array = array - lim[0]
+     array.mul_(1 / (1.e-10 + (lim[1] - lim[0])))
+     return array
+
+
+ def torch_to_rgba(img):
+     # CHW tensor -> HWC RGBA tensor with values in [0, 1]
+     img = min_max_norm(img)
+     rgba_im = img.permute(1, 2, 0).cpu()
+     if rgba_im.shape[2] == 3:
+         rgba_im = torch.cat((rgba_im, torch.ones(*rgba_im.shape[:2], 1)), dim=2)
+     assert rgba_im.shape[2] == 4
+     return rgba_im
+
+
+ def numpy_to_image(img, size):
+     """takes a [0..1] normalized rgba input and returns a resized [0..255] rgba image"""
+     resized = Image.fromarray((img * 255.).astype(np.uint8)).resize((size, size))
+     return resized
+
+
+ def upscale_pytorch(img: np.ndarray, size):
+     torch_img = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2)
+     upsampler = torch.nn.Upsample(size=size)
+     return upsampler(torch_img)[0].permute(1, 2, 0).cpu().numpy()
+
+
+ def heatmap(image: torch.Tensor, heatmap: torch.Tensor, size=None, alpha=.6):
+     if not size:
+         size = image.shape[1]
+
+     img = torch_to_rgba(image).numpy()          # [0..1] rgba numpy "image"
+     hm = cm.hot(min_max_norm(heatmap).numpy())  # [0..1] rgba numpy "image"
+
+     img = np.array(numpy_to_image(img, size))
+     hm = np.array(numpy_to_image(hm, size))
+
+     return Image.fromarray((alpha * hm + (1 - alpha) * img).astype(np.uint8))
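The helpers in imageWrangle.py expect a CHW image tensor and a small 2-D saliency map and return a blended PIL image. A quick shape check with random tensors follows; the 3x224x224 input and 7x7 map sizes are assumptions based on CLIP RN50 preprocessing and its layer4 resolution, not values stated in this commit.

# Tiny shape check for the helpers above, using random tensors in place of real activations.
import torch
from miniclip.imageWrangle import torch_to_rgba, heatmap

image = torch.rand(3, 224, 224)   # stand-in for a preprocessed image tensor (assumed size)
saliency = torch.rand(7, 7)       # stand-in for a layer4 Grad-CAM map (assumed size)

rgba = torch_to_rgba(image)       # HWC, 4 channels, values in [0, 1]
print(rgba.shape)                 # torch.Size([224, 224, 4])

overlay = heatmap(image, saliency, alpha=0.7)
print(overlay.size, overlay.mode) # (224, 224) RGBA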
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ numpy~=1.20.1
+ streamlit~=0.78.0
+ torch~=1.7.1
+ pillow~=8.1.2
+ torchray~=1.0.0.2
+ matplotlib~=3.3.4
+ git+https://github.com/openai/CLIP.git