| import re | |
| import os | |
| import yaml | |
| import cv2 | |
| import argparse | |
| import warnings | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from easydict import EasyDict as ed | |
class Simplify(nn.Module):
    """Expose a dict-in/dict-out model as a plain tensor-in/tensor-out module.

    ``forward(x)`` calls the wrapped model with ``{'image': x}`` and returns
    only the ``'pred'`` entry of its output.
    """

    def __init__(self, model):
        super(Simplify, self).__init__()
        self.model = model

    def cuda(self, device=None):
        """Move the wrapped model to CUDA and return ``self``.

        Accepts the same optional ``device`` argument as ``nn.Module.cuda`` so
        ``wrapper.cuda(0)`` works like it does on any other module.
        """
        # Forward the argument only when one was given, so wrapped objects
        # whose own cuda() takes no arguments keep working.
        if device is None:
            self.model = self.model.cuda()
        else:
            self.model = self.model.cuda(device)
        return self

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``'pred'`` output."""
        out = self.model({'image': x})
        return out['pred']
def parse_args(argv=None):
    """Parse command-line options and attach device/rank information.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``resume``, ``verbose`` and
        ``debug`` from the command line, plus:

        - ``local_rank``: ``int(LOCAL_RANK)`` if that environment variable is
          set, otherwise -1.
        - ``device_num``: 1 when LOCAL_RANK is unset. Otherwise it is the
          number of non-empty entries in CUDA_VISIBLE_DEVICES, or
          ``torch.cuda.device_count()`` when that variable is unset.

    Side effect: silences all Python warnings process-wide via
    ``warnings.simplefilter("ignore")``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml')
    parser.add_argument('--resume', '-r', action='store_true', default=False)
    parser.add_argument('--verbose', '-v', action='store_true', default=False)
    parser.add_argument('--debug', '-d', action='store_true', default=False)
    args = parser.parse_args(argv)

    # Only the number of visible devices is used, so entries stay strings:
    # CUDA_VISIBLE_DEVICES may list GPU UUIDs/MIG ids ("GPU-...") that int()
    # rejects. Blank entries (e.g. an empty variable) are dropped.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    cuda_visible_devices = None
    if visible is not None:
        cuda_visible_devices = [d.strip() for d in visible.split(',') if d.strip()]

    local_rank = int(os.environ.get("LOCAL_RANK", -1))

    if local_rank == -1:
        device_num = 1
    elif cuda_visible_devices is None:
        device_num = torch.cuda.device_count()
    else:
        device_num = len(cuda_visible_devices)

    args.device_num = device_num
    args.local_rank = local_rank
    warnings.simplefilter("ignore")
    return args
def sort(x):
    """Return the strings of *x* in natural ("human") order, case-insensitively.

    Runs of ASCII digits compare numerically, so ``'img2'`` sorts before
    ``'img10'``. All other text compares by its lower-cased form.

    >>> sort(['img10.png', 'img2.png', 'Img1.png'])
    ['Img1.png', 'img2.png', 'img10.png']
    """
    def natural_key(text):
        # re.split with a capturing group yields [text, digits, text, ...], so
        # the odd positions are exactly the digit runs. Converting by position
        # (rather than via str.isdigit) keeps element types aligned across keys.
        # It also never calls int() on characters like '²' that isdigit()
        # accepts but int() rejects.
        parts = re.split(r'([0-9]+)', text)
        return [int(p) if i % 2 else p.lower() for i, p in enumerate(parts)]

    return sorted(x, key=natural_key)
def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path to the YAML file (a file path, despite the name).
        easy: When exactly ``True``, wrap the result in ``EasyDict`` so keys
            can also be read as attributes.

    Returns:
        The parsed YAML document: an ``EasyDict`` when ``easy is True``,
        otherwise whatever ``yaml.load`` produced.
    """
    # Binary mode lets PyYAML detect UTF-8/UTF-16 itself instead of relying on
    # the platform's default text encoding. The `with` closes the handle.
    # NOTE(review): FullLoader is kept for compatibility with existing configs.
    # Do not point this at untrusted files; yaml.SafeLoader is the safe choice.
    with open(config_dir, 'rb') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    if easy is True:
        cfg = ed(cfg)
    return cfg
def to_cuda(sample, device=None):
    """Move every tensor value of *sample* to the GPU (or *device*), in place.

    Args:
        sample: Mapping whose tensor values are moved. Non-tensor values are
            left untouched.
        device: Optional target passed to ``Tensor.to``. ``None`` (the
            default) uses ``Tensor.cuda()``, i.e. the current CUDA device.

    Returns:
        The same *sample* object, mutated.
    """
    for key, value in sample.items():
        # isinstance also matches Tensor subclasses (e.g. nn.Parameter), which
        # an exact type() comparison would skip. Re-assigning existing keys does
        # not resize the dict, so updating it during iteration is safe.
        if isinstance(value, torch.Tensor):
            sample[key] = value.cuda() if device is None else value.to(device)
    return sample
def to_numpy(pred, shape):
    """Resize a prediction tensor and return it as a squeezed NumPy array.

    Args:
        pred: Tensor to resize. ``F.interpolate`` in ``'bilinear'`` mode
            requires a 4-D ``(N, C, H, W)`` input.
        shape: Target spatial size ``(H, W)``, forwarded as ``size``.

    Returns:
        A CPU ``numpy.ndarray`` with all singleton dimensions removed.
    """
    # Detach first: the result is only converted to NumPy, so there is no
    # reason to record the interpolation in the autograd graph.
    resized = F.interpolate(pred.detach(), size=shape, mode='bilinear', align_corners=True)
    return resized.cpu().numpy().squeeze()
def debug_tile(deblist, size=(100, 100), activation=None):
    """Tile debug maps into a single RGB ``uint8`` image.

    Args:
        deblist: Sequence of columns. Each column is a sequence of tensors that
            must squeeze to a 2-D map (``cv2.COLOR_GRAY2RGB`` expects a single
            channel).
        size: ``(width, height)`` of every tile, in the order ``cv2.resize``
            expects.
        activation: Optional callable applied to each tensor before display.

    Returns:
        ``np.ndarray``: the tiles of each column are stacked vertically, and
        the columns are placed side by side.
    """
    columns = []
    for column in deblist:
        tiles = []
        for tensor in column:
            if activation is not None:
                tensor = activation(tensor)
            arr = tensor.detach().cpu().numpy().squeeze()

            # Min-max normalize to 0..255. A constant map has a zero range and
            # would divide by zero (NaN -> undefined uint8), so render it black.
            lo, hi = arr.min(), arr.max()
            span = hi - lo
            if span > 0:
                arr = (arr - lo) / span * 255
            else:
                arr = np.zeros_like(arr)

            tile = cv2.cvtColor(arr.astype(np.uint8), cv2.COLOR_GRAY2RGB)
            tiles.append(cv2.resize(tile, size))
        columns.append(np.vstack(tiles))
    return np.hstack(columns)
| if __name__ == "__main__": | |
| x = torch.rand(4, 3, 576, 576) |