import os

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydiffvg
import skimage
import skimage.io
import torch
import wandb
import PIL
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid
from skimage.transform import resize

from U2Net_.model import U2NET


def imwrite(img, filename, gamma=2.2, normalize=False, use_wandb=False,
            wandb_name="", step=0, input_im=None):
    """Gamma-correct ``img``, save it to ``filename``, optionally log to wandb.

    Args:
        img: HxW or HxWxC numpy array or torch tensor, expected in [0, 1]
            (values are clipped; set ``normalize`` to min-max rescale first).
        filename: output path; missing parent directories are created.
        gamma: gamma-correction exponent applied to the first 3 channels.
        normalize: if True, min-max rescale ``img`` to [0, 1] before clipping.
        use_wandb: if True, also log the image under key ``wandb_name + "_"``.
        wandb_name: prefix for the wandb log key.
        step: wandb step; ``input_im`` is logged alongside only at step 0.
        input_im: optional raw input image to log next to the output.
    """
    directory = os.path.dirname(filename)
    if directory != '' and not os.path.exists(directory):
        os.makedirs(directory)

    if not isinstance(img, np.ndarray):
        # detach().cpu() so this also works for CUDA / grad-tracking tensors
        # (the original img.data.numpy() failed on CUDA tensors).
        img = img.detach().cpu().numpy()
    if normalize:
        img_rng = np.max(img) - np.min(img)
        if img_rng > 0:
            img = (img - np.min(img)) / img_rng
    img = np.clip(img, 0.0, 1.0)
    if img.ndim == 2:
        # promote grayscale to HxWx1 so the channel slicing below is valid
        img = np.expand_dims(img, 2)
    img[:, :, :3] = np.power(img[:, :, :3], 1.0 / gamma)
    img = (img * 255).astype(np.uint8)
    skimage.io.imsave(filename, img, check_contrast=False)
    if use_wandb:
        # Build the wandb payload only when logging is actually enabled.
        images = [wandb.Image(Image.fromarray(img), caption="output")]
        if input_im is not None and step == 0:
            images.append(wandb.Image(input_im, caption="input"))
        wandb.log({wandb_name + "_": images}, step=step)


def plot_batch(inputs, outputs, output_dir, step, use_wandb, title):
    """Plot the input and output batches as two stacked image grids.

    Saves the figure to ``{output_dir}/{title}`` and optionally logs it to
    wandb under key "output" at ``step``.
    """
    plt.figure()
    plt.subplot(2, 1, 1)
    grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.title("inputs")

    plt.subplot(2, 1, 2)
    # outputs are plotted unnormalized, unlike the inputs above
    grid = make_grid(outputs, normalize=False, pad_value=2)
    npgrid = grid.detach().cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.title("outputs")

    plt.tight_layout()
    if use_wandb:
        wandb.log({"output": wandb.Image(plt)}, step=step)
    plt.savefig("{}/{}".format(output_dir, title))
    plt.close()


def log_input(use_wandb, epoch, inputs, output_dir):
    """Log the input batch to wandb and save the first input image as a PNG."""
    grid = make_grid(inputs.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.tight_layout()
    if use_wandb:
        wandb.log({"input": wandb.Image(plt)}, step=epoch)
    plt.close()

    input_ = inputs[0].cpu().clone().detach().permute(1, 2, 0).numpy()
    # Min-max normalize to [0, 1]; guard the constant-image case, which made
    # the original divide by zero.
    input_ = input_ - input_.min()
    input_rng = input_.max()
    if input_rng > 0:
        input_ = input_ / input_rng
    input_ = (input_ * 255).astype(np.uint8)
    imageio.imwrite("{}/{}.png".format(output_dir, "input"), input_)


def _render_svg_scene(canvas_width, canvas_height, shapes, shape_groups, device):
    """Rasterize a diffvg scene and composite it over a white background.

    Returns an HxWx3 float tensor on ``device``.  Shared by
    ``log_sketch_summary_final`` and ``read_svg`` (both previously duplicated
    this sequence inline).
    """
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,   # width
                  canvas_height,  # height
                  2,              # num_samples_x
                  2,              # num_samples_y
                  0,              # seed
                  None,
                  *scene_args)
    # Alpha-composite over white, then drop the alpha channel.
    img = img[:, :, 3:4] * img[:, :, :3] + \
        torch.ones(img.shape[0], img.shape[1], 3,
                   device=device) * (1 - img[:, :, 3:4])
    return img[:, :, :3]


def log_sketch_summary_final(path_svg, use_wandb, device, epoch, loss, title):
    """Render the final sketch SVG and log it (with epoch/loss) to wandb."""
    canvas_width, canvas_height, shapes, shape_groups = load_svg(path_svg)
    img = _render_svg_scene(canvas_width, canvas_height,
                            shapes, shape_groups, device)
    plt.imshow(img.cpu().numpy())
    plt.axis("off")
    plt.title(f"{title} best res [{epoch}] [{loss}.]")
    if use_wandb:
        wandb.log({title: wandb.Image(plt)})
    plt.close()


def log_sketch_summary(sketch, title, use_wandb):
    """Record the best sketch image in the wandb run summary."""
    plt.figure()
    grid = make_grid(sketch.clone().detach(), normalize=True, pad_value=2)
    npgrid = grid.cpu().numpy()
    plt.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')
    plt.axis("off")
    plt.title(title)
    plt.tight_layout()
    if use_wandb:
        wandb.run.summary["best_loss_im"] = wandb.Image(plt)
    plt.close()


def load_svg(path_svg):
    """Parse an SVG file into a diffvg scene.

    Returns:
        (canvas_width, canvas_height, shapes, shape_groups)
    """
    svg = os.path.join(path_svg)
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        svg)
    return canvas_width, canvas_height, shapes, shape_groups


def read_svg(path_svg, device, multiply=False):
    """Rasterize an SVG file to an HxWx3 tensor over a white background.

    Args:
        path_svg: path to the SVG file.
        device: torch device for the rendered tensor.
        multiply: if True, upscale the canvas, control points, and stroke
            widths by 2x before rendering.
    """
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2
    return _render_svg_scene(canvas_width, canvas_height,
                             shapes, shape_groups, device)


def plot_attn_dino(attn, threshold_map, inputs, inds, use_wandb, output_path):
    """Visualize DINO attention heads and threshold maps for one image.

    Layout: row 1 = input + attention sum + per-head attention maps;
    row 2 = prob sum, threshold sum, and per-head threshold maps.
    Currently supports one image (and not a batch).
    """
    plt.figure(figsize=(10, 5))
    plt.subplot(2, attn.shape[0] + 2, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    # inds holds (row, col) coordinates, hence the swapped scatter order
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, 2)
    plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
    plt.title("atn map sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 3)
    plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
    plt.title("prob sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 4)
    plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
    plt.title("thresh sum")
    plt.axis("off")

    for i in range(attn.shape[0]):
        plt.subplot(2, attn.shape[0] + 2, i + 3)
        plt.imshow(attn[i].numpy())
        plt.axis("off")
        plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 1 + i + 4)
        plt.imshow(threshold_map[i].numpy())
        plt.axis("off")
    plt.tight_layout()
    if use_wandb:
        wandb.log({"attention_map": wandb.Image(plt)})
    plt.savefig(output_path)
    plt.close()


def plot_attn_clip(attn, threshold_map, inputs, inds, use_wandb, output_path,
                   display_logs):
    """Visualize a CLIP attention map and its softmax threshold map.

    Currently supports one image (and not a batch).
    """
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    # inds holds (row, col) coordinates, hence the swapped scatter order
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("atn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    threshold_map_ = (threshold_map - threshold_map.min()) / \
        (threshold_map.max() - threshold_map.min())
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    if use_wandb:
        wandb.log({"attention_map": wandb.Image(plt)})
    plt.savefig(output_path)
    plt.close()


def plot_atten(attn, threshold_map, inputs, inds, use_wandb, output_path,
               saliency_model, display_logs):
    """Dispatch attention plotting to the DINO or CLIP visualizer."""
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds,
                       use_wandb, output_path)
    elif saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds,
                       use_wandb, output_path, display_logs)


def fix_image_scale(im):
    """Center ``im`` on a square white canvas with a 20-pixel margin.

    NOTE(review): assumes an RGB (HxWx3) input — a 2-D grayscale array would
    fail the slice assignment below; confirm against callers.
    """
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    new_background = np.ones((max_len, max_len, 3))
    y = max_len // 2 - height // 2
    x = max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np
    new_background = (
        new_background / new_background.max() * 255).astype(np.uint8)
    new_im = Image.fromarray(new_background)
    return new_im


def get_mask_u2net(args, pil_im):
    """Compute a binary foreground mask for ``pil_im`` with U^2-Net.

    Saves the mask to ``{args.output_dir}/mask.png``.

    Args:
        args: namespace providing ``device``, ``use_gpu``, and ``output_dir``.
        pil_im: input PIL image (RGB).

    Returns:
        (masked PIL image with a white background, predicted mask tensor)
    """
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, im_size), interpolation=PIL.Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])

    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(args.device)

    model_dir = os.path.join("./U2Net_/saved_models/u2net.pth")
    net = U2NET(3, 1)
    if torch.cuda.is_available() and args.use_gpu:
        net.load_state_dict(torch.load(model_dir))
        net.to(args.device)
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach())

    # d1 is the fused saliency output; min-max normalize then binarize at 0.5.
    pred = d1[:, 0, :, :]
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1

    # Replicate the single-channel mask to 3 channels and resize back to the
    # original image size (resize interpolates, so re-binarize afterwards).
    mask = torch.cat([predict, predict, predict], axis=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    im.save(f"{args.output_dir}/mask.png")

    im_np = np.array(pil_im)
    im_np = im_np / im_np.max()
    im_np = mask * im_np
    im_np[mask == 0] = 1  # paint the background white
    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)
    return im_final, predict