File size: 1,998 Bytes
266ec16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) SenseTime Research. All rights reserved.

import torch
import cv2
from torchvision import transforms
import numpy as np
import math


def visual(output, out_path):
    """Save the first image of a batch tensor to ``out_path`` via OpenCV.

    Values are shifted from [-1, 1] to [0, 1] (anything outside is clamped),
    scaled to 8-bit, and written with ``cv2.imwrite``.

    Args:
        output: tensor indexed below as (batch, channel, height, width);
            only batch item 0 is written. A single-channel input is
            replicated to three channels.
        out_path: destination path handed to ``cv2.imwrite``.
    """
    image = torch.clamp((output + 1) / 2, 0, 1)
    if image.shape[1] == 1:
        # Grayscale input: repeat the lone channel three times along dim 1.
        image = torch.cat([image] * 3, 1)

    # First batch item, channels moved last for OpenCV.
    pixels = image[0].detach().cpu().permute(1, 2, 0).numpy()
    pixels = (pixels * 255).astype(np.uint8)

    # Reverse the channel axis before writing (presumably RGB -> BGR, the
    # order cv2.imwrite stores -- confirm the model's channel order).
    cv2.imwrite(out_path, pixels[:, :, ::-1])
    
    
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate for progress ``t`` under a ramped schedule.

    The rate rises linearly over the first ``rampup`` fraction of ``t`` and
    falls along a half-cosine over the last ``rampdown`` fraction; in between
    it equals ``initial_lr``.

    Args:
        t: progress value; the formulas treat 0 as start and 1 as end.
        initial_lr: peak learning rate.
        rampdown: width of the final cosine decay window (must be non-zero).
        rampup: width of the initial linear warm-up window (must be non-zero).

    Returns:
        ``initial_lr`` scaled by the combined warm-up / decay factor.
    """
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)
    warmup_factor = min(1, t / rampup)
    return initial_lr * (cosine_factor * warmup_factor)



def latent_noise(latent, strength):
    """Return ``latent`` plus standard-normal noise scaled by ``strength``.

    The noise has the same shape, dtype and device as ``latent`` (it is drawn
    with ``torch.randn_like``); the input tensor is not modified.
    """
    perturbation = torch.randn_like(latent) * strength
    return latent + perturbation

def noise_regularize_(noises):
    """Return a penalty on neighbour correlation in each noise map.

    For every tensor in ``noises`` the squared mean of the product between
    the map and a copy rolled by one position is accumulated, once along
    dim 3 and once along dim 2. The map is then 2x2 average-pooled and
    measured again, repeating until the side length is 8 or less.

    Args:
        noises: iterable of 4-D tensors. The downsampling reshape uses
            ``shape[2]`` for both spatial axes and a channel size of 1, so
            maps are assumed square with an even side at every level above
            8 -- TODO confirm against callers.

    Returns:
        The accumulated penalty as a 0-dim tensor, or the int ``0`` when
        ``noises`` is empty.
    """

    def shifted_products(level):
        # Squared mean correlation with the 1-step neighbour along each axis.
        along_w = (level * torch.roll(level, shifts=1, dims=3)).mean().pow(2)
        along_h = (level * torch.roll(level, shifts=1, dims=2)).mean().pow(2)
        return along_w, along_h

    total = 0

    for level in noises:
        side = level.shape[2]

        along_w, along_h = shifted_products(level)
        total = total + along_w + along_h

        while side > 8:
            half = side // 2
            # 2x2 average pooling expressed as reshape + mean.
            level = level.reshape([-1, 1, half, 2, half, 2]).mean([3, 5])
            side = half

            along_w, along_h = shifted_products(level)
            total = total + along_w + along_h

    return total


def noise_normalize_(noises):
    """Standardize every tensor in ``noises`` in place.

    Each tensor has its overall mean subtracted and is divided by its
    overall standard deviation. The update is applied through ``.data``
    with in-place ops, so the caller's tensors are modified and nothing is
    returned.
    """
    for tensor in noises:
        # Statistics are taken before any mutation of the tensor.
        center = tensor.mean()
        spread = tensor.std()

        storage = tensor.data
        storage.add_(-center)
        storage.div_(spread)
        

def tensor_to_numpy(x):
    """Convert the first item of a batch tensor into a uint8 NumPy image.

    Item 0 is taken, its leading axis is moved last, values are clamped to
    [-1, 1] and mapped to [0, 255]; the fractional part is dropped by the
    uint8 cast.

    Args:
        x: tensor whose item 0 is 3-D (channel-first layout is assumed --
            the code permutes axes ``(1, 2, 0)``).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the permuted axis order.
    """
    channels_last = x[0].permute(1, 2, 0)
    scaled = (torch.clamp(channels_last, -1, 1) + 1) * 127.5
    return scaled.cpu().detach().numpy().astype(np.uint8)

def numpy_to_tensor(x, device=None):
    """Convert an image array into a float32 batch tensor in [-1, 1].

    Pixel values are mapped from [0, 255] to [-1, 1], a leading batch axis
    is added, and the last axis is moved to position 1 (the permutation
    ``(0, 3, 1, 2)``), so a 3-D input with channels last becomes a
    ``(1, C, H, W)`` tensor.

    Args:
        x: 3-D ``numpy.ndarray`` (height, width, channel layout assumed).
        device: target device for the result. When ``None`` (the default)
            CUDA is used if available, otherwise the CPU, so the function no
            longer fails on machines without a GPU.

    Returns:
        ``torch.float32`` tensor on ``device``.
    """
    if device is None:
        # Previously a hard-coded .cuda(); fall back to CPU when no GPU exists.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    x = (x / 255 - 0.5) * 2
    x = torch.from_numpy(x).unsqueeze(0).permute(0, 3, 1, 2)
    return x.to(device).float()

def tensor_to_pil(x):
    """Convert an image tensor with values in [-1, 1] into a PIL image.

    Values are clamped to [-1, 1], mapped to [0, 255] and cast to uint8
    (dropping the fractional part, as ``tensor_to_numpy`` does) before being
    handed to ``transforms.ToPILImage``.

    Args:
        x: image tensor, either 3-D or with a leading batch axis of size 1
            (that axis is squeezed away). The input is not modified.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    x = torch.clamp(x, -1, 1)
    # ToPILImage multiplies floating-point tensors by 255 before casting to
    # uint8, so a float tensor already scaled to [0, 255] would overflow.
    # Passing uint8 data avoids that rescaling.
    x = ((x + 1) * 127.5).to(torch.uint8)
    # detach/cpu so the conversion also works for GPU or graph-attached input.
    return transforms.ToPILImage()(x.detach().cpu().squeeze(0))