mdanish committed on
Commit
6b2ffe4
·
1 Parent(s): 843d2c1

add KNN init. code

Browse files
Files changed (1) hide show
  1. app.py +77 -1
app.py CHANGED
@@ -6,6 +6,18 @@ from PIL import Image
6
  from io import BytesIO
7
  import os
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Securely get the token from environment variables
10
  MAPILLARY_ACCESS_TOKEN = os.environ.get('MAPILLARY_ACCESS_TOKEN')
11
 
@@ -60,9 +72,73 @@ def get_nearest_image(lat, lon):
60
  st.error(f"Error fetching Mapillary data: {str(e)}")
61
  return None
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def main():
64
- st.title("Amsterdam Street View Explorer")
65
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Initialize the map centered on Amsterdam
67
  amsterdam_coords = [52.3676, 4.9041]
68
  m = folium.Map(location=amsterdam_coords, zoom_start=13)
 
6
  from io import BytesIO
7
  import os
8
 
9
# Precomputed KNN reference data (per-category embedding vectors + scores)
# for Amsterdam, built with the OpenCLIP model configured below.
knnpath = '20241204-ams-no-env-open_clip_ViT-H-14-378-quickgelu.npz'
# OpenCLIP architecture name and pretrained-weights tag used to embed images.
clip_model_name = 'ViT-H-14-378-quickgelu'
pretrained_name = 'dfn5b'

# Perception categories the KNN model can score an image on.
categories = ['walkability', 'bikeability', 'pleasantness', 'greenness', 'safety']

# Set page config (must run before other Streamlit calls in the script)
st.set_page_config(
    page_title="Percept",
    layout="wide"
)
20
+
21
  # Securely get the token from environment variables
22
  MAPILLARY_ACCESS_TOKEN = os.environ.get('MAPILLARY_ACCESS_TOKEN')
23
 
 
72
  st.error(f"Error fetching Mapillary data: {str(e)}")
73
  return None
74
 
75
@st.cache_resource
def load_model():
    """Build the OpenCLIP model, preprocessing transform, and tokenizer.

    Cached via st.cache_resource so the (large) model is constructed only
    once per Streamlit server process.

    Returns:
        (model, preprocess, tokenizer) for the configured architecture.
    """
    clip_model, _, transform = open_clip.create_model_and_transforms(
        clip_model_name,
        pretrained=pretrained_name,
    )
    return clip_model, transform, open_clip.get_tokenizer(clip_model_name)
83
+
84
def process_image(image, preprocess):
    """Turn *image* into a model-ready tensor with a batch dimension.

    Args:
        image: a PIL image, or a URL string from which to download one.
        preprocess: the OpenCLIP preprocessing transform from load_model().

    Returns:
        The preprocessed image tensor with a leading batch axis (unsqueeze(0)).

    Raises:
        requests.HTTPError: if *image* is a URL and the download fails.
    """
    if isinstance(image, str):
        # If image is a URL: fail fast on a bad response instead of handing
        # a non-image payload (e.g. an HTML error page) to PIL.
        response = requests.get(image, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    # Ensure image is in RGB mode — the CLIP transform expects 3 channels
    if image.mode != 'RGB':
        image = image.convert('RGB')
    processed_image = preprocess(image).unsqueeze(0)
    return processed_image
95
+
96
def knn_get_score(knn, k, cat, vec):
    """Score *vec* for category *cat* by k-nearest-neighbour regression.

    Args:
        knn: mapping (e.g. the np.load'ed .npz archive) holding
            f'{cat}_vecs' — reference embeddings, and f'{cat}_scores' —
            their ratings. Assumes vecs are (N, D) and scores are
            indexable by the same N — TODO confirm against the .npz.
        k: number of nearest neighbours to weight into the result.
        cat: category name, e.g. 'walkability'.
        vec: query embedding(s), used as the left operand of `vec @ allvecs.T`,
            so presumably shape (1, D); assumed L2-normalized like the
            reference vectors (see comment below) so the dot product is
            cosine similarity.

    Returns:
        The similarity-weighted sum of the k nearest neighbours' scores.

    NOTE(review): relies on a module-level `debug` flag (defined elsewhere
    in the file — not visible in this diff) for shape logging via Streamlit.
    """
    allvecs = knn[f'{cat}_vecs']
    if debug: st.write('allvecs.shape', allvecs.shape)
    scores = knn[f'{cat}_scores']
    if debug: st.write('scores.shape', scores.shape)
    # Compute cosine similarity of vec against allvecs
    # (both are already normalized)
    cos_sim_table = vec @ allvecs.T
    if debug: st.write('cos_sim_table.shape', cos_sim_table.shape)
    # Get sorted array indices by similarity in descending order
    sortinds = np.flip(np.argsort(cos_sim_table, axis=1), axis=1)
    if debug: st.write('sortinds.shape', sortinds.shape)
    # Get corresponding scores for the sorted vectors, keep only the top k
    kscores = scores[sortinds][:,:k]
    if debug: st.write('kscores.shape', kscores.shape)
    # Get actual sorted similarity scores
    # (line copied from clip_retrieval_knn.py even though sortinds.shape[0] == 1 here)
    ksims = cos_sim_table[np.expand_dims(np.arange(sortinds.shape[0]), axis=1), sortinds]
    ksims = ksims[:,:k]
    if debug: st.write('ksims.shape', ksims.shape)
    # Apply normalization after exponential formula: 10**sim sharpens the
    # weighting toward the closest neighbours; softmax makes weights sum to 1
    ksims = softmax(10**ksims)
    # Weighted sum of neighbour scores by their (normalized) similarity
    kweightedscore = np.sum(kscores * ksims)
    return kweightedscore
121
+
122
+
123
@st.cache_resource
def load_knn():
    """Load the precomputed KNN .npz archive, cached once per server process."""
    knn_archive = np.load(knnpath)
    return knn_archive
126
+
127
  def main():
128
+ st.title("Percept: Map Explorer")
129
 
130
+ try:
131
+ with st.spinner('Loading CLIP model... This may take a moment.'):
132
+ model, preprocess, tokenizer = load_model()
133
+ device = "cuda" if torch.cuda.is_available() else "cpu"
134
+ model = model.to(device)
135
+ except Exception as e:
136
+ st.error(f"Error loading model: {str(e)}")
137
+ st.info("Please make sure you have enough memory and the correct dependencies installed.")
138
+
139
+ with st.spinner('Loading KNN model... This may take a moment.'):
140
+ knn = load_knn()
141
+
142
  # Initialize the map centered on Amsterdam
143
  amsterdam_coords = [52.3676, 4.9041]
144
  m = folium.Map(location=amsterdam_coords, zoom_start=13)