# Follow-Your-Emoji — models/__init__.py
import torch
import torch.nn as nn
import torchvision.transforms as T
class LVDM(nn.Module):
    """Latent video diffusion wrapper.

    Combines three sub-modules:
      - ``referencenet``: extracts appearance features from the reference
        image latents (run at timestep 0, i.e. no noise).
      - ``unet``: the denoising UNet, conditioned on pose features, CLIP
        image embeddings, and the ReferenceNet residuals.
      - ``pose_guider``: encodes the pose image into feature maps that are
        fed to the UNet as ``pose_cond_fea``.
    """

    def __init__(self, referencenet, unet, pose_guider):
        super().__init__()
        self.referencenet = referencenet
        self.unet = unet
        self.pose_guider = pose_guider

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        clip_image_embeds,
        pose_img,
        uncond_fwd: bool = False,
    ):
        """Predict the denoising target for one training step.

        Shapes observed during training (batch of 4, single frame):
          noisy_latents:     (4, 4, 1, 112, 80)
          timesteps:         (4,)
          ref_image_latents: (4, 4, 112, 80)
          clip_image_embeds: (4, 1, 768)
          pose_img:          (4, 3, 1, 896, 640)

        When ``uncond_fwd`` is True (classifier-free-guidance unconditional
        pass), the ReferenceNet is skipped and the UNet receives no
        reference residuals.
        """
        # FIX: follow the latents' device instead of hard-coding "cuda" —
        # the original `pose_img.to(device="cuda")` broke CPU runs and
        # pinned multi-GPU training to the default CUDA device.
        pose_cond_tensor = pose_img.to(device=noisy_latents.device)
        pose_fea = self.pose_guider(pose_cond_tensor)
        # pose_fea matches the latent spatial grid, e.g. (4, 320, 1, 112, 80)

        reference_down_block_res_samples = None
        reference_mid_block_res_sample = None
        reference_up_block_res_samples = None
        if not uncond_fwd:
            # ReferenceNet runs at t=0: it only extracts appearance
            # features, it is not denoising anything.
            ref_timesteps = torch.zeros_like(timesteps)
            (
                reference_down_block_res_samples,
                reference_mid_block_res_sample,
                reference_up_block_res_samples,
            ) = self.referencenet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=clip_image_embeds,
                return_dict=False,
            )

        self.unet.set_do_classifier_free_guidance(do_classifier_free_guidance=False)
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            pose_cond_fea=pose_fea,
            encoder_hidden_states=clip_image_embeds,
            reference_down_block_res_samples=reference_down_block_res_samples,
            reference_mid_block_res_sample=reference_mid_block_res_sample,
            reference_up_block_res_samples=reference_up_block_res_samples,
        ).sample
        return model_pred