Oliver Grainge commited on
Commit
351130e
·
1 Parent(s): e00d7ee

Initial VPR demo implementation

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,44 @@
1
- ---
2
- title: Simple Vpr Demo
3
- emoji: 🐠
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: 'This space, is a simple demo for a vpr system. '
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Visual Place Recognition Demo
2
+
3
+ This is a Visual Place Recognition (VPR) demo using EigenPlaces model. Upload a query image to find similar places in our database of 400+ images from various cities.
4
+
5
+ ## How it works
6
+
7
+ 1. Upload a query image
8
+ 2. The model extracts visual features from your image
9
+ 3. It compares these features with pre-computed features from 400+ database images
10
+ 4. Returns the most similar matches with similarity scores and location information
11
+
12
+ ## Dataset
13
+
14
+ - **Database**: 400+ images from various cities
15
+ - **Cities**: Melbourne, Boston, and others
16
+ - **Metadata**: Each image includes place ID, city, and GPS coordinates
17
+
18
+ ## Model
19
+
20
+ - **Architecture**: EigenPlaces with ResNet50 backbone
21
+ - **Descriptor Dimension**: 2048
22
+ - **Similarity Metric**: Cosine similarity
23
+
24
+ ## Usage
25
+
26
+ 1. Upload a query image using the interface
27
+ 2. Adjust the number of matches you want to see (1-10)
28
+ 3. Click "Find Matches" to get results
29
+ 4. View the matched images and their metadata
30
+
31
+ ## Technical Details
32
+
33
+ The demo uses:
34
+ - EigenPlaces model for visual feature extraction
35
+ - Pre-computed descriptors for fast similarity search
36
+ - Cosine similarity for matching
37
+ - Gradio for the web interface
38
+
39
+ ## Files
40
+
41
+ - `app.py`: Main Gradio application
42
+ - `model.py`: Model loading and descriptor computation
43
+ - `dataset.py`: Dataset handling and ground truth lookup
44
+ - `data/`: Contains database images, query images, and ground truth JSON
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import json
6
+ from pathlib import Path
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+
10
+ from model import load_model
11
+ from dataset import VPRDataset
12
+
13
+ # Global variables
14
+ model = None
15
+ dataset = None
16
+ db_descriptors = None
17
+ db_filenames = None
18
+
19
+ def load_everything():
20
+ """Load model, dataset, and pre-compute database descriptors."""
21
+ global model, dataset, db_descriptors, db_filenames
22
+
23
+ print("Loading model...")
24
+ model = load_model("eigenplaces")
25
+
26
+ print("Loading dataset...")
27
+ dataset = VPRDataset('data')
28
+
29
+ print("Pre-computing database descriptors...")
30
+ # Create database-only dataset
31
+ db_dataset = VPRDataset('data', include_queries=False, include_database=True)
32
+
33
+ # Create DataLoader for efficient batch processing
34
+ batch_size = 1
35
+ dataloader = DataLoader(db_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
36
+
37
+ # Compute descriptors for database images using DataLoader
38
+ db_descriptors = []
39
+ db_filenames = []
40
+
41
+ model.eval()
42
+ with torch.no_grad():
43
+ for batch_images, batch_filenames, batch_is_query in tqdm(dataloader, desc="Computing database descriptors"):
44
+ # Move batch to same device as model
45
+ device = next(model.parameters()).device
46
+ batch_images = batch_images.to(device)
47
+
48
+ # Compute descriptors for this batch
49
+ batch_descriptors = model(batch_images)
50
+
51
+ # Store results
52
+ db_descriptors.append(batch_descriptors.cpu())
53
+ db_filenames.extend(batch_filenames)
54
+
55
+ # Concatenate all descriptors
56
+ db_descriptors = torch.cat(db_descriptors, dim=0)
57
+ print(f"Pre-computed descriptors for {len(db_filenames)} database images")
58
+
59
+ def find_matches(query_image, top_k=5):
60
+ """Find top-k matches for a query image."""
61
+ if model is None or db_descriptors is None:
62
+ return "Model not loaded yet. Please wait..."
63
+
64
+ # Extract query descriptor
65
+ with torch.no_grad():
66
+ query_batch = query_image.unsqueeze(0)
67
+ query_descriptor = model(query_batch).cpu()
68
+
69
+ # Compute similarities (cosine similarity)
70
+ query_norm = query_descriptor / torch.norm(query_descriptor)
71
+ db_norm = db_descriptors / torch.norm(db_descriptors, dim=1, keepdim=True)
72
+ similarities = torch.mm(query_norm, db_norm.T).squeeze()
73
+
74
+ # Get top-k matches
75
+ top_similarities, top_indices = torch.topk(similarities, top_k)
76
+
77
+ # Prepare results
78
+ results = []
79
+ for i, (sim, idx) in enumerate(zip(top_similarities, top_indices)):
80
+ filename = db_filenames[idx]
81
+ img_path = Path('data') / 'database' / filename
82
+
83
+ # Get metadata
84
+ item_info = dataset.get_item_by_filename(filename)
85
+
86
+ results.append({
87
+ 'image': str(img_path),
88
+ 'similarity': float(sim),
89
+ 'place_id': item_info['place_id'],
90
+ 'city': item_info['city'],
91
+ 'coordinates': f"{item_info['lat']:.4f}, {item_info['lon']:.4f}"
92
+ })
93
+
94
+ return results
95
+
96
+ def demo_interface(query_image, top_k):
97
+ """Gradio interface function."""
98
+ if query_image is None:
99
+ return "Please upload a query image"
100
+
101
+ # Convert PIL to tensor format expected by model
102
+ import torchvision.transforms as T
103
+ transform = T.Compose([
104
+ T.Resize((480, 640)),
105
+ T.ToTensor(),
106
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
107
+ ])
108
+
109
+ query_tensor = transform(query_image)
110
+ matches = find_matches(query_tensor, int(top_k))
111
+
112
+ if isinstance(matches, str):
113
+ return matches
114
+
115
+ # Format results for display
116
+ result_text = "Top Matches:\n\n"
117
+ result_images = []
118
+
119
+ for i, match in enumerate(matches):
120
+ result_text += f"{i+1}. Similarity: {match['similarity']:.4f}\n"
121
+ result_text += f" Place ID: {match['place_id']}\n"
122
+ result_text += f" City: {match['city']}\n"
123
+ result_text += f" Coordinates: {match['coordinates']}\n\n"
124
+
125
+ result_images.append(match['image'])
126
+
127
+ return result_text, result_images
128
+
129
+ # Load everything on startup
130
+ load_everything()
131
+
132
+ # Create Gradio interface
133
+ with gr.Blocks(title="Visual Place Recognition Demo") as demo:
134
+ gr.Markdown("# Visual Place Recognition Demo")
135
+ gr.Markdown("Upload a query image to find similar places in our database!")
136
+
137
+ with gr.Row():
138
+ with gr.Column():
139
+ query_input = gr.Image(type="pil", label="Query Image")
140
+ top_k_slider = gr.Slider(1, 10, value=5, step=1, label="Number of matches")
141
+ find_button = gr.Button("Find Matches")
142
+
143
+ with gr.Column():
144
+ result_text = gr.Textbox(label="Results", lines=10)
145
+ result_gallery = gr.Gallery(label="Matched Images", show_label=True)
146
+
147
+ find_button.click(
148
+ fn=demo_interface,
149
+ inputs=[query_input, top_k_slider],
150
+ outputs=[result_text, result_gallery]
151
+ )
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch()
data/database/place00004796_db0000.jpg ADDED

Git LFS Details

  • SHA256: 60267c21981a6481e44b55ba55dc91d93ac3e1776007b519bec0472616df7e7c
  • Pointer size: 130 Bytes
  • Size of remote file: 47.7 kB
data/database/place00004796_db0001.jpg ADDED

Git LFS Details

  • SHA256: d9e9b6363f7ccef97dae12742159623986e6d21a94bcafe90177bdb47ee9e4bb
  • Pointer size: 130 Bytes
  • Size of remote file: 41.1 kB
data/database/place00004796_db0002.jpg ADDED

Git LFS Details

  • SHA256: e5200cba65f034684afe861632215e23f39f7feca539b3c67a8f0f726040b85c
  • Pointer size: 130 Bytes
  • Size of remote file: 55.9 kB
data/database/place00004796_db0003.jpg ADDED

Git LFS Details

  • SHA256: 4c6db70fc0940e9ff4ba7ec639e2a112bc250af3703f74a422c9ff3dd2e3fcdd
  • Pointer size: 130 Bytes
  • Size of remote file: 38.5 kB
data/database/place00008797_db0008.jpg ADDED

Git LFS Details

  • SHA256: b5e4a2c5e77085c7b6d74383a10a9bb97b91b4b9f86e72d7be5743ee362b8947
  • Pointer size: 130 Bytes
  • Size of remote file: 54.9 kB
data/database/place00008797_db0009.jpg ADDED

Git LFS Details

  • SHA256: ef6e88ba238847bcf247197b75ada3bded98e0f6002015fd130bc40b69107e20
  • Pointer size: 130 Bytes
  • Size of remote file: 62.4 kB
data/database/place00008797_db0010.jpg ADDED

Git LFS Details

  • SHA256: 5a0911faabcb845c96917bba820a80920607891e4789f4acf7dd91c61dbb4216
  • Pointer size: 130 Bytes
  • Size of remote file: 65.1 kB
data/database/place00008797_db0011.jpg ADDED

Git LFS Details

  • SHA256: 35729e9b4a73696ab97eea4d39a4b8942b4fb22d36dedac6893b3ca1dabd8f2b
  • Pointer size: 130 Bytes
  • Size of remote file: 68.5 kB
data/database/place00201236_db0012.jpg ADDED

Git LFS Details

  • SHA256: 3a1d599647a77a6f5774a6235939250a5318b014c2ab7e4e421a237c8c7f2521
  • Pointer size: 130 Bytes
  • Size of remote file: 53.4 kB
data/database/place00201236_db0013.jpg ADDED

Git LFS Details

  • SHA256: 52354c0aa539fbbced447e89a18bd462a2d61c853778ab8fafec9f41424f55cb
  • Pointer size: 130 Bytes
  • Size of remote file: 65.1 kB
data/database/place00201236_db0014.jpg ADDED

Git LFS Details

  • SHA256: f86f54b15f2f3124141c5a0e9534c5d34a13aca7e4d0103edf0726fb32c23882
  • Pointer size: 130 Bytes
  • Size of remote file: 58.2 kB
data/database/place00201236_db0015.jpg ADDED

Git LFS Details

  • SHA256: c1acd8d5c2bcb73e476fbae51e435f8823c47baa6205d632afa4a4ad8b28506d
  • Pointer size: 130 Bytes
  • Size of remote file: 32.6 kB
data/database/place00203981_db0004.jpg ADDED

Git LFS Details

  • SHA256: da4ce17c8bc5c05888b7b3cfac4478cdbfa2763b25956dd3172ec6cb1b95466e
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
data/database/place00203981_db0005.jpg ADDED

Git LFS Details

  • SHA256: fe79baa14257277a6ba7ad0bfeff0a5597ab891618fcf3be9fb91367a6fd56b6
  • Pointer size: 130 Bytes
  • Size of remote file: 71.9 kB
data/database/place00203981_db0006.jpg ADDED

Git LFS Details

  • SHA256: ea42c78cca5429501bad7d7920af86215877ae3cebc9dbf51625e1a283d52c09
  • Pointer size: 130 Bytes
  • Size of remote file: 87.1 kB
data/database/place00203981_db0007.jpg ADDED

Git LFS Details

  • SHA256: cd746c0bb10d7d8083f2ac4024e4e0c4977332153724e17761606453d7a7407c
  • Pointer size: 130 Bytes
  • Size of remote file: 89.5 kB
data/database/place00205527_db0016.jpg ADDED

Git LFS Details

  • SHA256: c1b4ee76fbe66cea5f63ae390957b4e43d829fba39a152bbd63545d3afccfd5c
  • Pointer size: 130 Bytes
  • Size of remote file: 34.5 kB
data/database/place00205527_db0017.jpg ADDED

Git LFS Details

  • SHA256: 106b606985e7b131fe7d5bc5acc35c2b7f34e956ec2abb6bc3502cc78c238d49
  • Pointer size: 130 Bytes
  • Size of remote file: 51.5 kB
data/database/place00205527_db0018.jpg ADDED

Git LFS Details

  • SHA256: 87f07997e2c28e2fa79d43cbfa829b0559ec1fc49d3683928a1920636e586604
  • Pointer size: 130 Bytes
  • Size of remote file: 33 kB
data/database/place00205527_db0019.jpg ADDED

Git LFS Details

  • SHA256: 844ac5e18c96ae4d0bdf8943a09f831877cdc9f63ff0de3ec3bec59b1a293fd3
  • Pointer size: 130 Bytes
  • Size of remote file: 49.8 kB
data/ground_truth.json ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "database": [
3
+ {
4
+ "filename": "place00004796_db0000.jpg",
5
+ "place_id": 4796,
6
+ "city": "London",
7
+ "lat": 51.50916209875723,
8
+ "lon": -0.1489464938656002
9
+ },
10
+ {
11
+ "filename": "place00004796_db0001.jpg",
12
+ "place_id": 4796,
13
+ "city": "London",
14
+ "lat": 51.5091715210003,
15
+ "lon": -0.148933838176506
16
+ },
17
+ {
18
+ "filename": "place00004796_db0002.jpg",
19
+ "place_id": 4796,
20
+ "city": "London",
21
+ "lat": 51.50917252788888,
22
+ "lon": -0.1489267203772275
23
+ },
24
+ {
25
+ "filename": "place00004796_db0003.jpg",
26
+ "place_id": 4796,
27
+ "city": "London",
28
+ "lat": 51.50916624802473,
29
+ "lon": -0.1489330344067381
30
+ },
31
+ {
32
+ "filename": "place00203981_db0004.jpg",
33
+ "place_id": 203981,
34
+ "city": "Melbourne",
35
+ "lat": -37.81509362211957,
36
+ "lon": 144.992838452311
37
+ },
38
+ {
39
+ "filename": "place00203981_db0005.jpg",
40
+ "place_id": 203981,
41
+ "city": "Melbourne",
42
+ "lat": -37.81506767547648,
43
+ "lon": 144.9928412577368
44
+ },
45
+ {
46
+ "filename": "place00203981_db0006.jpg",
47
+ "place_id": 203981,
48
+ "city": "Melbourne",
49
+ "lat": -37.81509508467622,
50
+ "lon": 144.9928784011848
51
+ },
52
+ {
53
+ "filename": "place00203981_db0007.jpg",
54
+ "place_id": 203981,
55
+ "city": "Melbourne",
56
+ "lat": -37.81507084333941,
57
+ "lon": 144.9928344766533
58
+ },
59
+ {
60
+ "filename": "place00008797_db0008.jpg",
61
+ "place_id": 8797,
62
+ "city": "London",
63
+ "lat": 51.52845407460239,
64
+ "lon": -0.1750029953860952
65
+ },
66
+ {
67
+ "filename": "place00008797_db0009.jpg",
68
+ "place_id": 8797,
69
+ "city": "London",
70
+ "lat": 51.52846288309365,
71
+ "lon": -0.1750064077374475
72
+ },
73
+ {
74
+ "filename": "place00008797_db0010.jpg",
75
+ "place_id": 8797,
76
+ "city": "London",
77
+ "lat": 51.52845549538673,
78
+ "lon": -0.1750071427715894
79
+ },
80
+ {
81
+ "filename": "place00008797_db0011.jpg",
82
+ "place_id": 8797,
83
+ "city": "London",
84
+ "lat": 51.52842807679626,
85
+ "lon": -0.1750499767058685
86
+ },
87
+ {
88
+ "filename": "place00201236_db0012.jpg",
89
+ "place_id": 201236,
90
+ "city": "Melbourne",
91
+ "lat": -37.84278510997986,
92
+ "lon": 144.9907996152359
93
+ },
94
+ {
95
+ "filename": "place00201236_db0013.jpg",
96
+ "place_id": 201236,
97
+ "city": "Melbourne",
98
+ "lat": -37.84281467344245,
99
+ "lon": 144.9908206918488
100
+ },
101
+ {
102
+ "filename": "place00201236_db0014.jpg",
103
+ "place_id": 201236,
104
+ "city": "Melbourne",
105
+ "lat": -37.8428116471166,
106
+ "lon": 144.9907948010103
107
+ },
108
+ {
109
+ "filename": "place00201236_db0015.jpg",
110
+ "place_id": 201236,
111
+ "city": "Melbourne",
112
+ "lat": -37.84279737925578,
113
+ "lon": 144.9907950980963
114
+ },
115
+ {
116
+ "filename": "place00205527_db0016.jpg",
117
+ "place_id": 205527,
118
+ "city": "Melbourne",
119
+ "lat": -37.79846599314756,
120
+ "lon": 144.9649501082595
121
+ },
122
+ {
123
+ "filename": "place00205527_db0017.jpg",
124
+ "place_id": 205527,
125
+ "city": "Melbourne",
126
+ "lat": -37.79846121519402,
127
+ "lon": 144.9649625622069
128
+ },
129
+ {
130
+ "filename": "place00205527_db0018.jpg",
131
+ "place_id": 205527,
132
+ "city": "Melbourne",
133
+ "lat": -37.79846527742957,
134
+ "lon": 144.9650067670199
135
+ },
136
+ {
137
+ "filename": "place00205527_db0019.jpg",
138
+ "place_id": 205527,
139
+ "city": "Melbourne",
140
+ "lat": -37.79846435981845,
141
+ "lon": 144.964976424101
142
+ }
143
+ ],
144
+ "query": [
145
+ {
146
+ "filename": "place00004796_q0000.jpg",
147
+ "place_id": 4796,
148
+ "city": "London",
149
+ "lat": 51.5091691701721,
150
+ "lon": -0.1489371501232689
151
+ },
152
+ {
153
+ "filename": "place00203981_q0001.jpg",
154
+ "place_id": 203981,
155
+ "city": "Melbourne",
156
+ "lat": -37.81508165748332,
157
+ "lon": 144.9928762644381
158
+ },
159
+ {
160
+ "filename": "place00008797_q0002.jpg",
161
+ "place_id": 8797,
162
+ "city": "London",
163
+ "lat": 51.52844848114152,
164
+ "lon": -0.1750700055099293
165
+ },
166
+ {
167
+ "filename": "place00201236_q0003.jpg",
168
+ "place_id": 201236,
169
+ "city": "Melbourne",
170
+ "lat": -37.84279084700875,
171
+ "lon": 144.9907951180407
172
+ },
173
+ {
174
+ "filename": "place00205527_q0004.jpg",
175
+ "place_id": 205527,
176
+ "city": "Melbourne",
177
+ "lat": -37.79845501432714,
178
+ "lon": 144.9649401788737
179
+ }
180
+ ],
181
+ "place_mapping": {}
182
+ }
data/query/place00004796_q0000.jpg ADDED

Git LFS Details

  • SHA256: 75a1e7c9faadbaf931799aee74f7ff94205ac99d3619c9b0af61d4a724ee0849
  • Pointer size: 130 Bytes
  • Size of remote file: 54.3 kB
data/query/place00008797_q0002.jpg ADDED

Git LFS Details

  • SHA256: c471e6e1ab91a3cb93aec32e8ff4bdd706351e029b3504d0817f1a5d7e32e3d7
  • Pointer size: 130 Bytes
  • Size of remote file: 60 kB
data/query/place00201236_q0003.jpg ADDED

Git LFS Details

  • SHA256: f7154ca4874b14b94180b256b1541cbc7c26ce3acd1372ea07dffde243a05848
  • Pointer size: 130 Bytes
  • Size of remote file: 71.9 kB
data/query/place00203981_q0001.jpg ADDED

Git LFS Details

  • SHA256: 7ba472eeb385423094f88869b27b2ee01dcc15816c06fb70b7fd051bdb87297d
  • Pointer size: 130 Bytes
  • Size of remote file: 67.6 kB
data/query/place00205527_q0004.jpg ADDED

Git LFS Details

  • SHA256: abd980ba66d287ad64e4c7cf636115621d19ec0725981c584b7a611892cade9c
  • Pointer size: 130 Bytes
  • Size of remote file: 39.8 kB
dataset.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple PyTorch Dataset for VPR (Visual Place Recognition)
3
+ Combines database and query images with ground truth lookup.
4
+ """
5
+
6
+ import json
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+ import torchvision.transforms as T
12
+ from typing import List, Dict, Tuple
13
+
14
+
15
+ class VPRDataset(Dataset):
16
+ """
17
+ Simple VPR Dataset that loads both database and query images.
18
+
19
+ Usage:
20
+ dataset = VPRDataset('data')
21
+
22
+ # Get an image
23
+ img, filename, is_query = dataset[0]
24
+
25
+ # Get ground truth matches for a query
26
+ matches = dataset.gt('place00000123_q0000.jpg')
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ data_dir='data',
32
+ transform=None,
33
+ include_queries=True,
34
+ include_database=True
35
+ ):
36
+ """
37
+ Args:
38
+ data_dir: Path to data folder containing database/, query/, and ground_truth.json
39
+ transform: Optional torchvision transforms to apply to images
40
+ include_queries: Whether to include query images in the dataset
41
+ include_database: Whether to include database images in the dataset
42
+ """
43
+ self.data_dir = Path(data_dir)
44
+ self.include_queries = include_queries
45
+ self.include_database = include_database
46
+
47
+ # Default transform if none provided
48
+ if transform is None:
49
+ self.transform = T.Compose([
50
+ T.Resize((480, 640)),
51
+ T.ToTensor(),
52
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
53
+ ])
54
+ else:
55
+ self.transform = transform
56
+
57
+ # Load ground truth
58
+ gt_path = self.data_dir / 'ground_truth.json'
59
+ with open(gt_path, 'r') as f:
60
+ self.ground_truth = json.load(f)
61
+
62
+ # Build the dataset items list
63
+ self.items = []
64
+
65
+ if include_database:
66
+ for item in self.ground_truth['database']:
67
+ self.items.append({
68
+ 'filename': item['filename'],
69
+ 'path': self.data_dir / 'database' / item['filename'],
70
+ 'place_id': item['place_id'],
71
+ 'is_query': False,
72
+ 'city': item['city'],
73
+ 'lat': item['lat'],
74
+ 'lon': item['lon']
75
+ })
76
+
77
+ if include_queries:
78
+ for item in self.ground_truth['query']:
79
+ self.items.append({
80
+ 'filename': item['filename'],
81
+ 'path': self.data_dir / 'query' / item['filename'],
82
+ 'place_id': item['place_id'],
83
+ 'is_query': True,
84
+ 'city': item['city'],
85
+ 'lat': item['lat'],
86
+ 'lon': item['lon']
87
+ })
88
+
89
+ # Build lookup tables for fast ground truth queries
90
+ self._build_lookup_tables()
91
+
92
+ def _build_lookup_tables(self):
93
+ """Build internal lookup tables for efficient ground truth queries."""
94
+ # Map filename -> full item info
95
+ self.filename_to_item = {item['filename']: item for item in self.items}
96
+
97
+ # Map place_id -> list of database filenames
98
+ self.place_to_db_files = {}
99
+ for item in self.ground_truth['database']:
100
+ place_id = item['place_id']
101
+ if place_id not in self.place_to_db_files:
102
+ self.place_to_db_files[place_id] = []
103
+ self.place_to_db_files[place_id].append(item['filename'])
104
+
105
+ # Map query filename -> its place_id for fast lookup
106
+ self.query_to_place = {
107
+ item['filename']: item['place_id']
108
+ for item in self.ground_truth['query']
109
+ }
110
+
111
+ def __len__(self):
112
+ """Return total number of images in dataset."""
113
+ return len(self.items)
114
+
115
+ def __getitem__(self, idx) -> Tuple[torch.Tensor, str, bool]:
116
+ """
117
+ Get an image from the dataset.
118
+
119
+ Returns:
120
+ tuple: (image_tensor, filename, is_query)
121
+ - image_tensor: Transformed image as torch.Tensor
122
+ - filename: String filename (e.g., 'place00000123_db0001.jpg')
123
+ - is_query: Boolean indicating if this is a query image
124
+ """
125
+ item = self.items[idx]
126
+
127
+ # Load image
128
+ img = Image.open(item['path']).convert('RGB')
129
+
130
+ # Apply transforms
131
+ if self.transform:
132
+ img = self.transform(img)
133
+
134
+ return img, item['filename'], item['is_query']
135
+
136
+ def gt(self, query_filename: str) -> List[str]:
137
+ """
138
+ Get ground truth database matches for a query image.
139
+
140
+ Args:
141
+ query_filename: Filename of the query image (e.g., 'place00000123_q0000.jpg')
142
+
143
+ Returns:
144
+ List of database image filenames that match this query (same place_id)
145
+
146
+ Example:
147
+ >>> dataset = VPRDataset('data')
148
+ >>> matches = dataset.gt('place00000123_q0000.jpg')
149
+ >>> print(matches)
150
+ ['place00000123_db0000.jpg', 'place00000123_db0001.jpg', 'place00000123_db0002.jpg']
151
+ """
152
+ if query_filename not in self.query_to_place:
153
+ raise ValueError(f"Query filename '{query_filename}' not found in dataset")
154
+
155
+ place_id = self.query_to_place[query_filename]
156
+ return self.place_to_db_files.get(place_id, [])
157
+
158
+ def get_query_filenames(self) -> List[str]:
159
+ """Get list of all query image filenames."""
160
+ return list(self.query_to_place.keys())
161
+
162
+ def get_database_filenames(self) -> List[str]:
163
+ """Get list of all database image filenames."""
164
+ all_db_files = []
165
+ for files in self.place_to_db_files.values():
166
+ all_db_files.extend(files)
167
+ return all_db_files
168
+
169
+ def get_item_by_filename(self, filename: str) -> Dict:
170
+ """
171
+ Get full item information by filename.
172
+
173
+ Args:
174
+ filename: Image filename
175
+
176
+ Returns:
177
+ Dictionary with keys: filename, path, place_id, is_query, city, lat, lon
178
+ """
179
+ if filename not in self.filename_to_item:
180
+ raise ValueError(f"Filename '{filename}' not found in dataset")
181
+ return self.filename_to_item[filename]
182
+
183
+ @staticmethod
184
+ def get_place_id_from_filename(filename: str) -> int:
185
+ """
186
+ Extract place_id from filename.
187
+
188
+ Args:
189
+ filename: Image filename (e.g., 'place00000123_db0001.jpg')
190
+
191
+ Returns:
192
+ Integer place_id (e.g., 123)
193
+ """
194
+ return int(filename.split('_')[0].replace('place', ''))
195
+
196
+
197
+ # ============================================================================
198
+ # EXAMPLE USAGE
199
+ # ============================================================================
200
+
201
+ if __name__ == "__main__":
202
+ from torch.utils.data import DataLoader
203
+
204
+ print("=" * 60)
205
+ print("EXAMPLE 1: Basic Dataset Usage")
206
+ print("=" * 60)
207
+
208
+ # Create dataset with both queries and database
209
+ dataset = VPRDataset('data')
210
+ print(f"Total images in dataset: {len(dataset)}")
211
+ print(f"Query images: {len(dataset.get_query_filenames())}")
212
+ print(f"Database images: {len(dataset.get_database_filenames())}")
213
+ print()
214
+
215
+ # Get a single image
216
+ img, filename, is_query = dataset[0]
217
+ print(f"First image:")
218
+ print(f" Filename: {filename}")
219
+ print(f" Is query: {is_query}")
220
+ print(f" Image shape: {img.shape}")
221
+ print()
222
+
223
+ print("=" * 60)
224
+ print("EXAMPLE 2: Ground Truth Lookup")
225
+ print("=" * 60)
226
+
227
+ # Get a query filename
228
+ query_files = dataset.get_query_filenames()
229
+ query_file = query_files[0]
230
+
231
+ print(f"Query: {query_file}")
232
+
233
+ # Get ground truth matches
234
+ matches = dataset.gt(query_file)
235
+ print(f"Ground truth matches ({len(matches)} images):")
236
+ for match in matches:
237
+ print(f" - {match}")
238
+ print()
239
+
240
+ print("=" * 60)
241
+ print("EXAMPLE 3: Create Separate Query and Database Datasets")
242
+ print("=" * 60)
243
+
244
+ # Create database-only dataset
245
+ db_dataset = VPRDataset('data', include_queries=False, include_database=True)
246
+ print(f"Database-only dataset size: {len(db_dataset)}")
247
+
248
+ # Create query-only dataset
249
+ query_dataset = VPRDataset('data', include_queries=True, include_database=False)
250
+ print(f"Query-only dataset size: {len(query_dataset)}")
251
+ print()
252
+
253
+ print("=" * 60)
254
+ print("EXAMPLE 4: Using with DataLoader")
255
+ print("=" * 60)
256
+
257
+ # Create dataloader
258
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
259
+
260
+ # Get a batch
261
+ batch_imgs, batch_filenames, batch_is_query = next(iter(dataloader))
262
+ print(f"Batch shape: {batch_imgs.shape}")
263
+ print(f"Batch filenames: {batch_filenames}")
264
+ print(f"Batch is_query flags: {batch_is_query}")
265
+ print()
266
+
267
+ print("=" * 60)
268
+ print("EXAMPLE 5: Get Item Info by Filename")
269
+ print("=" * 60)
270
+
271
+ item_info = dataset.get_item_by_filename(query_file)
272
+ print(f"Full info for {query_file}:")
273
+ for key, value in item_info.items():
274
+ if key != 'path': # Skip path for cleaner output
275
+ print(f" {key}: {value}")
276
+ print()
277
+
278
+ print("=" * 60)
279
+ print("EXAMPLE 6: Typical VPR Workflow")
280
+ print("=" * 60)
281
+
282
+ print("Typical usage pattern:")
283
+ print("""
284
+ # 1. Create separate datasets
285
+ db_dataset = VPRDataset('data', include_queries=False)
286
+ query_dataset = VPRDataset('data', include_database=False)
287
+
288
+ # 2. Extract features for all database images
289
+ db_features = []
290
+ db_filenames = []
291
+ for img, filename, _ in db_dataset:
292
+ feat = model(img.unsqueeze(0)) # Your VPR model
293
+ db_features.append(feat)
294
+ db_filenames.append(filename)
295
+
296
+ # 3. For each query, find matches
297
+ for img, query_filename, _ in query_dataset:
298
+ # Extract query features
299
+ query_feat = model(img.unsqueeze(0))
300
+
301
+ # Compute similarities with database
302
+ similarities = compute_similarity(query_feat, db_features)
303
+
304
+ # Get top-K predictions
305
+ top_k_indices = similarities.argsort()[::-1][:10]
306
+ predicted_files = [db_filenames[i] for i in top_k_indices]
307
+
308
+ # Get ground truth
309
+ gt_files = query_dataset.gt(query_filename)
310
+
311
+ # Evaluate: check if any gt_files are in predicted_files
312
+ recall_at_10 = any(gt in predicted_files for gt in gt_files)
313
+ """)
model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from dataset import VPRDataset
4
+ from torch.utils.data import DataLoader
5
+ from tqdm import tqdm
6
+
7
+
8
+ def load_model(model_name: str, device: str = "auto") -> nn.Module:
9
+ """Load a pre-trained VPR model.
10
+
11
+ Args:
12
+ model_name: Name of the model to load (currently supports "eigenplaces")
13
+ device: Device to load model on ("auto", "cpu", "cuda")
14
+
15
+ Returns:
16
+ Loaded model in evaluation mode
17
+ """
18
+ if device == "auto":
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ if model_name.lower() == "eigenplaces":
22
+ model = torch.hub.load("gmberton/eigenplaces", "get_trained_model", backbone="ResNet50", fc_output_dim=2048)
23
+ setattr(model, "descriptor_dim", 2048)
24
+ else:
25
+ raise ValueError(f"Model {model_name} not found")
26
+
27
+ model = model.to(device)
28
+ model.eval()
29
+ return model
30
+
31
+
32
+ def compute_descriptors(model: nn.Module, dataset: VPRDataset, batch_size: int = 32, device: str = "auto") -> torch.Tensor:
33
+ """Compute descriptors for all images in the dataset.
34
+
35
+ Args:
36
+ model: Pre-trained VPR model
37
+ dataset: VPRDataset containing images
38
+ batch_size: Batch size for processing
39
+ device: Device to run inference on
40
+
41
+ Returns:
42
+ Tensor of shape (len(dataset), descriptor_dim) containing all descriptors
43
+ """
44
+ if device == "auto":
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+
47
+ model.eval()
48
+ descriptors = torch.zeros(len(dataset), model.descriptor_dim, device="cpu")
49
+
50
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
51
+
52
+ with torch.no_grad():
53
+ start_idx = 0
54
+ for batch_images, batch_filenames, batch_is_query in tqdm(dataloader, desc="Computing descriptors"):
55
+ batch_images = batch_images.to(device)
56
+ batch_descriptors = model(batch_images)
57
+ batch_descriptors = batch_descriptors.cpu()
58
+ end_idx = start_idx + batch_descriptors.size(0)
59
+ descriptors[start_idx:end_idx] = batch_descriptors
60
+ start_idx = end_idx
61
+
62
+ return descriptors.cpu() # Return on CPU for easier handling
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision>=0.10.0
3
+ gradio>=3.0.0
4
+ Pillow>=8.0.0
5
+ numpy>=1.21.0
6
+ tqdm>=4.60.0
scripts/sample_data.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sample database and query images from GSV-Cities dataset for VPR demo.
3
+ Creates data/database/ and data/query/ folders with images.
4
+ Ground truth is encoded in the filename: placeID_imageID.jpg
5
+ """
6
+
7
+ import pandas as pd
8
+ from pathlib import Path
9
+ from PIL import Image
10
+ import shutil
11
+ import random
12
+ import json
13
+ from tqdm import tqdm
14
+
15
+ # Configuration
16
+ BASE_PATH = '/Users/olivergrainge/datasets/gsv-cities'
17
+ OUTPUT_PATH = 'data'
18
+ NUM_PLACES = 5 # Number of unique places to sample
19
+ DB_IMAGES_PER_PLACE = 4 # Images per place for database
20
+ QUERY_IMAGES_PER_PLACE = 1 # Images per place for queries
21
+ CITIES = ['London', 'Boston', 'Melbourne'] # Cities to sample from
22
+ MIN_IMAGES_PER_PLACE = 5 # Minimum images a place must have
23
+
24
+ def load_dataframes(base_path, cities):
25
+ """Load and combine dataframes from multiple cities."""
26
+ dfs = []
27
+ for i, city in enumerate(cities):
28
+ df_path = Path(base_path) / 'Dataframes' / f'{city}.csv'
29
+ if not df_path.exists():
30
+ print(f"Warning: {df_path} not found, skipping {city}")
31
+ continue
32
+
33
+ df = pd.read_csv(df_path)
34
+ # Add prefix to place_id to distinguish between cities
35
+ df['place_id'] = df['place_id'] + (i * 10**5)
36
+ df['city_name'] = city
37
+ dfs.append(df)
38
+
39
+ if not dfs:
40
+ raise FileNotFoundError("No valid city dataframes found!")
41
+
42
+ return pd.concat(dfs, ignore_index=True)
43
+
44
+ def get_img_path(base_path, row):
45
+ """Construct the full image path from a dataframe row."""
46
+ city = row['city_id']
47
+ pl_id = row['place_id'] % 10**5
48
+ pl_id = str(pl_id).zfill(7)
49
+ panoid = row['panoid']
50
+ year = str(row['year']).zfill(4)
51
+ month = str(row['month']).zfill(2)
52
+ northdeg = str(row['northdeg']).zfill(3)
53
+ lat, lon = str(row['lat']), str(row['lon'])
54
+
55
+ img_name = f"{city}_{pl_id}_{year}_{month}_{northdeg}_{lat}_{lon}_{panoid}.jpg"
56
+ return Path(base_path) / 'Images' / city / img_name
57
+
58
+ def sample_and_copy_images():
59
+ """Main function to sample and organize images."""
60
+
61
+ # Create output directories
62
+ db_path = Path(OUTPUT_PATH) / 'database'
63
+ query_path = Path(OUTPUT_PATH) / 'query'
64
+ db_path.mkdir(parents=True, exist_ok=True)
65
+ query_path.mkdir(parents=True, exist_ok=True)
66
+
67
+ print("Loading dataframes...")
68
+ df = load_dataframes(BASE_PATH, CITIES)
69
+
70
+ # Filter places with minimum number of images
71
+ place_counts = df.groupby('place_id').size()
72
+ valid_places = place_counts[place_counts >= (DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE)].index
73
+ df = df[df['place_id'].isin(valid_places)]
74
+
75
+ print(f"Found {len(valid_places)} valid places")
76
+
77
+ # Sample N random places
78
+ sampled_places = random.sample(list(valid_places), min(NUM_PLACES, len(valid_places)))
79
+
80
+ print(f"Sampling {len(sampled_places)} places...")
81
+
82
+ # Ground truth structure
83
+ ground_truth = {
84
+ 'database': [],
85
+ 'query': [],
86
+ 'place_mapping': {}
87
+ }
88
+
89
+ db_count = 0
90
+ query_count = 0
91
+
92
+ for place_id in tqdm(sampled_places, desc="Processing places"):
93
+ place_images = df[df['place_id'] == place_id]
94
+
95
+ # Sample images for this place
96
+ sampled = place_images.sample(n=min(DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE, len(place_images)))
97
+
98
+ # Split into database and query
99
+ db_images = sampled.iloc[:DB_IMAGES_PER_PLACE]
100
+ query_images = sampled.iloc[DB_IMAGES_PER_PLACE:DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE]
101
+
102
+ # Copy database images
103
+ for idx, (_, row) in enumerate(db_images.iterrows()):
104
+ src_path = get_img_path(BASE_PATH, row)
105
+ if not src_path.exists():
106
+ print(f"Warning: {src_path} not found, skipping")
107
+ continue
108
+
109
+ # New filename: placeID_dbXXXX.jpg
110
+ dst_filename = f"place{str(place_id).zfill(8)}_db{str(db_count).zfill(4)}.jpg"
111
+ dst_path = db_path / dst_filename
112
+
113
+ shutil.copy2(src_path, dst_path)
114
+ ground_truth['database'].append({
115
+ 'filename': dst_filename,
116
+ 'place_id': int(place_id),
117
+ 'city': row['city_name'],
118
+ 'lat': float(row['lat']),
119
+ 'lon': float(row['lon'])
120
+ })
121
+ db_count += 1
122
+
123
+ # Copy query images
124
+ for idx, (_, row) in enumerate(query_images.iterrows()):
125
+ src_path = get_img_path(BASE_PATH, row)
126
+ if not src_path.exists():
127
+ print(f"Warning: {src_path} not found, skipping")
128
+ continue
129
+
130
+ # New filename: placeID_qXXXX.jpg
131
+ dst_filename = f"place{str(place_id).zfill(8)}_q{str(query_count).zfill(4)}.jpg"
132
+ dst_path = query_path / dst_filename
133
+
134
+ shutil.copy2(src_path, dst_path)
135
+ ground_truth['query'].append({
136
+ 'filename': dst_filename,
137
+ 'place_id': int(place_id),
138
+ 'city': row['city_name'],
139
+ 'lat': float(row['lat']),
140
+ 'lon': float(row['lon'])
141
+ })
142
+ query_count += 1
143
+
144
+ # Save ground truth to JSON
145
+ gt_path = Path(OUTPUT_PATH) / 'ground_truth.json'
146
+ with open(gt_path, 'w') as f:
147
+ json.dump(ground_truth, f, indent=2)
148
+
149
+ print(f"\n✓ Successfully created dataset!")
150
+ print(f" Database images: {db_count} (in {db_path})")
151
+ print(f" Query images: {query_count} (in {query_path})")
152
+ print(f" Ground truth: {gt_path}")
153
+ print(f"\nGround truth structure:")
154
+ print(f" - Filenames contain place_id: place########_db####.jpg or place########_q####.jpg")
155
+ print(f" - JSON file contains detailed metadata including GPS coordinates")
156
+
157
+ if __name__ == "__main__":
158
+ sample_and_copy_images()