Spaces:
Runtime error
Runtime error
import os | |
import re | |
import random | |
import time | |
import torch | |
import numpy as np | |
from os import path as osp | |
from .dist_util import master_only | |
from .logger import get_root_logger | |
def parse_version_tuple(version):
    """Return the leading ``(major, minor, patch)`` integers of a version string.

    Anything after the numeric prefix is ignored. That covers pre-release, dev,
    post and local suffixes such as ``.dev20220801+cpu``, ``a0+git1234`` or
    ``+cu118``. A missing patch component counts as 0.

    Args:
        version (str): Version string, e.g. ``torch.__version__``.

    Returns:
        tuple[int, int, int]: Parsed version. ``(0, 0, 0)`` when the string
        does not start with ``<major>.<minor>``, so an unrecognised version
        never crashes module import.
    """
    match = re.match(r'^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?', version)
    if match is None:
        return (0, 0, 0)
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


# True when the installed torch is >= 1.12.0. gpu_is_available/get_device use it
# to decide whether ``torch.backends.mps.is_available()`` may be called
# (the MPS backend first shipped in PyTorch 1.12).
IS_HIGH_VERSION = parse_version_tuple(torch.__version__) >= (1, 12, 0)
def gpu_is_available():
    """Report whether torch can use a GPU backend.

    Returns:
        bool: True if the MPS backend is available (only probed when
        ``IS_HIGH_VERSION`` is set), or if CUDA is available together with
        cuDNN; False otherwise.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    return bool(torch.cuda.is_available() and torch.backends.cudnn.is_available())
def get_device(gpu_id=None):
    """Pick the torch device to run on.

    Preference order: MPS (only probed when ``IS_HIGH_VERSION`` is set), then
    CUDA (cuDNN must also be available), then CPU.

    Args:
        gpu_id (int | None): Optional device index, appended as ``':<id>'``
            to the MPS/CUDA device string. Not used for CPU. Default: None.

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        device_name = 'mps' + index_suffix
    elif torch.cuda.is_available() and torch.backends.cudnn.is_available():
        device_name = 'cuda' + index_suffix
    else:
        device_name = 'cpu'
    return torch.device(device_name)
def set_random_seed(seed):
    """Seed every random number generator used by this codebase.

    The generators are seeded in this order: Python's ``random``, NumPy's
    global RNG, torch's default generator, the current CUDA device's
    generator, and all CUDA devices' generators.

    Args:
        seed (int): Seed value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def mkdir_and_rename(path):
    """Create ``path``, archiving whatever already exists there first.

    If ``path`` exists, it is renamed to ``<path>_archived_<timestamp>``, with
    the timestamp taken from ``get_time_str``. A new directory is then created
    at ``path``.

    Args:
        path (str): Folder path.
    """
    if not osp.exists(path):
        os.makedirs(path, exist_ok=True)
        return

    archived = path + '_archived_' + get_time_str()
    print(f'Path already exists. Rename it to {archived}', flush=True)
    os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
def make_exp_dirs(opt):
    """Make dirs for experiments.

    The root directory is created with ``mkdir_and_rename``, so an existing
    root is archived first. The root is ``experiments_root`` when
    ``opt['is_train']`` is truthy and ``results_root`` otherwise.

    Every other entry of ``opt['path']`` is created with
    ``os.makedirs(..., exist_ok=True)``. Entries whose key contains
    ``strict_load``, ``pretrain_network`` or ``resume`` are skipped, presumably
    because they hold flags or checkpoint paths rather than directories.

    Args:
        opt (dict): Options; only ``opt['is_train']`` and ``opt['path']`` are
            read. ``opt['path']`` itself is not modified (a copy is used).
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    skip_tokens = ('strict_load', 'pretrain_network', 'resume')
    for key, folder in remaining.items():
        if any(token in key for token in skip_tokens):
            continue
        os.makedirs(folder, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (names starting with ``.``) are skipped. When ``recursive``
    is True, sub-directories are descended into, hidden ones included.
    Entries that are neither files nor directories (e.g. broken symlinks or
    sockets) are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files. Paths are relative to
        ``dir_path``, or full paths when ``full_path`` is True.

    Raises:
        TypeError: If ``suffix`` is neither None, a str, nor a tuple. This is
            raised immediately, before iteration starts.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # `with` closes the OS directory handle even if the caller stops
        # iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        continue  # skip hidden files such as .DS_Store
                    return_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only directories are descended into. os.scandir on a
                    # hidden file or a broken symlink would raise.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    Does nothing unless ``opt['path']['resume_state']`` is truthy. Otherwise,
    every ``opt`` key starting with ``network_`` is handled as follows:

    - ``basename`` is the key with ``network_`` removed.
    - ``opt['path']['pretrain_<key>']`` is set to
      ``<opt['path']['models']>/net_<basename>_<resume_iter>.pth``.
    - Networks whose basename is listed in
      ``opt['path']['ignore_resume_networks']`` are left unchanged.

    A warning is logged first if any of those pretrain paths was already set.

    Args:
        opt (dict): Options. ``opt['path']`` is modified in place.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    networks = [key for key in opt.keys() if key.startswith('network_')]
    if any(path_opt.get(f'pretrain_{net}') is not None for net in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')

    ignored = path_opt.get('ignore_resume_networks')
    for net in networks:
        pretrain_key = f'pretrain_{net}'
        basename = net.replace('network_', '')
        if ignored is not None and basename in ignored:
            continue
        path_opt[pretrain_key] = osp.join(path_opt['models'],
                                          f'net_{basename}_{resume_iter}.pth')
        logger.info(f"Set {pretrain_key} to {path_opt[pretrain_key]}")
def sizeof_fmt(size, suffix='B'):
    """Format a byte count as a human-readable string using 1024-based units.

    Args:
        size (int | float): File size.
        suffix (str): Text appended after the unit prefix. Default: 'B'.

    Returns:
        str: The size with one decimal place, e.g. ``'1.5 KB'``. Sizes beyond
        the Z range are expressed in Y units.
    """
    prefixes = ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z')
    level = 0
    while level < len(prefixes) and not abs(size) < 1024.0:
        size /= 1024.0
        level += 1
    prefix = prefixes[level] if level < len(prefixes) else 'Y'
    return f'{size:3.1f} {prefix}{suffix}'