# simple-vpr-demo/scripts/sample_data.py
"""
Sample database and query images from GSV-Cities dataset for VPR demo.
Creates data/database/ and data/query/ folders with images.
Ground truth is encoded in the filename: placeID_imageID.jpg
"""
import pandas as pd
from pathlib import Path
import shutil
import random
import json
from tqdm import tqdm
# Configuration
BASE_PATH = '/Users/olivergrainge/datasets/gsv-cities'
OUTPUT_PATH = 'data'
NUM_PLACES = 5 # Number of unique places to sample
DB_IMAGES_PER_PLACE = 4 # Images per place for database
QUERY_IMAGES_PER_PLACE = 1 # Images per place for queries
CITIES = ['London', 'Boston', 'Melbourne'] # Cities to sample from
MIN_IMAGES_PER_PLACE = DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE  # Minimum images a place must have to be sampled
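# With the defaults above the script copies at most 5 * 4 = 20 database images
# and 5 * 1 = 5 query images ("at most" because source files missing on disk
# are skipped).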
def load_dataframes(base_path, cities):
"""Load and combine dataframes from multiple cities."""
dfs = []
for i, city in enumerate(cities):
df_path = Path(base_path) / 'Dataframes' / f'{city}.csv'
if not df_path.exists():
print(f"Warning: {df_path} not found, skipping {city}")
continue
df = pd.read_csv(df_path)
        # Offset place_id by the city index so IDs stay unique across cities
        # (e.g. place 42 in the second city becomes 100042; get_img_path
        # recovers the original ID with % 10**5)
        df['place_id'] = df['place_id'] + (i * 10**5)
df['city_name'] = city
dfs.append(df)
if not dfs:
raise FileNotFoundError("No valid city dataframes found!")
return pd.concat(dfs, ignore_index=True)
def get_img_path(base_path, row):
"""Construct the full image path from a dataframe row."""
city = row['city_id']
pl_id = row['place_id'] % 10**5
pl_id = str(pl_id).zfill(7)
panoid = row['panoid']
year = str(row['year']).zfill(4)
month = str(row['month']).zfill(2)
northdeg = str(row['northdeg']).zfill(3)
lat, lon = str(row['lat']), str(row['lon'])
img_name = f"{city}_{pl_id}_{year}_{month}_{northdeg}_{lat}_{lon}_{panoid}.jpg"
return Path(base_path) / 'Images' / city / img_name
def sample_and_copy_images():
"""Main function to sample and organize images."""
# Create output directories
db_path = Path(OUTPUT_PATH) / 'database'
query_path = Path(OUTPUT_PATH) / 'query'
db_path.mkdir(parents=True, exist_ok=True)
query_path.mkdir(parents=True, exist_ok=True)
print("Loading dataframes...")
df = load_dataframes(BASE_PATH, CITIES)
    # Keep only places with enough images for both the database and query splits
    place_counts = df.groupby('place_id').size()
    valid_places = place_counts[place_counts >= MIN_IMAGES_PER_PLACE].index
df = df[df['place_id'].isin(valid_places)]
print(f"Found {len(valid_places)} valid places")
# Sample N random places
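    # Note: no RNG seed is set, so each run draws a different subset; for
    # reproducible sampling, call random.seed(...) here and pass random_state=
    # to .sample() below.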
sampled_places = random.sample(list(valid_places), min(NUM_PLACES, len(valid_places)))
print(f"Sampling {len(sampled_places)} places...")
    # Ground truth structure
    ground_truth = {
        'database': [],
        'query': []
    }
db_count = 0
query_count = 0
for place_id in tqdm(sampled_places, desc="Processing places"):
place_images = df[df['place_id'] == place_id]
# Sample images for this place
sampled = place_images.sample(n=min(DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE, len(place_images)))
# Split into database and query
db_images = sampled.iloc[:DB_IMAGES_PER_PLACE]
query_images = sampled.iloc[DB_IMAGES_PER_PLACE:DB_IMAGES_PER_PLACE + QUERY_IMAGES_PER_PLACE]
# Copy database images
        for _, row in db_images.iterrows():
src_path = get_img_path(BASE_PATH, row)
if not src_path.exists():
print(f"Warning: {src_path} not found, skipping")
continue
            # New filename: place########_db####.jpg
dst_filename = f"place{str(place_id).zfill(8)}_db{str(db_count).zfill(4)}.jpg"
dst_path = db_path / dst_filename
shutil.copy2(src_path, dst_path)
ground_truth['database'].append({
'filename': dst_filename,
'place_id': int(place_id),
'city': row['city_name'],
'lat': float(row['lat']),
'lon': float(row['lon'])
})
db_count += 1
# Copy query images
        for _, row in query_images.iterrows():
src_path = get_img_path(BASE_PATH, row)
if not src_path.exists():
print(f"Warning: {src_path} not found, skipping")
continue
            # New filename: place########_q####.jpg
dst_filename = f"place{str(place_id).zfill(8)}_q{str(query_count).zfill(4)}.jpg"
dst_path = query_path / dst_filename
shutil.copy2(src_path, dst_path)
ground_truth['query'].append({
'filename': dst_filename,
'place_id': int(place_id),
'city': row['city_name'],
'lat': float(row['lat']),
'lon': float(row['lon'])
})
query_count += 1
# Save ground truth to JSON
gt_path = Path(OUTPUT_PATH) / 'ground_truth.json'
with open(gt_path, 'w') as f:
json.dump(ground_truth, f, indent=2)
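    # Each entry in ground_truth.json has this shape (values hypothetical):
    #   {"filename": "place00000042_db0003.jpg", "place_id": 42,
    #    "city": "London", "lat": 51.5074, "lon": -0.1278}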
print(f"\n✓ Successfully created dataset!")
print(f" Database images: {db_count} (in {db_path})")
print(f" Query images: {query_count} (in {query_path})")
print(f" Ground truth: {gt_path}")
print(f"\nGround truth structure:")
print(f" - Filenames contain place_id: place########_db####.jpg or place########_q####.jpg")
print(f" - JSON file contains detailed metadata including GPS coordinates")
if __name__ == "__main__":
sample_and_copy_images()