# NOTE: recovered from a web scrape; page UI residue ("Spaces / Sleeping") and
# trailing "| |" artifacts removed so the file is valid Python again.
import torch
import torch.nn as nn
import torchvision.transforms as T
class LVDM(nn.Module):
    """Latent video diffusion model.

    Combines three sub-networks:
      * ``referencenet`` — encodes a clean reference image into multi-scale
        residual features (appearance conditioning),
      * ``pose_guider``  — encodes pose images into a feature map fed to the UNet,
      * ``unet``         — the denoising network that predicts the noise/sample.
    """

    def __init__(self, referencenet, unet, pose_guider):
        super().__init__()
        self.referencenet = referencenet
        self.unet = unet
        self.pose_guider = pose_guider

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        clip_image_embeds,
        pose_img,
        uncond_fwd: bool = False,
    ):
        """Run one denoising step and return the UNet prediction.

        Args:
            noisy_latents: noisy video latents — observed shape
                (B, C, F, H, W) = [4, 4, 1, 112, 80].
            timesteps: per-sample diffusion timesteps, shape (B,).
            ref_image_latents: reference-image latents — observed (B, 4, 112, 80).
            clip_image_embeds: CLIP image embeddings — observed (B, 1, 768).
            pose_img: pose condition frames — observed (B, 3, 1, 896, 640),
                i.e. 8x the latent spatial size (presumably pixel space —
                TODO confirm against the VAE scale factor).
            uncond_fwd: when True, skip ReferenceNet and pass no reference
                residuals to the UNet (unconditional branch for
                classifier-free-guidance training).

        Returns:
            Tensor: ``unet(...).sample``, the predicted noise/sample.
        """
        # BUGFIX: the original hard-coded .to(device="cuda"), which always
        # targets cuda:0 and breaks multi-GPU training (the inline trace shows
        # tensors on cuda:5) as well as CPU execution. Follow the input
        # latents' device instead.
        pose_cond_tensor = pose_img.to(device=noisy_latents.device)
        # Observed output shape: (B, 320, 1, 112, 80) — matches latent H, W.
        pose_fea = self.pose_guider(pose_cond_tensor)

        # Default to "no reference conditioning"; filled in on the cond branch.
        reference_down_block_res_samples = None
        reference_mid_block_res_sample = None
        reference_up_block_res_samples = None
        if not uncond_fwd:
            # The reference image is clean, so ReferenceNet runs at timestep 0.
            ref_timesteps = torch.zeros_like(timesteps)
            (
                reference_down_block_res_samples,
                reference_mid_block_res_sample,
                reference_up_block_res_samples,
            ) = self.referencenet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=clip_image_embeds,
                return_dict=False,
            )

        self.unet.set_do_classifier_free_guidance(do_classifier_free_guidance=False)
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            pose_cond_fea=pose_fea,
            encoder_hidden_states=clip_image_embeds,
            reference_down_block_res_samples=reference_down_block_res_samples,
            reference_mid_block_res_sample=reference_mid_block_res_sample,
            reference_up_block_res_samples=reference_up_block_res_samples,
        ).sample
        return model_pred