# musim / pipeline.py
# Author: juliagsy · commit 8bce792 (verified) "Update pipeline.py" · 3.78 kB
# (Remaining page chrome from the original copy: raw, history, blame,
#  contribute, delete; virus scan: no virus.)
import numpy as np
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
class MusImPipeline(DiffusionPipeline):
    """Audio-conditioned latent-diffusion pipeline that produces images.

    ``muvis`` encodes the audio input (``wav``), ``linear`` projects its
    ``last_hidden_state`` into the UNet's ``encoder_hidden_states`` space,
    ``vae`` maps between images and latents, and ``unet`` + ``scheduler``
    perform the denoising. ``forward`` returns a training loss; ``__call__``
    runs classifier-free-guided sampling and returns a uint8 numpy image.
    """

    # modeling_file = "pipeline.py"
    def __init__(self, vae, unet, scheduler, muvis, linear):
        super().__init__()
        self.register_modules(vae=vae, unet=unet, scheduler=scheduler, muvis=muvis, linear=linear)
        # Loss used by ``forward``; kept as an attribute so existing callers can reach it.
        self.loss_mse = nn.MSELoss()

    def forward(self, wav, img):
        """Compute the noise-prediction training loss for one batch.

        Args:
            wav: mapping unpacked as keyword arguments into ``self.muvis``
                (presumably a feature-extractor output — TODO confirm keys).
            img: image batch passed to ``self.vae.encode``; ``img.size(0)``
                is used as the batch size for sampling timesteps.

        Returns:
            Scalar MSE between the UNet's noise prediction and the noise
            that was added to the latents.
        """
        wav_embs = self.linear(self.muvis(**wav)["last_hidden_state"])

        latents = self.vae.encode(img).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # The forward diffusion process (scheduler.add_noise) and the epsilon
        # target both assume standard Gaussian noise; uniform noise
        # (torch.rand_like) trains the UNet on the wrong distribution.
        noise = torch.randn_like(latents, device=self.device, dtype=self.dtype)
        t = torch.randint(
            0, self.scheduler.config.num_train_timesteps, (img.size(0),), device=self.device
        ).long()
        noisy_latents = self.scheduler.add_noise(latents, noise, t)

        noise_pred = self.unet(noisy_latents, t, encoder_hidden_states=wav_embs).sample
        return self.loss_mse(noise_pred, noise)

    @torch.no_grad()
    def __call__(self, wav, img=None, batch_size=1, num_inference_steps=25, guidance_scale=7.5):
        """Generate image(s) conditioned on audio, with classifier-free guidance.

        Args:
            wav: mapping unpacked as keyword arguments into ``self.muvis``.
            img: optional image batch; when given, sampling starts from its
                VAE latents perturbed with Gaussian noise instead of pure noise.
            batch_size: number of latents to sample when ``img`` is None.
            num_inference_steps: number of scheduler denoising steps.
            guidance_scale: classifier-free guidance weight.

        Returns:
            ``numpy.ndarray`` of dtype uint8: ``(H, W, C)`` for a single image,
            ``(B, H, W, C)`` when more than one image is generated.
        """
        wav_t = self.muvis(**wav)["last_hidden_state"]

        if img is not None:
            latents = self.vae.encode(img).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor
            # Gaussian perturbation, consistent with the pure-noise branch below.
            latents = latents + torch.randn_like(latents)
        else:
            latents = torch.randn(
                (
                    batch_size,
                    self.unet.config.in_channels,
                    self.unet.config.sample_size,
                    self.unet.config.sample_size,
                )
            )
        latents = latents.to(self.device).to(self.dtype)
        latents = latents * self.scheduler.init_noise_sigma

        # A single audio clip conditions every latent in the batch.
        num_images = latents.size(0)
        if wav_t.size(0) == 1 and num_images > 1:
            wav_t = wav_t.expand(num_images, *wav_t.shape[1:])

        # Unconditional branch: encoder run on an all-zero input of shape
        # (1, 1024, 128) — presumably (batch, frames, mel_bins); TODO confirm.
        uncond_embeddings = self.muvis(torch.zeros(1, 1024, 128).to(self.device).to(self.dtype))["last_hidden_state"]
        # Match the conditional batch so both CFG halves line up with
        # torch.cat([latents] * 2) below (previously only batch size 1 worked).
        uncond_embeddings = uncond_embeddings.expand(wav_t.size(0), *uncond_embeddings.shape[1:])
        wav_embs = self.linear(torch.cat([uncond_embeddings, wav_t]))

        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.progress_bar(self.scheduler.timesteps):
            # First half is paired with the unconditional embeddings, second with the audio ones.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=wav_embs).sample
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents).sample
        # Map from [-1, 1] to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)
        # Decoder output assumed (B, C, H, W) -> (B, H, W, C); round, don't truncate.
        image = (image.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
        if image.shape[0] == 1:
            image = image[0]
        return image