import sys

sys.path.append('miniminiai/miniminiai')

import torch
import fastcore.all as fc
import gradio as gr
import numpy as np
from PIL import Image, ImageOps
from torch import nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import models, transforms

from miniminiai import *  # DataLoaders, TrainCB, DeviceCB, Learner, to_cpu


class LengthDataset():
    """Dummy dataset: `length` items of (0, 0), so a Learner can run a fixed
    number of optimisation steps without any real data."""
    def __init__(self, length=1):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return 0, 0


def get_dummy_dls(length=100):
    return DataLoaders(
        DataLoader(LengthDataset(length), batch_size=1),  # Train
        DataLoader(LengthDataset(1), batch_size=1),       # Valid (length 1)
    )


class TensorModel(nn.Module):
    """A 'model' whose only parameter is the image tensor being optimised."""
    def __init__(self, t):
        super().__init__()
        self.t = nn.Parameter(t.clone())

    def forward(self, x=0):
        return self.t


class ImageOptCB(TrainCB):
    # The model ignores the dummy batch; the loss depends only on the predictions.
    def predict(self, learn):
        learn.preds = learn.model()

    def get_loss(self, learn):
        learn.loss = learn.loss_func(learn.preds)


def calc_features(imgs, target_layers=(18, 25)):
    # Runs `imgs` through the module-level VGG16 feature extractor (set up in
    # the __main__ block) and collects activations at the requested layers.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    x = normalize(imgs)
    feats = []
    for i, layer in enumerate(vgg16[:max(target_layers) + 1]):
        x = layer(x)
        if i in target_layers:
            feats.append(x.clone())
    return feats


class ContentLossToTarget():
    def __init__(self, target_im, target_layers=(18, 25)):
        fc.store_attr()
        with torch.no_grad():
            self.target_features = calc_features(target_im, target_layers)

    def __call__(self, input_im):
        return sum((f1 - f2).pow(2).mean() for f1, f2 in
                   zip(calc_features(input_im, self.target_layers), self.target_features))


def calc_grams(img, target_layers=(1, 6, 11, 18, 25)):
    return fc.L(torch.einsum('chw, dhw -> cd', x, x) / (x.shape[-2] * x.shape[-1])  # 'bchw, bdhw -> bcd' if batched
                for x in calc_features(img, target_layers))


class StyleLossToTarget():
    def __init__(self, target_im, target_layers=(1, 6, 11, 18, 25), size=394):
        fc.store_attr()  # `size` is unused here, kept so all style losses share one signature
        with torch.no_grad():
            self.target_grams = calc_grams(target_im, target_layers)

    def __call__(self, input_im):
        return sum((f1 - f2).pow(2).mean() for f1, f2 in
                   zip(calc_grams(input_im, self.target_layers), self.target_grams))


class OTStyleLossToTarget(nn.Module):
    def __init__(self, target, size=128, style_layers=(1, 6, 11, 18, 25), scale_factor=2e-5):
        super().__init__()
        self.device = device  # module-level `device`, set in the __main__ block
        self.resize = transforms.Compose([transforms.Resize(size),
                                          transforms.CenterCrop(size)])
        self.target = self.resize(target)  # resize target image to size
        self.style_layers = style_layers
        self.scale_factor = scale_factor  # Defaults tend to be very large; we scale to make them easier to work with
        with torch.no_grad():
            self.target_features = calc_features(self.target, self.style_layers)

    def project_sort(self, x, proj):
        return torch.einsum('bcn,cp->bpn', x, proj).sort()[0]

    def ot_loss(self, source, target, proj_n=32):
        # calc_features returns unbatched (c, h, w) maps; flatten the spatial
        # dims so the projection sees (batch, channels, n).
        source = source.reshape(1, source.shape[0], -1)
        target = target.reshape(1, target.shape[0], -1)
        ch, n = source.shape[-2:]
        projs = F.normalize(torch.randn(ch, proj_n).to(self.device), dim=0)
        source_proj = self.project_sort(source, projs)
        target_proj = self.project_sort(target, projs)
        target_interp = F.interpolate(target_proj, n, mode='nearest')
        return (source_proj - target_interp).square().sum()

    def forward(self, input):
        input = self.resize(input)  # set size (assumes square images)
        input_features = calc_features(input, self.style_layers)
        # Sum the sliced-OT loss over all chosen feature layers
        return sum(self.ot_loss(x, y)
                   for x, y in zip(input_features, self.target_features)) * self.scale_factor
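
# For intuition, a minimal sketch of the sliced-Wasserstein ("OT") idea used
# above, kept as comments so nothing runs on import. `feats_a`/`feats_b` are
# hypothetical stand-ins for two (1, c, n) flattened feature maps:
#
#   c, n, p = 64, 1024, 32
#   feats_a, feats_b = torch.randn(1, c, n), torch.randn(1, c, n)
#   projs = F.normalize(torch.randn(c, p), dim=0)               # p random unit directions
#   a = torch.einsum('bcn,cp->bpn', feats_a, projs).sort()[0]   # project, then sort each slice
#   b = torch.einsum('bcn,cp->bpn', feats_b, projs).sort()[0]
#   sliced_ot = (a - b).square().sum()  # matching sorted samples gives the 1-D Wasserstein-2 per slice
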
class VincentStyleLossToTarget(nn.Module):
    def __init__(self, target, size=128, style_layers=(1, 6, 11, 18, 25), scale_factor=1e-5):
        super().__init__()
        self.resize = transforms.Compose([transforms.Resize(size),
                                          transforms.CenterCrop(size)])
        self.target = self.resize(target)  # resize target image to size
        self.style_layers = style_layers
        self.scale_factor = scale_factor  # Defaults tend to be very large; we scale to make them easier to work with
        with torch.no_grad():
            self.target_features = calc_features(self.target, self.style_layers)

    def calc_2_moments(self, x):
        c, w, h = x.shape
        x = x.reshape(1, c, w * h)  # b, c, n
        mu = x.mean(dim=-1, keepdim=True)  # b, c, 1
        cov = torch.matmul(x - mu, torch.transpose(x - mu, -1, -2))
        return mu, cov

    def matrix_diag(self, diagonal):
        # Batched equivalent of torch.diag: embed the last dim as a diagonal matrix
        N = diagonal.shape[-1]
        shape = diagonal.shape[:-1] + (N, N)
        device, dtype = diagonal.device, diagonal.dtype
        result = torch.zeros(shape, dtype=dtype, device=device)
        indices = torch.arange(result.numel(), device=device).reshape(shape)
        indices = indices.diagonal(dim1=-2, dim2=-1)
        result.view(-1)[indices] = diagonal
        return result

    def l2wass_dist(self, mean_stl, cov_stl, mean_synth, cov_synth):
        # Calculate tr_cov and root_cov from mean_stl and cov_stl
        eigvals, eigvects = torch.linalg.eigh(cov_stl)  # eig returns complex tensors; eigh should match tf self_adjoint_eig
        eigroot_mat = self.matrix_diag(torch.sqrt(eigvals.clip(0)))
        root_cov_stl = torch.matmul(torch.matmul(eigvects, eigroot_mat),
                                    torch.transpose(eigvects, -1, -2))
        tr_cov_stl = torch.sum(eigvals.clip(0), dim=1, keepdim=True)
        tr_cov_synth = torch.sum(torch.linalg.eigvalsh(cov_synth).clip(0), dim=1, keepdim=True)
        mean_diff_squared = torch.mean((mean_synth - mean_stl) ** 2)
        cov_prod = torch.matmul(torch.matmul(root_cov_stl, cov_synth), root_cov_stl)
        var_overlap = torch.sum(torch.sqrt(torch.linalg.eigvalsh(cov_prod).clip(0.1)),
                                dim=1, keepdim=True)  # .clip(0) caused errors when getting eigvals
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
        return dist

    def forward(self, input):
        input = self.resize(input)  # set size (assumes square images, center crops otherwise)
        input_features = calc_features(input, self.style_layers)  # get features
        l = 0
        for x, y in zip(input_features, self.target_features):
            mean_synth, cov_synth = self.calc_2_moments(x)  # input mean and cov
            mean_stl, cov_stl = self.calc_2_moments(y)      # target mean and cov
            l += self.l2wass_dist(mean_stl, cov_stl, mean_synth, cov_synth)
        return l.mean() * self.scale_factor
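
# For reference: l2wass_dist above follows the closed-form squared 2-Wasserstein
# distance between Gaussians N(m1, C1) and N(m2, C2),
#   W2^2 = |m1 - m2|^2 + tr(C1) + tr(C2) - 2 tr((C1^(1/2) C2 C1^(1/2))^(1/2)),
# with the matrix square roots computed via eigendecompositions (note the code
# uses the mean rather than the sum for the |m1 - m2|^2 term).
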
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img.resize((w, h)), box=(i % cols * w, i // cols * h))
    grid = ImageOps.expand(grid, border=20, fill=(255, 255, 255))
    return grid


def style_image(content_image, style_image, style_losses):
    content_image = content_image.resize((384, 384))
    style_image = style_image.resize((384, 384))
    output = [content_image]
    content_image = torch.tensor(np.array(content_image).astype(np.float32) / 255.).permute(2, 0, 1)
    style_image = torch.tensor(np.array(style_image).astype(np.float32) / 255.).permute(2, 0, 1)
    content_loss = ContentLossToTarget(content_image.to(device))
    sim = style_image.to(device)
    for loss_name in style_losses:  # one optimisation run per selected style loss
        style_loss = map_style_losses[loss_name](sim, size=384)
        model = TensorModel(content_image)

        def combined_loss(x):
            return style_loss(x) + content_loss(x)

        learn = Learner(model, get_dummy_dls(150), combined_loss, lr=1e-2,
                        cbs=[ImageOptCB(), DeviceCB()], opt_func=torch.optim.Adam)
        learn.fit(1)
        im = to_cpu(learn.preds.clip(0, 1))
        output.append(Image.fromarray((im.permute(1, 2, 0).numpy() * 255).astype(np.uint8)))
    return image_grid(output, 1, len(style_losses) + 1)


def run():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1):
                content_im = gr.Image(shape=(318, 318), type='pil', label="Content image")
                style_img = gr.Image(shape=(318, 318), type='pil', label="Style image")
                style_losses = gr.CheckboxGroup(["Gram Matrix", "OT-Based", "Vincent's"],
                                                value=["Gram Matrix", "OT-Based", "Vincent's"],
                                                label="Style Loss")
                btn = gr.Button("Generate")
            with gr.Column(scale=1):
                out = gr.Image(shape=(384, 384))
        btn.click(fn=style_image, inputs=[content_im, style_img, style_losses], outputs=out)
    demo.launch(server_name="0.0.0.0", server_port=7860)


map_style_losses = {
    "Gram Matrix": StyleLossToTarget,
    "OT-Based": OTStyleLossToTarget,
    "Vincent's": VincentStyleLossToTarget,
}

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).to(device)
    vgg16.eval()
    vgg16 = vgg16.features  # keep only the convolutional feature layers
    run()
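
# A minimal sketch of calling style_image outside the Gradio UI, assuming the
# same globals (device, vgg16) have been initialised as in the __main__ block;
# the file names here are hypothetical:
#
#   content = Image.open('content.jpg').convert('RGB')
#   style = Image.open('style.jpg').convert('RGB')
#   grid = style_image(content, style, ["Gram Matrix", "OT-Based"])
#   grid.save('comparison.png')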