#!/usr/bin/env python3
"""
Selective SA-1B downloader: fetch only LayoutSAM-annotated images and masks.
"""
import argparse
import json
import os
import tarfile
import zlib
from multiprocessing import Pool

import requests
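
# Input formats, inferred from the parsing in main() below:
#   --input_file  : tab-separated lines of "<file_name>\t<url>"; an optional
#                   header row beginning with "file_name" is skipped
#   --images_json : JSON list of "<shard_name>/<image_stem>.jpg" keys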


def download_and_extract(task):
    """
    Download a shard and extract only the requested images and masks.
    task: (file_name, url, raw_dir, images_dir, masks_dir, skip_existing, image_paths, mask_paths)
    """
    try:
        file_name, url, raw_dir, images_dir, masks_dir, skip_existing, image_paths, mask_paths = task
        raw_path = os.path.join(raw_dir, file_name)
        # 1) Download if missing
        if not os.path.exists(raw_path):
            print(f"Downloading {file_name} from {url}...")
            resp = requests.get(url, stream=True)
            resp.raise_for_status()
            with open(raw_path, 'wb') as f:
                # Stream in 8 KiB chunks so multi-GB shards never sit in memory.
                for chunk in resp.iter_content(8192):
                    f.write(chunk)
        else:
            print(f"{file_name} already exists in {raw_dir}. Skipping download.")
        # 2) Extract only target files
        if file_name.endswith('.tar'):
            shard_name = os.path.splitext(file_name)[0]
            # Skip if already extracted
            if (skip_existing
                    and os.path.isdir(os.path.join(images_dir, shard_name))
                    and os.path.isdir(os.path.join(masks_dir, shard_name))):
                print(f"{file_name} already extracted. Skipping.")
                return
            print(f"Extracting {file_name}...")
            extracted_count = 0
            # Try-extract loop: if corrupt, delete & redownload once
            for attempt in (1, 2):
                try:
                    with tarfile.open(raw_path) as tar:
                        for member in tar.getmembers():
                            # Strip a leading "./" before the key lookup.
                            # (str.lstrip("./") would strip any run of '.'/'/'
                            # characters, not just the "./" prefix.)
                            name = member.name[2:] if member.name.startswith('./') else member.name
                            key = f"{shard_name}/{name}"
                            if name.endswith('.jpg') and key in image_paths:
                                out = os.path.join(images_dir, shard_name)
                                os.makedirs(out, exist_ok=True)
                                tar.extract(member, path=out)
                                extracted_count += 1
                            elif name.endswith('.json') and key in mask_paths:
                                out = os.path.join(masks_dir, shard_name)
                                os.makedirs(out, exist_ok=True)
                                tar.extract(member, path=out)
                                extracted_count += 1
                    # extraction succeeded
                    break
                except (tarfile.ReadError, zlib.error, OSError) as e:
                    print(f"⚠️ Shard {file_name} seems corrupt (attempt {attempt}): {e}")
                    if attempt == 1:
                        print("  → Removing and redownloading shard…")
                        os.remove(raw_path)
                        # re-download
                        resp = requests.get(url, stream=True)
                        resp.raise_for_status()
                        with open(raw_path, 'wb') as f:
                            for chunk in resp.iter_content(8192):
                                f.write(chunk)
                        continue
                    else:
                        print("  → Second attempt failed; skipping this shard.")
                        return
print(f"{file_name} extracted {extracted_count} files.")
else:
print(f"{file_name} is not a .tar archive. Skipping extraction.")
except Exception as e:
# catch any unexpected error so the pool never crashes
print(f"‼️ Unexpected error processing {file_name}: {e}")
return
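
# Illustrative direct call for a single shard (names and URL below are
# hypothetical SA-1B-style examples, not taken from a real shard list):
#   download_and_extract((
#       "sa_000020.tar", "https://example.com/sa_000020.tar",
#       "raw", "images", "annotations", True,
#       {"sa_000020/sa_223750.jpg"}, {"sa_000020/sa_223750.json"},
#   ))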


def main():
    parser = argparse.ArgumentParser(
        description="Download SA-1B shards and extract only LayoutSAM images and masks"
    )
    parser.add_argument('--processes', type=int, default=4)
    parser.add_argument('--input_file', type=str, default='shard_links.txt')
    parser.add_argument('--raw_dir', type=str, default='raw')
    parser.add_argument('--images_dir', type=str, default='images')
    parser.add_argument('--masks_dir', type=str, default='annotations')
    parser.add_argument('--images_json', type=str, required=True,
                        help="Path to images_to_download.json with annotated image paths.")
    parser.add_argument('--skip_existing', action='store_true')
    args = parser.parse_args()
    # Load image list and derive mask list
    with open(args.images_json, 'r') as f:
        image_list = json.load(f)
    image_paths = set(image_list)
    mask_paths = {p.replace('.jpg', '.json') for p in image_list}
    print(f"Looking for {len(image_paths)} images and {len(mask_paths)} masks")
    # Read shard links
    with open(args.input_file) as f:
        lines = f.read().strip().splitlines()
    if lines and lines[0].startswith('file_name'):
        lines = lines[1:]
    # Create dirs
    os.makedirs(args.raw_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)
    os.makedirs(args.masks_dir, exist_ok=True)
    # Build tasks
    tasks = []
    for line in lines:
        name, url = line.split('\t')
        tasks.append((name, url,
                      args.raw_dir,
                      args.images_dir,
                      args.masks_dir,
                      args.skip_existing,
                      image_paths,
                      mask_paths))
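    # Note: Pool pickles each task tuple, so the full image/mask key sets are
    # serialized once per shard; fine here, but worth knowing for huge key lists.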
    # Parallel processing
    with Pool(args.processes) as pool:
        pool.map(download_and_extract, tasks)
    print("✅ All done.")


if __name__ == '__main__':
    main()
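
# Resulting on-disk layout after a run (shard and file names are illustrative):
#   raw/sa_000020.tar
#   images/sa_000020/sa_223750.jpg
#   annotations/sa_000020/sa_223750.json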