Spaces:
Paused
Paused
File size: 6,615 Bytes
135b069 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
def gumbel_noise(t, generator=None):
    """
    Draw Gumbel-distributed noise shaped like `t`.

    Uniform samples in `[0, 1)` are drawn and mapped through `-log(-log(u))`. Both logarithm arguments are clamped
    from below at `1e-20` before the logarithm is taken.

    Args:
        t (`torch.Tensor`):
            Reference tensor; the result has its shape and dtype and lives on its device.
        generator (`torch.Generator`, *optional*):
            Random generator used for the uniform draw. When given, sampling happens on the generator's device and
            the samples are moved to `t.device` afterwards.

    Returns:
        `torch.Tensor`: The noise tensor, on `t.device`.
    """
    if generator is None:
        sampling_device = t.device
    else:
        # Sample where the generator lives (e.g. a CPU generator with a tensor on another device).
        sampling_device = generator.device

    uniform = torch.zeros_like(t, device=sampling_device)
    uniform.uniform_(0, 1, generator=generator)
    uniform = uniform.to(t.device)

    neg_log_uniform = -torch.log(uniform.clamp(1e-20))
    return -torch.log(neg_log_uniform.clamp(1e-20))
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    """
    Select the lowest-confidence positions of each row for masking.

    The confidence of a position is `log(probs)` (with `probs` clamped from below at `1e-20`) plus
    `temperature`-scaled Gumbel noise. For every row, the confidence found at sorted position `mask_len` (ascending
    order) is used as a cut-off, and positions strictly below it are selected.

    Args:
        mask_len (`torch.Tensor`):
            Per-row cut-off positions. It is cast to `long` and used as the index of a `torch.gather` along
            dimension 1, so it presumably has shape `(batch_size, 1)` -- confirm against the caller.
        probs (`torch.Tensor`):
            Probabilities of the currently chosen tokens; sorted along the last dimension.
        temperature (`float` or `torch.Tensor`, defaults to 1.0):
            Scale applied to the Gumbel noise.
        generator (`torch.Generator`, *optional*):
            Random generator forwarded to `gumbel_noise`.

    Returns:
        `torch.BoolTensor`: `True` where the position's confidence is below the row's cut-off.
    """
    log_probs = torch.log(probs.clamp(1e-20))
    perturbation = temperature * gumbel_noise(probs, generator=generator)
    scores = log_probs + perturbation

    ascending_scores, _ = torch.sort(scores, dim=-1)
    threshold = torch.gather(ascending_scores, 1, mask_len.long())
    return scores < threshold
@dataclass
class AmusedSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
            Token ids for the next iteration, with the positions that stay masked set to the mask token id.
            `prev_sample` should be used as next model input in the denoising loop.
        pred_original_sample (`torch.LongTensor` with the same shape as `prev_sample`, *optional*):
            The fully unmasked token ids sampled from the model output at the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.LongTensor
    # Defaults to `None`, hence `Optional`.
    pred_original_sample: Optional[torch.LongTensor] = None
class AmusedScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler that progressively unmasks a sequence of discrete tokens.

    At every step, tokens are sampled from the model's output distribution for all currently masked positions, and
    the least confident of those predictions are masked again according to the configured masking schedule.

    Args:
        mask_token_id (`int`):
            Token id that marks a masked position. Read back as `self.config.mask_token_id`.
        masking_schedule (`str`, defaults to `"cosine"`):
            Either `"cosine"` or `"linear"`. Read back as `self.config.masking_schedule`; any other value raises a
            `ValueError` when the schedule is evaluated.
    """

    order = 1

    # Both are `None` until `set_timesteps` has been called.
    temperatures: torch.Tensor
    timesteps: torch.Tensor

    @register_to_config
    def __init__(
        self,
        mask_token_id: int,
        masking_schedule: str = "cosine",
    ):
        self.temperatures = None
        self.timesteps = None

    def set_timesteps(
        self,
        num_inference_steps: int,
        temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Build the timestep and temperature schedules.

        Args:
            num_inference_steps (`int`):
                Number of steps. Timesteps run from `num_inference_steps - 1` down to `0`.
            temperature (`float` or pair of `float`, defaults to `(2, 0)`):
                A `(start, end)` pair is interpolated linearly over the steps. A single value is used as the start
                and interpolated linearly down to `0.01`.
            device (`str` or `torch.device`, *optional*):
                Device on which both schedules are created.
        """
        self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)

        if isinstance(temperature, (tuple, list)):
            self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
        else:
            self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)

    def _get_step_index_and_mask_ratio(self, timestep) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the position of `timestep` in `self.timesteps` and the schedule's mask ratio at that position."""
        if self.timesteps is None:
            raise ValueError("`set_timesteps` must be called before the scheduler can be used.")

        if isinstance(timestep, torch.Tensor):
            # Compare on the schedule's device so a timestep tensor from elsewhere does not raise.
            timestep = timestep.to(self.timesteps.device)

        step_idx = (self.timesteps == timestep).nonzero()
        if step_idx.numel() == 0:
            raise ValueError(f"timestep {timestep} is not part of the current timestep schedule")

        ratio = (step_idx + 1) / len(self.timesteps)

        if self.config.masking_schedule == "cosine":
            mask_ratio = torch.cos(ratio * math.pi / 2)
        elif self.config.masking_schedule == "linear":
            mask_ratio = 1 - ratio
        else:
            raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")

        return step_idx, mask_ratio

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.LongTensor,
        starting_mask_ratio: float = 1,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[AmusedSchedulerOutput, Tuple]:
        """
        Sample tokens for the masked positions and re-mask the least confident ones.

        Args:
            model_output (`torch.FloatTensor`):
                Model logits. Either `(batch_size, codebook_size, height, width)` together with a 3-dimensional
                `sample`, or a tensor whose last dimension is the codebook (softmax is taken over it).
            timestep (`int` or `torch.Tensor`):
                Current timestep. At timestep `0` nothing is masked again.
            sample (`torch.LongTensor`):
                Current token ids; positions equal to `self.config.mask_token_id` are treated as unknown.
            starting_mask_ratio (`float`, defaults to 1):
                Factor applied to the schedule's mask ratio.
            generator (`torch.Generator`, *optional*):
                Random generator for token sampling and the masking noise.
            return_dict (`bool`, defaults to `True`):
                Whether to return an `AmusedSchedulerOutput` instead of a plain tuple.

        Returns:
            `AmusedSchedulerOutput` or `tuple`: `(prev_sample, pred_original_sample)`, in the same spatial layout
            as the input `sample`.

        Raises:
            ValueError: If (for a non-zero timestep) `set_timesteps` has not been called, the timestep is not in
                the schedule, or the configured masking schedule is unknown.
        """
        two_dim_input = sample.ndim == 3 and model_output.ndim == 4

        if two_dim_input:
            # Flatten the spatial grid so the rest of the method works on (batch, seq_len[, codebook]).
            batch_size, codebook_size, height, width = model_output.shape
            sample = sample.reshape(batch_size, height * width)
            model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)

        unknown_map = sample == self.config.mask_token_id

        probs = model_output.softmax(dim=-1)

        device = probs.device
        probs_ = probs.to(generator.device) if generator is not None else probs  # handles when generator is on CPU
        if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
            probs_ = probs_.float()  # multinomial is not implemented for cpu half precision
        probs_ = probs_.reshape(-1, probs.size(-1))
        pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
        pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
        # Tokens that were already known in the input are kept as they are.
        pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)

        if timestep == 0:
            prev_sample = pred_original_sample
        else:
            seq_len = sample.shape[1]
            step_idx, mask_ratio = self._get_step_index_and_mask_ratio(timestep)

            # The schedule tensors live on the device passed to `set_timesteps`; align them with the model output.
            mask_ratio = starting_mask_ratio * mask_ratio.to(device)
            temperature = self.temperatures[step_idx].to(device)

            mask_len = (seq_len * mask_ratio).floor()
            # do not mask more than amount previously masked
            mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            # mask at least one
            mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)

            selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator)

            # Masks tokens with lower confidence.
            prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)

        if two_dim_input:
            prev_sample = prev_sample.reshape(batch_size, height, width)
            pred_original_sample = pred_original_sample.reshape(batch_size, height, width)

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return AmusedSchedulerOutput(prev_sample, pred_original_sample)

    def add_noise(self, sample, timesteps, generator=None):
        """
        Randomly replace tokens of `sample` with the mask token.

        Each position is masked independently with a probability equal to the schedule's mask ratio at `timesteps`.

        Args:
            sample (`torch.LongTensor`):
                Token ids to corrupt; the input is not modified.
            timesteps (`int` or `torch.Tensor`):
                Timestep whose mask ratio is used; must be present in `self.timesteps`.
            generator (`torch.Generator`, *optional*):
                Random generator for the mask draw. Sampling happens on its device when given.

        Returns:
            `torch.LongTensor`: A copy of `sample` with the selected positions set to `self.config.mask_token_id`.

        Raises:
            ValueError: If `set_timesteps` has not been called, the timestep is not in the schedule, or the
                configured masking schedule is unknown.
        """
        _, mask_ratio = self._get_step_index_and_mask_ratio(timesteps)

        mask_indices = (
            torch.rand(
                sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
            ).to(sample.device)
            < mask_ratio.to(sample.device)
        )

        masked_sample = sample.clone()
        masked_sample[mask_indices] = self.config.mask_token_id

        return masked_sample
|