# Source: styledrop/libs/muse.py
# Commit: "StyleDrop init" (28c6826), committed by zideliu
import numpy as np
import torch
import math
from einops import rearrange
from torch.nn import functional as F
def add_gumbel_noise(t, temperature, device):
    """Return ``t`` perturbed by Gumbel noise scaled by ``temperature``.

    The noise is drawn on the host with NumPy's global random generator
    (so ``np.random.seed`` controls it, not the torch seed), wrapped in a
    ``torch.Tensor`` and moved to ``device`` before being added to ``t``.

    Args:
        t: Tensor to perturb; the noise has the same shape.
        temperature: Multiplier applied to the raw Gumbel samples. A value
            of 0 yields all-zero noise.
        device: Device the noise tensor is moved to before the addition.

    Returns:
        ``t`` plus the scaled noise.
    """
    raw_gumbel = np.random.gumbel(size=t.shape)
    scaled_noise = torch.Tensor(temperature * raw_gumbel)
    scaled_noise = scaled_noise.to(device)
    return t + scaled_noise
class MUSE(object):
    """Masked-token helper: training-time masking, loss, and iterative decoding.

    Args:
        codebook_size: Used as the id of the mask token (``mask_ind``).
        device: Device on which random draws and id tensors are created.
        ignore_ind: Label value skipped by the cross-entropy loss.
        smoothing: Label-smoothing factor handed to the cross-entropy loss.
        gen_temp: Starting temperature of the Gumbel noise used by
            :meth:`generate`; it is annealed linearly to 0 over the steps.
    """

    def __init__(self, codebook_size, device, ignore_ind=-1, smoothing=0., gen_temp=4.5):
        self.mask_ind = codebook_size  # token id written into masked input positions
        self.ignore_ind = ignore_ind  # label for visible positions, excluded from the CE loss
        self.device = device
        self.smoothing = smoothing
        self.gen_temp = gen_temp

    @staticmethod
    def cosine_schedule(t):
        """Map ``t`` to ``cos(t * pi / 2)`` (1 at ``t == 0``, ~0 at ``t == 1``)."""
        return torch.cos(t * math.pi * 0.5)

    def sample(self, x0):
        """Randomly mask a batch of token ids for training.

        Each row draws a uniform timestep in [0, 1], converts it to a mask
        fraction with :meth:`cosine_schedule`, and masks that many randomly
        chosen positions (at least one per row).

        Args:
            x0: 2-D tensor of token ids (rows x positions).

        Returns:
            Tuple ``(labels, masked_ids)``. ``labels`` holds ``x0`` at masked
            positions and ``ignore_ind`` elsewhere; ``masked_ids`` holds
            ``mask_ind`` at masked positions and ``x0`` elsewhere.
        """
        N, L, device = *x0.shape, self.device

        timesteps = torch.zeros((N,), device=device).float().uniform_(0, 1)
        rand_mask_probs = self.cosine_schedule(timesteps)
        num_token_masked = (L * rand_mask_probs).round().clamp(min=1)

        # argsort of uniform noise gives every position a random rank per row;
        # positions whose rank is below the row's quota are masked.
        batch_randperm = torch.rand(N, L, device=device).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(-1)

        masked_ids = torch.where(mask, self.mask_ind, x0)
        labels = torch.where(mask, x0, self.ignore_ind)
        return labels, masked_ids

    def loss(self, pred, label):
        """Cross-entropy between ``pred`` and ``label``.

        ``pred`` has its dims 1 and 2 swapped before the call so the class
        dimension sits at index 1, as ``F.cross_entropy`` expects. Labels
        equal to ``ignore_ind`` do not contribute.
        """
        return F.cross_entropy(pred.transpose(1, 2), label.long(),
                               ignore_index=self.ignore_ind, label_smoothing=self.smoothing)

    @torch.no_grad()
    def _decode_tokens(self, config, n_samples, seq_len, sample_steps, nnet,
                       with_adapter, nnet_kwargs):
        """Run one full iterative decoding pass and return the sampled ids.

        Starts from an all-mask sequence. On every step it samples ids from
        the Gumbel-perturbed logits, then re-masks the least confident
        positions according to the cosine schedule.

        When ``with_adapter`` is true, ``lambdaA``/``lambdaB`` are passed to
        ``nnet``: 0 on the first step, ``config.sample.lambdaA``/``lambdaB``
        afterwards. Otherwise those keywords are not passed at all.

        Returns:
            Long tensor of shape ``(n_samples, seq_len)`` from the final step.
        """
        device = self.device
        ids = torch.full((n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device)
        cfg_scale = 0.
        adapter_kwargs = {'lambdaA': 0., 'lambdaB': 0.} if with_adapter else {}

        for step in range(sample_steps):
            ratio = 1. * (step + 1) / sample_steps
            annealed_temp = self.gen_temp * (1 - ratio)
            is_mask = (ids == self.mask_ind)

            # NOTE: the original author left a reminder here to try scaling
            # lambdaA/lambdaB by ``ratio``.
            logits = nnet(ids, **nnet_kwargs, scale=cfg_scale, **adapter_kwargs)

            # sampling & scoring
            sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1)
            sampled_logits = torch.squeeze(
                torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            # Positions decided on earlier steps keep their id and get +inf
            # confidence so they are never selected for re-masking.
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()

            # masking
            mask_ratio = np.cos(ratio * math.pi * 0.5)
            mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(device)
            # Clamp to [1, currently_masked - 1]. Only row 0's masked count is
            # consulted ([0]), so one cut-off rank is shared by the whole batch.
            mask_len = torch.maximum(torch.Tensor([1]).to(device),
                                     torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1,
                                                   mask_len))[0].squeeze()
            confidence = add_gumbel_noise(sampled_logits, annealed_temp, device)
            sorted_confidence, _ = torch.sort(confidence, axis=-1)
            # Per-row threshold: the mask_len-th smallest confidence.
            cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
            masking = (confidence <= cut_off)
            ids = torch.where(masking, self.mask_ind, sampled_ids)

            # Guidance scale (and adapter weights) take effect from the next step.
            cfg_scale = ratio * config.sample.scale
            if with_adapter:
                adapter_kwargs = {'lambdaA': config.sample.lambdaA,
                                  'lambdaB': config.sample.lambdaB}

        # The final step's ``ids`` still has a re-masked slot (mask_len >= 1),
        # so the fully filled ``sampled_ids`` is what gets returned.
        return sampled_ids

    @torch.no_grad()
    def generate(self, config, _n_samples, nnet, decode_fn, is_eval=False, **kwargs):
        """Iteratively decode token grids and pass them to ``decode_fn``.

        Args:
            config: Object providing ``z_shape`` (its last entry is the side
                length of the square token grid), ``sample.sample_steps``,
                ``sample.scale``, ``sample.lambdaA`` and ``sample.lambdaB``.
            _n_samples: Number of sequences decoded per pass.
            nnet: Callable invoked as ``nnet(ids, **kwargs, scale=...)`` for
                the plain pass and additionally with ``lambdaA=...,
                lambdaB=...`` for the adapter pass. It must return logits
                whose last dimension indexes token ids.
            decode_fn: Called with the token grid(s); its result is returned.
            is_eval: If true, only the adapter pass is run and decoded. If
                false, a plain pass runs first and both results are
                concatenated along dim 0 (plain first, then adapter).
            **kwargs: Forwarded unchanged to every ``nnet`` call.

        Returns:
            ``decode_fn`` applied to a long tensor of shape
            ``(_n_samples, side, side)`` in eval mode, or
            ``(2 * _n_samples, side, side)`` otherwise.

        Raises:
            ValueError: If ``config.sample.sample_steps`` is smaller than 1.

        Note:
            In eval mode the plain pass is skipped because its output was
            never used. This halves the number of ``nnet`` calls, and it also
            means fewer draws from NumPy's global RNG than when the discarded
            pass was still executed.
        """
        fmap_size = config.z_shape[-1]
        sample_steps = config.sample.sample_steps
        if sample_steps < 1:
            raise ValueError(
                'config.sample.sample_steps must be >= 1, got {}'.format(sample_steps))
        seq_len = fmap_size ** 2

        decoded = []
        if not is_eval:
            # without adapter
            decoded.append(self._decode_tokens(
                config, _n_samples, seq_len, sample_steps, nnet, False, kwargs))
        # with adapter
        decoded.append(self._decode_tokens(
            config, _n_samples, seq_len, sample_steps, nnet, True, kwargs))

        grids = [seq.reshape(seq.shape[0], fmap_size, fmap_size) for seq in decoded]
        _z = grids[0] if is_eval else torch.cat(grids, dim=0)
        out = decode_fn(_z)
        return out