File size: 1,834 Bytes
1046772
a2ca503
 
1ae2bef
 
 
fc4cd1f
 
 
 
 
 
 
 
e53b04f
 
1f029d6
 
 
 
 
321f9f3
51718fb
d7f67d7
ad572cd
d7f67d7
1f029d6
 
 
 
 
fc4cd1f
0325baf
 
fc4cd1f
a2ca503
 
3010861
e029f9c
 
a2ca503
3c26522
e029f9c
 
 
 
 
 
a2ca503
e029f9c
 
 
 
 
85a0494
a2ca503
9d17ea5
e670f36
85a0494
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import utils
from huggingface_hub.keras_mixin import from_pretrained_keras
from PIL import Image
import streamlit as st
import tensorflow as tf

# NOTE: the original line lacked the leading "@", so the st.cache(...) call was
# evaluated and discarded and the model was reloaded on every script rerun.
# allow_output_mutation=True stops Streamlit from hashing the returned model
# object on each cache hit.
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
	"""Load the DINO model published as "probing-vits/vit-dino-base16".

	The result is cached by Streamlit so the model is fetched and built
	once instead of on every rerun of the script.

	Returns:
		The Keras model object returned by `from_pretrained_keras`.
	"""
	return from_pretrained_keras("probing-vits/vit-dino-base16")
	
dino = load_model()

# Inputs
st.title("Input your image")
image_url = st.text_input(
	label="URL of image",
	value="https://dl.fbaipublicfiles.com/dino/img.png",
	placeholder="https://your-favourite-image.png"
)
uploaded_file = st.file_uploader("or an image file", type=["jpg", "jpeg"])

# Outputs
st.title("Original Image from URL")

# Pick exactly one image source. An uploaded file takes precedence; the URL is
# only fetched when nothing was uploaded, so a broken or blank URL can no
# longer break the upload path (and no needless download happens).
if uploaded_file is not None:
	# JPEGs may be grayscale or CMYK; force three channels because the
	# de-normalization later in this script uses 3-element mean/std.
	# NOTE(review): harmless if utils.preprocess_image already converts.
	image = Image.open(uploaded_file).convert("RGB")
	preprocessed_image = utils.preprocess_image(image, "dino")
elif image_url.strip():
	# Load the image and its normalized, model-ready counterpart.
	image, preprocessed_image = utils.load_image_from_url(
		image_url,
		model_type="dino"
	)
else:
	# Nothing to process: tell the user and end this script run cleanly
	# instead of failing inside the loader.
	st.warning("Please enter an image URL or upload an image file.")
	st.stop()

st.image(image, caption="Original Image")

with st.spinner("Generating the attention scores..."):
	# Run the model; only the second of its two outputs is used below.
	_ignored, attention_score_dict = dino.predict(preprocessed_image)

with st.spinner("Generating the heat maps... HOLD ON!"):
	# Undo the input normalization so the picture looks natural when plotted:
	# multiply by the per-channel std, add the per-channel mean (both given
	# on the 0-255 scale), then rescale to [0, 1] and clip.
	channel_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
	channel_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
	rescaled = ((preprocessed_image * channel_std) + channel_mean) / 255.
	display_image = tf.clip_by_value(rescaled, 0.0, 1.0).numpy()

	# Build the attention heat maps for the de-normalized image.
	heat_maps = utils.attention_heatmap(
		attention_score_dict=attention_score_dict,
		image=display_image
	)

	# Render the heat maps alongside the image.
	# NOTE(review): presumably this writes "heat_map.png", which the script
	# opens afterwards - confirm in utils.plot.
	utils.plot(attentions=heat_maps, image=display_image)

# Display the rendered attention figure read back from disk.
st.title("Attention 🔥 Maps")
heat_map_figure = Image.open("heat_map.png")
st.image(heat_map_figure, caption="Attention Heat Maps")