Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import math | |
from einops import rearrange | |
from torch.nn import functional as F | |
def add_gumbel_noise(t, temperature, device):
    """Return ``t`` plus temperature-scaled Gumbel noise.

    The noise is drawn with NumPy's global RNG (``np.random.gumbel``), so it
    is controlled by ``np.random.seed`` rather than ``torch.manual_seed``.

    Args:
        t: tensor to perturb; the noise has the same shape.
        temperature: multiplier applied to the standard Gumbel samples
            (``0`` yields all-zero noise).
        device: device the noise tensor is created on before the addition.

    Returns:
        ``t + noise``, where ``noise`` is a float32 tensor on ``device``.
    """
    noise = temperature * np.random.gumbel(size=tuple(t.shape))
    # Build the float32 noise tensor directly on the target device instead of
    # going through the legacy ``torch.Tensor(...)`` constructor + ``.to()``.
    return t + torch.as_tensor(noise, dtype=torch.float32, device=device)
class MUSE(object):
    """Masked-token training sampler and iterative decoder.

    ``sample``/``loss`` build and score a randomly masked copy of a token
    batch; ``generate`` decodes token grids from a fully masked sequence by
    repeatedly calling a network and re-masking low-confidence positions.
    """

    def __init__(self, codebook_size, device, ignore_ind=-1, smoothing=0., gen_temp=4.5):
        """Store the masking/loss configuration.

        Args:
            codebook_size: used as the id of the mask token (``mask_ind``).
            device: device on which random draws and generated ids are created.
            ignore_ind: label value excluded from the cross-entropy loss.
            smoothing: ``label_smoothing`` passed to the cross-entropy loss.
            gen_temp: initial Gumbel temperature used by ``generate``.
        """
        self.mask_ind = codebook_size  # for input masking
        self.ignore_ind = ignore_ind  # for ce loss, excluding visible
        self.device = device
        self.smoothing = smoothing
        self.gen_temp = gen_temp

    @staticmethod
    def cosine_schedule(t):
        """Return ``cos(t * pi / 2)`` element-wise for the tensor ``t``.

        Declared static so that the ``self.cosine_schedule(...)`` call in
        ``sample`` does not pass the instance as ``t``.
        """
        return torch.cos(t * math.pi * 0.5)

    def sample(self, x0):
        """Randomly mask a batch of token ids.

        Args:
            x0: 2-D tensor of token ids, unpacked as ``(N, L)``.

        Returns:
            ``(labels, masked_ids)``: ``masked_ids`` is ``x0`` with the chosen
            positions replaced by ``mask_ind``; ``labels`` is ``x0`` at those
            positions and ``ignore_ind`` everywhere else.
        """
        N, L = x0.shape
        device = self.device
        # One uniform timestep per row -> masking probability via the cosine schedule.
        timesteps = torch.zeros((N,), device=device).float().uniform_(0, 1)
        rand_mask_probs = self.cosine_schedule(timesteps)
        # At least one token is always masked in every row.
        num_token_masked = (L * rand_mask_probs).round().clamp(min=1)
        # argsort of random scores gives each row a random permutation of
        # 0..L-1; entries below the per-row count form the mask.
        batch_randperm = torch.rand(N, L, device=device).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(-1)
        masked_ids = x0.masked_fill(mask, self.mask_ind)
        labels = x0.masked_fill(~mask, self.ignore_ind)
        return labels, masked_ids

    def loss(self, pred, label):
        """Cross-entropy between ``pred`` and ``label``.

        Dims 1 and 2 of ``pred`` are swapped before the call so the class axis
        sits at dim 1 (assumes ``pred`` is ``(N, L, num_classes)`` - confirm
        against the network output). Positions whose label equals
        ``ignore_ind`` are skipped and ``smoothing`` is applied as label
        smoothing.
        """
        return F.cross_entropy(
            pred.transpose(1, 2),
            label.long(),
            ignore_index=self.ignore_ind,
            label_smoothing=self.smoothing,
        )

    def _iterative_decode(self, nnet, nnet_kwargs, n_samples, seq_len, sample_steps,
                          final_scale, adapter_lambdas=None):
        """Run one full decoding pass and return the final sampled ids.

        Args:
            nnet: callable invoked as ``nnet(ids, **nnet_kwargs, scale=...)``.
            nnet_kwargs: dict of extra keyword arguments forwarded to ``nnet``.
            n_samples: number of sequences to decode.
            seq_len: number of tokens per sequence.
            sample_steps: number of decoding iterations (>= 1).
            final_scale: ``scale`` grows as ``ratio * final_scale``; the first
                call uses ``0.``.
            adapter_lambdas: ``None`` to call ``nnet`` without ``lambdaA`` /
                ``lambdaB``; otherwise a ``(lambdaA, lambdaB)`` pair that is
                passed from the second step on (the first step passes ``0.``
                for both).

        Returns:
            Long tensor of shape ``(n_samples, seq_len)`` with the ids sampled
            in the last step.
        """
        device = self.device
        ids = torch.full((n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device)
        one = torch.ones(1, dtype=torch.float32, device=device)  # lower bound for mask_len
        cfg_scale = 0.
        adapter_kwargs = {} if adapter_lambdas is None else {'lambdaA': 0., 'lambdaB': 0.}
        sampled_ids = ids
        for step in range(sample_steps):
            ratio = 1. * (step + 1) / sample_steps
            annealed_temp = self.gen_temp * (1 - ratio)
            is_mask = (ids == self.mask_ind)
            # Original author's note (translated): try using "* ratio" here.
            logits = nnet(ids, **nnet_kwargs, scale=cfg_scale, **adapter_kwargs)

            # sampling & scoring
            sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1)
            sampled_logits = torch.gather(
                logits, dim=-1, index=sampled_ids.unsqueeze(-1)).squeeze(-1)
            # Already-decoded positions keep their id and get +inf confidence
            # so that they are never re-masked.
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = sampled_logits.masked_fill(~is_mask, float('inf')).float()

            # masking
            # np.cos / np.floor are kept as-is: a one-ulp difference in the
            # cosine can change the floored length.
            mask_ratio = np.cos(ratio * math.pi * 0.5)
            target_len = torch.tensor(
                [np.floor(seq_len * mask_ratio)], dtype=torch.float32, device=device)
            still_masked = is_mask.sum(dim=-1, keepdim=True) - 1
            # [0] takes the first row's value, so a single cut-off rank is
            # shared by the whole batch.
            mask_len = torch.maximum(one, torch.minimum(still_masked, target_len))[0].squeeze()
            confidence = add_gumbel_noise(sampled_logits, annealed_temp, device)
            sorted_confidence, _ = torch.sort(confidence, dim=-1)
            k = mask_len.long()
            cut_off = sorted_confidence[:, k - 1:k]
            masking = (confidence <= cut_off)
            ids = sampled_ids.masked_fill(masking, self.mask_ind)

            # Values used by the *next* nnet call.
            cfg_scale = ratio * final_scale
            if adapter_lambdas is not None:
                adapter_kwargs = {'lambdaA': adapter_lambdas[0], 'lambdaB': adapter_lambdas[1]}
        return sampled_ids

    def generate(self, config, _n_samples, nnet, decode_fn, is_eval=False, **kwargs):
        """Decode token grids without and with the adapter and decode them.

        Args:
            config: object providing ``z_shape[-1]`` (grid side length),
                ``sample.sample_steps``, ``sample.scale``, ``sample.lambdaA``
                and ``sample.lambdaB``.
            _n_samples: number of sequences decoded per pass.
            nnet: network callable; see ``_iterative_decode`` for how it is
                invoked.
            decode_fn: callable applied to the ``(B, side, side)`` id tensor.
            is_eval: if true only the adapter pass is decoded; otherwise the
                plain and adapter results are concatenated along dim 0.
            **kwargs: forwarded to every ``nnet`` call.

        Returns:
            Whatever ``decode_fn`` returns.

        Raises:
            ValueError: if ``config.sample.sample_steps`` is smaller than 1.
        """
        fmap_size = config.z_shape[-1]
        _sample_steps = config.sample.sample_steps
        if _sample_steps < 1:
            raise ValueError(
                'config.sample.sample_steps must be >= 1, got %r' % (_sample_steps,))
        seq_len = fmap_size ** 2

        # NOTE(review): the plain pass also runs when is_eval is True although
        # its result is then unused; kept so the NumPy RNG stream consumed
        # before the adapter pass is unchanged.
        plain_ids = self._iterative_decode(
            nnet, kwargs, _n_samples, seq_len, _sample_steps, config.sample.scale)
        _z1 = plain_ids.reshape(plain_ids.shape[0], fmap_size, fmap_size)

        # with adapter
        adapter_ids = self._iterative_decode(
            nnet, kwargs, _n_samples, seq_len, _sample_steps, config.sample.scale,
            adapter_lambdas=(config.sample.lambdaA, config.sample.lambdaB))
        _z2 = adapter_ids.reshape(adapter_ids.shape[0], fmap_size, fmap_size)

        _z = _z2 if is_eval else torch.cat([_z1, _z2], dim=0)
        out = decode_fn(_z)
        return out