import typing as tp

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from x_transformers import ContinuousTransformerWrapper, Encoder

from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from model.stable import transformer_use_mask


class DiffusionTransformerV2(nn.Module):
    def __init__(self,
                 io_channels=32,
                 patch_size=1,
                 embed_dim=768,
                 cond_token_dim=0,
                 project_cond_tokens=True,
                 global_cond_dim=0,
                 project_global_cond=True,
                 input_concat_dim=0,
                 prepend_cond_dim=0,
                 depth=12,
                 num_heads=8,
                 transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
                 global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
                 **kwargs):
        super().__init__()

        # Configuration referenced later by _forward
        self.patch_size = patch_size
        self.transformer_type = transformer_type
        self.global_cond_type = global_cond_type

        d_model = embed_dim
        n_head = num_heads
        n_layers = depth
        encoder_layer = torch.nn.TransformerEncoderLayer(batch_first=True,
                                                         norm_first=True,
                                                         d_model=d_model,
                                                         nhead=n_head)
        self.transformer = torch.nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # ===================================== timestep embedding
        timestep_features_dim = 256
        self.timestep_features = FourierFeatures(1, timestep_features_dim)
        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim, bias=True),
        )
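
        # NOTE: _forward below also expects conditioning projections (to_cond_embed,
        # to_global_embed, to_prepend_embed) and the residual preprocess_conv /
        # postprocess_conv 1x1 convolutions, and its dispatch on transformer_type
        # assumes the x-transformers / ContinuousTransformer call signature rather
        # than the plain torch.nn.TransformerEncoder built above; those pieces are
        # not constructed in this __init__ yet.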

    def _forward(
            self,
            x,
            t,  # (b,) 1-D tensor of timesteps
            cross_attn_cond=None,
            cross_attn_cond_mask=None,
            input_concat_cond=None,
            global_embed=None,
            prepend_cond=None,
            prepend_cond_mask=None,
            mask=None,
            return_info=False,
            **kwargs,
    ):
        ### 1. TODO: rewrite this to handle cond of varying length
        # Unfinished V2 input path; left disabled because mu and x_t are not defined here:
        # cated_input = torch.cat([t, mu, x_t])

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:
            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest')
            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        try:
            timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))  # (b, embed_dim)
        except Exception as e:
            print("t.shape:", t.shape, "x.shape", x.shape)
            print("t:", t)
            raise e

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                prepend_mask = torch.cat(
                    [prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)],
                    dim=1)

            prepend_length = prepend_inputs.shape[1]
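            # These prepended tokens are stripped back off the transformer output below
            # (the [:, :, prepend_length:] slice), so the output length matches the input.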

        x = self.preprocess_conv(x) + x
        x = rearrange(x, "b c t -> b t c")

        extra_args = {}
        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed
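
        # Patching: fold groups of patch_size consecutive timesteps into the channel
        # dimension so the transformer runs over a sequence shortened by patch_size;
        # the inverse rearrange after the transformer restores the original length.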
        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
                                      context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
                                      **extra_args, **kwargs)
        elif self.transformer_type in ["continuous_transformer", "continuous_transformer_with_mask"]:
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
                                      context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
                                      return_info=return_info, **extra_args, **kwargs)
            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask,
                                      **extra_args, **kwargs)

        output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
            self,
            x,
            t,
            cross_attn_cond=None,
            cross_attn_cond_mask=None,
            negative_cross_attn_cond=None,
            negative_cross_attn_mask=None,
            input_concat_cond=None,
            global_embed=None,
            negative_global_embed=None,
            prepend_cond=None,
            prepend_cond_mask=None,
            cfg_scale=1.0,
            cfg_dropout_prob=0.0,
            causal=False,
            scale_phi=0.0,
            mask=None,
            return_info=False,
            **kwargs):
        assert causal == False, "Causal mode is not supported for DiffusionTransformer"

        if cross_attn_cond_mask is not None:
            cross_attn_cond_mask = cross_attn_cond_mask.bool()
            cross_attn_cond_mask = None  # Temporarily disabling conditioning masks due to kernel issue for flash attention

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.bool()

        # CFG dropout
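        # With probability cfg_dropout_prob per batch element, the conditioning tensors are
        # replaced by all-zero null embeddings, so the unconditional branch needed for
        # classifier-free guidance at sampling time is also trained.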
        if cfg_dropout_prob > 0.0:
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
                dropout_mask = torch.bernoulli(
                    torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob,
                               device=cross_attn_cond.device)).to(torch.bool)
                cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                dropout_mask = torch.bernoulli(
                    torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob,
                               device=prepend_cond.device)).to(torch.bool)
                prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)

        if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
            # Classifier-free guidance
            # Concatenate conditioned and unconditioned inputs on the batch dimension
            batch_inputs = torch.cat([x, x], dim=0)
            batch_timestep = torch.cat([t, t], dim=0)

            if global_embed is not None:
                batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
            else:
                batch_global_cond = None

            if input_concat_cond is not None:
                batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
            else:
                batch_input_concat_cond = None

            batch_cond = None
            batch_cond_masks = None

            # Handle CFG for cross-attention conditioning
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)

                # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
                if negative_cross_attn_cond is not None:
                    # If there's a negative cross-attention mask, set the masked tokens to the null embed
                    if negative_cross_attn_mask is not None:
                        negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
                        negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond,
                                                               null_embed)

                    batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
                else:
                    batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

                if cross_attn_cond_mask is not None:
                    batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)

            batch_prepend_cond = None
            batch_prepend_cond_mask = None

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)

                if prepend_cond_mask is not None:
                    batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)

            if mask is not None:
                batch_masks = torch.cat([mask, mask], dim=0)
            else:
                batch_masks = None

            batch_output = self._forward(
                batch_inputs,
                batch_timestep,
                cross_attn_cond=batch_cond,
                cross_attn_cond_mask=batch_cond_masks,
                mask=batch_masks,
                input_concat_cond=batch_input_concat_cond,
                global_embed=batch_global_cond,
                prepend_cond=batch_prepend_cond,
                prepend_cond_mask=batch_prepend_cond_mask,
                return_info=return_info,
                **kwargs)

            if return_info:
                batch_output, info = batch_output
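
            # Split the doubled batch back into conditional / unconditional halves and
            # apply the guidance formula: uncond + (cond - uncond) * cfg_scale.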
            cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
            cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

            # CFG Rescale
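            # Rescale the guided output so its standard deviation over the channel
            # dimension matches that of the conditional output, then blend the rescaled
            # and unrescaled outputs with weight scale_phi.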
            if scale_phi != 0.0:
                cond_out_std = cond_output.std(dim=1, keepdim=True)
                out_cfg_std = cfg_output.std(dim=1, keepdim=True)
                output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
            else:
                output = cfg_output

            if return_info:
                return output, info

            return output
        else:
            return self._forward(
                x,
                t,
                cross_attn_cond=cross_attn_cond,
                cross_attn_cond_mask=cross_attn_cond_mask,
                input_concat_cond=input_concat_cond,
                global_embed=global_embed,
                prepend_cond=prepend_cond,
                prepend_cond_mask=prepend_cond_mask,
                mask=mask,
                return_info=return_info,
                **kwargs
            )
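

# A minimal smoke test for the timestep-embedding path, the only part of the module
# that is fully wired up in this file. The shapes below are illustrative assumptions,
# and FourierFeatures is assumed to map (b, 1) -> (b, timestep_features_dim), matching
# how it is called in _forward. Run as a module (python -m ...) so the relative imports
# at the top of the file resolve.
if __name__ == "__main__":
    model = DiffusionTransformerV2(embed_dim=768, depth=2, num_heads=8)  # shallow depth just for the test
    t = torch.rand(4)  # one timestep per batch element, shape (b,)
    timestep_embed = model.to_timestep_embed(model.timestep_features(t[:, None]))
    print(timestep_embed.shape)  # expected: torch.Size([4, 768])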