|
import re |
|
import os |
|
import yaml |
|
import cv2 |
|
import argparse |
|
import warnings |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from easydict import EasyDict as ed |
|
|
|
class Simplify(nn.Module): |
|
def __init__(self, model): |
|
super(Simplify, self).__init__() |
|
self.model = model |
|
|
|
def cuda(self): |
|
self.model = self.model.cuda() |
|
return self |
|
|
|
def forward(self, x): |
|
out = self.model({'image': x}) |
|
return out['pred'] |
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml') |
|
parser.add_argument('--resume', '-r', action='store_true', default=False) |
|
parser.add_argument('--verbose', '-v', action='store_true', default=False) |
|
parser.add_argument('--debug', '-d', action='store_true', default=False) |
|
args = parser.parse_args() |
|
|
|
cuda_visible_devices = None |
|
local_rank = -1 |
|
|
|
if "CUDA_VISIBLE_DEVICES" in os.environ.keys(): |
|
cuda_visible_devices = [int(i) for i in os.environ["CUDA_VISIBLE_DEVICES"].split(',')] |
|
if "LOCAL_RANK" in os.environ.keys(): |
|
local_rank = int(os.environ["LOCAL_RANK"]) |
|
|
|
if local_rank == -1: |
|
device_num = 1 |
|
elif cuda_visible_devices is None: |
|
device_num = torch.cuda.device_count() |
|
else: |
|
device_num = len(cuda_visible_devices) |
|
|
|
args.device_num = device_num |
|
args.local_rank = local_rank |
|
|
|
warnings.simplefilter("ignore") |
|
|
|
return args |
|
|
|
def sort(x): |
|
convert = lambda text: int(text) if text.isdigit() else text.lower() |
|
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] |
|
return sorted(x, key=alphanum_key) |
|
|
|
def load_config(config_dir, easy=True): |
|
cfg = yaml.load(open(config_dir), yaml.FullLoader) |
|
if easy is True: |
|
cfg = ed(cfg) |
|
return cfg |
|
|
|
def to_cuda(sample): |
|
for key in sample.keys(): |
|
if type(sample[key]) == torch.Tensor: |
|
sample[key] = sample[key].cuda() |
|
return sample |
|
|
|
def to_numpy(pred, shape): |
|
pred = F.interpolate(pred, shape, mode='bilinear', align_corners=True) |
|
pred = pred.data.cpu() |
|
pred = pred.numpy().squeeze() |
|
return pred |
|
|
|
def debug_tile(deblist, size=(100, 100), activation=None): |
|
debugs = [] |
|
for debs in deblist: |
|
debug = [] |
|
for deb in debs: |
|
if activation is not None: |
|
deb = activation(deb) |
|
log = deb.cpu().detach().numpy().squeeze() |
|
log = ((log - log.min()) / (log.max() - log.min()) * 255).astype(np.uint8) |
|
log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB) |
|
log = cv2.resize(log, size) |
|
debug.append(log) |
|
debugs.append(np.vstack(debug)) |
|
return np.hstack(debugs) |
|
|
|
|
|
if __name__ == "__main__": |
|
x = torch.rand(4, 3, 576, 576) |