#!/usr/bin/env python3
"""
Selective SA-1B downloader: fetch only LayoutSAM-annotated images and masks.
"""
import os
import tarfile
import json
import argparse
import requests
from multiprocessing import Pool
import zlib
import sys

# (connect, read) timeouts in seconds for shard downloads. Without a timeout a
# stalled connection would block a pool worker indefinitely.
DOWNLOAD_TIMEOUT = (30, 300)

# Streaming chunk size for downloads; shards are large, so use 1 MiB chunks.
CHUNK_SIZE = 1 << 20

# Use the safe "data" extraction filter when this Python's tarfile supports it
# (it rejects absolute paths / links escaping the target directory).
_EXTRACT_KWARGS = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}


def _download(url, dest):
    """Stream ``url`` into ``dest`` atomically.

    The body is written to ``dest + ".part"`` and renamed into place only once
    the transfer has completed, so an interrupted run never leaves a truncated
    file that a later run would mistake for a finished shard.

    Raises:
        requests.RequestException: on HTTP/network errors (including timeouts).
        OSError: on local filesystem errors.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as resp:
            resp.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(CHUNK_SIZE):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # After a successful os.replace the temp file is gone; otherwise clean
        # up the partial download so it is not left lying around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _strip_dot_slash(name):
    """Return tar member ``name`` with any leading "./" prefixes removed."""
    while name.startswith("./"):
        name = name[2:]
    return name


def _extract_selected(raw_path, shard_name, images_dir, masks_dir,
                      image_paths, mask_paths):
    """Extract the wanted ``.jpg``/``.json`` members of one shard archive.

    A member is wanted when ``"<shard_name>/<member name>"`` appears in
    ``image_paths`` (for ``.jpg``) or ``mask_paths`` (for ``.json``). Images go
    under ``images_dir/<shard_name>``, masks under ``masks_dir/<shard_name>``.

    Returns:
        The number of members extracted.

    Raises:
        tarfile.ReadError, zlib.error, OSError: if the archive is unreadable.
    """
    count = 0
    with tarfile.open(raw_path) as tar:
        for member in tar.getmembers():
            # Only regular files: never materialize links or directories
            # that merely happen to carry a matching name.
            if not member.isfile():
                continue
            name = _strip_dot_slash(member.name)
            key = f"{shard_name}/{name}"
            if name.endswith(".jpg") and key in image_paths:
                out = os.path.join(images_dir, shard_name)
            elif name.endswith(".json") and key in mask_paths:
                out = os.path.join(masks_dir, shard_name)
            else:
                continue
            os.makedirs(out, exist_ok=True)
            tar.extract(member, path=out, **_EXTRACT_KWARGS)
            count += 1
    return count


def download_and_extract(task):
    """
    Download a shard and extract only the requested images and masks.

    task: (file_name, url, raw_dir, images_dir, masks_dir, skip_existing,
           image_paths, mask_paths)

    If the archive cannot be read, it is deleted, re-downloaded and extraction
    is retried once. This function never raises: any unexpected error is
    reported and the shard is skipped, so one bad shard cannot abort the Pool.
    """
    # Bound before the try so the error handler can always name the task,
    # even if unpacking the tuple itself fails.
    file_name = "<unknown>"
    try:
        (file_name, url, raw_dir, images_dir, masks_dir,
         skip_existing, image_paths, mask_paths) = task
        raw_path = os.path.join(raw_dir, file_name)
        is_tar = file_name.endswith('.tar')
        shard_name = os.path.splitext(file_name)[0]

        # Check for a finished extraction *before* downloading, so a shard whose
        # raw archive was cleaned up is not re-fetched just to be skipped.
        if (is_tar and skip_existing
                and os.path.isdir(os.path.join(images_dir, shard_name))
                and os.path.isdir(os.path.join(masks_dir, shard_name))):
            print(f"{file_name} already extracted. Skipping.")
            return

        # 1) Download if missing
        if not os.path.exists(raw_path):
            print(f"Downloading {file_name} from {url}...")
            _download(url, raw_path)
        else:
            print(f"{file_name} already exists in {raw_dir}. Skipping download.")

        if not is_tar:
            print(f"{file_name} is not a .tar archive. Skipping extraction.")
            return

        # 2) Extract only target files; if corrupt, delete & redownload once.
        print(f"Extracting {file_name}...")
        for attempt in (1, 2):
            try:
                # Count is recomputed per attempt so files from a failed
                # attempt are not double-counted.
                extracted_count = _extract_selected(
                    raw_path, shard_name, images_dir, masks_dir,
                    image_paths, mask_paths,
                )
                break
            except (tarfile.ReadError, zlib.error, OSError) as e:
                print(f"⚠️ Shard {file_name} seems corrupt (attempt {attempt}): {e}")
                if attempt == 2:
                    print(" → Second attempt failed; skipping this shard.")
                    return
                print(" → Removing and redownloading shard…")
                os.remove(raw_path)
                _download(url, raw_path)
        print(f"{file_name} extracted {extracted_count} files.")
    except Exception as e:
        # catch any unexpected error so the pool never crashes
        print(f"‼️ Unexpected error processing {file_name}: {e}")
        return


def _read_shard_links(path):
    """Parse a tab-separated shard list into ``(file_name, url)`` pairs.

    An optional header row starting with ``file_name`` and blank lines are
    skipped; columns beyond the second are ignored.

    Raises:
        ValueError: if a non-blank row has fewer than two tab-separated fields.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.read().strip().splitlines()
    if lines and lines[0].startswith('file_name'):
        lines = lines[1:]
    pairs = []
    for line in lines:
        if not line.strip():
            continue
        fields = line.split('\t')
        if len(fields) < 2:
            raise ValueError(f"Malformed line in {path}: {line!r}")
        pairs.append((fields[0].strip(), fields[1].strip()))
    return pairs


def main():
    """Parse CLI arguments and process every listed shard in a process pool."""
    parser = argparse.ArgumentParser(
        description="Download SA-1B shards and extract only LayoutSAM images and masks"
    )
    parser.add_argument('--processes', type=int, default=4)
    parser.add_argument('--input_file', type=str, default='shard_links.txt')
    parser.add_argument('--raw_dir', type=str, default='raw')
    parser.add_argument('--images_dir', type=str, default='images')
    parser.add_argument('--masks_dir', type=str, default='annotations')
    parser.add_argument('--images_json', type=str, required=True,
                        help="Path to images_to_download.json with annotated image paths.")
    parser.add_argument('--skip_existing', action='store_true')
    args = parser.parse_args()

    # Load image list and derive mask list (same path, .jpg suffix -> .json).
    with open(args.images_json, 'r') as f:
        image_list = json.load(f)
    image_paths = set(image_list)
    mask_paths = {p[:-len('.jpg')] + '.json' for p in image_list if p.endswith('.jpg')}
    print(f"Looking for {len(image_paths)} images and {len(mask_paths)} masks")

    # Read shard links
    shard_links = _read_shard_links(args.input_file)

    # Create dirs
    os.makedirs(args.raw_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)
    os.makedirs(args.masks_dir, exist_ok=True)

    # Build tasks
    tasks = [
        (name, url, args.raw_dir, args.images_dir, args.masks_dir,
         args.skip_existing, image_paths, mask_paths)
        for name, url in shard_links
    ]

    # Parallel processing
    with Pool(args.processes) as pool:
        pool.map(download_and_extract, tasks)

    print("✅ All done.")


if __name__ == '__main__':
    main()