|
|
|
""" |
|
Selective SA-1B downloader: fetch only LayoutSAM-annotated images and masks. |
|
""" |
|
import os |
|
import tarfile |
|
import json |
|
import argparse |
|
import requests |
|
from multiprocessing import Pool |
|
import zlib |
|
import sys |
|
|
|
def _download(url, dest_path):
    """Stream *url* to *dest_path* in 8 KiB chunks, raising on HTTP errors.

    The response is used as a context manager so the streamed connection is
    always released, and a timeout prevents hanging forever on a dead server.
    """
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(dest_path, 'wb') as f:
            for chunk in resp.iter_content(8192):
                f.write(chunk)


def download_and_extract(task):
    """
    Download a shard and extract only the requested images and masks.

    task: (file_name, url, raw_dir, images_dir, masks_dir, skip_existing, image_paths, mask_paths)

    image_paths / mask_paths are sets of "<shard_name>/<member_name>" keys;
    only tar members whose key is in the matching set are extracted.
    Runs as a multiprocessing worker, so all failures are reported by
    printing and returning rather than raising.
    """
    try:
        file_name, url, raw_dir, images_dir, masks_dir, skip_existing, image_paths, mask_paths = task
        raw_path = os.path.join(raw_dir, file_name)

        if not os.path.exists(raw_path):
            print(f"Downloading {file_name} from {url}...")
            _download(url, raw_path)
        else:
            print(f"{file_name} already exists in {raw_dir}. Skipping download.")

        if not file_name.endswith('.tar'):
            print(f"{file_name} is not a .tar archive. Skipping extraction.")
            return

        shard_name = os.path.splitext(file_name)[0]

        # Both output shard dirs already present -> assume fully extracted.
        if skip_existing and os.path.isdir(os.path.join(images_dir, shard_name)) and os.path.isdir(os.path.join(masks_dir, shard_name)):
            print(f"{file_name} already extracted. Skipping.")
            return

        print(f"Extracting {file_name}...")
        extracted_count = 0

        # Two attempts: if the archive is corrupt, redownload once and retry.
        for attempt in (1, 2):
            try:
                with tarfile.open(raw_path) as tar:
                    for member in tar.getmembers():
                        # Remove any leading "./" path prefixes. NOTE: the
                        # previous lstrip("./") stripped a *character set* and
                        # could mangle names starting with '.' or '/'.
                        name = member.name
                        while name.startswith("./"):
                            name = name[2:]
                        key = f"{shard_name}/{name}"
                        if name.endswith(".jpg") and key in image_paths:
                            out = os.path.join(images_dir, shard_name)
                        elif name.endswith(".json") and key in mask_paths:
                            out = os.path.join(masks_dir, shard_name)
                        else:
                            continue
                        os.makedirs(out, exist_ok=True)
                        tar.extract(member, path=out)
                        extracted_count += 1
                break
            except (tarfile.ReadError, zlib.error, OSError) as e:
                # IOError is an alias of OSError, so OSError covers both.
                print(f"⚠️ Shard {file_name} seems corrupt (attempt {attempt}): {e}")
                if attempt == 1:
                    print(" → Removing and redownloading shard…")
                    os.remove(raw_path)
                    _download(url, raw_path)
                    continue
                else:
                    print(" → Second attempt failed; skipping this shard.")
                    return

        print(f"{file_name} extracted {extracted_count} files.")
    except Exception as e:
        # Worker boundary: never let one bad shard kill the whole pool.
        print(f"‼️ Unexpected error processing {file_name}: {e}")
        return
|
|
|
|
|
def main():
    """CLI entry point.

    Reads the list of wanted image paths from --images_json, derives the
    matching mask (.json) paths, then downloads each shard listed in
    --input_file (tab-separated "file_name<TAB>url" lines, optional header)
    and extracts only the wanted files, fanning out over a process pool.
    """
    parser = argparse.ArgumentParser(
        description="Download SA-1B shards and extract only LayoutSAM images and masks"
    )
    parser.add_argument('--processes', type=int, default=4)
    parser.add_argument('--input_file', type=str, default='shard_links.txt')
    parser.add_argument('--raw_dir', type=str, default='raw')
    parser.add_argument('--images_dir', type=str, default='images')
    parser.add_argument('--masks_dir', type=str, default='annotations')
    parser.add_argument('--images_json', type=str, required=True,
                        help="Path to images_to_download.json with annotated image paths.")
    parser.add_argument('--skip_existing', action='store_true')
    args = parser.parse_args()

    with open(args.images_json, 'r') as f:
        image_list = json.load(f)
    image_paths = set(image_list)
    # Each annotation lives next to its image, with a .json extension.
    mask_paths = {p.replace('.jpg', '.json') for p in image_list}

    print(f"Looking for {len(image_paths)} images and {len(mask_paths)} masks")

    # Context manager so the link file is closed promptly (the old code
    # leaked the handle via a bare open().read()).
    with open(args.input_file) as f:
        lines = f.read().strip().splitlines()
    if lines and lines[0].startswith('file_name'):
        lines = lines[1:]  # drop the header row

    os.makedirs(args.raw_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)
    os.makedirs(args.masks_dir, exist_ok=True)

    tasks = []
    for line in lines:
        name, url = line.split('\t')
        tasks.append((name, url,
                      args.raw_dir,
                      args.images_dir,
                      args.masks_dir,
                      args.skip_existing,
                      image_paths,
                      mask_paths))

    with Pool(args.processes) as pool:
        pool.map(download_and_extract, tasks)

    print("✅ All done.")


if __name__ == '__main__':
    main()