# yolo/utils/get_dataset.py
import os
import zipfile
import requests
from hydra import main
from loguru import logger
from tqdm import tqdm


def download_file(url, destination):
    """
    Downloads a file from the specified URL to the destination path with progress logging.
    """
    logger.info(f"Downloading {os.path.basename(destination)}...")
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
        with open(destination, "wb") as file:
            for data in response.iter_content(chunk_size=1024 * 1024):  # Stream in 1 MB chunks
                file.write(data)
                progress.update(len(data))
        progress.close()
    logger.info("Download completed.")
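
# Illustrative call only; the URL and destination below are placeholders, not values
# this module ships with:
#   download_file("https://example.com/datasets/train2017.zip", "data/train2017.zip")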


def unzip_file(source, destination):
    """
    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
    """
    logger.info(f"Unzipping {os.path.basename(source)}...")
    with zipfile.ZipFile(source, "r") as zip_ref:
        zip_ref.extractall(destination)
    os.remove(source)
    logger.info(f"Removed {source}.")


def check_files(directory, expected_count=None):
    """
    Returns True if the directory contains the expected number of files.

    If expected_count is None, returns True as long as the directory contains any file.
    """
    if not os.path.isdir(directory):
        return False  # Directory does not exist yet, e.g. before the first extraction
    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    return len(files) == expected_count if expected_count is not None else bool(files)


@main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(cfg):
    """
    Prepares the dataset by downloading and unzipping archives where necessary.
    """
    data_dir = cfg.save_path
    for data_type, settings in cfg.datasets.items():
        base_url = settings["base_url"]
        for dataset_type, dataset_args in settings.items():
            if dataset_type == "base_url":
                continue  # Skip the base_url entry

            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = os.path.join(data_dir, file_name)
            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
            final_place = os.path.join(extract_to, dataset_type)

            os.makedirs(extract_to, exist_ok=True)
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f"Dataset {dataset_type} already verified.")
                continue

            if not os.path.exists(local_zip_path):
                download_file(url, local_zip_path)
            unzip_file(local_zip_path, extract_to)

            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
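
# For reference, a sketch of the Hydra config this entry point expects
# (config/data/download.yaml), inferred from how cfg is consumed above. The dataset
# names, URLs, and file counts are illustrative placeholders, not the repository's
# actual values:
#
#   save_path: data
#   datasets:
#     images:
#       base_url: https://example.com/zips/
#       train2017:
#         file_num: 118287
#       val2017:
#         file_num: 5000
#     annotations:
#       base_url: https://example.com/annotations/
#       trainval2017:
#         file_name: annotations_trainval2017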


if __name__ == "__main__":
    import sys

    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
    prepare_dataset()
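
# Typical invocation from the repository root; any config key can be overridden on the
# command line with Hydra's key=value syntax (save_path is the field read above):
#   python yolo/utils/get_dataset.py save_path=data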