# alvan
# Added gradio space for domain expansion
# 560a1b9
# raw
# history blame
# 3.82 kB
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import json
from pathlib import Path
import torch
from torch.utils.data import DataLoader
import torchvision
from PIL import Image, ImageDraw, ImageFont
from expansion_utils import consts
def max_num_not_in_list(max_num, lst):
    """Return the largest integer in [1, max_num] that is absent from lst.

    Counts down from max_num; returns None when every candidate appears
    in lst (mirrors the implicit fall-through of a search loop).
    """
    return next((candidate for candidate in range(max_num, 0, -1) if candidate not in lst), None)
def process_config(expansion_cfg_path):
    """Load an expansion config JSON and assign a free latent dimension to every task.

    Tasks that already specify a "dimension" keep it; tasks without one are
    assigned the highest still-unused dimension, counting down from
    consts.LATENT_DIM - 1.

    Args:
        expansion_cfg_path: path to a JSON file containing at least the
            "tasks" (list of dicts) and "tasks_losses" keys.

    Returns:
        The parsed config dict with every task's "dimension" filled in.

    Raises:
        ValueError: if a required key is missing, the same dimension is
            repurposed more than once, or no free dimension remains.
    """
    with Path(expansion_cfg_path).open() as fp:
        expansion_cfg = json.load(fp)
    # `assert` is stripped when Python runs with -O; validate explicitly.
    if "tasks" not in expansion_cfg or "tasks_losses" not in expansion_cfg:
        raise ValueError('Config must contain both "tasks" and "tasks_losses" keys')
    curr_max_dim = consts.LATENT_DIM - 1
    used_dims = [task.get("dimension") for task in expansion_cfg["tasks"]]
    # Reject explicit duplicates up front (None entries mean "auto-assign").
    for dim in used_dims:
        if dim is not None and used_dims.count(dim) > 1:
            raise ValueError(f"Config tries to repurpose the same dim {dim} more than once, unsupported...")
    for task in expansion_cfg["tasks"]:
        if task.get("dimension") is None:
            curr_max_dim = max_num_not_in_list(curr_max_dim, used_dims)
            if curr_max_dim is None:
                raise ValueError("No available dimension was found")
            task["dimension"] = curr_max_dim
            used_dims.append(curr_max_dim)
    print(f"Parsed config successfully! Repurposing {len(used_dims)} dims, good luck!")
    return expansion_cfg
def label_image(img: torch.Tensor, label: str = None):
batch_size = img.shape[0]
img = torchvision.utils.make_grid(img, batch_size) # concat over W
if label is not None:
H, W = img.shape[-2:]
W = W // (4 * batch_size)
H, W = W, H # will be rotated, so H is W
font = ImageFont.truetype("DejaVuSans.ttf", 60) # TODO: use different font sizes for different resolutions.
label_img = Image.new('RGB', (W ,H), color='white')
draw = ImageDraw.Draw(label_img)
w, h = draw.textsize(label, font=font)
draw.text(((W - w) / 2, (H - h) / 2), label, font=font, fill=(0, 0, 0))
label_img = torchvision.transforms.functional.pil_to_tensor(label_img.rotate(90, expand=True))
label_img = label_img.to(torch.float32) / 127.5 - 1
img = torch.cat([label_img, img], dim=-1)
return img
def save_batched_images(images: torch.Tensor, output_path: Path, labels: list = None, max_row_in_img=5):
    """Label each row of images and write them out in files of at most max_row_in_img rows.

    Args:
        images: tensor whose first dimension indexes the rows to label/save.
        output_path: base path; each file gets a "_batch_<i>" suffix on the stem.
        labels: optional per-row labels; length must equal images.shape[0].
        max_row_in_img: maximum number of rows written into a single file.

    Raises:
        ValueError: when labels is given but its length differs from the
            number of rows.
    """
    num_rows = images.shape[0]
    if labels is None:
        labels = [None] * num_rows
    elif len(labels) != num_rows:
        raise ValueError('Number of labels should match number of batches')
    labeled_rows = torch.stack([label_image(row, row_label) for row, row_label in zip(images, labels)])
    chunk_starts = range(0, num_rows, max_row_in_img)
    for file_idx, start in enumerate(chunk_starts):
        chunk = labeled_rows[start:start + max_row_in_img]
        save_images(chunk, output_path.with_name(f"{output_path.stem}_batch_{file_idx}"), 1)
def save_images(frames: torch.Tensor, output_path: Path, nrow=None, size=None, separate=False):
parent_dir = output_path.parent
parent_dir.mkdir(exist_ok=True, parents=True)
if size:
frames = torch.nn.functional.interpolate(frames, size)
if separate:
base_name = output_path.stem
for i, frame in enumerate(frames):
torchvision.utils.save_image(
frame,
output_path.with_name(f"{i:05d}_{base_name}.jpg"),
nrow=len(frame),
normalize=True,
range=(-1, 1),
)
else:
torchvision.utils.save_image(
frames,
output_path.with_suffix(".jpg"),
nrow=nrow if nrow else len(frames),
normalize=True,
range=(-1, 1),
)