import copy
import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)

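# Hedged usage sketch (illustrative shapes, not taken from the original code):
# PositionalEncoding expects sequence-first input so that the
# (max_len, 1, d_model) `pe` buffer broadcasts over the batch dimension.
def _demo_positional_encoding():
    pos_enc = PositionalEncoding(d_model=128, dropout=0.0)
    x = torch.zeros(16, 4, 128)  # (seq_len, batch, d_model)
    y = pos_enc(x)               # same shape, with positions added
    assert y.shape == (16, 4, 128)
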
class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str,
                 out_dim: Optional[int] = None, post_act_fn: Optional[str] = None,
                 cond_proj_dim: Optional[int] = None, zero_init_cond: bool = True) -> None:
        super(TimestepEmbedding, self).__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        if cond_proj_dim is not None:
            # Optional projection for an extra conditioning signal; zero init
            # makes the conditioning a no-op at the start of training.
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
            if zero_init_cond:
                self.cond_proj.weight.data.fill_(0.0)
        else:
            self.cond_proj = None
        self.act = torch.nn.GELU() if act_fn == 'gelu' else torch.nn.SiLU()
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = torch.nn.GELU() if post_act_fn == 'gelu' else torch.nn.SiLU()

    def forward(self, sample: torch.Tensor, timestep_cond=None) -> torch.Tensor:
        if timestep_cond is not None:
            sample = sample + self.cond_proj(timestep_cond)
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

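# Hedged usage sketch (illustrative dims, not taken from the original code):
# TimestepEmbedding is a two-layer MLP applied to a precomputed sinusoidal
# timestep embedding (e.g. the output of Timesteps below).
def _demo_timestep_embedding():
    embed = TimestepEmbedding(in_channels=128, time_embed_dim=512, act_fn='silu')
    sinusoid = torch.randn(4, 128)  # stand-in for a sinusoidal embedding
    out = embed(sinusoid)
    assert out.shape == (4, 512)
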
class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder
        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        # pe has shape (max_len, 1, latent_dim); indexing with timesteps gives
        # (bs, 1, latent_dim), and the permute returns (1, bs, latent_dim).
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)

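# Hedged usage sketch (illustrative dims, not taken from the original code):
# TimestepEmbedder looks integer timesteps up in a shared PositionalEncoding
# table and projects them through a small MLP.
def _demo_timestep_embedder():
    pos_enc = PositionalEncoding(d_model=256, dropout=0.0)
    embedder = TimestepEmbedder(latent_dim=256, sequence_pos_encoder=pos_enc)
    t = torch.tensor([0, 100, 999])
    emb = embedder(t)
    assert emb.shape == (1, 3, 256)
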
class InputProcess(nn.Module):
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        # [bs, njoints, nfeats, nframes] -> [bs, njoints, nframes, nfeats]
        x = x.permute((0, 1, 3, 2))
        x = self.poseEmbedding(x)  # [bs, njoints, nframes, latent_dim]
        return x

class OutputProcess(nn.Module):
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        bs, n_joints, nframes, d = output.shape
        output = self.poseFinal(output)  # [bs, njoints, nframes, nfeats]
        output = output.permute(0, 1, 3, 2)  # [bs, njoints, nfeats, nframes]
        # Fold the per-joint feature dim into the joint axis (the original
        # hardcoded 128 here, which assumes input_feats == 128).
        output = output.reshape(bs, n_joints * self.input_feats, 1, nframes)
        return output

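# Hedged shape check (illustrative dims, not taken from the original code):
# InputProcess and OutputProcess are shape-level inverses when the feature
# dimension equals input_feats.
def _demo_in_out_process():
    in_proc = InputProcess(input_feats=128, latent_dim=256)
    out_proc = OutputProcess(input_feats=128, latent_dim=256)
    x = torch.randn(2, 3, 128, 30)  # [bs, njoints, nfeats, nframes]
    h = in_proc(x)                  # [2, 3, 30, 256]
    y = out_proc(h)                 # [2, 3 * 128, 1, 30]
    assert h.shape == (2, 3, 30, 256) and y.shape == (2, 384, 1, 30)
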
class SinusoidalEmbeddings(nn.Module):
    # Rotary-embedding frequency table: one angle per position per feature pair.
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n = x.shape[-2]
        t = torch.arange(n, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, 'b ... (r d) -> b (...) r d', r=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, freqs):
    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
    return q, k

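# Hedged usage sketch (illustrative dims, not taken from the original code):
# rotary position embeddings rotate query/key feature pairs by
# position-dependent angles. As written, rotate_half flattens any extra middle
# dims, so (batch, seq, dim) inputs are the safe layout here.
def _demo_rotary():
    emb = SinusoidalEmbeddings(dim=64)
    q = torch.randn(2, 10, 64)
    k = torch.randn(2, 10, 64)
    freqs = emb(q)  # (10, 64), one angle pair per position
    q_rot, k_rot = apply_rotary_pos_emb(q, k, freqs)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
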
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool,
                 downscale_freq_shift: float) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift)
        return t_emb

def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    # assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]
    # scale embeddings
    emb = scale * emb
    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

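# Hedged usage sketch (illustrative values, not taken from the original code):
# Timesteps wraps get_timestep_embedding as a module; both map a batch of
# scalar timesteps to (batch, num_channels) sinusoidal features.
def _demo_timesteps():
    time_proj = Timesteps(num_channels=128, flip_sin_to_cos=True, downscale_freq_shift=0.0)
    t = torch.tensor([0, 250, 999])
    emb = time_proj(t)
    assert emb.shape == (3, 128)
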
def reparameterize(mu, logvar):
    # VAE reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def init_weight(m):
    # Xavier-normal weights and zero biases for conv / linear layers.
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def init_weight_skcnn(m):
    # Kaiming-uniform init matching PyTorch's default Linear/Conv reset.
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
        if m.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(m.bias, -bound, bound)

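# Hedged usage note (not taken from the original code): both initializers are
# meant to be passed to Module.apply, which visits every submodule.
def _demo_init():
    model = nn.Sequential(nn.Linear(8, 16), nn.Conv1d(16, 16, 3))
    model.apply(init_weight)        # Xavier-normal weights, zero biases
    model.apply(init_weight_skcnn)  # or PyTorch-default-style Kaiming init
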
def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
    # Sample (or greedily pick) the next token from the last position's logits.
    logits = logits[:, -1, :] / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    if sample_logits:
        idx = torch.multinomial(probs, num_samples=1)
    else:
        _, idx = torch.topk(probs, k=1, dim=-1)
    return idx, probs

### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
        top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).

    Makes sure we keep at least min_tokens_to_keep tokens per batch example in the output.
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter sorted tensors back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

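# Hedged usage sketch (illustrative values, not taken from the original code):
# temperature, top-k, and top-p compose when picking the next token id.
def _demo_sampling():
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 1000)  # (batch, seq, vocab)
    idx, probs = sample(logits, temperature=0.8, top_k=50, top_p=0.9)
    assert idx.shape == (2, 1) and probs.shape == (2, 1000)
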
class FlowMatchScheduler:
    def __init__(self, num_inference_steps=20, num_train_timesteps=1000, shift=3.0,
                 sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False,
                 extra_one_step=False, reverse_sigmas=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps, training=True)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        # Shift the sigma schedule: sigma -> shift * sigma / (1 + (shift - 1) * sigma).
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            # Bell-shaped loss weighting over timesteps, normalized so the
            # weights sum to num_inference_steps.
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing

    def step(self, model_output, timestep, sample, to_final=False):
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        self.sigmas = self.sigmas.to(model_output.device)
        self.timesteps = self.timesteps.to(model_output.device)
        # Match each (possibly continuous) timestep to the nearest schedule entry.
        timestep_id = torch.argmin(
            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
        if to_final or (timestep_id + 1 >= len(self.timesteps)).any():
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
        # Euler step along the model-predicted velocity.
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def add_noise(self, original_samples, noise, timestep):
        """
        Diffusion forward corruption process.
        Input:
            - original_samples: the clean latent with shape [B*T, C, H, W]
            - noise: the noise with shape [B*T, C, H, W]
            - timestep: the timestep with shape [B*T]
        Output: the corrupted latent with shape [B*T, C, H, W]
        """
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        self.sigmas = self.sigmas.to(noise.device)
        self.timesteps = self.timesteps.to(noise.device)
        timestep_id = torch.argmin(
            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
        # Linear interpolation between data and noise (rectified-flow corruption).
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample.type_as(noise)

    def training_target(self, sample, noise, timestep):
        # Flow-matching velocity target: v = noise - sample (timestep is unused).
        target = noise - sample
        return target

    def training_weight(self, timestep):
        """
        Input:
            - timestep: the timestep with shape [B*T]
        Output: the corresponding loss weighting with shape [B*T]
        """
        if timestep.ndim == 2:
            timestep = timestep.flatten(0, 1)
        self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
        # Move the schedule to the same device before matching (the original
        # only moved the weights, which fails for CUDA timesteps).
        self.timesteps = self.timesteps.to(timestep.device)
        timestep_id = torch.argmin(
            (self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0)
        weights = self.linear_timesteps_weights[timestep_id]
        return weights