import numpy as np
import torch
import math
from einops import rearrange
from torch.nn import functional as F


def add_gumbel_noise(t, temperature, device):
    """Return ``t`` plus Gumbel(0, 1) noise scaled by ``temperature``.

    The noise is drawn from NumPy's global RNG (so ``np.random.seed`` controls
    it), wrapped with ``torch.Tensor`` and moved to ``device`` before being
    added to ``t``. A ``temperature`` of 0 adds zero noise.
    """
    noise = temperature * np.random.gumbel(size=t.shape)
    return t + torch.Tensor(noise).to(device)


class MUSE(object):
    """Masked-token training and iterative sampling helpers.

    ``mask_ind`` (equal to ``codebook_size``) is the token id written into
    masked input positions. ``ignore_ind`` is the label value that the
    cross-entropy loss skips (used for visible, unmasked positions).
    """

    def __init__(self, codebook_size, device, ignore_ind=-1, smoothing=0., gen_temp=4.5):
        self.mask_ind = codebook_size  # token id used for input masking
        self.ignore_ind = ignore_ind  # CE-loss ignore label, i.e. visible positions are excluded
        self.device = device
        self.smoothing = smoothing  # label smoothing passed to F.cross_entropy
        self.gen_temp = gen_temp  # initial Gumbel temperature, annealed linearly to 0 in generate()

    @staticmethod
    def cosine_schedule(t):
        """Return the masking ratio ``cos(t * pi / 2)`` for progress ``t``."""
        return torch.cos(t * math.pi * 0.5)

    def sample(self, x0):
        """Randomly mask token ids for training.

        Args:
            x0: 2-D tensor of token ids (unpacked as ``N, L = x0.shape``).

        Returns:
            ``(labels, masked_ids)``. ``masked_ids`` is ``x0`` with the chosen
            positions replaced by ``mask_ind``. ``labels`` holds the original id
            at those positions and ``ignore_ind`` everywhere else.
        """
        N, L = x0.shape
        device = self.device
        # Per-sample progress t ~ U(0, 1), mapped through the cosine schedule.
        timesteps = torch.zeros((N,), device=device).float().uniform_(0, 1)
        rand_mask_probs = self.cosine_schedule(timesteps)
        # Always mask at least one token per sample.
        num_token_masked = (L * rand_mask_probs).round().clamp(min=1)
        # argsort of i.i.d. uniforms is a random permutation per row, so exactly
        # num_token_masked entries per row fall below the threshold.
        batch_randperm = torch.rand(N, L, device=device).argsort(dim=-1)
        mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
        masked_ids = torch.where(mask, self.mask_ind, x0)
        labels = torch.where(mask, x0, self.ignore_ind)
        return labels, masked_ids

    def loss(self, pred, label):
        """Cross-entropy over masked positions only.

        ``pred`` carries class scores in its last dim; it is transposed so the
        class dim becomes dim 1 as ``F.cross_entropy`` expects. Positions
        labelled ``ignore_ind`` do not contribute.
        """
        return F.cross_entropy(pred.transpose(1, 2), label.long(), ignore_index=self.ignore_ind,
                               label_smoothing=self.smoothing)

    def _decode_tokens(self, config, n_samples, nnet, nnet_kwargs, use_adapter):
        """Run one full iterative-decoding pass from an all-masked sequence.

        When ``use_adapter`` is true, ``nnet`` also receives ``lambdaA`` and
        ``lambdaB``: 0 on the first step, then ``config.sample.lambdaA`` /
        ``config.sample.lambdaB``. Otherwise those keywords are not passed at
        all, so ``nnet``'s own defaults apply.

        Returns:
            Token ids of the final step, rearranged to
            ``(n_samples, fmap_size, fmap_size)``.
        """
        fmap_size = config.z_shape[-1]
        sample_steps = config.sample.sample_steps
        seq_len = fmap_size ** 2
        device = self.device

        ids = torch.full((n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device)
        cfg_scale = 0.
        lambda_a = 0.
        lambda_b = 0.
        sampled_ids = ids
        for step in range(sample_steps):
            ratio = 1. * (step + 1) / sample_steps
            annealed_temp = self.gen_temp * (1 - ratio)
            is_mask = ids == self.mask_ind

            if use_adapter:
                # TODO(review): original author's note "try using *ratio" -- presumably
                # an idea to scale lambdaA/lambdaB by ``ratio``; not implemented.
                logits = nnet(ids, **nnet_kwargs, scale=cfg_scale, lambdaA=lambda_a, lambdaB=lambda_b)
            else:
                logits = nnet(ids, **nnet_kwargs, scale=cfg_scale)

            # Sampling & scoring: Gumbel-max sample per position, scored by its logit.
            sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1)
            sampled_logits = torch.gather(logits, dim=-1, index=sampled_ids.unsqueeze(-1)).squeeze(-1)
            # Already-revealed tokens are kept and given +inf confidence so they never get re-masked.
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()

            # Re-masking: keep floor(seq_len * cos(ratio*pi/2)) tokens masked, clamped to
            # [1, currently_masked - 1]. Batch row 0's count is used for every row.
            mask_ratio = np.cos(ratio * math.pi * 0.5)
            target_len = math.floor(seq_len * mask_ratio)
            num_masked = int(is_mask[0].sum())
            mask_len = max(1, min(num_masked - 1, target_len))

            confidence = add_gumbel_noise(sampled_logits, annealed_temp, device)
            sorted_confidence, _ = torch.sort(confidence, dim=-1)
            cut_off = sorted_confidence[:, mask_len - 1:mask_len]
            masking = confidence <= cut_off
            ids = torch.where(masking, self.mask_ind, sampled_ids)

            cfg_scale = ratio * config.sample.scale
            if use_adapter:
                lambda_a = config.sample.lambdaA
                lambda_b = config.sample.lambdaB

        return rearrange(sampled_ids, 'b (i j) -> b i j', i=fmap_size, j=fmap_size)

    @torch.no_grad()
    def generate(self, config, _n_samples, nnet, decode_fn, is_eval=False, **kwargs):
        """Sample token maps by iterative unmasking and decode them.

        Reads ``config.z_shape[-1]`` (square feature-map side) and
        ``config.sample.sample_steps`` / ``.scale``. In adapter mode it also
        reads ``.lambdaA`` / ``.lambdaB``. ``nnet`` is called as
        ``nnet(ids, **kwargs, scale=...)`` and must return logits whose last
        dim indexes tokens. On the first step ``scale`` is 0; afterwards it is
        ``ratio * config.sample.scale``.

        Two passes are possible:
          * a plain pass, where ``nnet`` gets no ``lambdaA``/``lambdaB``;
          * an adapter pass, where ``nnet`` gets ``lambdaA``/``lambdaB``.

        If ``is_eval`` is true, only the adapter pass is run; the plain pass's
        output would be discarded anyway. Otherwise plain then adapter samples
        are concatenated along dim 0.

        Returns:
            Whatever ``decode_fn`` returns for the token maps.

        Raises:
            ValueError: if ``config.sample.sample_steps`` is less than 1.
        """
        if config.sample.sample_steps < 1:
            raise ValueError("config.sample.sample_steps must be >= 1")

        # Keep the original pass order (plain before adapter) so that, outside eval,
        # the NumPy/torch RNG draws happen in exactly the same sequence as before.
        z_plain = None
        if not is_eval:
            z_plain = self._decode_tokens(config, _n_samples, nnet, kwargs, use_adapter=False)
        z_adapter = self._decode_tokens(config, _n_samples, nnet, kwargs, use_adapter=True)

        _z = z_adapter if is_eval else torch.cat([z_plain, z_adapter], dim=0)
        return decode_fn(_z)