File size: 3,450 Bytes
16d007c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from madgrad import MADGRAD
from torchvision import transforms


def get_optimizer(cfg, params):
    """Build the optimizer selected by ``cfg["optimizer"]``.

    Args:
        cfg: Mapping with an ``"optimizer"`` name (one of ``"adam"``,
            ``"radam"``, ``"madgrad"``, ``"rmsprop"``, ``"sgd"``) and an
            ``"lr"`` learning rate.
        params: Parameters (or parameter groups) handed to the optimizer.

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: If the optimizer name is not one of the above.
    """
    name = cfg["optimizer"]
    if name == "adam":
        return torch.optim.Adam(params, lr=cfg["lr"])
    if name == "radam":
        return torch.optim.RAdam(params, lr=cfg["lr"])
    if name == "madgrad":
        return MADGRAD(params, lr=cfg["lr"], weight_decay=0.01, momentum=0.9)
    if name == "rmsprop":
        return torch.optim.RMSprop(params, lr=cfg["lr"], weight_decay=0.01)
    if name == "sgd":
        return torch.optim.SGD(params, lr=cfg["lr"])
    # Raise (not return) so a bad config fails here instead of later, when the
    # caller tries to use the exception object as an optimizer.
    raise NotImplementedError(f"optimizer [{name}] is not implemented")


def get_text_criterion(cfg):
    """Return the text loss function selected by ``cfg["text_criterion"]``.

    Args:
        cfg: Mapping with a ``"text_criterion"`` entry, either
            ``"spherical"`` or ``"cosine"``.

    Returns:
        ``spherical_dist_loss`` or ``cosine_loss`` (the function itself, not
        a call result).

    Raises:
        NotImplementedError: If the criterion name is not recognised.
    """
    name = cfg["text_criterion"]
    if name == "spherical":
        return spherical_dist_loss
    if name == "cosine":
        return cosine_loss
    # Raise (not return) so callers never receive an exception object in
    # place of a callable loss.
    raise NotImplementedError(f"text criterion [{name}] is not implemented")


def spherical_dist_loss(x, y):
    """Return the mean of ``2 * arcsin(||x̂ - ŷ|| / 2) ** 2``.

    ``x̂`` and ``ŷ`` are ``x`` and ``y`` L2-normalised along the last
    dimension; the norm of their difference is also taken along the last
    dimension, and the result is averaged over all remaining entries.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1).div(2)
    per_pair = half_chord.arcsin().pow(2).mul(2)
    return per_pair.mean()


def cosine_loss(x, y, scaling=1.2):
    """Return ``scaling * (1 - mean cosine similarity of x and y)``.

    The similarity is computed with ``F.cosine_similarity`` using its
    default dimension and averaged before being subtracted from 1.
    """
    mean_similarity = F.cosine_similarity(x, y).mean()
    return scaling * (1 - mean_similarity)


def tensor2im(input_image, imtype=np.uint8):
    """Convert an image tensor or array to a NumPy array of dtype ``imtype``.

    * NumPy arrays are only cast to ``imtype``.
    * Torch tensors: the first element along dim 0 is taken, clamped to
      ``[0, 1]``, moved to CPU, transposed with axes ``(1, 2, 0)``, scaled
      by 255 and cast to ``imtype``.
    * Any other object is returned unchanged.
    """
    # Arrays need no post-processing beyond the dtype cast.
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)

    # Anything that is neither an array nor a tensor is passed through.
    if not isinstance(input_image, torch.Tensor):
        return input_image

    first_image = input_image.data[0].clamp(0.0, 1.0).cpu().float().numpy()
    # Reorder axes to (1, 2, 0) and rescale from [0, 1] to [0, 255].
    scaled = np.transpose(first_image, (1, 2, 0)) * 255.0
    return scaled.astype(imtype)


def get_screen_template():
    """Return the two green-screen format strings (one ``{}`` slot each)."""
    over_screen = "{} over a green screen."
    in_front_of_screen = "{} in front of a green screen."
    return [over_screen, in_front_of_screen]


def get_augmentations_template():
    """Return the list of text format strings, each with one ``{}`` slot.

    A fresh list is built on every call, in a fixed order: the "photo"
    phrasings, then the "image" phrasings, then the bare/article variants.
    """
    photo_variants = [
        "photo of {}.",
        "high quality photo of {}.",
        "a photo of {}.",
        "the photo of {}.",
    ]
    image_variants = [
        "image of {}.",
        "an image of {}.",
        "high quality image of {}.",
        "a high quality image of {}.",
    ]
    bare_variants = [
        "the {}.",
        "a {}.",
        "{}.",
        "{}",
        "{}!",
        "{}...",
    ]
    return photo_variants + image_variants + bare_variants


def compose_text_with_templates(text: str, templates) -> list:
    """Fill ``text`` into every format string in ``templates``.

    Returns a new list with one formatted string per template, in the order
    the templates are iterated.
    """
    filled = []
    for pattern in templates:
        filled.append(pattern.format(text))
    return filled


def get_mask_boundary(img, mask):
    """Crop ``img`` to the bounding box of the non-zero region of ``mask``.

    Args:
        img: Tensor indexed as ``img[:, :, y, x]``.
        mask: Tensor that squeezes down to 2-D ``(H, W)``; non-zero entries
            mark the region of interest.

    Returns:
        ``img`` cropped along its last two dimensions to the smallest box
        that contains every non-zero mask entry, or ``img`` unchanged when
        the mask sums to zero or less.
    """
    mask = mask.squeeze()  # mask.shape -> (H, W)
    if torch.sum(mask) > 0:
        y, x = torch.where(mask)
        y0, x0 = y.min(), x.min()
        # max() is the last *included* index, so the exclusive slice end is
        # max + 1; without it the final mask row/column is dropped and a
        # single-pixel mask produces an empty crop.
        y1, x1 = y.max() + 1, x.max() + 1
        return img[:, :, y0:y1, x0:x1]
    return img


def load_video(folder: str, resize=(432, 768), num_frames=70):
    """Load the frames of a video stored as image files in ``folder``.

    The ``*.jpg`` and ``*.png`` files directly inside ``folder`` are sorted
    by path and the first ``num_frames`` are loaded.

    Args:
        folder: Directory containing the frame images.
        resize: ``(height, width)`` every frame is resized to.
        num_frames: Maximum number of frames to load (used as a slice end).

    Returns:
        A float tensor of shape ``(n, 3, height, width)`` where ``n`` is the
        number of frames loaded; ``(0, 3, height, width)`` when the folder
        holds no matching files.
    """
    resy, resx = resize
    folder = Path(folder)
    input_files = sorted(list(folder.glob("*.jpg")) + list(folder.glob("*.png")))[:num_frames]
    video = torch.zeros((len(input_files), 3, resy, resx))

    for i, file in enumerate(input_files):
        # Close the file handle promptly instead of relying on garbage
        # collection; convert() returns an independent, fully loaded image.
        with Image.open(str(file)) as frame:
            # Force 3 channels so RGBA / palette / grayscale files fit the
            # (3, H, W) slot of the output buffer.
            rgb_frame = frame.convert("RGB").resize((resx, resy), Image.LANCZOS)
        video[i] = transforms.ToTensor()(rgb_frame)

    return video