Spaces:
Sleeping
Sleeping
Oliver Grainge
committed on
Commit
·
351130e
1
Parent(s):
e00d7ee
Initial VPR demo implementation
Browse files- .gitattributes +1 -0
- README.md +44 -14
- app.py +154 -0
- data/database/place00004796_db0000.jpg +3 -0
- data/database/place00004796_db0001.jpg +3 -0
- data/database/place00004796_db0002.jpg +3 -0
- data/database/place00004796_db0003.jpg +3 -0
- data/database/place00008797_db0008.jpg +3 -0
- data/database/place00008797_db0009.jpg +3 -0
- data/database/place00008797_db0010.jpg +3 -0
- data/database/place00008797_db0011.jpg +3 -0
- data/database/place00201236_db0012.jpg +3 -0
- data/database/place00201236_db0013.jpg +3 -0
- data/database/place00201236_db0014.jpg +3 -0
- data/database/place00201236_db0015.jpg +3 -0
- data/database/place00203981_db0004.jpg +3 -0
- data/database/place00203981_db0005.jpg +3 -0
- data/database/place00203981_db0006.jpg +3 -0
- data/database/place00203981_db0007.jpg +3 -0
- data/database/place00205527_db0016.jpg +3 -0
- data/database/place00205527_db0017.jpg +3 -0
- data/database/place00205527_db0018.jpg +3 -0
- data/database/place00205527_db0019.jpg +3 -0
- data/ground_truth.json +182 -0
- data/query/place00004796_q0000.jpg +3 -0
- data/query/place00008797_q0002.jpg +3 -0
- data/query/place00201236_q0003.jpg +3 -0
- data/query/place00203981_q0001.jpg +3 -0
- data/query/place00205527_q0004.jpg +3 -0
- dataset.py +313 -0
- model.py +62 -0
- requirements.txt +6 -0
- scripts/sample_data.py +158 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,44 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Visual Place Recognition Demo
|
| 2 |
+
|
| 3 |
+
This is a Visual Place Recognition (VPR) demo using the EigenPlaces model. Upload a query image to find similar places in our database of 400+ images from various cities.
|
| 4 |
+
|
| 5 |
+
## How it works
|
| 6 |
+
|
| 7 |
+
1. Upload a query image
|
| 8 |
+
2. The model extracts visual features from your image
|
| 9 |
+
3. It compares these features with pre-computed features from 400+ database images
|
| 10 |
+
4. It returns the most similar matches with similarity scores and location information
|
| 11 |
+
|
| 12 |
+
## Dataset
|
| 13 |
+
|
| 14 |
+
- **Database**: 400+ images from various cities
|
| 15 |
+
- **Cities**: Melbourne, Boston, and others
|
| 16 |
+
- **Metadata**: Each image includes place ID, city, and GPS coordinates
|
| 17 |
+
|
| 18 |
+
## Model
|
| 19 |
+
|
| 20 |
+
- **Architecture**: EigenPlaces with ResNet50 backbone
|
| 21 |
+
- **Descriptor Dimension**: 2048
|
| 22 |
+
- **Similarity Metric**: Cosine similarity
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
1. Upload a query image using the interface
|
| 27 |
+
2. Adjust the number of matches you want to see (1-10)
|
| 28 |
+
3. Click "Find Matches" to get results
|
| 29 |
+
4. View the matched images and their metadata
|
| 30 |
+
|
| 31 |
+
## Technical Details
|
| 32 |
+
|
| 33 |
+
The demo uses:
|
| 34 |
+
- EigenPlaces model for visual feature extraction
|
| 35 |
+
- Pre-computed descriptors for fast similarity search
|
| 36 |
+
- Cosine similarity for matching
|
| 37 |
+
- Gradio for the web interface
|
| 38 |
+
|
| 39 |
+
## Files
|
| 40 |
+
|
| 41 |
+
- `app.py`: Main Gradio application
|
| 42 |
+
- `model.py`: Model loading and descriptor computation
|
| 43 |
+
- `dataset.py`: Dataset handling and ground truth lookup
|
| 44 |
+
- `data/`: Contains database images, query images, and ground truth JSON
|
app.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from model import load_model
|
| 11 |
+
from dataset import VPRDataset
|
| 12 |
+
|
| 13 |
+
# Global variables
|
| 14 |
+
model = None
|
| 15 |
+
dataset = None
|
| 16 |
+
db_descriptors = None
|
| 17 |
+
db_filenames = None
|
| 18 |
+
|
| 19 |
+
def load_everything():
    """Load the model and dataset, then pre-compute database descriptors.

    Populates the module-level globals ``model``, ``dataset``,
    ``db_descriptors`` and ``db_filenames``. The descriptor globals are only
    assigned once every database image has been processed, so
    ``db_descriptors`` stays ``None`` (the "not loaded" signal checked by
    ``find_matches``) until the results are complete.

    Raises:
        RuntimeError: If the database split contains no images.
    """
    global model, dataset, db_descriptors, db_filenames

    print("Loading model...")
    model = load_model("eigenplaces")

    print("Loading dataset...")
    dataset = VPRDataset('data')

    print("Pre-computing database descriptors...")
    # Database-only view used just for descriptor extraction.
    db_dataset = VPRDataset('data', include_queries=False, include_database=True)
    if len(db_dataset) == 0:
        # Fail with a clear message instead of an opaque torch.cat error below.
        raise RuntimeError("No database images found; cannot pre-compute descriptors")

    batch_size = 1
    dataloader = DataLoader(db_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # The model does not move during the loop, so resolve its device once.
    device = next(model.parameters()).device

    # Accumulate into locals; the globals are published only when complete.
    descriptor_batches = []
    filenames = []

    model.eval()
    with torch.no_grad():
        for batch_images, batch_filenames, _ in tqdm(dataloader, desc="Computing database descriptors"):
            batch_descriptors = model(batch_images.to(device))

            # Keep descriptors on the CPU so matching does not depend on the
            # device the model was run on.
            descriptor_batches.append(batch_descriptors.cpu())
            filenames.extend(batch_filenames)

    # Publish filenames first: once db_descriptors is non-None the filename
    # list it indexes into is already in place.
    db_filenames = filenames
    db_descriptors = torch.cat(descriptor_batches, dim=0)
    print(f"Pre-computed descriptors for {len(db_filenames)} database images")
|
| 58 |
+
|
| 59 |
+
def find_matches(query_image, top_k=5):
    """Find the top-k database matches for a query image.

    Args:
        query_image: Pre-processed query image tensor without a batch
            dimension; a batch dimension is added before it is passed to the
            model.
        top_k: Maximum number of matches to return. It is clamped to the
            number of database descriptors, so asking for more matches than
            there are database images returns all of them.

    Returns:
        A list of dicts, best match first, with keys ``image``,
        ``similarity``, ``place_id``, ``city`` and ``coordinates``; or a
        status string when the model or descriptors are not loaded yet.
    """
    if model is None or db_descriptors is None:
        return "Model not loaded yet. Please wait..."

    # Extract query descriptor
    with torch.no_grad():
        query_batch = query_image.unsqueeze(0)
        # The query must live on the same device as the model's weights,
        # otherwise inference fails when the model was loaded onto CUDA.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            query_batch = query_batch.to(first_param.device)
        query_descriptor = model(query_batch).cpu()

    # Cosine similarity = dot product of L2-normalised rows. normalize()
    # guards against division by zero for an all-zero descriptor.
    query_norm = torch.nn.functional.normalize(query_descriptor, dim=1)
    db_norm = torch.nn.functional.normalize(db_descriptors, dim=1)
    # squeeze(0) drops only the query axis, so the result stays 1-D even when
    # the database holds a single image.
    similarities = torch.mm(query_norm, db_norm.T).squeeze(0)

    # torch.topk raises if k exceeds the number of candidates.
    k = max(0, min(int(top_k), similarities.numel()))
    top_similarities, top_indices = torch.topk(similarities, k)

    # Prepare results
    results = []
    for sim, idx in zip(top_similarities.tolist(), top_indices.tolist()):
        filename = db_filenames[idx]
        img_path = Path('data') / 'database' / filename

        # Get metadata
        item_info = dataset.get_item_by_filename(filename)

        results.append({
            'image': str(img_path),
            'similarity': float(sim),
            'place_id': item_info['place_id'],
            'city': item_info['city'],
            'coordinates': f"{item_info['lat']:.4f}, {item_info['lon']:.4f}"
        })

    return results
|
| 95 |
+
|
| 96 |
+
def demo_interface(query_image, top_k):
    """Gradio callback: match an uploaded image against the database.

    Args:
        query_image: PIL image supplied by the image input, or ``None`` when
            nothing was uploaded.
        top_k: Number of matches requested (converted with ``int``).

    Returns:
        A ``(text, image_paths)`` tuple. The click handler is wired to two
        output components, so every path returns two values; on the
        early-exit paths the text carries the status message and the image
        list is empty.
    """
    if query_image is None:
        return "Please upload a query image", []

    # Same preprocessing as the dataset's default transform, so query and
    # database descriptors are computed from comparably prepared inputs.
    import torchvision.transforms as T
    transform = T.Compose([
        T.Resize((480, 640)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Normalize uses 3-channel statistics, so force RGB (no-op for RGB input).
    query_tensor = transform(query_image.convert('RGB'))
    matches = find_matches(query_tensor, int(top_k))

    # find_matches reports "not ready" as a plain string.
    if isinstance(matches, str):
        return matches, []

    # Format results for display
    text_parts = ["Top Matches:\n\n"]
    result_images = []

    for rank, match in enumerate(matches, start=1):
        text_parts.append(f"{rank}. Similarity: {match['similarity']:.4f}\n")
        text_parts.append(f" Place ID: {match['place_id']}\n")
        text_parts.append(f" City: {match['city']}\n")
        text_parts.append(f" Coordinates: {match['coordinates']}\n\n")

        result_images.append(match['image'])

    return "".join(text_parts), result_images
|
| 128 |
+
|
| 129 |
+
# Populate the model, dataset and database descriptors before the UI is built
load_everything()

# Assemble the Gradio UI: inputs on the left, results on the right
with gr.Blocks(title="Visual Place Recognition Demo") as demo:
    gr.Markdown("# Visual Place Recognition Demo")
    gr.Markdown("Upload a query image to find similar places in our database!")

    with gr.Row():
        # Input column: query image, match count and the trigger button
        with gr.Column():
            query_input = gr.Image(label="Query Image", type="pil")
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=5,
                label="Number of matches",
            )
            find_button = gr.Button("Find Matches")

        # Output column: textual summary plus a gallery of matched images
        with gr.Column():
            result_text = gr.Textbox(lines=10, label="Results")
            result_gallery = gr.Gallery(show_label=True, label="Matched Images")

    # Run the matcher when the button is pressed
    find_button.click(
        demo_interface,
        inputs=[query_input, top_k_slider],
        outputs=[result_text, result_gallery],
    )

if __name__ == "__main__":
    demo.launch()
|
data/database/place00004796_db0000.jpg
ADDED
|
Git LFS Details
|
data/database/place00004796_db0001.jpg
ADDED
|
Git LFS Details
|
data/database/place00004796_db0002.jpg
ADDED
|
Git LFS Details
|
data/database/place00004796_db0003.jpg
ADDED
|
Git LFS Details
|
data/database/place00008797_db0008.jpg
ADDED
|
Git LFS Details
|
data/database/place00008797_db0009.jpg
ADDED
|
Git LFS Details
|
data/database/place00008797_db0010.jpg
ADDED
|
Git LFS Details
|
data/database/place00008797_db0011.jpg
ADDED
|
Git LFS Details
|
data/database/place00201236_db0012.jpg
ADDED
|
Git LFS Details
|
data/database/place00201236_db0013.jpg
ADDED
|
Git LFS Details
|
data/database/place00201236_db0014.jpg
ADDED
|
Git LFS Details
|
data/database/place00201236_db0015.jpg
ADDED
|
Git LFS Details
|
data/database/place00203981_db0004.jpg
ADDED
|
Git LFS Details
|
data/database/place00203981_db0005.jpg
ADDED
|
Git LFS Details
|
data/database/place00203981_db0006.jpg
ADDED
|
Git LFS Details
|
data/database/place00203981_db0007.jpg
ADDED
|
Git LFS Details
|
data/database/place00205527_db0016.jpg
ADDED
|
Git LFS Details
|
data/database/place00205527_db0017.jpg
ADDED
|
Git LFS Details
|
data/database/place00205527_db0018.jpg
ADDED
|
Git LFS Details
|
data/database/place00205527_db0019.jpg
ADDED
|
Git LFS Details
|
data/ground_truth.json
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"database": [
|
| 3 |
+
{
|
| 4 |
+
"filename": "place00004796_db0000.jpg",
|
| 5 |
+
"place_id": 4796,
|
| 6 |
+
"city": "London",
|
| 7 |
+
"lat": 51.50916209875723,
|
| 8 |
+
"lon": -0.1489464938656002
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"filename": "place00004796_db0001.jpg",
|
| 12 |
+
"place_id": 4796,
|
| 13 |
+
"city": "London",
|
| 14 |
+
"lat": 51.5091715210003,
|
| 15 |
+
"lon": -0.148933838176506
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"filename": "place00004796_db0002.jpg",
|
| 19 |
+
"place_id": 4796,
|
| 20 |
+
"city": "London",
|
| 21 |
+
"lat": 51.50917252788888,
|
| 22 |
+
"lon": -0.1489267203772275
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"filename": "place00004796_db0003.jpg",
|
| 26 |
+
"place_id": 4796,
|
| 27 |
+
"city": "London",
|
| 28 |
+
"lat": 51.50916624802473,
|
| 29 |
+
"lon": -0.1489330344067381
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"filename": "place00203981_db0004.jpg",
|
| 33 |
+
"place_id": 203981,
|
| 34 |
+
"city": "Melbourne",
|
| 35 |
+
"lat": -37.81509362211957,
|
| 36 |
+
"lon": 144.992838452311
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"filename": "place00203981_db0005.jpg",
|
| 40 |
+
"place_id": 203981,
|
| 41 |
+
"city": "Melbourne",
|
| 42 |
+
"lat": -37.81506767547648,
|
| 43 |
+
"lon": 144.9928412577368
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"filename": "place00203981_db0006.jpg",
|
| 47 |
+
"place_id": 203981,
|
| 48 |
+
"city": "Melbourne",
|
| 49 |
+
"lat": -37.81509508467622,
|
| 50 |
+
"lon": 144.9928784011848
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"filename": "place00203981_db0007.jpg",
|
| 54 |
+
"place_id": 203981,
|
| 55 |
+
"city": "Melbourne",
|
| 56 |
+
"lat": -37.81507084333941,
|
| 57 |
+
"lon": 144.9928344766533
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"filename": "place00008797_db0008.jpg",
|
| 61 |
+
"place_id": 8797,
|
| 62 |
+
"city": "London",
|
| 63 |
+
"lat": 51.52845407460239,
|
| 64 |
+
"lon": -0.1750029953860952
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"filename": "place00008797_db0009.jpg",
|
| 68 |
+
"place_id": 8797,
|
| 69 |
+
"city": "London",
|
| 70 |
+
"lat": 51.52846288309365,
|
| 71 |
+
"lon": -0.1750064077374475
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"filename": "place00008797_db0010.jpg",
|
| 75 |
+
"place_id": 8797,
|
| 76 |
+
"city": "London",
|
| 77 |
+
"lat": 51.52845549538673,
|
| 78 |
+
"lon": -0.1750071427715894
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"filename": "place00008797_db0011.jpg",
|
| 82 |
+
"place_id": 8797,
|
| 83 |
+
"city": "London",
|
| 84 |
+
"lat": 51.52842807679626,
|
| 85 |
+
"lon": -0.1750499767058685
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"filename": "place00201236_db0012.jpg",
|
| 89 |
+
"place_id": 201236,
|
| 90 |
+
"city": "Melbourne",
|
| 91 |
+
"lat": -37.84278510997986,
|
| 92 |
+
"lon": 144.9907996152359
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"filename": "place00201236_db0013.jpg",
|
| 96 |
+
"place_id": 201236,
|
| 97 |
+
"city": "Melbourne",
|
| 98 |
+
"lat": -37.84281467344245,
|
| 99 |
+
"lon": 144.9908206918488
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"filename": "place00201236_db0014.jpg",
|
| 103 |
+
"place_id": 201236,
|
| 104 |
+
"city": "Melbourne",
|
| 105 |
+
"lat": -37.8428116471166,
|
| 106 |
+
"lon": 144.9907948010103
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"filename": "place00201236_db0015.jpg",
|
| 110 |
+
"place_id": 201236,
|
| 111 |
+
"city": "Melbourne",
|
| 112 |
+
"lat": -37.84279737925578,
|
| 113 |
+
"lon": 144.9907950980963
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"filename": "place00205527_db0016.jpg",
|
| 117 |
+
"place_id": 205527,
|
| 118 |
+
"city": "Melbourne",
|
| 119 |
+
"lat": -37.79846599314756,
|
| 120 |
+
"lon": 144.9649501082595
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"filename": "place00205527_db0017.jpg",
|
| 124 |
+
"place_id": 205527,
|
| 125 |
+
"city": "Melbourne",
|
| 126 |
+
"lat": -37.79846121519402,
|
| 127 |
+
"lon": 144.9649625622069
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"filename": "place00205527_db0018.jpg",
|
| 131 |
+
"place_id": 205527,
|
| 132 |
+
"city": "Melbourne",
|
| 133 |
+
"lat": -37.79846527742957,
|
| 134 |
+
"lon": 144.9650067670199
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"filename": "place00205527_db0019.jpg",
|
| 138 |
+
"place_id": 205527,
|
| 139 |
+
"city": "Melbourne",
|
| 140 |
+
"lat": -37.79846435981845,
|
| 141 |
+
"lon": 144.964976424101
|
| 142 |
+
}
|
| 143 |
+
],
|
| 144 |
+
"query": [
|
| 145 |
+
{
|
| 146 |
+
"filename": "place00004796_q0000.jpg",
|
| 147 |
+
"place_id": 4796,
|
| 148 |
+
"city": "London",
|
| 149 |
+
"lat": 51.5091691701721,
|
| 150 |
+
"lon": -0.1489371501232689
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"filename": "place00203981_q0001.jpg",
|
| 154 |
+
"place_id": 203981,
|
| 155 |
+
"city": "Melbourne",
|
| 156 |
+
"lat": -37.81508165748332,
|
| 157 |
+
"lon": 144.9928762644381
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"filename": "place00008797_q0002.jpg",
|
| 161 |
+
"place_id": 8797,
|
| 162 |
+
"city": "London",
|
| 163 |
+
"lat": 51.52844848114152,
|
| 164 |
+
"lon": -0.1750700055099293
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"filename": "place00201236_q0003.jpg",
|
| 168 |
+
"place_id": 201236,
|
| 169 |
+
"city": "Melbourne",
|
| 170 |
+
"lat": -37.84279084700875,
|
| 171 |
+
"lon": 144.9907951180407
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"filename": "place00205527_q0004.jpg",
|
| 175 |
+
"place_id": 205527,
|
| 176 |
+
"city": "Melbourne",
|
| 177 |
+
"lat": -37.79845501432714,
|
| 178 |
+
"lon": 144.9649401788737
|
| 179 |
+
}
|
| 180 |
+
],
|
| 181 |
+
"place_mapping": {}
|
| 182 |
+
}
|
data/query/place00004796_q0000.jpg
ADDED
|
Git LFS Details
|
data/query/place00008797_q0002.jpg
ADDED
|
Git LFS Details
|
data/query/place00201236_q0003.jpg
ADDED
|
Git LFS Details
|
data/query/place00203981_q0001.jpg
ADDED
|
Git LFS Details
|
data/query/place00205527_q0004.jpg
ADDED
|
Git LFS Details
|
dataset.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple PyTorch Dataset for VPR (Visual Place Recognition)
|
| 3 |
+
Combines database and query images with ground truth lookup.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
from typing import List, Dict, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class VPRDataset(Dataset):
    """
    Simple VPR Dataset that loads both database and query images.

    Usage:
        dataset = VPRDataset('data')

        # Get an image
        img, filename, is_query = dataset[0]

        # Get ground truth matches for a query
        matches = dataset.gt('place00000123_q0000.jpg')
    """

    def __init__(
        self,
        data_dir='data',
        transform=None,
        include_queries=True,
        include_database=True
    ):
        """
        Args:
            data_dir: Path to data folder containing database/, query/, and ground_truth.json
            transform: Optional torchvision transforms to apply to images
            include_queries: Whether to include query images in the dataset
            include_database: Whether to include database images in the dataset
        """
        self.data_dir = Path(data_dir)
        self.include_queries = include_queries
        self.include_database = include_database

        # Default transform if none provided
        if transform is None:
            self.transform = T.Compose([
                T.Resize((480, 640)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

        # Load ground truth; JSON is read as UTF-8 regardless of the
        # platform's default encoding.
        gt_path = self.data_dir / 'ground_truth.json'
        with open(gt_path, 'r', encoding='utf-8') as f:
            self.ground_truth = json.load(f)

        # Build the dataset items list: database entries first, then queries.
        self.items = []

        if include_database:
            self.items.extend(
                self._make_item(entry, 'database', False)
                for entry in self.ground_truth['database']
            )

        if include_queries:
            self.items.extend(
                self._make_item(entry, 'query', True)
                for entry in self.ground_truth['query']
            )

        # Build lookup tables for fast ground truth queries
        self._build_lookup_tables()

    def _make_item(self, entry: Dict, subdir: str, is_query: bool) -> Dict:
        """Build the item record for one ground-truth entry stored under ``subdir``."""
        return {
            'filename': entry['filename'],
            'path': self.data_dir / subdir / entry['filename'],
            'place_id': entry['place_id'],
            'is_query': is_query,
            'city': entry['city'],
            'lat': entry['lat'],
            'lon': entry['lon']
        }

    def _build_lookup_tables(self):
        """Build internal lookup tables for efficient ground truth queries."""
        # Map filename -> full item info (only for the items actually included)
        self.filename_to_item = {item['filename']: item for item in self.items}

        # Map place_id -> list of database filenames. Built from the full
        # ground truth, so it is available even in a query-only dataset.
        self.place_to_db_files = {}
        for entry in self.ground_truth['database']:
            self.place_to_db_files.setdefault(entry['place_id'], []).append(entry['filename'])

        # Map query filename -> its place_id for fast lookup
        self.query_to_place = {
            entry['filename']: entry['place_id']
            for entry in self.ground_truth['query']
        }

    def __len__(self):
        """Return total number of images in dataset."""
        return len(self.items)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, str, bool]:
        """
        Get an image from the dataset.

        Returns:
            tuple: (image_tensor, filename, is_query)
                - image_tensor: Transformed image as torch.Tensor
                - filename: String filename (e.g., 'place00000123_db0001.jpg')
                - is_query: Boolean indicating if this is a query image
        """
        item = self.items[idx]

        # Load image. convert() returns an independent in-memory copy, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(item['path']) as raw:
            img = raw.convert('RGB')

        # Apply transforms
        if self.transform:
            img = self.transform(img)

        return img, item['filename'], item['is_query']

    def gt(self, query_filename: str) -> List[str]:
        """
        Get ground truth database matches for a query image.

        Args:
            query_filename: Filename of the query image (e.g., 'place00000123_q0000.jpg')

        Returns:
            List of database image filenames that match this query (same place_id)

        Raises:
            ValueError: If ``query_filename`` is not a query in the ground truth.

        Example:
            >>> dataset = VPRDataset('data')
            >>> matches = dataset.gt('place00000123_q0000.jpg')
            >>> print(matches)
            ['place00000123_db0000.jpg', 'place00000123_db0001.jpg', 'place00000123_db0002.jpg']
        """
        if query_filename not in self.query_to_place:
            raise ValueError(f"Query filename '{query_filename}' not found in dataset")

        place_id = self.query_to_place[query_filename]
        return self.place_to_db_files.get(place_id, [])

    def get_query_filenames(self) -> List[str]:
        """Get list of all query image filenames."""
        return list(self.query_to_place.keys())

    def get_database_filenames(self) -> List[str]:
        """Get list of all database image filenames, grouped by place."""
        return [
            filename
            for files in self.place_to_db_files.values()
            for filename in files
        ]

    def get_item_by_filename(self, filename: str) -> Dict:
        """
        Get full item information by filename.

        Args:
            filename: Image filename

        Returns:
            Dictionary with keys: filename, path, place_id, is_query, city, lat, lon

        Raises:
            ValueError: If ``filename`` is not among the items of this dataset.
        """
        if filename not in self.filename_to_item:
            raise ValueError(f"Filename '{filename}' not found in dataset")
        return self.filename_to_item[filename]

    @staticmethod
    def get_place_id_from_filename(filename: str) -> int:
        """
        Extract place_id from filename.

        Args:
            filename: Image filename (e.g., 'place00000123_db0001.jpg')

        Returns:
            Integer place_id (e.g., 123)
        """
        return int(filename.split('_')[0].replace('place', ''))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# ============================================================================
|
| 198 |
+
# EXAMPLE USAGE
|
| 199 |
+
# ============================================================================
|
| 200 |
+
|
| 201 |
+
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    RULE = "=" * 60

    def banner(title):
        """Print a section title framed by two rule lines."""
        print(RULE)
        print(title)
        print(RULE)

    banner("EXAMPLE 1: Basic Dataset Usage")

    # Combined dataset: database images followed by query images
    dataset = VPRDataset('data')
    print(f"Total images in dataset: {len(dataset)}")
    print(f"Query images: {len(dataset.get_query_filenames())}")
    print(f"Database images: {len(dataset.get_database_filenames())}")
    print()

    # Fetch one sample and describe it
    img, filename, is_query = dataset[0]
    print("First image:")
    print(f" Filename: {filename}")
    print(f" Is query: {is_query}")
    print(f" Image shape: {img.shape}")
    print()

    banner("EXAMPLE 2: Ground Truth Lookup")

    # Use the first query as the lookup key
    query_file = dataset.get_query_filenames()[0]
    print(f"Query: {query_file}")

    # Database images that share the query's place_id
    matches = dataset.gt(query_file)
    print(f"Ground truth matches ({len(matches)} images):")
    for match in matches:
        print(f" - {match}")
    print()

    banner("EXAMPLE 3: Create Separate Query and Database Datasets")

    # Database images only
    db_dataset = VPRDataset('data', include_queries=False, include_database=True)
    print(f"Database-only dataset size: {len(db_dataset)}")

    # Query images only
    query_dataset = VPRDataset('data', include_queries=True, include_database=False)
    print(f"Query-only dataset size: {len(query_dataset)}")
    print()

    banner("EXAMPLE 4: Using with DataLoader")

    # Batch the combined dataset and inspect the first batch
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch_imgs, batch_filenames, batch_is_query = next(iter(dataloader))
    print(f"Batch shape: {batch_imgs.shape}")
    print(f"Batch filenames: {batch_filenames}")
    print(f"Batch is_query flags: {batch_is_query}")
    print()

    banner("EXAMPLE 5: Get Item Info by Filename")

    item_info = dataset.get_item_by_filename(query_file)
    print(f"Full info for {query_file}:")
    for key, value in item_info.items():
        # The path is omitted to keep the output short
        if key == 'path':
            continue
        print(f" {key}: {value}")
    print()

    banner("EXAMPLE 6: Typical VPR Workflow")

    print("Typical usage pattern:")
    print("""
    # 1. Create separate datasets
    db_dataset = VPRDataset('data', include_queries=False)
    query_dataset = VPRDataset('data', include_database=False)

    # 2. Extract features for all database images
    db_features = []
    db_filenames = []
    for img, filename, _ in db_dataset:
        feat = model(img.unsqueeze(0)) # Your VPR model
        db_features.append(feat)
        db_filenames.append(filename)

    # 3. For each query, find matches
    for img, query_filename, _ in query_dataset:
        # Extract query features
        query_feat = model(img.unsqueeze(0))

        # Compute similarities with database
        similarities = compute_similarity(query_feat, db_features)

        # Get top-K predictions
        top_k_indices = similarities.argsort()[::-1][:10]
        predicted_files = [db_filenames[i] for i in top_k_indices]

        # Get ground truth
        gt_files = query_dataset.gt(query_filename)

        # Evaluate: check if any gt_files are in predicted_files
        recall_at_10 = any(gt in predicted_files for gt in gt_files)
    """)
|
model.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from dataset import VPRDataset
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_model(model_name: str, device: str = "auto") -> nn.Module:
    """Build a pre-trained VPR network and return it ready for inference.

    Args:
        model_name: Case-insensitive model identifier. Only "eigenplaces" is
            recognised.
        device: Target device string. "auto" selects "cuda" when
            ``torch.cuda.is_available()`` is true and "cpu" otherwise.

    Returns:
        The model, moved to the chosen device and switched to eval mode, with
        a ``descriptor_dim`` attribute attached.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    target_device = device
    if target_device == "auto":
        target_device = "cpu" if not torch.cuda.is_available() else "cuda"

    if model_name.lower() != "eigenplaces":
        raise ValueError(f"Model {model_name} not found")

    # The same size is requested from the hub entry point and recorded on the
    # model so that callers can read it back as ``model.descriptor_dim``.
    descriptor_dim = 2048
    network = torch.hub.load(
        "gmberton/eigenplaces",
        "get_trained_model",
        backbone="ResNet50",
        fc_output_dim=descriptor_dim,
    )
    network.descriptor_dim = descriptor_dim

    network = network.to(target_device)
    network.eval()
    return network
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_descriptors(model: nn.Module, dataset: VPRDataset, batch_size: int = 32, device: str = "auto") -> torch.Tensor:
    """Compute descriptors for all images in the dataset.

    Args:
        model: Pre-trained VPR model. It must expose a ``descriptor_dim``
            attribute, which sizes the output buffer.
        dataset: VPRDataset containing images. Each batch is unpacked as a
            3-tuple and only the first element (the images) is used.
        batch_size: Batch size for processing.
        device: Device to run inference on. "auto" selects "cuda" when
            ``torch.cuda.is_available()`` is true and "cpu" otherwise.

    Returns:
        CPU tensor of shape (len(dataset), model.descriptor_dim) whose rows
        follow dataset order (the loader does not shuffle).
    """
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Batches are moved to ``device`` below, so the model must live there too.
    # Otherwise a model loaded on the CPU fails with a device-mismatch error
    # as soon as "auto" resolves to CUDA. nn.Module.to() moves the module in
    # place and returns the same object.
    model = model.to(device)
    model.eval()

    # Results are accumulated on the CPU so that memory use on the inference
    # device stays bounded by a single batch.
    descriptors = torch.zeros(len(dataset), model.descriptor_dim, device="cpu")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    with torch.no_grad():
        start_idx = 0
        for batch_images, _, _ in tqdm(dataloader, desc="Computing descriptors"):
            batch_descriptors = model(batch_images.to(device)).cpu()
            # The final batch may be smaller than batch_size, so the slice
            # width is taken from the batch that was actually produced.
            end_idx = start_idx + batch_descriptors.size(0)
            descriptors[start_idx:end_idx] = batch_descriptors
            start_idx = end_idx

    return descriptors
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.9.0
|
| 2 |
+
torchvision>=0.10.0
|
| 3 |
+
gradio>=3.0.0
|
| 4 |
+
Pillow>=8.0.0
|
| 5 |
+
numpy>=1.21.0
|
| 6 |
+
tqdm>=4.60.0
|
scripts/sample_data.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sample database and query images from GSV-Cities dataset for VPR demo.
|
| 3 |
+
Creates data/database/ and data/query/ folders with images.
|
| 4 |
+
Ground truth is encoded in the filename: placeID_imageID.jpg
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import shutil
|
| 11 |
+
import random
|
| 12 |
+
import json
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
# Configuration
# Root of the local GSV-Cities copy. load_dataframes() reads
# <BASE_PATH>/Dataframes/<city>.csv and get_img_path() builds paths under
# <BASE_PATH>/Images/<city>/.
BASE_PATH = '/Users/olivergrainge/datasets/gsv-cities'
# Output root: sample_and_copy_images() creates database/ and query/ inside it
# and writes ground_truth.json next to them.
OUTPUT_PATH = 'data'
NUM_PLACES = 5  # Number of unique places to sample
DB_IMAGES_PER_PLACE = 4  # Images per place for database
QUERY_IMAGES_PER_PLACE = 1  # Images per place for queries
CITIES = ['London', 'Boston', 'Melbourne']  # Cities to sample from
MIN_IMAGES_PER_PLACE = 5  # Minimum images a place must have
|
| 23 |
+
|
| 24 |
+
def load_dataframes(base_path, cities):
    """Read each city's CSV under ``<base_path>/Dataframes`` and stack them.

    Cities whose CSV file is missing are reported with a warning and skipped.
    Each loaded frame gets a ``city_name`` column, and its ``place_id`` values
    are shifted by ``position_in_cities * 10**5`` so ids from different cities
    do not collide.

    Raises:
        FileNotFoundError: If none of the requested city files exist.
    """
    frames_dir = Path(base_path) / 'Dataframes'
    loaded = []

    for city_index, city in enumerate(cities):
        csv_file = frames_dir / f'{city}.csv'
        if csv_file.exists():
            frame = pd.read_csv(csv_file)
            # The offset is keyed on the city's position in ``cities`` (not on
            # how many files were found), so skipped cities leave a gap.
            frame['place_id'] = frame['place_id'] + (city_index * 10**5)
            frame['city_name'] = city
            loaded.append(frame)
        else:
            print(f"Warning: {csv_file} not found, skipping {city}")

    if len(loaded) == 0:
        raise FileNotFoundError("No valid city dataframes found!")

    return pd.concat(loaded, ignore_index=True)
|
| 43 |
+
|
| 44 |
+
def get_img_path(base_path, row):
    """Return the source image path for one dataframe row.

    The file name is
    ``<city>_<place>_<year>_<month>_<northdeg>_<lat>_<lon>_<panoid>.jpg`` and
    the file is looked up under ``<base_path>/Images/<city>/``.
    """
    city = row['city_id']
    # Reduce the id modulo 10**5 to strip the per-city offset that
    # load_dataframes() adds, then zero-pad to 7 characters.
    local_place = str(row['place_id'] % 10**5).zfill(7)
    panoid = row['panoid']

    fields = (
        city,
        local_place,
        str(row['year']).zfill(4),
        str(row['month']).zfill(2),
        str(row['northdeg']).zfill(3),
        str(row['lat']),
        str(row['lon']),
        panoid,
    )
    img_name = "{}_{}_{}_{}_{}_{}_{}_{}.jpg".format(*fields)

    images_root = Path(base_path) / 'Images'
    return images_root / city / img_name
|
| 57 |
+
|
| 58 |
+
def _copy_split(rows, place_id, dst_dir, tag, next_index, records):
    """Copy one place's images for a single split and record their metadata.

    Args:
        rows: Dataframe rows (one per image) to copy.
        place_id: Offset place id that the rows belong to.
        dst_dir: Destination directory for this split.
        tag: Filename tag for the split ("db" or "q").
        next_index: Running image counter for the split.
        records: List that receives one metadata dict per copied image.

    Returns:
        The counter value to use for the next copied image.
    """
    for _, row in rows.iterrows():
        src_path = get_img_path(BASE_PATH, row)
        if not src_path.exists():
            print(f"Warning: {src_path} not found, skipping")
            continue

        # New filename: place########_db####.jpg or place########_q####.jpg
        dst_filename = f"place{str(place_id).zfill(8)}_{tag}{str(next_index).zfill(4)}.jpg"
        shutil.copy2(src_path, dst_dir / dst_filename)

        # Cast to built-in types so json.dump can serialise the values.
        records.append({
            'filename': dst_filename,
            'place_id': int(place_id),
            'city': row['city_name'],
            'lat': float(row['lat']),
            'lon': float(row['lon'])
        })
        next_index += 1

    return next_index


def sample_and_copy_images():
    """Sample places, copy their images into data folders, and write metadata.

    Creates ``<OUTPUT_PATH>/database`` and ``<OUTPUT_PATH>/query``, copies
    DB_IMAGES_PER_PLACE and QUERY_IMAGES_PER_PLACE randomly chosen images per
    sampled place into them, and writes ``<OUTPUT_PATH>/ground_truth.json``.
    Sampling is unseeded, so each run selects different places and images.
    """
    # Create output directories
    db_path = Path(OUTPUT_PATH) / 'database'
    query_path = Path(OUTPUT_PATH) / 'query'
    db_path.mkdir(parents=True, exist_ok=True)
    query_path.mkdir(parents=True, exist_ok=True)

    print("Loading dataframes...")
    df = load_dataframes(BASE_PATH, CITIES)

    # A place needs enough images to fill both splits. MIN_IMAGES_PER_PLACE
    # can raise that threshold but never lower it.
    images_needed = DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE
    min_images = max(MIN_IMAGES_PER_PLACE, images_needed)
    place_counts = df.groupby('place_id').size()
    valid_places = place_counts[place_counts >= min_images].index
    df = df[df['place_id'].isin(valid_places)]

    print(f"Found {len(valid_places)} valid places")

    # Sample N random places
    sampled_places = random.sample(list(valid_places), min(NUM_PLACES, len(valid_places)))

    print(f"Sampling {len(sampled_places)} places...")

    # Ground truth structure
    # NOTE(review): 'place_mapping' is never populated in this script and is
    # written out as an empty object; confirm whether any consumer relies on it.
    ground_truth = {
        'database': [],
        'query': [],
        'place_mapping': {}
    }

    db_count = 0
    query_count = 0

    for place_id in tqdm(sampled_places, desc="Processing places"):
        place_images = df[df['place_id'] == place_id]

        # Sample images for this place
        sampled = place_images.sample(n=min(images_needed, len(place_images)))

        # Split into database and query
        db_images = sampled.iloc[:DB_IMAGES_PER_PLACE]
        query_images = sampled.iloc[DB_IMAGES_PER_PLACE:images_needed]

        db_count = _copy_split(
            db_images, place_id, db_path, 'db', db_count, ground_truth['database']
        )
        query_count = _copy_split(
            query_images, place_id, query_path, 'q', query_count, ground_truth['query']
        )

    # Save ground truth to JSON
    gt_path = Path(OUTPUT_PATH) / 'ground_truth.json'
    with open(gt_path, 'w', encoding='utf-8') as f:
        json.dump(ground_truth, f, indent=2)

    print("\n✓ Successfully created dataset!")
    print(f" Database images: {db_count} (in {db_path})")
    print(f" Query images: {query_count} (in {query_path})")
    print(f" Ground truth: {gt_path}")
    print("\nGround truth structure:")
    print(" - Filenames contain place_id: place########_db####.jpg or place########_q####.jpg")
    print(" - JSON file contains detailed metadata including GPS coordinates")


if __name__ == "__main__":
    sample_and_copy_images()
|