John6666 committed on
Commit 31934a1
1 Parent(s): 18c7f8b

Upload animate.py

Files changed (1)
  1. demo/animate.py +188 -194
demo/animate.py CHANGED
@@ -1,195 +1,189 @@
- # Copyright 2023 ByteDance and/or its affiliates.
- #
- # Copyright (2023) MagicAnimate Authors
- #
- # ByteDance, its affiliates and licensors retain all intellectual
- # property and proprietary rights in and to this material, related
- # documentation and any modifications thereto. Any use, reproduction,
- # disclosure or distribution of this material and related documentation
- # without an express license agreement from ByteDance or
- # its affiliates is strictly prohibited.
- import argparse
- import argparse
- import datetime
- import inspect
- import os
- import numpy as np
- from PIL import Image
- from omegaconf import OmegaConf
- from collections import OrderedDict
-
- import torch
-
- from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
-
- from tqdm import tqdm
- from transformers import CLIPTextModel, CLIPTokenizer
-
- from magicanimate.models.unet_controlnet import UNet3DConditionModel
- from magicanimate.models.controlnet import ControlNetModel
- from magicanimate.models.appearance_encoder import AppearanceEncoderModel
- from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
- from magicanimate.pipelines.pipeline_animation import AnimationPipeline
- from magicanimate.utils.util import save_videos_grid
- from accelerate.utils import set_seed
-
- from magicanimate.utils.videoreader import VideoReader
-
- from einops import rearrange, repeat
-
- import csv, pdb, glob
- from safetensors import safe_open
- import math
- from pathlib import Path
-
- class MagicAnimate():
-     def __init__(self, config="configs/prompts/animation.yaml") -> None:
-         print("Initializing MagicAnimate Pipeline...")
-         *_, func_args = inspect.getargvalues(inspect.currentframe())
-         func_args = dict(func_args)
-
-         config = OmegaConf.load(config)
-
-         inference_config = OmegaConf.load(config.inference_config)
-
-         motion_module = config.motion_module
-
-         ### >>> create animation pipeline >>> ###
-         tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
-         text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
-         if config.pretrained_unet_path:
-             unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
-         else:
-             unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
-         self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda()
-         self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
-         self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
-         if config.pretrained_vae_path is not None:
-             vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
-         else:
-             vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
-
-         ### Load controlnet
-         controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
-
-         vae.to(torch.float16)
-         unet.to(torch.float16)
-         text_encoder.to(torch.float16)
-         controlnet.to(torch.float16)
-         self.appearance_encoder.to(torch.float16)
-
-         unet.enable_xformers_memory_efficient_attention()
-         self.appearance_encoder.enable_xformers_memory_efficient_attention()
-         controlnet.enable_xformers_memory_efficient_attention()
-
-         self.pipeline = AnimationPipeline(
-             vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
-             scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
-             # NOTE: UniPCMultistepScheduler
-         ).to("cuda")
-
-         # 1. unet ckpt
-         # 1.1 motion module
-         motion_module_state_dict = torch.load(motion_module, map_location="cpu")
-         if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
-         motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
-         try:
-             # extra steps for self-trained models
-             state_dict = OrderedDict()
-             for key in motion_module_state_dict.keys():
-                 if key.startswith("module."):
-                     _key = key.split("module.")[-1]
-                     state_dict[_key] = motion_module_state_dict[key]
-                 else:
-                     state_dict[key] = motion_module_state_dict[key]
-             motion_module_state_dict = state_dict
-             del state_dict
-             missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
-             assert len(unexpected) == 0
-         except:
-             _tmp_ = OrderedDict()
-             for key in motion_module_state_dict.keys():
-                 if "motion_modules" in key:
-                     if key.startswith("unet."):
-                         _key = key.split('unet.')[-1]
-                         _tmp_[_key] = motion_module_state_dict[key]
-                     else:
-                         _tmp_[key] = motion_module_state_dict[key]
-             missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
-             assert len(unexpected) == 0
-             del _tmp_
-         del motion_module_state_dict
-
-         self.pipeline.to("cuda")
-         self.L = config.L
-
-         print("Initialization Done!")
-
-     def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
-         prompt = n_prompt = ""
-         random_seed = int(random_seed)
-         step = int(step)
-         guidance_scale = float(guidance_scale)
-         samples_per_video = []
-         # manually set random seed for reproduction
-         if random_seed != -1:
-             torch.manual_seed(random_seed)
-             set_seed(random_seed)
-         else:
-             torch.seed()
-
-         if motion_sequence.endswith('.mp4'):
-             control = VideoReader(motion_sequence).read()
-             if control[0].shape[0] != size:
-                 control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
-             control = np.array(control)
-
-         if source_image.shape[0] != size:
-             source_image = np.array(Image.fromarray(source_image).resize((size, size)))
-         H, W, C = source_image.shape
-
-         init_latents = None
-         original_length = control.shape[0]
-         if control.shape[0] % self.L > 0:
-             control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
-         generator = torch.Generator(device=torch.device("cuda:0"))
-         generator.manual_seed(torch.initial_seed())
-         sample = self.pipeline(
-             prompt,
-             negative_prompt = n_prompt,
-             num_inference_steps = step,
-             guidance_scale = guidance_scale,
-             width = W,
-             height = H,
-             video_length = len(control),
-             controlnet_condition = control,
-             init_latents = init_latents,
-             generator = generator,
-             appearance_encoder = self.appearance_encoder,
-             reference_control_writer = self.reference_control_writer,
-             reference_control_reader = self.reference_control_reader,
-             source_image = source_image,
-         ).videos
-
-         source_images = np.array([source_image] * original_length)
-         source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
-         samples_per_video.append(source_images)
-
-         control = control / 255.0
-         control = rearrange(control, "t h w c -> 1 c t h w")
-         control = torch.from_numpy(control)
-         samples_per_video.append(control[:, :, :original_length])
-
-         samples_per_video.append(sample[:, :, :original_length])
-
-         samples_per_video = torch.cat(samples_per_video)
-
-         time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-         savedir = f"demo/outputs"
-         animation_path = f"{savedir}/{time_str}.mp4"
-
-         os.makedirs(savedir, exist_ok=True)
-         save_videos_grid(samples_per_video, animation_path)
-
-         return animation_path
+ # Copyright 2023 ByteDance and/or its affiliates.
+ #
+ # Copyright (2023) MagicAnimate Authors
+ #
+ # ByteDance, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from ByteDance or
+ # its affiliates is strictly prohibited.
+ import datetime
+ import inspect
+ import os
+ import numpy as np
+ from PIL import Image
+ from omegaconf import OmegaConf
+ from collections import OrderedDict
+
+ import torch
+
+ from diffusers import AutoencoderKL, DDIMScheduler
+
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from magicanimate.models.unet_controlnet import UNet3DConditionModel
+ from magicanimate.models.controlnet import ControlNetModel
+ from magicanimate.models.appearance_encoder import AppearanceEncoderModel
+ from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
+ from magicanimate.pipelines.pipeline_animation import AnimationPipeline
+ from magicanimate.utils.util import save_videos_grid
+ from accelerate.utils import set_seed
+
+ from magicanimate.utils.videoreader import VideoReader
+
+ from einops import rearrange
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class MagicAnimate():
+     def __init__(self, config="configs/prompts/animation.yaml") -> None:
+         print("Initializing MagicAnimate Pipeline...")
+         *_, func_args = inspect.getargvalues(inspect.currentframe())
+         func_args = dict(func_args)
+
+         config = OmegaConf.load(config)
+
+         inference_config = OmegaConf.load(config.inference_config)
+
+         motion_module = config.motion_module
+
+         ### >>> create animation pipeline >>> ###
+         tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
+         text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
+         if config.pretrained_unet_path:
+             unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+         else:
+             unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+         self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
+         self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
+         self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
+         if config.pretrained_vae_path is not None:
+             vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
+         else:
+             vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
+
+         ### Load controlnet
+         controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
+
+         vae.to(torch.float16)
+         unet.to(torch.float16)
+         text_encoder.to(torch.float16)
+         controlnet.to(torch.float16)
+         self.appearance_encoder.to(torch.float16)
+
+         unet.enable_xformers_memory_efficient_attention()
+         self.appearance_encoder.enable_xformers_memory_efficient_attention()
+         controlnet.enable_xformers_memory_efficient_attention()
+
+         self.pipeline = AnimationPipeline(
+             vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
+             scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+             # NOTE: UniPCMultistepScheduler
+         ).to(device)
+
+         # 1. unet ckpt
+         # 1.1 motion module
+         motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+         if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
+         motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
+         try:
+             # extra steps for self-trained models
+             state_dict = OrderedDict()
+             for key in motion_module_state_dict.keys():
+                 if key.startswith("module."):
+                     _key = key.split("module.")[-1]
+                     state_dict[_key] = motion_module_state_dict[key]
+                 else:
+                     state_dict[key] = motion_module_state_dict[key]
+             motion_module_state_dict = state_dict
+             del state_dict
+             missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
+             assert len(unexpected) == 0
+         except:
+             _tmp_ = OrderedDict()
+             for key in motion_module_state_dict.keys():
+                 if "motion_modules" in key:
+                     if key.startswith("unet."):
+                         _key = key.split('unet.')[-1]
+                         _tmp_[_key] = motion_module_state_dict[key]
+                     else:
+                         _tmp_[key] = motion_module_state_dict[key]
+             missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
+             assert len(unexpected) == 0
+             del _tmp_
+         del motion_module_state_dict
+
+         self.pipeline.to(device)
+         self.L = config.L
+
+         print("Initialization Done!")
+
+     def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
+         prompt = n_prompt = ""
+         random_seed = int(random_seed)
+         step = int(step)
+         guidance_scale = float(guidance_scale)
+         samples_per_video = []
+         # manually set random seed for reproduction
+         if random_seed != -1:
+             torch.manual_seed(random_seed)
+             set_seed(random_seed)
+         else:
+             torch.seed()
+
+         if motion_sequence.endswith('.mp4'):
+             control = VideoReader(motion_sequence).read()
+             if control[0].shape[0] != size:
+                 control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
+             control = np.array(control)
+
+         if source_image.shape[0] != size:
+             source_image = np.array(Image.fromarray(source_image).resize((size, size)))
+         H, W, C = source_image.shape
+
+         init_latents = None
+         original_length = control.shape[0]
+         if control.shape[0] % self.L > 0:
+             control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
+         generator = torch.Generator(device=torch.device(device))
+         generator.manual_seed(torch.initial_seed())
+         sample = self.pipeline(
+             prompt,
+             negative_prompt = n_prompt,
+             num_inference_steps = step,
+             guidance_scale = guidance_scale,
+             width = W,
+             height = H,
+             video_length = len(control),
+             controlnet_condition = control,
+             init_latents = init_latents,
+             generator = generator,
+             appearance_encoder = self.appearance_encoder,
+             reference_control_writer = self.reference_control_writer,
+             reference_control_reader = self.reference_control_reader,
+             source_image = source_image,
+         ).videos
+
+         source_images = np.array([source_image] * original_length)
+         source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
+         samples_per_video.append(source_images)
+
+         control = control / 255.0
+         control = rearrange(control, "t h w c -> 1 c t h w")
+         control = torch.from_numpy(control)
+         samples_per_video.append(control[:, :, :original_length])
+
+         samples_per_video.append(sample[:, :, :original_length])
+
+         samples_per_video = torch.cat(samples_per_video)
+
+         time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+         savedir = f"demo/outputs"
+         animation_path = f"{savedir}/{time_str}.mp4"
+
+         os.makedirs(savedir, exist_ok=True)
+         save_videos_grid(samples_per_video, animation_path)
+
+         return animation_path
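
For context, a minimal usage sketch of the class as it stands after this commit. This is an assumption-laden example, not part of the commit: the checkpoints referenced by configs/prompts/animation.yaml are presumed downloaded, the repo root is presumed to be on PYTHONPATH, and the input paths are hypothetical placeholders. The new module-level device line makes the pipeline pick CUDA when available and fall back to CPU.

    # Hypothetical driver script; input paths are placeholders, not part of this commit.
    import numpy as np
    from PIL import Image

    from demo.animate import MagicAnimate  # assumes the repo root is on PYTHONPATH

    animator = MagicAnimate(config="configs/prompts/animation.yaml")

    # Reference image as an H x W x 3 uint8 array (resized to 512x512 internally).
    source_image = np.array(Image.open("inputs/source.png").convert("RGB"))

    # The driving motion sequence must be an .mp4 (e.g. a DensePose video).
    output_path = animator(
        source_image,
        "inputs/motion.mp4",
        random_seed=1,      # -1 leaves the seed unfixed
        step=25,            # DDIM sampling steps
        guidance_scale=7.5,
    )
    print(output_path)      # saved under demo/outputs/<timestamp>.mp4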