File size: 3,484 Bytes
b762e56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86e64e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn.functional as F
import cairosvg
from data_utils.common_utils import trans2_white_bg
from PIL import Image
import numpy as np

def select_imgs(images_of_onefont, selected_cls, opts):
    """Pick the glyph images of the selected character classes.

    Args:
        images_of_onefont: [bs, 52, opts.img_size, opts.img_size]
        selected_cls: [bs, nshot] integer class indices
    Returns:
        [bs, nshot, opts.img_size, opts.img_size] gathered images
    """
    bs = images_of_onefont.size(0)
    nshot = selected_cls.size(1)
    # Broadcast the class indices over both spatial dims so they can be
    # used as a gather index along the character axis (dim 1).
    idx = selected_cls[:, :, None, None].expand(bs, nshot, opts.img_size, opts.img_size)
    return torch.gather(images_of_onefont, 1, idx)

def select_seqs(seqs_of_onefont, selected_cls, opts, seq_dim):
    """Pick the drawing sequences of the selected character classes.

    Args:
        seqs_of_onefont: [bs, num_chars, opts.max_seq_len, seq_dim]
        selected_cls: [bs, nshot] integer class indices
        seq_dim: size of the last (per-step feature) axis
    Returns:
        [bs, nshot, opts.max_seq_len, seq_dim] gathered sequences
    """
    bs = seqs_of_onefont.size(0)
    nshot = selected_cls.size(1)
    # Expand the indices over the sequence-length and feature axes so
    # torch.gather can select whole sequences along dim 1.
    idx = selected_cls[:, :, None, None].expand(bs, nshot, opts.max_seq_len, seq_dim)
    return torch.gather(seqs_of_onefont, 1, idx)

def select_seqlens(seqlens_of_onefont, selected_cls, opts):
    """Pick the sequence lengths of the selected character classes.

    Args:
        seqlens_of_onefont: [bs, num_chars, 1]
        selected_cls: [bs, nshot] integer class indices
        opts: unused; kept for a uniform signature with the other selectors
    Returns:
        [bs, nshot, 1] gathered lengths
    """
    nshot = selected_cls.size(1)
    idx = selected_cls.unsqueeze(-1).expand(seqlens_of_onefont.size(0), nshot, 1)
    return torch.gather(seqlens_of_onefont, 1, idx)

def trgcls_to_onehot(trg_cls, opts):
    """One-hot encode target class ids over opts.char_num classes.

    The singleton class axis (dim 1) produced by one_hot on a
    [bs, 1] input is squeezed away, yielding [bs, opts.char_num].
    """
    encoded = F.one_hot(trg_cls, num_classes=opts.char_num)
    return encoded.squeeze(dim=1)


def shift_right(x, pad_value=None):
    """Shift a [seq_len, batch, dim] tensor one step along dim 0.

    The first step becomes zeros (or `pad_value`, a [1, batch, dim]
    tensor, when given) and the last step is dropped, so seq_len is
    preserved — the usual teacher-forcing input shift.
    """
    if pad_value is not None:
        return torch.cat([pad_value, x], dim=0)[:-1, :, :]
    # Pad spec covers the last three dims; (..., 1, 0) prepends one
    # zero step along dim 0, then slicing drops the final step.
    padded = F.pad(x, (0, 0, 0, 0, 1, 0))
    return padded[:-1, :, :]


def length_form_embedding(emb):
    """Count the non-padding steps of each sequence in the batch.

    A step counts as real (non-padding) when its feature vector is not
    all zeros.

    Args:
        emb: [seq_len, batch, depth]
    Returns:
        int64 tensor [batch]: number of non-zero steps per sequence.
    """
    step_is_nonzero = torch.abs(emb).sum(dim=2, keepdim=True) != 0
    return step_is_nonzero.sum(dim=(0, 2), dtype=torch.long)


def lognormal(y, mean, logstd, logsqrttwopi):
    """Elementwise Gaussian log-density of y under N(mean, exp(logstd)^2).

    y broadcasts against mean/logstd (e.g. y: [b*51*6, 1],
    mean/logstd: [b*51*6, 50] for a 50-component mixture).
    logsqrttwopi is the precomputed constant log(sqrt(2*pi)).
    """
    z = (y - mean) / logstd.exp()
    return -0.5 * z ** 2 - logstd - logsqrttwopi

def sequence_mask(lengths, max_len=None):
    """Boolean mask [batch, max_len]: True where position < length.

    When max_len is falsy (None or 0) it defaults to lengths.max(),
    matching the original `or`-fallback behavior.
    """
    batch_size = lengths.numel()
    if not max_len:
        max_len = lengths.max()
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    grid = positions.unsqueeze(0).expand(batch_size, max_len)
    return grid.lt(lengths.unsqueeze(1))

def svg2img(path_svg, path_img, img_size):
    """Rasterize an SVG to an img_size x img_size PNG at path_img.

    The written PNG is post-processed by trans2_white_bg (composited
    onto a white background) and that result is returned as an array.
    """
    cairosvg.svg2png(
        url=path_svg,
        write_to=path_img,
        output_width=img_size,
        output_height=img_size,
    )
    return trans2_white_bg(path_img)

def cal_img_l1_dist(path_img1, path_img2):
    """Mean absolute pixel difference between two images on disk.

    img1 is compared against channel 0 of img2 — presumably img1 is
    single-channel while img2 is RGB/RGBA; TODO confirm against callers.

    Returns:
        float: mean |img1 - img2[:, :, 0]| over all pixels.
    """
    # Cast to a signed dtype before subtracting: images load as uint8,
    # and uint8 subtraction wraps around (0 - 255 == 1) instead of
    # going negative, which silently corrupted the distance.
    img1 = np.asarray(Image.open(path_img1), dtype=np.int16)
    img2 = np.asarray(Image.open(path_img2), dtype=np.int16)
    dist = np.mean(np.abs(img1 - img2[:, :, 0]))
    return dist

def cal_iou(path_img1, path_img2):
    """Foreground IoU and binary-mask L1 distance of two glyph images.

    Pixels darker than 3/4 of full white (value < 191.25) count as
    foreground.  img1 is used as-is while channel 0 of img2 is taken —
    presumably img2 is RGB/RGBA; TODO confirm against callers.

    Returns:
        (iou, l1_dist): IoU of the two foreground masks, and the mean
        absolute difference of the binary masks.
    """
    img1 = np.array(Image.open(path_img1))
    img2 = np.array(Image.open(path_img2))[:, :, 0]
    mask_img1 = img1 < (255 * 3 / 4)
    mask_img2 = img2 < (255 * 3 / 4)
    # Explicit boolean ops (&, |) instead of arithmetic *, + on bools.
    intersection = np.sum(mask_img1 & mask_img2)
    union = np.sum(mask_img1 | mask_img2)
    # Two all-background images are identical: define IoU as 1.0 there
    # instead of the 0/0 NaN the unguarded division produced.
    iou = 1.0 if union == 0 else intersection / union
    l1_dist = np.mean(np.abs(mask_img1.astype(float) - mask_img2.astype(float)))
    return iou, l1_dist