import os

os.environ["TOKENIZERS_PARALLELISM"] = "False"

import torch
from tqdm import tqdm

from .models.utils import fast_sta_nabla


def get_sparse_params(conf, batch_embeds, device):
    # Temporal patching is assumed to be disabled (patch_size[0] == 1).
    assert conf.model.dit_params.patch_size[0] == 1
    T, H, W, _ = batch_embeds["visual"].shape
    # Convert the latent shape to the DiT token grid.
    T, H, W = (
        T // conf.model.dit_params.patch_size[0],
        H // conf.model.dit_params.patch_size[1],
        W // conf.model.dit_params.patch_size[2],
    )
    if conf.model.attention.type == "nabla":
        # Sliding-tile attention mask built on a coarser tile grid.
        sta_mask = fast_sta_nabla(
            T, H // 8, W // 8,
            conf.model.attention.wT, conf.model.attention.wH, conf.model.attention.wW,
            device=device,
        )
        sparse_params = {
            "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
            "attention_type": conf.model.attention.type,
            "to_fractal": True,
            "P": conf.model.attention.P,
            "wT": conf.model.attention.wT,
            "wW": conf.model.attention.wW,
            "wH": conf.model.attention.wH,
            "add_sta": conf.model.attention.add_sta,
            "visual_shape": (T, H, W),
            "method": getattr(conf.model.attention, "method", "topcdf"),
        }
    else:
        sparse_params = None
    return sparse_params
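
# Shape bookkeeping example for get_sparse_params above (hypothetical sizes, for
# illustration only): a latent batch of shape (T, H, W, C) = (16, 96, 96, 16) with
# patch_size = (1, 2, 2) gives a DiT token grid of (T, H, W) = (16, 48, 48), and
# fast_sta_nabla then builds the STA mask on the coarser
# (T, H // 8, W // 8) = (16, 6, 6) grid using the (wT, wH, wW) attention windows.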


def get_velocity(
    dit,
    x,
    t,
    text_embeds,
    null_text_embeds,
    visual_rope_pos,
    text_rope_pos,
    null_text_rope_pos,
    guidance_weight,
    conf,
    sparse_params=None,
):
    # Conditional velocity prediction; t is rescaled from [0, 1] to [0, 1000]
    # before being passed to the DiT.
    pred_velocity = dit(
        x,
        text_embeds["text_embeds"],
        text_embeds["pooled_embed"],
        t * 1000,
        visual_rope_pos,
        text_rope_pos,
        scale_factor=conf.metrics.scale_factor,
        sparse_params=sparse_params,
    )
    # Classifier-free guidance: skip the unconditional pass when guidance_weight == 1.
    if abs(guidance_weight - 1.0) > 1e-6:
        uncond_pred_velocity = dit(
            x,
            null_text_embeds["text_embeds"],
            null_text_embeds["pooled_embed"],
            t * 1000,
            visual_rope_pos,
            null_text_rope_pos,
            scale_factor=conf.metrics.scale_factor,
            sparse_params=sparse_params,
        )
        pred_velocity = uncond_pred_velocity + guidance_weight * (
            pred_velocity - uncond_pred_velocity
        )
    return pred_velocity
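
# Guidance note for get_velocity above: the blend
#     v = v_uncond + w * (v_cond - v_uncond) = (1 - w) * v_uncond + w * v_cond
# is standard classifier-free guidance applied to the predicted velocity; with w = 1 it
# reduces to the purely conditional prediction, which is why the unconditional forward
# pass is skipped when |w - 1| is negligible.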


def generate(
    model,
    device,
    shape,
    num_steps,
    text_embeds,
    null_text_embeds,
    visual_rope_pos,
    text_rope_pos,
    null_text_rope_pos,
    guidance_weight,
    scheduler_scale,
    conf,
    progress=False,
    seed=6554,
):
    # Seeded initial noise on the sampling device.
    g = torch.Generator(device=device)
    g.manual_seed(seed)
    img = torch.randn(*shape, device=device, generator=g)
    sparse_params = get_sparse_params(conf, {"visual": img}, device)
    # Time-shifted schedule from t = 1 (pure noise) down to t = 0 (clean latent).
    timesteps = torch.linspace(1, 0, num_steps + 1, device=device)
    timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)
    steps = list(zip(timesteps[:-1], torch.diff(timesteps)))
    for timestep, timestep_diff in tqdm(steps, disable=not progress):
        time = timestep.unsqueeze(0)
        if model.visual_cond:
            # No visual conditioning here: feed zeroed condition and mask channels.
            visual_cond = torch.zeros_like(img)
            visual_cond_mask = torch.zeros(
                [*img.shape[:-1], 1], dtype=img.dtype, device=img.device
            )
            model_input = torch.cat([img, visual_cond, visual_cond_mask], dim=-1)
        else:
            model_input = img
        pred_velocity = get_velocity(
            model,
            model_input,
            time,
            text_embeds,
            null_text_embeds,
            visual_rope_pos,
            text_rope_pos,
            null_text_rope_pos,
            guidance_weight,
            conf,
            sparse_params=sparse_params,
        )
        # Euler update of the latent along the predicted velocity.
        img = img + timestep_diff * pred_velocity
    return img
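
# Sampling notes for generate above:
#   * timesteps run from t = 1 to t = 0; with scheduler_scale = s each t is remapped to
#     s * t / (1 + (s - 1) * t), which for s > 1 packs more of the steps near the
#     high-noise end of the trajectory.
#   * `img = img + timestep_diff * pred_velocity` is a plain Euler step of the velocity
#     ODE dx/dt = v(x, t); timestep_diff is negative, so the latent is integrated
#     toward t = 0.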


def generate_sample(
    shape,
    caption,
    dit,
    vae,
    conf,
    text_embedder,
    num_steps=25,
    guidance_weight=5.0,
    scheduler_scale=1,
    negative_caption="",
    seed=6554,
    device="cuda",
    vae_device="cuda",
    text_embedder_device="cuda",
    progress=True,
    offload=False,
):
    bs, duration, height, width, dim = shape
    type_of_content = "image" if duration == 1 else "video"

    # Encode the prompt and the negative prompt.
    with torch.no_grad():
        bs_text_embed, text_cu_seqlens = text_embedder.encode(
            [caption], type_of_content=type_of_content
        )
        bs_null_text_embed, null_text_cu_seqlens = text_embedder.encode(
            [negative_caption], type_of_content=type_of_content
        )
    if offload:
        text_embedder = text_embedder.to("cpu")

    # Move embeddings to the sampling device and extract the total sequence lengths.
    for key in bs_text_embed:
        bs_text_embed[key] = bs_text_embed[key].to(device=device)
        bs_null_text_embed[key] = bs_null_text_embed[key].to(device=device)
    text_cu_seqlens = text_cu_seqlens.to(device=device)[-1].item()
    null_text_cu_seqlens = null_text_cu_seqlens.to(device=device)[-1].item()

    # RoPE positions for the visual token grid and the text tokens.
    visual_rope_pos = [
        torch.arange(duration),
        torch.arange(shape[-3] // conf.model.dit_params.patch_size[1]),
        torch.arange(shape[-2] // conf.model.dit_params.patch_size[2]),
    ]
    text_rope_pos = torch.arange(text_cu_seqlens)
    null_text_rope_pos = torch.arange(null_text_cu_seqlens)

    if offload:
        dit.to(device, non_blocking=True)

    # Denoise the latent video with the DiT.
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            latent_visual = generate(
                dit,
                device,
                (bs * duration, height, width, dim),
                num_steps,
                bs_text_embed,
                bs_null_text_embed,
                visual_rope_pos,
                text_rope_pos,
                null_text_rope_pos,
                guidance_weight,
                scheduler_scale,
                conf,
                seed=seed,
                progress=progress,
            )
    if offload:
        dit = dit.to("cpu", non_blocking=True)
        torch.cuda.empty_cache()
    if offload:
        vae = vae.to(vae_device, non_blocking=True)

    # Decode latents to pixel space and convert to uint8 frames.
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            images = latent_visual.reshape(
                bs,
                -1,
                latent_visual.shape[-3],
                latent_visual.shape[-2],
                latent_visual.shape[-1],
            )
            images = images.to(device=vae_device)
            images = (images / vae.config.scaling_factor).permute(0, 4, 1, 2, 3)
            images = vae.decode(images).sample
            images = ((images.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
    if offload:
        vae = vae.to("cpu", non_blocking=True)
        torch.cuda.empty_cache()
    return images
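
# Usage sketch (illustrative only; `load_pipeline_components` and the latent sizes below
# are hypothetical placeholders, not part of this module). generate_sample expects an
# already-built DiT, VAE, text embedder, and config:
#
#     dit, vae, text_embedder, conf = load_pipeline_components(...)  # hypothetical loader
#     images = generate_sample(
#         shape=(1, 16, 96, 96, 16),  # (bs, duration, height, width, dim) in latent space
#         caption="a red fox running through fresh snow",
#         dit=dit,
#         vae=vae,
#         conf=conf,
#         text_embedder=text_embedder,
#         num_steps=25,
#         guidance_weight=5.0,
#         seed=6554,
#     )
#     # `images` is a uint8 tensor of decoded frames in [0, 255].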