crapthings's picture
Upload folder using huggingface_hub
f7f604d
import re
import os
import yaml
import cv2
import argparse
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as ed
class Simplify(nn.Module):
    """Adapter that gives a dict-in / dict-out model a tensor-in / tensor-out API.

    The wrapped ``model`` is invoked with ``{'image': x}`` and only the
    ``'pred'`` entry of whatever it returns is handed back to the caller.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def cuda(self):
        """Call the wrapped model's own ``cuda()`` and keep its result.

        Returns:
            This wrapper, so calls can be chained.
        """
        moved = self.model.cuda()
        self.model = moved
        return self

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``'pred'`` output."""
        sample = {'image': x}
        return self.model(sample)['pred']
def parse_args(argv=None):
    """Parse command-line options and attach distributed-run information.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        argparse.Namespace with ``config``, ``resume``, ``verbose`` and
        ``debug`` from the command line, plus:

        * ``local_rank`` -- integer value of the ``LOCAL_RANK`` environment
          variable, or ``-1`` when it is not set.
        * ``device_num`` -- ``1`` when ``local_rank`` is ``-1``; otherwise the
          number of entries in ``CUDA_VISIBLE_DEVICES`` or, when that variable
          is unset, ``torch.cuda.device_count()``.

    Side effects:
        Installs a global ``"ignore"`` warnings filter.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml')
    parser.add_argument('--resume', '-r', action='store_true', default=False)
    parser.add_argument('--verbose', '-v', action='store_true', default=False)
    parser.add_argument('--debug', '-d', action='store_true', default=False)
    args = parser.parse_args(argv)

    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")

    if local_rank == -1:
        device_num = 1
    elif visible_devices is None:
        device_num = torch.cuda.device_count()
    else:
        # Only the number of entries matters, so do not convert them to int:
        # the variable may legitimately be empty (no visible GPU) or hold
        # non-numeric identifiers such as "GPU-<uuid>".
        device_num = sum(1 for entry in visible_devices.split(',') if entry.strip())

    args.device_num = device_num
    args.local_rank = local_rank

    warnings.simplefilter("ignore")
    return args
def sort(x):
    """Return the strings of ``x`` in natural (human) order.

    Runs of ASCII digits are compared as integers and all other text is
    compared case-insensitively, so ``'img9'`` sorts before ``'img10'``.

    Args:
        x: Iterable of strings.

    Returns:
        A new sorted list; ``x`` itself is not modified.
    """
    digit_runs = re.compile(r'([0-9]+)')

    def alphanum_key(key):
        # With one capture group, re.split alternates text / digit-run, so the
        # odd positions are exactly the ASCII digit runs. Deciding by position
        # instead of str.isdigit() avoids int() failures on characters such as
        # superscript digits, and keeps every position's type consistent
        # between keys so the lists stay comparable.
        return [
            int(chunk) if index % 2 else chunk.lower()
            for index, chunk in enumerate(digit_runs.split(key))
        ]

    return sorted(x, key=alphanum_key)
def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path of the YAML file; it is passed directly to ``open()``.
        easy: When exactly ``True``, the parsed result is wrapped in an
            ``EasyDict``; any other value returns the parsed result unchanged.

    Returns:
        The parsed configuration, as an ``EasyDict`` or as returned by
        ``yaml.load``.
    """
    # NOTE(review): yaml.FullLoader is kept for compatibility with existing
    # config files; only load trusted files, or consider yaml.safe_load.
    # The context manager closes the file handle once parsing is done.
    with open(config_dir) as stream:
        cfg = yaml.load(stream, yaml.FullLoader)

    if easy is True:
        cfg = ed(cfg)
    return cfg
def to_cuda(sample):
    """Move every tensor value of ``sample`` to CUDA, in place.

    Values that are not ``torch.Tensor`` instances are left untouched.
    Tensor subclasses are moved as well, which is why ``isinstance`` is used
    rather than an exact type comparison.

    Args:
        sample: Mutable mapping whose values may include tensors.

    Returns:
        The same ``sample`` object, with tensor values replaced by the result
        of their ``.cuda()`` call.
    """
    # Re-assigning existing keys does not resize the mapping, so updating it
    # while iterating over items() is safe.
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            sample[key] = value.cuda()
    return sample
def to_numpy(pred, shape):
    """Resize ``pred`` to ``shape`` and return it as a squeezed NumPy array.

    Args:
        pred: Tensor accepted by ``F.interpolate`` in ``'bilinear'`` mode.
        shape: Target output size forwarded to ``F.interpolate``.

    Returns:
        NumPy array holding the resized values with all length-1 axes removed.
    """
    resized = F.interpolate(pred, shape, mode='bilinear', align_corners=True)
    return resized.data.cpu().numpy().squeeze()
def debug_tile(deblist, size=(100, 100), activation=None):
    """Build one image that tiles a nested list of debug maps.

    Each map is min-max normalised to 0..255, converted to a 3-channel image
    and resized to ``size``. Maps of one inner list are stacked vertically,
    and the resulting columns are placed side by side.

    Args:
        deblist: Iterable of iterables of tensors. Each tensor must squeeze to
            a single-channel 2-D map for the grey-to-RGB conversion to work.
        size: Target tile size passed to ``cv2.resize`` as ``dsize``, which
            OpenCV interprets as ``(width, height)``.
        activation: Optional callable applied to every tensor before it is
            converted (for example a sigmoid).

    Returns:
        ``uint8`` NumPy array containing the tiled image.
    """
    columns = []
    for debs in deblist:
        tiles = []
        for deb in debs:
            if activation is not None:
                deb = activation(deb)

            log = deb.cpu().detach().numpy().squeeze()
            low = log.min()
            span = log.max() - low
            if span > 0:
                log = ((log - low) / span * 255).astype(np.uint8)
            else:
                # A constant (or NaN) map has no usable range; dividing by a
                # zero span would cast NaN to uint8. Render it as black.
                log = np.zeros_like(log, dtype=np.uint8)

            log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
            log = cv2.resize(log, size)
            tiles.append(log)
        columns.append(np.vstack(tiles))
    return np.hstack(columns)
# Script entry point; runs only when the file is executed directly.
if __name__ == "__main__":
    # Random float tensor of shape (4, 3, 576, 576), values uniform in [0, 1).
    x = torch.rand(4, 3, 576, 576)