jhj0517 committed
Commit 4a16e03
1 Parent(s): 1f6f578

add `init_model()` to musepose_inference.py

Files changed (1)
  1. musepose_inference.py +46 -36
musepose_inference.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 import torch
 from diffusers import AutoencoderKL, DDIMScheduler
 from einops import repeat
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig
 from PIL import Image
 from torchvision import transforms
 from transformers import CLIPVisionModelWithProjection
@@ -94,33 +94,9 @@ class MusePoseInference:
         else:
             weight_dtype = torch.float32
 
-        self.vae = AutoencoderKL.from_pretrained(
-            self.image_gen_model_paths["pretrained_vae"],
-        ).to("cuda", dtype=weight_dtype)
-
-        self.reference_unet = UNet2DConditionModel.from_pretrained(
-            self.image_gen_model_paths["pretrained_base_model"],
-            subfolder="unet",
-        ).to(dtype=weight_dtype, device="cuda")
-
         inference_config_path = self.inference_config_path
         infer_config = OmegaConf.load(inference_config_path)
 
-        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-            Path(self.image_gen_model_paths["pretrained_base_model"]),
-            Path(self.musepose_model_paths["motion_module"]),
-            subfolder="unet",
-            unet_additional_kwargs=infer_config.unet_additional_kwargs,
-        ).to(dtype=weight_dtype, device="cuda")
-
-        self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
-            dtype=weight_dtype, device="cuda"
-        )
-
-        self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
-            self.image_gen_model_paths["image_encoder"]
-        ).to(dtype=weight_dtype, device="cuda")
-
         sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
         scheduler = DDIMScheduler(**sched_kwargs)
 
@@ -128,17 +104,8 @@ class MusePoseInference:
 
         width, height = W, H
 
-        # load pretrained weights
-        self.denoising_unet.load_state_dict(
-            torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
-            strict=False,
-        )
-        self.reference_unet.load_state_dict(
-            torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
-        )
-        self.pose_guider.load_state_dict(
-            torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
-        )
+        self.init_model(weight_dtype=weight_dtype, infer_config=infer_config)
+
         self.pipe = Pose2VideoPipeline(
             vae=self.vae,
             image_encoder=self.image_enc,
@@ -225,6 +192,49 @@ class MusePoseInference:
         self.release_vram()
         return output_path, output_path_demo
 
+    def init_model(self,
+                   weight_dtype: torch.dtype,
+                   infer_config: DictConfig
+                   ):
+        if self.vae is None:
+            self.vae = AutoencoderKL.from_pretrained(
+                self.image_gen_model_paths["pretrained_vae"],
+            ).to("cuda", dtype=weight_dtype)
+
+        if self.reference_unet is None:
+            self.reference_unet = UNet2DConditionModel.from_pretrained(
+                self.image_gen_model_paths["pretrained_base_model"],
+                subfolder="unet",
+            ).to(dtype=weight_dtype, device="cuda")
+            self.reference_unet.load_state_dict(
+                torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
+            )
+
+        if self.denoising_unet is None:
+            self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+                Path(self.image_gen_model_paths["pretrained_base_model"]),
+                Path(self.musepose_model_paths["motion_module"]),
+                subfolder="unet",
+                unet_additional_kwargs=infer_config.unet_additional_kwargs,
+            ).to(dtype=weight_dtype, device="cuda")
+            self.denoising_unet.load_state_dict(
+                torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
+                strict=False,
+            )
+
+        if self.pose_guider is None:
+            self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
+                dtype=weight_dtype, device="cuda"
+            )
+            self.pose_guider.load_state_dict(
+                torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
+            )
+
+        if self.image_enc is None:
+            self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
+                self.image_gen_model_paths["image_encoder"]
+            ).to(dtype=weight_dtype, device="cuda")
+
     def release_vram(self):
         models = [
             'vae', 'reference_unet', 'denoising_unet',
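The net effect of the diff is that model construction and weight loading move out of the inference path into a guarded `init_model()` helper: each component is built only while its attribute is still `None`, so a repeated inference call reuses the models already on the GPU instead of reloading them, and a call after `release_vram()` rebuilds only what was freed. A minimal sketch of that lazy-initialization pattern, using placeholder modules rather than the repo's real classes:

```python
# Illustrative only: LazyPipeline and its Linear "components" are
# placeholders, not the classes used in musepose_inference.py.
import torch


class LazyPipeline:
    def __init__(self):
        # Components start unloaded; init_model() fills them in on demand.
        self.vae = None
        self.unet = None

    def init_model(self, weight_dtype: torch.dtype):
        # Each `is None` guard makes the call idempotent: after
        # release_vram(), only the freed components get rebuilt.
        if self.vae is None:
            self.vae = torch.nn.Linear(4, 4).to(dtype=weight_dtype)
        if self.unet is None:
            self.unet = torch.nn.Linear(4, 4).to(dtype=weight_dtype)

    def release_vram(self):
        # Drop references so the memory can be freed; the next
        # init_model() call reloads from scratch.
        self.vae = None
        self.unet = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


pipe = LazyPipeline()
pipe.init_model(torch.float32)  # first call: builds both components
pipe.init_model(torch.float32)  # second call: no-op, both already set
pipe.release_vram()
pipe.init_model(torch.float32)  # rebuilds after VRAM was released
```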
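One detail carried over unchanged from the original code: the denoising UNet is the only component loaded with `strict=False`. Since `from_pretrained_2d` assembles the 3D UNet from a 2D base model plus a separate motion module, the MusePose checkpoint presumably does not cover every parameter key. A small generic illustration of what `strict=False` tolerates (plain torch modules, nothing repo-specific):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))

# A partial checkpoint: only the weight, no bias.
partial_state = {"0.weight": torch.zeros(4, 4)}

# strict=False skips missing keys instead of raising, and reports them.
result = model.load_state_dict(partial_state, strict=False)
print(result.missing_keys)     # ['0.bias'] -- left at its initialized value
print(result.unexpected_keys)  # []
```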