import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel

import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='adam')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)

# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
parser.add_argument("-multidevice_strategy", default='4,7,29')
params = parser.parse_args()

Image.MAX_IMAGE_PIXELS = 1000000000  # Support gigapixel images
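
# Example invocation (a sketch; assumes this file is saved as neural_style.py,
# the upstream neural-style-pt name, so adjust the script name to your checkout):
#
#   python neural_style.py -content_image examples/inputs/tubingen.jpg \
#       -style_image examples/inputs/seated-nude.jpg \
#       -image_size 512 -gpu c -optimizer adam -output_image out.png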

class TransferParams():
    style_image = 'examples/inputs/seated-nude.jpg'
    style_blend_weights = None
    content_image = 'examples/inputs/tubingen.jpg'
    image_size = 300
    gpu = "c"  # 'c' = CPU; use a zero-indexed GPU ID (e.g. 0) for CUDA
    content_weight = 5e0
    style_weight = 1e2
    normalize_weights = False
    tv_weight = 1e-3
    num_iterations = 1000
    init = 'random'
    init_image = None
    optimizer = 'adam'
    learning_rate = 1e0
    lbfgs_num_correction = 100
    print_iter = 50
    save_iter = 1000
    output_image = 'out.png'
    log_level = 10
    style_scale = 1.0
    original_colors = 0
    pooling = 'max'
    model_file = 'models/nin_imagenet.pth'  # alternative: 'models/vgg16-00b39a1b.pth'
    disable_check = False
    backend = 'mkl'
    cudnn_autotune = False
    seed = -1
    content_layers = 'relu0,relu3,relu7,relu12'  # NIN layers; for VGG use 'relu4_2'
    style_layers = 'relu0,relu3,relu7,relu12'  # for VGG use 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1'
    multidevice_strategy = '4,7,29'
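
# TransferParams mirrors the argparse namespace so transfer() can be driven
# programmatically. A minimal sketch (the module name "neural_style" is an
# assumption; use whatever this file is called in your project). Note that the
# module-level helpers (setup_gpu, setup_optimizer) read the module-level
# params, so it must be replaced as well:
#
#   import neural_style
#   cfg = neural_style.TransferParams()
#   cfg.content_image = 'photo.jpg'      # hypothetical input paths
#   cfg.style_image = 'painting.jpg'
#   cfg.num_iterations = 500
#   neural_style.params = cfg
#   neural_style.transfer(cfg)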

def main():
    transfer(params)

def transfer(params):
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weight in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"
        # Normalize the style blending weights so they sum to 1
        style_blend_sum = 0
        for i, blend_weight in enumerate(style_blend_weights):
            style_blend_weights[i] = float(style_blend_weights[i])
            style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
        for i, blend_weight in enumerate(style_blend_weights):
            style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
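
    # Worked example: "-style_blend_weights 3,7" with two style images is
    # normalized above to [0.3, 0.7], so the second image contributes 70%
    # of the blended Gram target captured by each StyleLoss module.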

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    # print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    # print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    # print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    # print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
                net.add_module(str(len(net)), layer)

    if multidevice:
        net = setup_multi_device(net)

    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    # print("Capturing content targets")
    print_torch(net, multidevice)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'
    for i, image in enumerate(style_images_caffe):
        # print("Capturing style target " + str(i+1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img)
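
    # img is now the only trainable tensor: the network weights were frozen
    # above, so the optimizer's gradient steps update the pixels themselves.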

    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img.clone())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)

            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)

# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
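
# Note on loopVal: optim.LBFGS runs up to max_iter closure evaluations inside
# a single optimizer.step(feval) call, so one pass through the outer while
# loop suffices (loopVal = 1). Adam performs exactly one update per step(),
# so the outer loop must run the full num_iterations.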

def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)

    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device

def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices must be one less than the number of -gpu devices."
    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net

# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor

# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
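
# preprocess() and deprocess() are inverses up to resizing, rounding, and
# clamping. A quick round-trip sketch (the file paths are hypothetical):
#
#   t = preprocess('photo.jpg', 64)         # 1 x 3 x H x W, BGR, mean-subtracted
#   deprocess(t.clone()).save('check.png')  # back to an RGB PIL image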

# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')
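
# In YCbCr, channel 0 (Y) is luminance and Cb/Cr carry the color, so taking Y
# from the stylized output and Cb/Cr from the content keeps the photo's
# original colors while preserving the transferred structure.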

# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + ks.replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + (ks.replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")

# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())

# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input

class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
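
# For the C x (H*W) feature matrix F above, G = F F^T is C x C: entry G[i, j]
# is the inner product of channels i and j over all spatial positions, so G
# records which features co-activate (texture statistics) while discarding
# where they activate (layout). StyleLoss divides G by input.nelement() to
# keep targets comparable across feature-map sizes.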

# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                # accumulate target + blend_weight * G; keyword alpha replaces
                # the deprecated add(alpha, tensor) overload
                self.target = self.target.add(self.G.detach(), alpha=self.blend_weight)
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input

class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
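
# TVLoss implements anisotropic total variation: the L1 norm of horizontal
# and vertical neighbor differences, sum(|x[h+1,w] - x[h,w]|) +
# sum(|x[h,w+1] - x[h,w]|). Penalizing it suppresses the high-frequency noise
# left over from the 'random' initialization.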

if __name__ == "__main__":
    main()