import glob
import os

import torch
import torch.nn as nn


def get_padding(kernel_size, dilation=1):
    """Return ``int(dilation * (kernel_size - 1) / 2)`` as a 1-D padding size.

    Args:
        kernel_size: Size of the convolution kernel.
        dilation: Dilation factor of the convolution.

    Returns:
        The padding amount as an ``int`` (truncated toward zero).

    NOTE(review): presumably used to get "same"-length output for odd
    kernels at stride 1 -- confirm against the callers.
    """
    return int((kernel_size * dilation - dilation) / 2)


def get_padding_2d(kernel_size, dilation=(1, 1)):
    """Return the per-axis padding for a 2-D kernel.

    Applies the same formula as :func:`get_padding` independently to each
    of the two axes.

    Args:
        kernel_size: Two-element sequence with the kernel size per axis.
        dilation: Two-element sequence with the dilation per axis.

    Returns:
        A 2-tuple of ``int`` padding sizes, one per axis.
    """
    return (
        get_padding(kernel_size[0], dilation[0]),
        get_padding(kernel_size[1], dilation[1]),
    )


def load_checkpoint(filepath, device):
    """Load a checkpoint written with ``torch.save`` and return it.

    Args:
        filepath: Path to an existing checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object was stored in the checkpoint file.

    Raises:
        AssertionError: If ``filepath`` is not an existing regular file.
    """
    assert os.path.isfile(filepath), "Checkpoint not found: {}".format(filepath)
    print("Loading '{}'".format(filepath))
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    """Serialize ``obj`` to ``filepath`` with ``torch.save``.

    Args:
        filepath: Destination path of the checkpoint file.
        obj: Object to serialize.
    """
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    """Find the lexicographically last checkpoint matching ``prefix``.

    Candidates are paths inside ``cp_dir`` whose name is ``prefix`` followed
    by exactly eight arbitrary characters (glob pattern ``????????``).

    Args:
        cp_dir: Directory to search.
        prefix: File-name prefix of the checkpoints.

    Returns:
        The matching path that sorts last as a string, or ``None`` when
        nothing matches.
    """
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if not cp_list:
        return None
    return sorted(cp_list)[-1]


class LearnableSigmoid_1d(nn.Module):
    """Sigmoid with a learnable per-feature slope, scaled by ``beta``.

    Computes ``beta * sigmoid(slope * x)`` where ``slope`` is a trainable
    parameter of shape ``(in_features,)`` initialised to ones.
    """

    def __init__(self, in_features, beta=1):
        """Create the module.

        Args:
            in_features: Number of slope values to learn.
            beta: Constant output scale (not trained).
        """
        super().__init__()
        self.beta = beta
        # The original set a misspelled ``requiresGrad`` attribute, which
        # autograd ignores; request gradients through the real API instead.
        self.slope = nn.Parameter(torch.ones(in_features), requires_grad=True)

    def forward(self, x):
        """Return ``beta * sigmoid(slope * x)``.

        ``slope`` has shape ``(in_features,)`` and is broadcast against
        ``x``; assumes the last dimension of ``x`` is ``in_features`` --
        TODO confirm with callers.
        """
        return self.beta * torch.sigmoid(self.slope * x)


class LearnableSigmoid_2d(nn.Module):
    """Sigmoid with a learnable per-feature slope, scaled by ``beta``.

    Computes ``beta * sigmoid(slope * x)`` where ``slope`` is a trainable
    parameter of shape ``(in_features, 1)`` initialised to ones.
    """

    def __init__(self, in_features, beta=1):
        """Create the module.

        Args:
            in_features: Number of slope values to learn.
            beta: Constant output scale (not trained).
        """
        super().__init__()
        self.beta = beta
        # The original set a misspelled ``requiresGrad`` attribute, which
        # autograd ignores; request gradients through the real API instead.
        self.slope = nn.Parameter(
            torch.ones(in_features, 1), requires_grad=True
        )

    def forward(self, x):
        """Return ``beta * sigmoid(slope * x)``.

        ``slope`` has shape ``(in_features, 1)`` and is broadcast against
        ``x``; assumes the second-to-last dimension of ``x`` is
        ``in_features`` -- TODO confirm with callers.
        """
        return self.beta * torch.sigmoid(self.slope * x)