| import os |
| import re |
| import random |
| import time |
| import torch |
| import torch.nn as nn |
| import logging |
| import numpy as np |
| from os import path as osp |
|
|
def constant_init(module, val, bias=0):
    """Fill a module's ``weight`` and ``bias`` parameters with constants.

    Attributes that are missing or set to ``None`` are skipped, so this is
    safe to call on modules without a bias (or without weights at all).

    Args:
        module (nn.Module): Module whose parameters are initialized in place.
        val (float): Constant written into ``module.weight``.
        bias (float): Constant written into ``module.bias``. Default: 0.
    """
    weight_param = getattr(module, 'weight', None)
    if weight_param is not None:
        nn.init.constant_(weight_param, val)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
|
|
initialized_logger = {}  # logger names already configured by get_root_logger


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger, initializing it on first use.

    On the first call for a given ``logger_name`` a StreamHandler is attached,
    propagation to ancestor loggers is disabled and the logger level is set to
    ``log_level``. If ``log_file`` is specified, a FileHandler (append mode)
    at ``log_level`` is attached as well. Subsequent calls with the same name
    return the already-configured logger unchanged; their ``log_level`` and
    ``log_file`` arguments are ignored.

    Args:
        logger_name (str): Root logger name. Default: 'basicsr'.
        log_level (int): Level applied to the logger (and to the file handler
            when ``log_file`` is given). Default: logging.INFO.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger. Default: None.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # Return early so repeated calls do not stack duplicate handlers.
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    formatter = logging.Formatter(format_str)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    # Prevent records from also being emitted by ancestor loggers' handlers.
    logger.propagate = False
    # Apply the level unconditionally: if it were only set when a log file is
    # given, console-only loggers would fall back to the inherited effective
    # level (WARNING by default) and silently drop INFO messages.
    logger.setLevel(log_level)

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    initialized_logger[logger_name] = True
    return logger
|
|
|
|
def _torch_version_tuple(version):
    """Return the leading ``(major, minor, patch)`` integers of a version string.

    Only the numeric prefix is parsed, so local/pre-release/dev suffixes such
    as ``+cu118``, ``a0+git1234`` or ``.dev20231010+cpu`` are tolerated. A
    missing patch component counts as 0; an unparseable string yields
    ``(0, 0, 0)`` instead of raising at import time.
    """
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', str(version))
    if match is None:
        return (0, 0, 0)
    return tuple(int(part or 0) for part in match.groups())


# True when the installed torch is >= 1.12.0 (the gate used below before
# probing torch.backends.mps).
IS_HIGH_VERSION = _torch_version_tuple(torch.__version__) >= (1, 12, 0)
|
|
def gpu_is_available():
    """Report whether a GPU backend is usable.

    Apple MPS is probed only when ``IS_HIGH_VERSION`` is true; otherwise (or
    when MPS is unavailable) CUDA counts as available only if cuDNN is too.

    Returns:
        bool: True if MPS or CUDA+cuDNN is available, else False.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return bool(cuda_ready)
|
|
def get_device(gpu_id=None):
    """Return the preferred torch device, optionally pinned to an index.

    Preference order: Apple MPS (probed only when ``IS_HIGH_VERSION`` is
    true), then CUDA when both CUDA and cuDNN are available, otherwise CPU.
    The CPU device never carries an index.

    Args:
        gpu_id (int | None): Index appended as ``':<gpu_id>'`` to the MPS or
            CUDA device string. ``None`` selects the backend's default device.

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If ``gpu_id`` is neither ``None`` nor an int. ``bool`` is
            rejected even though it subclasses ``int``.
    """
    if gpu_id is None:
        gpu_str = ''
    # bool is a subclass of int; without this guard get_device(True) would
    # build the invalid device string 'cuda:True'.
    elif isinstance(gpu_id, int) and not isinstance(gpu_id, bool):
        gpu_str = f':{gpu_id}'
    else:
        raise TypeError('Input should be int value.')

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + gpu_str)
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + gpu_str)
    return torch.device('cpu')
|
|
|
|
def set_random_seed(seed):
    """Seed every random number generator used by the project.

    Seeds, in order: Python's ``random``, NumPy's global generator, PyTorch's
    CPU generator, the CUDA generator of the current device, and the CUDA
    generators of all devices.

    Args:
        seed (int): Seed value passed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    # time.strftime formats time.localtime() when no time tuple is supplied.
    return time.strftime('%Y%m%d_%H%M%S')
|
|
|
|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Entries whose name starts with '.' are never yielded. When ``recursive``
    is True, every sub-directory (hidden ones included) is descended into;
    other non-file entries are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files, with paths relative to
        ``dir_path`` (or full paths when ``full_path`` is True).

    Raises:
        TypeError: If ``suffix`` is neither None, a str, nor a tuple. This is
            raised immediately, not on first iteration.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager closes the directory handle even when the
        # caller stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    return_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                # Only descend into directories. Recursing into every other
                # entry made hidden files (e.g. '.DS_Store') raise
                # NotADirectoryError.
                elif recursive and entry.is_dir():
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)