File size: 1,268 Bytes
a23872f
 
 
 
 
 
 
 
 
 
 
82e6d22
a23872f
e0f92a0
a23872f
e0f92a0
 
a23872f
 
 
 
 
 
 
e0f92a0
a23872f
e0f92a0
 
 
a23872f
 
e0f92a0
 
 
 
 
 
 
 
 
 
 
 
a23872f
 
e0f92a0
 
 
 
 
a23872f
 
e0f92a0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import io

import numpy as np
import PIL
import requests
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw, ImageFont


def preprocess(img, target_image_size=256):
    """Resize and center-crop a PIL image into a batched square tensor.

    The image is Lanczos-resized so its shorter side equals
    ``target_image_size``, center-cropped to a
    ``target_image_size x target_image_size`` square, converted with
    ``torchvision.transforms.ToTensor`` and given a leading batch axis.

    Args:
        img: A PIL image; ``img.size`` is read as ``(width, height)``.
        target_image_size: Side length, in pixels, of the square output.

    Returns:
        A tensor of shape ``(1, C, target_image_size, target_image_size)``.
        For 8-bit modes such as RGB, ``ToTensor`` scales pixels into
        ``[0, 1]``.

    Raises:
        ValueError: If the shorter side of ``img`` is smaller than
            ``target_image_size`` (images are never upscaled).
    """
    shorter_side = min(img.size)

    if shorter_side < target_image_size:
        raise ValueError(f"min dim for image {shorter_side} < {target_image_size}")

    scale = target_image_size / shorter_side
    width, height = img.size
    # TF.resize takes (height, width), whereas PIL's img.size is (width, height).
    new_size = (round(scale * height), round(scale * width))
    # Use the InterpolationMode enum: passing integer / PIL resampling
    # constants as `interpolation` is deprecated in torchvision.
    img = TF.resize(img, new_size, interpolation=TF.InterpolationMode.LANCZOS)
    img = TF.center_crop(img, output_size=[target_image_size, target_image_size])
    return torch.unsqueeze(T.ToTensor()(img), 0)


def preprocess_vqgan(x):
    """Linearly map ``x`` from the ``[0, 1]`` range onto ``[-1, 1]``.

    Computes ``2 * x - 1`` element-wise for anything that supports scalar
    arithmetic (Python numbers, NumPy arrays, torch tensors). Values that
    fall outside ``[0, 1]`` are not clamped.
    """
    return x * 2.0 - 1.0


def custom_to_pil(x, process=True, mode="RGB"):
    """Convert a single 3-D image tensor into a PIL image.

    Args:
        x: A 3-D tensor, presumably laid out as ``(C, H, W)`` (it is
            permuted to ``(H, W, C)`` before conversion). It is detached
            and moved to the CPU first.
        process: If True, ``x`` is clamped to ``[-1, 1]``, rescaled to
            ``[0, 1]`` and quantized to ``uint8`` (truncated, not rounded).
            If False, the array is passed to ``Image.fromarray`` as-is, so it
            must already have a dtype PIL accepts.
        mode: PIL mode of the returned image. The image is converted when
            ``Image.fromarray`` infers a different mode.

    Returns:
        A ``PIL.Image.Image`` in ``mode``.
    """
    x = x.detach().cpu()
    if process:
        x = torch.clamp(x, -1.0, 1.0)
        x = (x + 1.0) / 2.0
    arr = x.permute(1, 2, 0).numpy()
    if process:
        arr = (255 * arr).astype(np.uint8)
    if arr.shape[-1] == 1:
        # Image.fromarray cannot handle an (H, W, 1) array; drop the channel
        # axis so single-channel input becomes a 2-D (grayscale) image.
        arr = arr[..., 0]
    img = Image.fromarray(arr)
    if img.mode != mode:
        img = img.convert(mode)
    return img


def get_pil(x):
    """Rescale a 3-D tensor from ``[-1, 1]`` to ``[0, 1]`` and move axis 0 last.

    Despite its name, this returns a tensor, not a PIL image.

    Args:
        x: A 3-D tensor, presumably ``(C, H, W)``. Values outside
            ``[-1, 1]`` are clamped before rescaling.

    Returns:
        A permuted view of shape ``(H, W, C)`` with values in ``[0, 1]``.
    """
    unit_range = x.clamp(-1.0, 1.0).add(1.0).div(2.0)
    return unit_range.permute(1, 2, 0)


def loop_post_process(x):
    """Map a batched image tensor from ``[-1, 1]`` to ``[0, 1]``, keeping its layout.

    The leading batch axis is removed, and ``get_pil`` clamps, rescales and
    moves channels last. The result is then permuted back to channel-first
    and the batch axis is restored.

    Args:
        x: A tensor presumably of shape ``(1, C, H, W)`` (one image); TODO:
            confirm with callers. A 3-D ``(C, H, W)`` input with ``C > 1``
            is also accepted, because ``squeeze(0)`` leaves it unchanged.

    Returns:
        A tensor of shape ``(1, C, H, W)``.
    """
    # Use squeeze(0), not squeeze(): only the batch axis should be removed.
    # A bare squeeze() would also drop a singleton channel/height/width axis,
    # giving get_pil's 3-axis permute a tensor of the wrong rank.
    x = get_pil(x.squeeze(0))
    return x.permute(2, 0, 1).unsqueeze(0)