# basicsr/utils/misc.py
import os
import re
import random
import time
import torch
import numpy as np
from os import path as osp
from .dist_util import master_only
from .logger import get_root_logger
# ---------------------------
# GPU / MPS Compatibility
# ---------------------------
# Check if PyTorch >= 1.12, the first release with MPS (Apple Silicon) support.
try:
    # torch.__version__ may carry suffixes such as '1.13.0+cu117'; capture only
    # the leading major.minor.patch triplet.
    version_match = re.findall(
        r"^([0-9]+)\.([0-9]+)\.([0-9]+)",
        torch.__version__
    )[0]
    IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
except (IndexError, ValueError):
    # IndexError: version string did not match (e.g. exotic nightly builds);
    # ValueError: non-numeric components. A bare `except:` previously swallowed
    # everything, including KeyboardInterrupt/SystemExit.
    IS_HIGH_VERSION = False
def gpu_is_available():
    """Report whether an accelerator (Apple MPS or NVIDIA CUDA) is usable."""
    mps_ok = IS_HIGH_VERSION and torch.backends.mps.is_available()
    cuda_ok = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return True if mps_ok else cuda_ok
def get_device(gpu_id=None):
    """Select the best available device, preferring MPS, then CUDA, then CPU.

    Args:
        gpu_id (int | None): optional CUDA device index, appended as ':<id>'.

    Returns:
        torch.device: the chosen device.
    """
    # Apple Silicon first (requires PyTorch >= 1.12).
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device("mps")
    # NVIDIA CUDA next; gpu_id only matters on this branch.
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        suffix = f":{gpu_id}" if isinstance(gpu_id, int) else ""
        return torch.device(f"cuda{suffix}")
    # Fall back to the CPU.
    return torch.device("cpu")
# ---------------------------
# Utilities
# ---------------------------
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Seed CUDA generators as well when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
def get_time_str():
    """Return the current local time as a 'YYYYMMDD_HHMMSS' string."""
    # strftime() uses the current local time when no struct_time is given.
    return time.strftime('%Y%m%d_%H%M%S')
def mkdir_and_rename(path):
    """Create *path*, archiving any directory already at that location.

    An existing path is renamed to '<path>_archived_<timestamp>' before a
    fresh directory is created, so previous results are never overwritten.
    """
    if osp.exists(path):
        archived = f'{path}_archived_{get_time_str()}'
        print(f'Path already exists. Renamed to {archived}', flush=True)
        os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
    """Create the experiment/result directory tree declared in opt['path'].

    Only runs on the master process (see ``master_only``). Keys mentioning
    'strict_load', 'pretrain_network' or 'resume' are option values, not
    directories, and are skipped.
    """
    paths = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(paths.pop(root_key))
    skip_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in paths.items():
        if not any(marker in key for marker in skip_markers):
            os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Lazily yield file paths under *dir_path*.

    Args:
        dir_path (str): directory to scan.
        suffix (str | tuple | None): keep only paths ending with this suffix
            (``str.endswith`` semantics; a tuple matches any element).
        recursive (bool): descend into subdirectories.
        full_path (bool): yield entry paths as-is instead of paths relative
            to *dir_path*.

    Yields:
        str: matching file paths. Hidden files ('.'-prefixed) are skipped;
        hidden directories are still traversed when recursive.
    """
    root = dir_path

    def _walk(current):
        for entry in os.scandir(current):
            if entry.is_file():
                if entry.name.startswith('.'):
                    continue
                candidate = entry.path if full_path else osp.relpath(entry.path, root)
                if suffix is None or candidate.endswith(suffix):
                    yield candidate
            elif recursive and entry.is_dir():
                yield from _walk(entry.path)

    return _walk(dir_path)
def check_resume(opt, resume_iter):
    """Point every network's pretrain path at the checkpoint for *resume_iter*.

    Does nothing unless opt['path']['resume_state'] is set. Networks listed
    in opt['path']['ignore_resume_networks'] keep their original paths.

    Args:
        opt (dict): option dict; reads top-level 'network_*' keys and rewrites
            opt['path']['pretrain_network_*'] entries in place.
        resume_iter (int): iteration number of the checkpoint to load.
    """
    logger = get_root_logger()
    if not opt['path']['resume_state']:
        return
    networks = [key for key in opt if key.startswith('network_')]
    # Any pre-set pretrain paths are superseded by the resume checkpoint.
    if any(opt['path'].get(f'pretrain_{name}') for name in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')
    ignored = opt['path'].get('ignore_resume_networks')
    for network in networks:
        basename = network.replace('network_', '')
        if ignored is not None and basename in ignored:
            continue
        opt['path'][f'pretrain_{network}'] = osp.join(
            opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
        logger.info(f"Set pretrain for {network}")
def sizeof_fmt(size, suffix='B'):
    """Format a byte count as a human-readable string.

    Args:
        size (int | float): size in bytes.
        suffix (str): unit suffix appended after the scale prefix.

    Returns:
        str: e.g. ``'3.5 MB'``.
    """
    # Fix: the ladder previously jumped from 'P' straight to 'Y', skipping
    # 'E' (exbi) and 'Z' (zebi), so sizes >= 1024**6 were mislabelled.
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if size < 1024:
            return f"{size:3.1f} {unit}{suffix}"
        size /= 1024
    return f"{size:3.1f} Y{suffix}"