Spaces:
Sleeping
Sleeping
File size: 3,484 Bytes
b762e56 86e64e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
import torch.nn.functional as F
import cairosvg
from data_utils.common_utils import trans2_white_bg
from PIL import Image
import numpy as np
def select_imgs(images_of_onefont, selected_cls, opts):
    """Gather the reference glyph images for the selected character classes.

    Args:
        images_of_onefont: [bs, num_chars, opts.img_size, opts.img_size]
        selected_cls: [bs, nshot] integer class indices into dim 1.
        opts: options object providing ``img_size``.

    Returns:
        [bs, nshot, opts.img_size, opts.img_size] gathered images.
    """
    n_selected = selected_cls.size(1)
    # Broadcast the class indices over the two spatial dims so they can
    # drive a gather along the character axis (dim 1).
    idx = selected_cls[:, :, None, None].expand(
        images_of_onefont.size(0), n_selected, opts.img_size, opts.img_size)
    return torch.gather(images_of_onefont, 1, idx)
def select_seqs(seqs_of_onefont, selected_cls, opts, seq_dim):
    """Gather the drawing sequences for the selected character classes.

    Args:
        seqs_of_onefont: [bs, num_chars, opts.max_seq_len, seq_dim]
        selected_cls: [bs, nshot] integer class indices into dim 1.
        opts: options object providing ``max_seq_len``.
        seq_dim: size of the per-step feature dimension.

    Returns:
        [bs, nshot, opts.max_seq_len, seq_dim] gathered sequences.
    """
    n_selected = selected_cls.size(1)
    # Expand indices over the sequence-length and feature dims for gather.
    idx = selected_cls[:, :, None, None].expand(
        seqs_of_onefont.size(0), n_selected, opts.max_seq_len, seq_dim)
    return torch.gather(seqs_of_onefont, 1, idx)
def select_seqlens(seqlens_of_onefont, selected_cls, opts):
    """Gather the sequence lengths for the selected character classes.

    Args:
        seqlens_of_onefont: [bs, num_chars, 1] per-character lengths.
        selected_cls: [bs, nshot] integer class indices into dim 1.
        opts: options object (kept for a uniform signature; unused here).

    Returns:
        [bs, nshot, 1] gathered lengths.
    """
    n_selected = selected_cls.size(1)
    idx = selected_cls.unsqueeze(-1).expand(
        seqlens_of_onefont.size(0), n_selected, 1)
    return torch.gather(seqlens_of_onefont, 1, idx)
def trgcls_to_onehot(trg_cls, opts):
    """Convert target class indices to one-hot vectors.

    Args:
        trg_cls: [bs, 1] integer class indices.
        opts: options object providing ``char_num`` (number of classes).

    Returns:
        [bs, opts.char_num] one-hot (0/1) long tensor.
    """
    onehot = F.one_hot(trg_cls, num_classes=opts.char_num)
    # one_hot adds a trailing class dim; drop the singleton middle dim.
    return onehot.squeeze(dim=1)
def shift_right(x, pad_value=None):
    """Shift a sequence one step to the right along the time axis (dim 0).

    The last timestep is dropped and a new first timestep is inserted:
    zeros when ``pad_value`` is None, otherwise ``pad_value`` itself.

    Args:
        x: [seq_len, batch, depth] tensor.
        pad_value: optional [1, batch, depth] tensor to prepend.

    Returns:
        [seq_len, batch, depth] shifted tensor.
    """
    if pad_value is None:
        # Pad one zero step at the front of dim 0 (pad spec is last-dim-first).
        padded = F.pad(x, (0, 0, 0, 0, 1, 0))
    else:
        # `dim=` is the torch-idiomatic kwarg (the original used numpy's `axis=`).
        padded = torch.cat([pad_value, x], dim=0)
    return padded[:-1, :, :]
def length_form_embedding(emb):
    """Infer the length of each sequence in the batch from its embedding.

    A timestep counts toward the length when its feature vector is not
    entirely zero (i.e. it is not padding).

    Args:
        emb: [seq_len, batch, depth] tensor.

    Returns:
        int64 tensor of shape [batch] holding the per-sequence lengths.
    """
    # [seq_len, batch, 1] boolean: True where the step has any nonzero feature.
    nonzero_step = emb.abs().sum(dim=2, keepdim=True) != 0
    # Count valid steps per batch element.
    return nonzero_step.sum(dim=(0, 2), dtype=torch.long)
def lognormal(y, mean, logstd, logsqrttwopi):
    """Elementwise log-density of y under Normal(mean, exp(logstd)).

    Shapes broadcast, e.g. y: [b*51*6, 1] against mean/logstd: [b*51*6, 50].

    Args:
        y: observed values.
        mean: per-component means.
        logstd: per-component log standard deviations.
        logsqrttwopi: the constant log(sqrt(2*pi)).

    Returns:
        Tensor of log-probabilities, broadcast shape of the inputs.
    """
    # Standardized residual (y - mu) / sigma.
    residual = (y - mean) / logstd.exp()
    return -0.5 * residual ** 2 - logstd - logsqrttwopi
def sequence_mask(lengths, max_len=None):
    """Build a boolean mask marking valid positions for each sequence.

    Args:
        lengths: integer tensor of per-sequence lengths (any shape; flattened).
        max_len: mask width; falls back to ``lengths.max()`` when falsy.

    Returns:
        [batch, max_len] boolean tensor, True where position < length.
    """
    batch = lengths.numel()
    max_len = max_len or lengths.max()
    # Positions 0..max_len-1, broadcast against each sequence's length.
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    grid = positions.unsqueeze(0).expand(batch, max_len)
    return grid < lengths.unsqueeze(1)
def svg2img(path_svg, path_img, img_size):
    """Rasterize an SVG file to a square PNG and load it back as an array.

    Args:
        path_svg: path to the input SVG file.
        path_img: path where the rendered PNG is written.
        img_size: output width and height in pixels.

    Returns:
        Whatever ``trans2_white_bg`` yields for the written PNG
        (presumably an image array with transparency flattened to white
        — confirm against data_utils.common_utils).
    """
    cairosvg.svg2png(url=path_svg, write_to=path_img,
                     output_height=img_size, output_width=img_size)
    return trans2_white_bg(path_img)
def cal_img_l1_dist(path_img1, path_img2):
    """Mean absolute per-pixel difference between two images on disk.

    Args:
        path_img1: path to the first image.
        path_img2: path to the second image (first channel only is used).

    Returns:
        float, the mean absolute pixel difference.
    """
    # Cast to float BEFORE subtracting: PIL images load as uint8, and uint8
    # arithmetic wraps around (e.g. 0 - 255 == 1), which silently corrupted
    # the distance in the original code.
    img1 = np.array(Image.open(path_img1)).astype(np.float64)
    img2 = np.array(Image.open(path_img2)).astype(np.float64)
    # NOTE(review): img2 is reduced to its first channel while img1 is used
    # as-is — presumably img1 is grayscale and img2 RGB(A); confirm callers.
    dist = np.mean(np.abs(img1 - img2[:, :, 0]))
    return dist
def cal_iou(path_img1, path_img2):
    """Compute foreground IoU and mask L1 distance between two images.

    A pixel is foreground when its intensity is below 3/4 of full scale
    (i.e. dark-on-light glyph rendering).

    Args:
        path_img1: path to the first image (used as-is; presumably grayscale).
        path_img2: path to the second image (first channel only is used).

    Returns:
        (iou, l1_dist): intersection-over-union of the foreground masks
        (0.0 when both masks are empty), and the mean absolute difference
        of the binary masks.
    """
    img1 = np.array(Image.open(path_img1))
    img2 = np.array(Image.open(path_img2))[:, :, 0]
    threshold = 255 * 3 / 4
    mask_img1 = img1 < threshold
    mask_img2 = img2 < threshold
    intersection = np.sum(mask_img1 & mask_img2)
    union = np.sum(mask_img1 | mask_img2)
    # Guard the empty-union case: the original divided by zero and
    # returned NaN when both images had no foreground.
    iou = intersection / union if union > 0 else 0.0
    l1_dist = np.mean(np.abs(mask_img1.astype(float) - mask_img2.astype(float)))
    return iou, l1_dist