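# Hugging Face file-page residue, kept for provenance:
# uploaded by SupermanxKiaski ("Upload 356 files"), commit 16d007c,
# file size 3.45 kB, virus scan: "No virus". (Page links: raw | history | blame)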
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from madgrad import MADGRAD
from torchvision import transforms
def get_optimizer(cfg, params):
    """Build the optimizer named by ``cfg["optimizer"]`` for ``params``.

    Args:
        cfg: mapping with keys ``"optimizer"`` (one of ``"adam"``, ``"radam"``,
            ``"madgrad"``, ``"rmsprop"``, ``"sgd"``) and ``"lr"`` (learning rate).
        params: parameters (or parameter groups) handed to the optimizer constructor.

    Returns:
        The constructed optimizer. ``madgrad`` uses hard-coded
        ``weight_decay=0.01, momentum=0.9``; ``rmsprop`` uses ``weight_decay=0.01``.

    Raises:
        NotImplementedError: if the optimizer name is not recognized.
    """
    name = cfg["optimizer"]
    if name == "adam":
        return torch.optim.Adam(params, lr=cfg["lr"])
    if name == "radam":
        return torch.optim.RAdam(params, lr=cfg["lr"])
    if name == "madgrad":
        return MADGRAD(params, lr=cfg["lr"], weight_decay=0.01, momentum=0.9)
    if name == "rmsprop":
        return torch.optim.RMSprop(params, lr=cfg["lr"], weight_decay=0.01)
    if name == "sgd":
        return torch.optim.SGD(params, lr=cfg["lr"])
    # Raise (not return) so a misspelled config fails here instead of handing
    # an exception object to the training loop as if it were an optimizer.
    raise NotImplementedError(f"optimizer [{name}] is not implemented")
def get_text_criterion(cfg):
    """Select the text-embedding loss named by ``cfg["text_criterion"]``.

    Args:
        cfg: mapping with key ``"text_criterion"``, one of ``"spherical"`` or ``"cosine"``.

    Returns:
        The loss callable ``spherical_dist_loss`` or ``cosine_loss``; both take ``(x, y)``.

    Raises:
        NotImplementedError: if the criterion name is not recognized.
    """
    name = cfg["text_criterion"]
    if name == "spherical":
        return spherical_dist_loss
    if name == "cosine":
        return cosine_loss
    # Raise (not return) so a bad config fails here rather than later with a
    # confusing "object is not callable" when the returned value is used as a loss.
    raise NotImplementedError(f"text criterion [{name}] is not implemented")
def spherical_dist_loss(x, y):
    """Mean spherical distance loss between embeddings ``x`` and ``y``.

    Both inputs are L2-normalized along the last dimension. For unit vectors the
    chord length is ``2 * sin(theta / 2)``, so ``arcsin(chord / 2)`` recovers half
    the angle; the per-pair value ``2 * (theta / 2) ** 2`` is averaged over all
    remaining dimensions.
    """
    x_hat = F.normalize(x, dim=-1)
    y_hat = F.normalize(y, dim=-1)
    half_chord = (x_hat - y_hat).norm(dim=-1) / 2
    half_angle = torch.arcsin(half_chord)
    return (2 * half_angle ** 2).mean()
def cosine_loss(x, y, scaling=1.2):
    """Scaled cosine-distance loss: ``scaling * (1 - mean cosine similarity)``.

    Similarity is computed by ``F.cosine_similarity`` with its default ``dim``,
    then averaged over all results before the ``1 - ...`` and scaling are applied.
    """
    mean_similarity = F.cosine_similarity(x, y).mean()
    distance = 1 - mean_similarity
    return scaling * distance
def tensor2im(input_image, imtype=np.uint8):
    """Convert an image tensor (or array) into a NumPy array of ``imtype``.

    - NumPy arrays are only cast to ``imtype``.
    - Torch tensors: the first element along dim 0 is taken, clamped to [0, 1],
      moved to CPU as float, transposed with axes ``(1, 2, 0)`` and scaled by 255
      before the cast. The tensor is presumably ``(B, C, H, W)``, giving an
      ``(H, W, C)`` result; this is an assumption to confirm with callers.
    - Any other object is returned unchanged.

    The cast truncates toward zero (``astype``); it does not round.
    """
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image
    # `.data` detaches from autograd so `.numpy()` works on tensors requiring grad.
    first_image = input_image.data[0].clamp(0.0, 1.0).cpu().float().numpy()
    channels_last = np.transpose(first_image, (1, 2, 0)) * 255.0
    return channels_last.astype(imtype)
def get_screen_template():
    """Return a new list of green-screen prompt templates with one ``{}`` slot each."""
    screen_templates = [
        "{} over a green screen.",
        "{} in front of a green screen.",
    ]
    return screen_templates
def get_augmentations_template():
    """Return a new list of prompt-augmentation templates, each with one ``{}`` slot.

    The order is fixed; callers fill the slot via ``str.format``.
    """
    return [
        "photo of {}.",
        "high quality photo of {}.",
        "a photo of {}.",
        "the photo of {}.",
        "image of {}.",
        "an image of {}.",
        "high quality image of {}.",
        "a high quality image of {}.",
        "the {}.",
        "a {}.",
        "{}.",
        "{}",
        "{}!",
        "{}...",
    ]
def compose_text_with_templates(text: str, templates) -> list:
    """Fill every template's ``{}`` placeholder with ``text``, preserving order.

    Args:
        text: string substituted into each template.
        templates: iterable of format strings (e.g. from ``get_augmentations_template``).

    Returns:
        A new list with one formatted prompt per template.
    """
    return [tpl.format(text) for tpl in templates]
def get_mask_boundary(img, mask):
    """Crop ``img`` to the bounding box of the nonzero entries of ``mask``.

    Args:
        img: tensor indexed as ``img[:, :, rows, cols]``; presumably
            ``(B, C, H, W)``. Confirm with callers.
        mask: tensor that squeezes to 2-D ``(rows, cols)`` aligned with the last two
            dims of ``img`` (assumed, e.g. ``(1, 1, H, W)``).

    Returns:
        ``img`` cropped to the inclusive bounding box of the nonzero mask pixels,
        or ``img`` itself when the mask sum is not positive.
    """
    mask = mask.squeeze()  # expected to leave (H, W)
    if torch.sum(mask) > 0:
        rows, cols = torch.where(mask)
        top, bottom = int(rows.min()), int(rows.max())
        left, right = int(cols.min()), int(cols.max())
        # max() is the index of the last masked row/col, so the exclusive slice
        # stop must be max() + 1; otherwise the last row/col is dropped and a
        # single-pixel mask yields an empty crop.
        return img[:, :, top:bottom + 1, left:right + 1]
    return img
def load_video(folder: str, resize=(432, 768), num_frames=70):
    """Load up to ``num_frames`` images from ``folder`` into one float tensor.

    Frames are the ``*.jpg`` and ``*.png`` files directly in ``folder``, sorted by
    path (the two extensions are merged before sorting), then truncated to
    ``num_frames``.

    Args:
        folder: directory containing the frame images.
        resize: ``(height, width)`` every frame is resized to (Lanczos filter).
        num_frames: maximum number of frames to load.

    Returns:
        Tensor of shape ``(N, 3, height, width)``, where ``N`` is the number of
        frames loaded. Values come from torchvision ``ToTensor`` on 8-bit RGB
        images, i.e. in ``[0, 1]``.
    """
    resy, resx = resize
    folder = Path(folder)
    input_files = sorted(list(folder.glob("*.jpg")) + list(folder.glob("*.png")))[:num_frames]
    video = torch.zeros((len(input_files), 3, resy, resx))
    for i, file in enumerate(input_files):
        # `with` closes the file handle deterministically. Converting to RGB makes
        # RGBA/CMYK/palette/grayscale inputs all yield exactly 3 channels, which
        # is what the preallocated buffer expects.
        with Image.open(str(file)) as img:
            frame = img.convert("RGB").resize((resx, resy), Image.LANCZOS)
        video[i] = transforms.ToTensor()(frame)
    return video