Spanicin committed · Commit e863356 · verified · 1 Parent(s): 677bc30

Upload 7 files

scripts/app.py ADDED
@@ -0,0 +1,51 @@
1
+ """
2
+ This script provides a Gradio web UI for the Hallo inference pipeline.
3
+
4
+ The script takes a source image and a driving audio clip, and lets you configure
5
+ the generation weights such as pose_weight, face_weight, lip_weight, and face_expand_ratio.
6
+
7
+ Usage:
8
+ This script can be run from the command line with the following command:
9
+
10
+ python scripts/app.py
11
+ """
12
+ import argparse
13
+
14
+ import gradio as gr
15
+ from inference import inference_process
16
+
17
+
18
+ def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
19
+ """
20
+ Run the inference pipeline with the inputs and weights collected from the Gradio UI.
21
+ """
22
+ _ = progress
23
+ config = {
24
+ 'source_image': image,
25
+ 'driving_audio': audio,
26
+ 'pose_weight': pose_weight,
27
+ 'face_weight': face_weight,
28
+ 'lip_weight': lip_weight,
29
+ 'face_expand_ratio': face_expand_ratio,
30
+ 'config': 'configs/inference/default.yaml',
31
+ 'checkpoint': None,
32
+ 'output': ".cache/output.mp4"
33
+ }
34
+ args = argparse.Namespace()
35
+ for key, value in config.items():
36
+ setattr(args, key, value)
37
+ return inference_process(args)
38
+
39
+ app = gr.Interface(
40
+ fn=predict,
41
+ inputs=[
42
+ gr.Image(label="source image (no webp)", type="filepath", format="jpeg"),
43
+ gr.Audio(label="source audio", type="filepath"),
44
+ gr.Number(label="pose weight", value=1.0),
45
+ gr.Number(label="face weight", value=1.0),
46
+ gr.Number(label="lip weight", value=1.0),
47
+ gr.Number(label="face expand ratio", value=1.2),
48
+ ],
49
+ outputs=[gr.Video()],
50
+ )
51
+ app.launch()
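
A note on the dict-to-Namespace bridge used in predict above: inference_process expects an argparse.Namespace, so the Gradio inputs are packed into a config dict and copied onto a Namespace with setattr. Below is a minimal, self-contained sketch of the same pattern; run_inference is a hypothetical stand-in for inference_process, which needs the model checkpoints to actually run.

    import argparse

    def run_inference(args: argparse.Namespace) -> str:
        # Hypothetical stand-in for inference_process: just echo the output path.
        return args.output

    config = {
        "source_image": "face.jpg",
        "driving_audio": "speech.wav",
        "pose_weight": 1.0,
        "face_weight": 1.0,
        "lip_weight": 1.0,
        "face_expand_ratio": 1.2,
        "config": "configs/inference/default.yaml",
        "checkpoint": None,
        "output": ".cache/output.mp4",
    }

    # argparse.Namespace(**config) has the same effect as the setattr loop in app.py.
    args = argparse.Namespace(**config)
    print(run_inference(args))  # -> .cache/output.mp4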
scripts/data_preprocess.py ADDED
@@ -0,0 +1,191 @@
1
+ # pylint: disable=W1203,W0718
2
+ """
3
+ This module is used to process videos to prepare data for training. It utilizes various libraries and models
4
+ to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction.
5
+ The script takes in command-line arguments to specify the input and output directories, processing step, level of parallelism,
6
+ and rank for distributed processing.
7
+
8
+ Usage:
9
+ python -m scripts.data_preprocess --input_dir /path/to/video_dir --output_dir /path/to/output_dir --step 1 --parallelism 4 --rank 0
10
+
11
+ Example:
12
+ python -m scripts.data_preprocess -i data/videos -o data/output -s 1 -p 4 -r 0
13
+ """
14
+ import argparse
15
+ import logging
16
+ import os
17
+ from pathlib import Path
18
+ from typing import List
19
+
20
+ import cv2
21
+ import torch
22
+ from tqdm import tqdm
23
+
24
+ from hallo.datasets.audio_processor import AudioProcessor
25
+ from hallo.datasets.image_processor import ImageProcessorForDataProcessing
26
+ from hallo.utils.util import convert_video_to_images, extract_audio_from_videos
27
+
28
+ # Configure logging
29
+ logging.basicConfig(level=logging.INFO,
30
+ format='%(asctime)s - %(levelname)s - %(message)s')
31
+
32
+
33
+ def setup_directories(video_path: Path) -> dict:
34
+ """
35
+ Setup directories for storing processed files.
36
+
37
+ Args:
38
+ video_path (Path): Path to the video file.
39
+
40
+ Returns:
41
+ dict: A dictionary containing paths for various directories.
42
+ """
43
+ base_dir = video_path.parent.parent
44
+ dirs = {
45
+ "face_mask": base_dir / "face_mask",
46
+ "sep_pose_mask": base_dir / "sep_pose_mask",
47
+ "sep_face_mask": base_dir / "sep_face_mask",
48
+ "sep_lip_mask": base_dir / "sep_lip_mask",
49
+ "face_emb": base_dir / "face_emb",
50
+ "audio_emb": base_dir / "audio_emb"
51
+ }
52
+
53
+ for path in dirs.values():
54
+ path.mkdir(parents=True, exist_ok=True)
55
+
56
+ return dirs
57
+
58
+
59
+ def process_single_video(video_path: Path,
60
+ output_dir: Path,
61
+ image_processor: ImageProcessorForDataProcessing,
62
+ audio_processor: AudioProcessor,
63
+ step: int) -> None:
64
+ """
65
+ Process a single video file.
66
+
67
+ Args:
68
+ video_path (Path): Path to the video file.
69
+ output_dir (Path): Directory to save the output.
70
+ image_processor (ImageProcessorForDataProcessing): Image processor object.
71
+ audio_processor (AudioProcessor): Audio processor object.
72
+ step (int): Processing step; 1 extracts frames, audio, and masks, 2 extracts face and audio embeddings.
73
+ """
74
+ assert video_path.exists(), f"Video path {video_path} does not exist"
75
+ dirs = setup_directories(video_path)
76
+ logging.info(f"Processing video: {video_path}")
77
+
78
+ try:
79
+ if step == 1:
80
+ images_output_dir = output_dir / 'images' / video_path.stem
81
+ images_output_dir.mkdir(parents=True, exist_ok=True)
82
+ images_output_dir = convert_video_to_images(
83
+ video_path, images_output_dir)
84
+ logging.info(f"Images saved to: {images_output_dir}")
85
+
86
+ audio_output_dir = output_dir / 'audios'
87
+ audio_output_dir.mkdir(parents=True, exist_ok=True)
88
+ audio_output_path = audio_output_dir / f'{video_path.stem}.wav'
89
+ audio_output_path = extract_audio_from_videos(
90
+ video_path, audio_output_path)
91
+ logging.info(f"Audio extracted to: {audio_output_path}")
92
+
93
+ face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess(
94
+ images_output_dir)
95
+ cv2.imwrite(
96
+ str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask)
97
+ cv2.imwrite(str(dirs["sep_pose_mask"] /
98
+ f"{video_path.stem}.png"), sep_pose_mask)
99
+ cv2.imwrite(str(dirs["sep_face_mask"] /
100
+ f"{video_path.stem}.png"), sep_face_mask)
101
+ cv2.imwrite(str(dirs["sep_lip_mask"] /
102
+ f"{video_path.stem}.png"), sep_lip_mask)
103
+ else:
104
+ images_dir = output_dir / "images" / video_path.stem
105
+ audio_path = output_dir / "audios" / f"{video_path.stem}.wav"
106
+ _, face_emb, _, _, _ = image_processor.preprocess(images_dir)
107
+ torch.save(face_emb, str(
108
+ dirs["face_emb"] / f"{video_path.stem}.pt"))
109
+ audio_emb, _ = audio_processor.preprocess(audio_path)
110
+ torch.save(audio_emb, str(
111
+ dirs["audio_emb"] / f"{video_path.stem}.pt"))
112
+ except Exception as e:
113
+ logging.error(f"Failed to process video {video_path}: {e}")
114
+
115
+
116
+ def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None:
117
+ """
118
+ Process all videos in the input list.
119
+
120
+ Args:
121
+ input_video_list (List[Path]): List of video paths to process.
122
+ output_dir (Path): Directory to save the output.
123
+ step (int): Processing step; 1 extracts frames, audio, and masks, 2 extracts face and audio embeddings.
124
+ """
125
+ face_analysis_model_path = "pretrained_models/face_analysis"
126
+ landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
127
+ audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx"
128
+ wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h'
129
+
130
+ audio_processor = AudioProcessor(
131
+ 16000,
132
+ 25,
133
+ wav2vec_model_path,
134
+ False,
135
+ os.path.dirname(audio_separator_model_file),
136
+ os.path.basename(audio_separator_model_file),
137
+ os.path.join(output_dir, "vocals"),
138
+ ) if step == 2 else None
139
+
140
+ image_processor = ImageProcessorForDataProcessing(
141
+ face_analysis_model_path, landmark_model_path, step)
142
+
143
+ for video_path in tqdm(input_video_list, desc="Processing videos"):
144
+ process_single_video(video_path, output_dir,
145
+ image_processor, audio_processor, step)
146
+
147
+
148
+ def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]:
149
+ """
150
+ Get paths of videos to process, partitioned for parallel processing.
151
+
152
+ Args:
153
+ source_dir (Path): Source directory containing videos.
154
+ parallelism (int): Level of parallelism.
155
+ rank (int): Rank for distributed processing.
156
+
157
+ Returns:
158
+ List[Path]: List of video paths to process.
159
+ """
160
+ video_paths = [item for item in sorted(
161
+ source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4']
162
+ return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank]
163
+
164
+
165
+ if __name__ == "__main__":
166
+ parser = argparse.ArgumentParser(
167
+ description="Process videos to prepare data for training. Run this script twice, first with --step 1 and then with --step 2."
168
+ )
169
+ parser.add_argument("-i", "--input_dir", type=Path,
170
+ required=True, help="Directory containing videos")
171
+ parser.add_argument("-o", "--output_dir", type=Path,
172
+ help="Directory to save results, default is parent dir of input dir")
173
+ parser.add_argument("-s", "--step", type=int, default=1,
174
+ help="Specify data processing step 1 or 2; run step 1 first, then step 2")
175
+ parser.add_argument("-p", "--parallelism", default=1,
176
+ type=int, help="Level of parallelism")
177
+ parser.add_argument("-r", "--rank", default=0, type=int,
178
+ help="Rank for distributed processing")
179
+
180
+ args = parser.parse_args()
181
+
182
+ if args.output_dir is None:
183
+ args.output_dir = args.input_dir.parent
184
+
185
+ video_path_list = get_video_paths(
186
+ args.input_dir, args.parallelism, args.rank)
187
+
188
+ if not video_path_list:
189
+ logging.warning("No videos to process.")
190
+ else:
191
+ process_all_videos(video_path_list, args.output_dir, args.step)
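
The --parallelism/--rank pair in data_preprocess.py splits the sorted video list round-robin: the worker with a given rank takes every video whose index satisfies i % parallelism == rank, so several processes or machines can preprocess disjoint subsets. A small sketch of that partitioning rule, using dummy file names rather than files on disk:

    from pathlib import Path

    def partition(paths, parallelism: int, rank: int):
        # Same round-robin rule as get_video_paths: keep indices where i % parallelism == rank.
        return [p for i, p in enumerate(sorted(paths)) if i % parallelism == rank]

    videos = [Path(f"clip_{i:02d}.mp4") for i in range(7)]
    print(partition(videos, parallelism=4, rank=0))  # [clip_00.mp4, clip_04.mp4]
    print(partition(videos, parallelism=4, rank=1))  # [clip_01.mp4, clip_05.mp4]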
scripts/extract_meta_info_stage1.py ADDED
@@ -0,0 +1,106 @@
1
+ # pylint: disable=R0801
2
+ """
3
+ This module is used to extract meta information from video directories.
4
+
5
+ It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path`
6
+ specifies the path to the video directory, while the `dataset_name` specifies the name
7
+ of the dataset. The module then collects all the video folder paths, and for each video
8
+ folder, it checks if a mask path and a face embedding path exist. If they do, it appends
9
+ a dictionary containing the image path, mask path, and face embedding path to a list.
10
+
11
+ Finally, the module writes the list of dictionaries to a JSON file with the filename
12
+ constructed using the `dataset_name`.
13
+
14
+ Usage:
15
+ python tools/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf
16
+
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import os
22
+ from pathlib import Path
23
+
24
+ import torch
25
+
26
+
27
+ def collect_video_folder_paths(root_path: Path) -> list:
28
+ """
29
+ Collect all video folder paths from the root path.
30
+
31
+ Args:
32
+ root_path (Path): The root directory containing video folders.
33
+
34
+ Returns:
35
+ list: List of video folder paths.
36
+ """
37
+ return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()]
38
+
39
+
40
+ def construct_meta_info(frames_dir_path: Path) -> dict:
41
+ """
42
+ Construct meta information for a given frames directory.
43
+
44
+ Args:
45
+ frames_dir_path (Path): The path to the frames directory.
46
+
47
+ Returns:
48
+ dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist.
49
+ """
50
+ mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png"
51
+ face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt"
52
+
53
+ if not os.path.exists(mask_path):
54
+ print(f"Mask path not found: {mask_path}")
55
+ return None
56
+
57
+ if not os.path.exists(face_emb_path) or torch.load(face_emb_path) is None:
58
+ print(f"Face emb is None: {face_emb_path}")
59
+ return None
60
+
61
+ return {
62
+ "image_path": str(frames_dir_path),
63
+ "mask_path": mask_path,
64
+ "face_emb": face_emb_path,
65
+ }
66
+
67
+
68
+ def main():
69
+ """
70
+ Main function to extract meta info for training.
71
+ """
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument("-r", "--root_path", type=str,
74
+ required=True, help="Root path of the video directories")
75
+ parser.add_argument("-n", "--dataset_name", type=str,
76
+ required=True, help="Name of the dataset")
77
+ parser.add_argument("--meta_info_name", type=str,
78
+ help="Name of the meta information file")
79
+
80
+ args = parser.parse_args()
81
+
82
+ if args.meta_info_name is None:
83
+ args.meta_info_name = args.dataset_name
84
+
85
+ image_dir = Path(args.root_path) / "images"
86
+ output_dir = Path("./data")
87
+ output_dir.mkdir(exist_ok=True)
88
+
89
+ # Collect all video folder paths
90
+ frames_dir_paths = collect_video_folder_paths(image_dir)
91
+
92
+ meta_infos = []
93
+ for frames_dir_path in frames_dir_paths:
94
+ meta_info = construct_meta_info(frames_dir_path)
95
+ if meta_info:
96
+ meta_infos.append(meta_info)
97
+
98
+ output_file = output_dir / f"{args.meta_info_name}_stage1.json"
99
+ with output_file.open("w", encoding="utf-8") as f:
100
+ json.dump(meta_infos, f, indent=4)
101
+
102
+ print(f"Final data count: {len(meta_infos)}")
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()
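
For reference, each entry written to {dataset_name}_stage1.json is just the triple produced by construct_meta_info. A sketch of the resulting layout, using hypothetical paths that follow the directory convention from data_preprocess.py:

    import json

    meta_infos = [
        {
            "image_path": "data/output/images/clip_00",          # frames directory
            "mask_path": "data/output/face_mask/clip_00.png",    # written in step 1
            "face_emb": "data/output/face_emb/clip_00.pt",       # written in step 2
        },
    ]
    print(json.dumps(meta_infos, indent=4))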
scripts/extract_meta_info_stage2.py ADDED
@@ -0,0 +1,192 @@
1
+ # pylint: disable=R0801
2
+ """
3
+ This module is used to extract meta information from video files and store them in a JSON file.
4
+
5
+ The script takes in command line arguments to specify the root path of the video files,
6
+ the dataset name, and the name of the meta information file. It then generates a list of
7
+ dictionaries containing the meta information for each video file and writes it to a JSON
8
+ file with the specified name.
9
+
10
+ The meta information includes the path to the video file, the face mask path, the
11
+ separate pose mask path, the separate face mask path, the separate lip mask path,
12
+ the face embedding path, the audio path, and the vocals embedding path.
17
+
18
+ The script checks if the mask path exists before adding the information to the list.
19
+
20
+ Usage:
21
+ python scripts/extract_meta_info_stage2.py --root_path <root_path> --dataset_name <dataset_name> --meta_info_name <meta_info_name>
22
+
23
+ Example:
24
+ python scripts/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info
25
+ """
26
+
27
+ import argparse
28
+ import json
29
+ import os
30
+ from pathlib import Path
31
+
32
+ import torch
33
+ from decord import VideoReader, cpu
34
+ from tqdm import tqdm
35
+
36
+
37
+ def get_video_paths(root_path: Path, extensions: list) -> list:
38
+ """
39
+ Get a list of video paths from the root path with the specified extensions.
40
+
41
+ Args:
42
+ root_path (Path): The root directory containing video files.
43
+ extensions (list): List of file extensions to include.
44
+
45
+ Returns:
46
+ list: List of video file paths.
47
+ """
48
+ return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions]
49
+
50
+
51
+ def file_exists(file_path: str) -> bool:
52
+ """
53
+ Check if a file exists.
54
+
55
+ Args:
56
+ file_path (str): The path to the file.
57
+
58
+ Returns:
59
+ bool: True if the file exists, False otherwise.
60
+ """
61
+ return os.path.exists(file_path)
62
+
63
+
64
+ def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str:
65
+ """
66
+ Construct a new path by replacing the base directory and extension in the original path.
67
+
68
+ Args:
69
+ video_path (str): The original video path.
70
+ base_dir (str): The base directory to be replaced.
71
+ new_dir (str): The new directory to replace the base directory.
72
+ new_ext (str): The new file extension.
73
+
74
+ Returns:
75
+ str: The constructed path.
76
+ """
77
+ return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)
78
+
79
+
80
+ def extract_meta_info(video_path: str) -> dict:
81
+ """
82
+ Extract meta information for a given video file.
83
+
84
+ Args:
85
+ video_path (str): The path to the video file.
86
+
87
+ Returns:
88
+ dict: A dictionary containing the meta information for the video.
89
+ """
90
+ mask_path = construct_paths(
91
+ video_path, "videos", "face_mask", ".png")
92
+ sep_mask_border = construct_paths(
93
+ video_path, "videos", "sep_pose_mask", ".png")
94
+ sep_mask_face = construct_paths(
95
+ video_path, "videos", "sep_face_mask", ".png")
96
+ sep_mask_lip = construct_paths(
97
+ video_path, "videos", "sep_lip_mask", ".png")
98
+ face_emb_path = construct_paths(
99
+ video_path, "videos", "face_emb", ".pt")
100
+ audio_path = construct_paths(video_path, "videos", "audios", ".wav")
101
+ vocal_emb_base_all = construct_paths(
102
+ video_path, "videos", "audio_emb", ".pt")
103
+
104
+ assert_flag = True
105
+
106
+ if not file_exists(mask_path):
107
+ print(f"Mask path not found: {mask_path}")
108
+ assert_flag = False
109
+ if not file_exists(sep_mask_border):
110
+ print(f"Separate mask border not found: {sep_mask_border}")
111
+ assert_flag = False
112
+ if not file_exists(sep_mask_face):
113
+ print(f"Separate mask face not found: {sep_mask_face}")
114
+ assert_flag = False
115
+ if not file_exists(sep_mask_lip):
116
+ print(f"Separate mask lip not found: {sep_mask_lip}")
117
+ assert_flag = False
118
+ if not file_exists(face_emb_path):
119
+ print(f"Face embedding path not found: {face_emb_path}")
120
+ assert_flag = False
121
+ if not file_exists(audio_path):
122
+ print(f"Audio path not found: {audio_path}")
123
+ assert_flag = False
124
+ if not file_exists(vocal_emb_base_all):
125
+ print(f"Vocal embedding base all not found: {vocal_emb_base_all}")
126
+ assert_flag = False
127
+
128
+ video_frames = VideoReader(video_path, ctx=cpu(0))
129
+ audio_emb = torch.load(vocal_emb_base_all)
130
+ if abs(len(video_frames) - audio_emb.shape[0]) > 3:
131
+ print(f"Frame count mismatch for video: {video_path}")
132
+ assert_flag = False
133
+
134
+ face_emb = torch.load(face_emb_path)
135
+ if face_emb is None:
136
+ print(f"Face embedding is None for video: {video_path}")
137
+ assert_flag = False
138
+
139
+ del video_frames, audio_emb
140
+
141
+ if assert_flag:
142
+ return {
143
+ "video_path": str(video_path),
144
+ "mask_path": mask_path,
145
+ "sep_mask_border": sep_mask_border,
146
+ "sep_mask_face": sep_mask_face,
147
+ "sep_mask_lip": sep_mask_lip,
148
+ "face_emb_path": face_emb_path,
149
+ "audio_path": audio_path,
150
+ "vocals_emb_base_all": vocal_emb_base_all,
151
+ }
152
+ return None
153
+
154
+
155
+ def main():
156
+ """
157
+ Main function to extract meta info for training.
158
+ """
159
+ parser = argparse.ArgumentParser()
160
+ parser.add_argument("-r", "--root_path", type=str,
161
+ required=True, help="Root path of the video files")
162
+ parser.add_argument("-n", "--dataset_name", type=str,
163
+ required=True, help="Name of the dataset")
164
+ parser.add_argument("--meta_info_name", type=str,
165
+ help="Name of the meta information file")
166
+
167
+ args = parser.parse_args()
168
+
169
+ if args.meta_info_name is None:
170
+ args.meta_info_name = args.dataset_name
171
+
172
+ video_dir = Path(args.root_path) / "videos"
173
+ video_paths = get_video_paths(video_dir, [".mp4"])
174
+
175
+ meta_infos = []
176
+
177
+ for video_path in tqdm(video_paths, desc="Extracting meta info"):
178
+ meta_info = extract_meta_info(video_path)
179
+ if meta_info:
180
+ meta_infos.append(meta_info)
181
+
182
+ print(f"Final data count: {len(meta_infos)}")
183
+
184
+ output_file = Path(f"./data/{args.meta_info_name}_stage2.json")
185
+ output_file.parent.mkdir(parents=True, exist_ok=True)
186
+
187
+ with output_file.open("w", encoding="utf-8") as f:
188
+ json.dump(meta_infos, f, indent=4)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
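
construct_paths derives every auxiliary path with plain substring replacement: it swaps the "videos" directory component for the artifact directory and ".mp4" for the new extension, so it assumes the rest of the root path does not itself contain the string "videos". A quick illustration with a hypothetical layout:

    def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str:
        # Same string substitution as in extract_meta_info_stage2.py.
        return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)

    video = "data/hdtf/videos/clip_00.mp4"
    print(construct_paths(video, "videos", "face_mask", ".png"))
    # -> data/hdtf/face_mask/clip_00.png
    print(construct_paths(video, "videos", "audio_emb", ".pt"))
    # -> data/hdtf/audio_emb/clip_00.pt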
scripts/inference.py ADDED
@@ -0,0 +1,376 @@
1
+ # pylint: disable=E1101
2
+ # scripts/inference.py
3
+
4
+ """
5
+ This script contains the main inference pipeline for processing audio and image inputs to generate a video output.
6
+
7
+ The script imports necessary packages and classes, defines a neural network model,
8
+ and contains functions for processing audio embeddings and performing inference.
9
+
10
+ The main inference process is outlined in the following steps:
11
+ 1. Initialize the configuration.
12
+ 2. Set up runtime variables.
13
+ 3. Prepare the input data for inference (source image, face mask, and face embeddings).
14
+ 4. Process the audio embeddings.
15
+ 5. Build and freeze the model and scheduler.
16
+ 6. Run the inference loop and save the result.
17
+
18
+ Usage:
19
+ This script can be run from the command line with the following arguments:
20
+ - audio_path: Path to the audio file.
21
+ - image_path: Path to the source image.
22
+ - face_mask_path: Path to the face mask image.
23
+ - face_emb_path: Path to the face embeddings file.
24
+ - output_path: Path to save the output video.
25
+
26
+ Example:
27
+ python scripts/inference.py --audio_path audio.wav --image_path image.jpg
28
+ --face_mask_path face_mask.png --face_emb_path face_emb.pt --output_path output.mp4
29
+ """
30
+
31
+ import argparse
32
+ import os
33
+
34
+ import torch
35
+ from diffusers import AutoencoderKL, DDIMScheduler
36
+ from omegaconf import OmegaConf
37
+ from torch import nn
38
+
39
+ from hallo.animate.face_animate import FaceAnimatePipeline
40
+ from hallo.datasets.audio_processor import AudioProcessor
41
+ from hallo.datasets.image_processor import ImageProcessor
42
+ from hallo.models.audio_proj import AudioProjModel
43
+ from hallo.models.face_locator import FaceLocator
44
+ from hallo.models.image_proj import ImageProjModel
45
+ from hallo.models.unet_2d_condition import UNet2DConditionModel
46
+ from hallo.models.unet_3d import UNet3DConditionModel
47
+ from hallo.utils.config import filter_non_none
48
+ from hallo.utils.util import tensor_to_video
49
+
50
+
51
+ class Net(nn.Module):
52
+ """
53
+ The Net class combines all the necessary modules for the inference process.
54
+
55
+ Args:
56
+ reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference.
57
+ denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used for denoising the input audio.
58
+ face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image.
59
+ imageproj (nn.Module): The ImageProjector model used to project the source image onto the face.
60
+ audioproj (nn.Module): The AudioProjector model used to project the audio embeddings onto the face.
61
+ """
62
+ def __init__(
63
+ self,
64
+ reference_unet: UNet2DConditionModel,
65
+ denoising_unet: UNet3DConditionModel,
66
+ face_locator: FaceLocator,
67
+ imageproj,
68
+ audioproj,
69
+ ):
70
+ super().__init__()
71
+ self.reference_unet = reference_unet
72
+ self.denoising_unet = denoising_unet
73
+ self.face_locator = face_locator
74
+ self.imageproj = imageproj
75
+ self.audioproj = audioproj
76
+
77
+ def forward(self,):
78
+ """
79
+ empty function to override abstract function of nn Module
80
+ """
81
+
82
+ def get_modules(self):
83
+ """
84
+ Simple method to avoid too-few-public-methods pylint error
85
+ """
86
+ return {
87
+ "reference_unet": self.reference_unet,
88
+ "denoising_unet": self.denoising_unet,
89
+ "face_locator": self.face_locator,
90
+ "imageproj": self.imageproj,
91
+ "audioproj": self.audioproj,
92
+ }
93
+
94
+
95
+ def process_audio_emb(audio_emb):
96
+ """
97
+ Process the audio embedding to concatenate with other tensors.
98
+
99
+ Parameters:
100
+ audio_emb (torch.Tensor): The audio embedding tensor to process.
101
+
102
+ Returns:
103
+ concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
104
+ """
105
+ concatenated_tensors = []
106
+
107
+ for i in range(audio_emb.shape[0]):
108
+ vectors_to_concat = [
109
+ audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)]
110
+ concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
111
+
112
+ audio_emb = torch.stack(concatenated_tensors, dim=0)
113
+
114
+ return audio_emb
115
+
116
+
117
+
118
+ def inference_process(args: argparse.Namespace):
119
+ """
120
+ Perform inference processing.
121
+
122
+ Args:
123
+ args (argparse.Namespace): Command-line arguments.
124
+
125
+ This function initializes the configuration for the inference process. It sets up the necessary
126
+ modules and variables to prepare for the upcoming inference steps.
127
+ """
128
+ # 1. init config
129
+ cli_args = filter_non_none(vars(args))
130
+ config = OmegaConf.load(args.config)
131
+ config = OmegaConf.merge(config, cli_args)
132
+ source_image_path = config.source_image
133
+ driving_audio_path = config.driving_audio
134
+ save_path = config.save_path
135
+ if not os.path.exists(save_path):
136
+ os.makedirs(save_path)
137
+ motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
138
+
139
+ # 2. runtime variables
140
+ device = torch.device(
141
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
142
+ if config.weight_dtype == "fp16":
143
+ weight_dtype = torch.float16
144
+ elif config.weight_dtype == "bf16":
145
+ weight_dtype = torch.bfloat16
146
+ elif config.weight_dtype == "fp32":
147
+ weight_dtype = torch.float32
148
+ else:
149
+ weight_dtype = torch.float32
150
+
151
+ # 3. prepare inference data
152
+ # 3.1 prepare source image, face mask, face embeddings
153
+ img_size = (config.data.source_image.width,
154
+ config.data.source_image.height)
155
+ clip_length = config.data.n_sample_frames
156
+ face_analysis_model_path = config.face_analysis.model_path
157
+ with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
158
+ source_image_pixels, \
159
+ source_image_face_region, \
160
+ source_image_face_emb, \
161
+ source_image_full_mask, \
162
+ source_image_face_mask, \
163
+ source_image_lip_mask = image_processor.preprocess(
164
+ source_image_path, save_path, config.face_expand_ratio)
165
+
166
+ # 3.2 prepare audio embeddings
167
+ sample_rate = config.data.driving_audio.sample_rate
168
+ assert sample_rate == 16000, "audio sample rate must be 16000"
169
+ fps = config.data.export_video.fps
170
+ wav2vec_model_path = config.wav2vec.model_path
171
+ wav2vec_only_last_features = config.wav2vec.features == "last"
172
+ audio_separator_model_file = config.audio_separator.model_path
173
+ with AudioProcessor(
174
+ sample_rate,
175
+ fps,
176
+ wav2vec_model_path,
177
+ wav2vec_only_last_features,
178
+ os.path.dirname(audio_separator_model_file),
179
+ os.path.basename(audio_separator_model_file),
180
+ os.path.join(save_path, "audio_preprocess")
181
+ ) as audio_processor:
182
+ audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
183
+
184
+ # 4. build modules
185
+ sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
186
+ if config.enable_zero_snr:
187
+ sched_kwargs.update(
188
+ rescale_betas_zero_snr=True,
189
+ timestep_spacing="trailing",
190
+ prediction_type="v_prediction",
191
+ )
192
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
193
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
194
+
195
+ vae = AutoencoderKL.from_pretrained(config.vae.model_path)
196
+ reference_unet = UNet2DConditionModel.from_pretrained(
197
+ config.base_model_path, subfolder="unet")
198
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
199
+ config.base_model_path,
200
+ config.motion_module_path,
201
+ subfolder="unet",
202
+ unet_additional_kwargs=OmegaConf.to_container(
203
+ config.unet_additional_kwargs),
204
+ use_landmark=False,
205
+ )
206
+ face_locator = FaceLocator(conditioning_embedding_channels=320)
207
+ image_proj = ImageProjModel(
208
+ cross_attention_dim=denoising_unet.config.cross_attention_dim,
209
+ clip_embeddings_dim=512,
210
+ clip_extra_context_tokens=4,
211
+ )
212
+
213
+ audio_proj = AudioProjModel(
214
+ seq_len=5,
215
+ blocks=12, # use 12 layers' hidden states of wav2vec
216
+ channels=768, # audio embedding channel
217
+ intermediate_dim=512,
218
+ output_dim=768,
219
+ context_tokens=32,
220
+ ).to(device=device, dtype=weight_dtype)
221
+
222
+ audio_ckpt_dir = config.audio_ckpt_dir
223
+
224
+
225
+ # Freeze
226
+ vae.requires_grad_(False)
227
+ image_proj.requires_grad_(False)
228
+ reference_unet.requires_grad_(False)
229
+ denoising_unet.requires_grad_(False)
230
+ face_locator.requires_grad_(False)
231
+ audio_proj.requires_grad_(False)
232
+
233
+ reference_unet.enable_gradient_checkpointing()
234
+ denoising_unet.enable_gradient_checkpointing()
235
+
236
+ net = Net(
237
+ reference_unet,
238
+ denoising_unet,
239
+ face_locator,
240
+ image_proj,
241
+ audio_proj,
242
+ )
243
+
244
+ m,u = net.load_state_dict(
245
+ torch.load(
246
+ os.path.join(audio_ckpt_dir, "net.pth"),
247
+ map_location="cpu",
248
+ ),
249
+ )
250
+ assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint."
251
+ print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
252
+
253
+ # 5. inference
254
+ pipeline = FaceAnimatePipeline(
255
+ vae=vae,
256
+ reference_unet=net.reference_unet,
257
+ denoising_unet=net.denoising_unet,
258
+ face_locator=net.face_locator,
259
+ scheduler=val_noise_scheduler,
260
+ image_proj=net.imageproj,
261
+ )
262
+ pipeline.to(device=device, dtype=weight_dtype)
263
+
264
+ audio_emb = process_audio_emb(audio_emb)
265
+
266
+ source_image_pixels = source_image_pixels.unsqueeze(0)
267
+ source_image_face_region = source_image_face_region.unsqueeze(0)
268
+ source_image_face_emb = source_image_face_emb.reshape(1, -1)
269
+ source_image_face_emb = torch.tensor(source_image_face_emb)
270
+
271
+ source_image_full_mask = [
272
+ (mask.repeat(clip_length, 1))
273
+ for mask in source_image_full_mask
274
+ ]
275
+ source_image_face_mask = [
276
+ (mask.repeat(clip_length, 1))
277
+ for mask in source_image_face_mask
278
+ ]
279
+ source_image_lip_mask = [
280
+ (mask.repeat(clip_length, 1))
281
+ for mask in source_image_lip_mask
282
+ ]
283
+
284
+
285
+ times = audio_emb.shape[0] // clip_length
286
+
287
+ tensor_result = []
288
+
289
+ generator = torch.manual_seed(42)
290
+
291
+ for t in range(times):
292
+ print(f"[{t+1}/{times}]")
293
+
294
+ if len(tensor_result) == 0:
295
+ # The first iteration
296
+ motion_zeros = source_image_pixels.repeat(
297
+ config.data.n_motion_frames, 1, 1, 1)
298
+ motion_zeros = motion_zeros.to(
299
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
300
+ pixel_values_ref_img = torch.cat(
301
+ [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
302
+ else:
303
+ motion_frames = tensor_result[-1][0]
304
+ motion_frames = motion_frames.permute(1, 0, 2, 3)
305
+ motion_frames = motion_frames[0-config.data.n_motion_frames:]
306
+ motion_frames = motion_frames * 2.0 - 1.0
307
+ motion_frames = motion_frames.to(
308
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
309
+ pixel_values_ref_img = torch.cat(
310
+ [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
311
+
312
+ pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
313
+
314
+ audio_tensor = audio_emb[
315
+ t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
316
+ ]
317
+ audio_tensor = audio_tensor.unsqueeze(0)
318
+ audio_tensor = audio_tensor.to(
319
+ device=net.audioproj.device, dtype=net.audioproj.dtype)
320
+ audio_tensor = net.audioproj(audio_tensor)
321
+
322
+ pipeline_output = pipeline(
323
+ ref_image=pixel_values_ref_img,
324
+ audio_tensor=audio_tensor,
325
+ face_emb=source_image_face_emb,
326
+ face_mask=source_image_face_region,
327
+ pixel_values_full_mask=source_image_full_mask,
328
+ pixel_values_face_mask=source_image_face_mask,
329
+ pixel_values_lip_mask=source_image_lip_mask,
330
+ width=img_size[0],
331
+ height=img_size[1],
332
+ video_length=clip_length,
333
+ num_inference_steps=config.inference_steps,
334
+ guidance_scale=config.cfg_scale,
335
+ generator=generator,
336
+ motion_scale=motion_scale,
337
+ )
338
+
339
+ tensor_result.append(pipeline_output.videos)
340
+
341
+ tensor_result = torch.cat(tensor_result, dim=2)
342
+ tensor_result = tensor_result.squeeze(0)
343
+ tensor_result = tensor_result[:, :audio_length]
344
+
345
+ output_file = config.output
346
+ # save the result after all iteration
347
+ tensor_to_video(tensor_result, output_file, driving_audio_path)
348
+ return output_file
349
+
350
+
351
+ if __name__ == "__main__":
352
+ parser = argparse.ArgumentParser()
353
+
354
+ parser.add_argument(
355
+ "-c", "--config", default="configs/inference/default.yaml")
356
+ parser.add_argument("--source_image", type=str, required=False,
357
+ help="source image")
358
+ parser.add_argument("--driving_audio", type=str, required=False,
359
+ help="driving audio")
360
+ parser.add_argument(
361
+ "--output", type=str, help="output video file name", default=".cache/output.mp4")
362
+ parser.add_argument(
363
+ "--pose_weight", type=float, help="weight of pose", required=False)
364
+ parser.add_argument(
365
+ "--face_weight", type=float, help="weight of face", required=False)
366
+ parser.add_argument(
367
+ "--lip_weight", type=float, help="weight of lip", required=False)
368
+ parser.add_argument(
369
+ "--face_expand_ratio", type=float, help="face region", required=False)
370
+ parser.add_argument(
371
+ "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False)
372
+
373
+
374
+ command_line_args = parser.parse_args()
375
+
376
+ inference_process(command_line_args)
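
process_audio_emb above turns the per-frame wav2vec features into overlapping windows: for each frame i it stacks the embeddings at offsets -2..+2, clamped at the clip boundaries, which matches the seq_len=5 expected by AudioProjModel. A standalone sketch on a dummy tensor; the (frames, blocks, channels) shape is an assumption based on the blocks=12, channels=768 arguments used above:

    import torch

    def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
        # For each frame, stack the embeddings at offsets -2..+2, clamped to valid indices.
        windows = []
        for i in range(audio_emb.shape[0]):
            idx = [max(min(i + j, audio_emb.shape[0] - 1), 0) for j in range(-2, 3)]
            windows.append(torch.stack([audio_emb[k] for k in idx], dim=0))
        return torch.stack(windows, dim=0)

    dummy = torch.randn(10, 12, 768)       # (frames, wav2vec blocks, channels)
    print(process_audio_emb(dummy).shape)  # torch.Size([10, 5, 12, 768])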
scripts/train_stage1.py ADDED
@@ -0,0 +1,793 @@
1
+ # pylint: disable=E1101,C0415,W0718,R0801
2
+ # scripts/train_stage1.py
3
+ """
4
+ This is the main training script for stage 1 of the project.
5
+ It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.
6
+
7
+ The script includes the following classes and functions:
8
+
9
+ 1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
10
+ and face masks as input and returns the denoised latents.
11
+ 2. log_validation: A function that logs the validation information using the given VAE, image encoder,
13
+ network, scheduler, accelerator, width, height, and configuration.
14
+ 3. train_stage1_process: A function that runs training stage 1 using the given configuration.
14
+
15
+ The script also includes the necessary imports and a brief description of the purpose of the file.
16
+ """
17
+
18
+ import argparse
19
+ import copy
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import warnings
25
+ from datetime import datetime
26
+
27
+ import cv2
28
+ import diffusers
29
+ import mlflow
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torch.utils.checkpoint
34
+ import transformers
35
+ from accelerate import Accelerator
36
+ from accelerate.logging import get_logger
37
+ from accelerate.utils import DistributedDataParallelKwargs
38
+ from diffusers import AutoencoderKL, DDIMScheduler
39
+ from diffusers.optimization import get_scheduler
40
+ from diffusers.utils import check_min_version
41
+ from diffusers.utils.import_utils import is_xformers_available
42
+ from insightface.app import FaceAnalysis
43
+ from omegaconf import OmegaConf
44
+ from PIL import Image
45
+ from torch import nn
46
+ from tqdm.auto import tqdm
47
+
48
+ from hallo.animate.face_animate_static import StaticPipeline
49
+ from hallo.datasets.mask_image import FaceMaskDataset
50
+ from hallo.models.face_locator import FaceLocator
51
+ from hallo.models.image_proj import ImageProjModel
52
+ from hallo.models.mutual_self_attention import ReferenceAttentionControl
53
+ from hallo.models.unet_2d_condition import UNet2DConditionModel
54
+ from hallo.models.unet_3d import UNet3DConditionModel
55
+ from hallo.utils.util import (compute_snr, delete_additional_ckpt,
56
+ import_filename, init_output_dir,
57
+ load_checkpoint, move_final_checkpoint,
58
+ save_checkpoint, seed_everything)
59
+
60
+ warnings.filterwarnings("ignore")
61
+
62
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
63
+ check_min_version("0.10.0.dev0")
64
+
65
+ logger = get_logger(__name__, log_level="INFO")
66
+
67
+
68
+ class Net(nn.Module):
69
+ """
70
+ The Net class defines a neural network model that combines a reference UNet2DConditionModel,
71
+ a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
72
+
73
+ Args:
74
+ reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
75
+ denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
76
+ face_locator (FaceLocator): The face locator model used for face animation.
77
+ reference_control_writer: The reference control writer component.
78
+ reference_control_reader: The reference control reader component.
79
+ imageproj: The image projection model.
80
+
81
+ Forward method:
82
+ noisy_latents (torch.Tensor): The noisy latents tensor.
83
+ timesteps (torch.Tensor): The timesteps tensor.
84
+ ref_image_latents (torch.Tensor): The reference image latents tensor.
85
+ face_emb (torch.Tensor): The face embeddings tensor.
86
+ face_mask (torch.Tensor): The face mask tensor.
87
+ uncond_fwd (bool): A flag indicating whether to perform unconditional forward pass.
88
+
89
+ Returns:
90
+ torch.Tensor: The output tensor of the neural network model.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ reference_unet: UNet2DConditionModel,
96
+ denoising_unet: UNet3DConditionModel,
97
+ face_locator: FaceLocator,
98
+ reference_control_writer: ReferenceAttentionControl,
99
+ reference_control_reader: ReferenceAttentionControl,
100
+ imageproj: ImageProjModel,
101
+ ):
102
+ super().__init__()
103
+ self.reference_unet = reference_unet
104
+ self.denoising_unet = denoising_unet
105
+ self.face_locator = face_locator
106
+ self.reference_control_writer = reference_control_writer
107
+ self.reference_control_reader = reference_control_reader
108
+ self.imageproj = imageproj
109
+
110
+ def forward(
111
+ self,
112
+ noisy_latents,
113
+ timesteps,
114
+ ref_image_latents,
115
+ face_emb,
116
+ face_mask,
117
+ uncond_fwd: bool = False,
118
+ ):
119
+ """
120
+ Forward pass of the model.
121
+ Args:
122
+ self (Net): The model instance.
123
+ noisy_latents (torch.Tensor): Noisy latents.
124
+ timesteps (torch.Tensor): Timesteps.
125
+ ref_image_latents (torch.Tensor): Reference image latents.
126
+ face_emb (torch.Tensor): Face embedding.
127
+ face_mask (torch.Tensor): Face mask.
128
+ uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False.
129
+
130
+ Returns:
131
+ torch.Tensor: Model prediction.
132
+ """
133
+
134
+ face_emb = self.imageproj(face_emb)
135
+ face_mask = face_mask.to(device="cuda")
136
+ face_mask_feature = self.face_locator(face_mask)
137
+
138
+ if not uncond_fwd:
139
+ ref_timesteps = torch.zeros_like(timesteps)
140
+ self.reference_unet(
141
+ ref_image_latents,
142
+ ref_timesteps,
143
+ encoder_hidden_states=face_emb,
144
+ return_dict=False,
145
+ )
146
+ self.reference_control_reader.update(self.reference_control_writer)
147
+ model_pred = self.denoising_unet(
148
+ noisy_latents,
149
+ timesteps,
150
+ mask_cond_fea=face_mask_feature,
151
+ encoder_hidden_states=face_emb,
152
+ ).sample
153
+
154
+ return model_pred
155
+
156
+
157
+ def get_noise_scheduler(cfg: argparse.Namespace):
158
+ """
159
+ Create noise scheduler for training
160
+
161
+ Args:
162
+ cfg (omegaconf.dictconfig.DictConfig): Configuration object.
163
+
164
+ Returns:
165
+ train noise scheduler and val noise scheduler
166
+ """
167
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
168
+ if cfg.enable_zero_snr:
169
+ sched_kwargs.update(
170
+ rescale_betas_zero_snr=True,
171
+ timestep_spacing="trailing",
172
+ prediction_type="v_prediction",
173
+ )
174
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
175
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
176
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
177
+
178
+ return train_noise_scheduler, val_noise_scheduler
179
+
180
+
181
+ def log_validation(
182
+ vae,
183
+ net,
184
+ scheduler,
185
+ accelerator,
186
+ width,
187
+ height,
188
+ imageproj,
189
+ cfg,
190
+ save_dir,
191
+ global_step,
192
+ face_analysis_model_path,
193
+ ):
194
+ """
195
+ Log validation generation image.
196
+
197
+ Args:
198
+ vae (nn.Module): Variational Autoencoder model.
199
+ net (Net): Main model.
200
+ scheduler (diffusers.SchedulerMixin): Noise scheduler.
201
+ accelerator (accelerate.Accelerator): Accelerator for training.
202
+ width (int): Width of the input images.
203
+ height (int): Height of the input images.
204
+ imageproj (nn.Module): Image projection model.
205
+ cfg (omegaconf.dictconfig.DictConfig): Configuration object.
206
+ save_dir (str): directory path to save log result.
207
+ global_step (int): Global step number.
208
+
209
+ Returns:
210
+ None
211
+ """
212
+ logger.info("Running validation... ")
213
+
214
+ ori_net = accelerator.unwrap_model(net)
215
+ ori_net = copy.deepcopy(ori_net)
216
+ reference_unet = ori_net.reference_unet
217
+ denoising_unet = ori_net.denoising_unet
218
+ face_locator = ori_net.face_locator
219
+
220
+ generator = torch.manual_seed(42)
221
+ image_enc = FaceAnalysis(
222
+ name="",
223
+ root=face_analysis_model_path,
224
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
225
+ )
226
+ image_enc.prepare(ctx_id=0, det_size=(640, 640))
227
+
228
+ pipe = StaticPipeline(
229
+ vae=vae,
230
+ reference_unet=reference_unet,
231
+ denoising_unet=denoising_unet,
232
+ face_locator=face_locator,
233
+ scheduler=scheduler,
234
+ imageproj=imageproj,
235
+ )
236
+
237
+ pil_images = []
238
+ for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths):
239
+ # for mask_image_path in mask_image_paths:
240
+ mask_name = os.path.splitext(
241
+ os.path.basename(mask_image_path))[0]
242
+ ref_name = os.path.splitext(
243
+ os.path.basename(ref_image_path))[0]
244
+ ref_image_pil = Image.open(ref_image_path).convert("RGB")
245
+ mask_image_pil = Image.open(mask_image_path).convert("RGB")
246
+
247
+ # Prepare face embeds
248
+ face_info = image_enc.get(
249
+ cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR))
250
+ face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (
251
+ x['bbox'][3] - x['bbox'][1]))[-1] # only use the maximum face
252
+ face_emb = torch.tensor(face_info['embedding'])
253
+ face_emb = face_emb.to(
254
+ imageproj.device, imageproj.dtype)
255
+
256
+ image = pipe(
257
+ ref_image_pil,
258
+ mask_image_pil,
259
+ width,
260
+ height,
261
+ 20,
262
+ 3.5,
263
+ face_emb,
264
+ generator=generator,
265
+ ).images
266
+ image = image[0, :, 0].permute(1, 2, 0).cpu().numpy() # (3, 512, 512)
267
+ res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
268
+ # Save ref_image, src_image and the generated_image
269
+ w, h = res_image_pil.size
270
+ canvas = Image.new("RGB", (w * 3, h), "white")
271
+ ref_image_pil = ref_image_pil.resize((w, h))
272
+ mask_image_pil = mask_image_pil.resize((w, h))
273
+ canvas.paste(ref_image_pil, (0, 0))
274
+ canvas.paste(mask_image_pil, (w, 0))
275
+ canvas.paste(res_image_pil, (w * 2, 0))
276
+
277
+ out_file = os.path.join(
278
+ save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg"
279
+ )
280
+ canvas.save(out_file)
281
+
282
+ del pipe
283
+ del ori_net
284
+ torch.cuda.empty_cache()
285
+
286
+ return pil_images
287
+
288
+
289
+ def train_stage1_process(cfg: argparse.Namespace) -> None:
290
+ """
291
+ Trains the model using the given configuration (cfg).
292
+
293
+ Args:
294
+ cfg (dict): The configuration dictionary containing the parameters for training.
295
+
296
+ Notes:
297
+ - This function trains the model using the given configuration.
298
+ - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
299
+ - The training progress is logged and tracked using the accelerator.
300
+ - The trained model is saved after the training is completed.
301
+ """
302
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
303
+ accelerator = Accelerator(
304
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
305
+ mixed_precision=cfg.solver.mixed_precision,
306
+ log_with="mlflow",
307
+ project_dir="./mlruns",
308
+ kwargs_handlers=[kwargs],
309
+ )
310
+
311
+ # Make one log on every process with the configuration for debugging.
312
+ logging.basicConfig(
313
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
314
+ datefmt="%m/%d/%Y %H:%M:%S",
315
+ level=logging.INFO,
316
+ )
317
+
318
+ logger.info(accelerator.state, main_process_only=False)
319
+ if accelerator.is_local_main_process:
320
+ transformers.utils.logging.set_verbosity_warning()
321
+ diffusers.utils.logging.set_verbosity_info()
322
+ else:
323
+ transformers.utils.logging.set_verbosity_error()
324
+ diffusers.utils.logging.set_verbosity_error()
325
+
326
+ # If passed along, set the training seed now.
327
+ if cfg.seed is not None:
328
+ seed_everything(cfg.seed)
329
+
330
+ # create output dir for training
331
+ exp_name = cfg.exp_name
332
+ save_dir = f"{cfg.output_dir}/{exp_name}"
333
+ checkpoint_dir = os.path.join(save_dir, "checkpoints")
334
+ module_dir = os.path.join(save_dir, "modules")
335
+ validation_dir = os.path.join(save_dir, "validation")
336
+
337
+ if accelerator.is_main_process:
338
+ init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])
339
+
340
+ accelerator.wait_for_everyone()
341
+
342
+ # create model
343
+ if cfg.weight_dtype == "fp16":
344
+ weight_dtype = torch.float16
345
+ elif cfg.weight_dtype == "bf16":
346
+ weight_dtype = torch.bfloat16
347
+ elif cfg.weight_dtype == "fp32":
348
+ weight_dtype = torch.float32
349
+ else:
350
+ raise ValueError(
351
+ f"Do not support weight dtype: {cfg.weight_dtype} during training"
352
+ )
353
+
354
+ # create model
355
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
356
+ "cuda", dtype=weight_dtype
357
+ )
358
+ reference_unet = UNet2DConditionModel.from_pretrained(
359
+ cfg.base_model_path,
360
+ subfolder="unet",
361
+ ).to(device="cuda", dtype=weight_dtype)
362
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
363
+ cfg.base_model_path,
364
+ "",
365
+ subfolder="unet",
366
+ unet_additional_kwargs={
367
+ "use_motion_module": False,
368
+ "unet_use_temporal_attention": False,
369
+ },
370
+ use_landmark=False
371
+ ).to(device="cuda", dtype=weight_dtype)
372
+ imageproj = ImageProjModel(
373
+ cross_attention_dim=denoising_unet.config.cross_attention_dim,
374
+ clip_embeddings_dim=512,
375
+ clip_extra_context_tokens=4,
376
+ ).to(device="cuda", dtype=weight_dtype)
377
+
378
+ if cfg.face_locator_pretrained:
379
+ face_locator = FaceLocator(
380
+ conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
381
+ ).to(device="cuda", dtype=weight_dtype)
382
+ miss, _ = face_locator.load_state_dict(
383
+ cfg.face_state_dict_path, strict=False)
384
+ logger.info(f"Missing key for face locator: {len(miss)}")
385
+ else:
386
+ face_locator = FaceLocator(
387
+ conditioning_embedding_channels=320,
388
+ ).to(device="cuda", dtype=weight_dtype)
389
+ # Freeze
390
+ vae.requires_grad_(False)
391
+ denoising_unet.requires_grad_(True)
392
+ reference_unet.requires_grad_(True)
393
+ imageproj.requires_grad_(True)
394
+ face_locator.requires_grad_(True)
395
+
396
+ reference_control_writer = ReferenceAttentionControl(
397
+ reference_unet,
398
+ do_classifier_free_guidance=False,
399
+ mode="write",
400
+ fusion_blocks="full",
401
+ )
402
+ reference_control_reader = ReferenceAttentionControl(
403
+ denoising_unet,
404
+ do_classifier_free_guidance=False,
405
+ mode="read",
406
+ fusion_blocks="full",
407
+ )
408
+
409
+ net = Net(
410
+ reference_unet,
411
+ denoising_unet,
412
+ face_locator,
413
+ reference_control_writer,
414
+ reference_control_reader,
415
+ imageproj,
416
+ ).to(dtype=weight_dtype)
417
+
418
+ # get noise scheduler
419
+ train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)
420
+
421
+ # init optimizer
422
+ if cfg.solver.enable_xformers_memory_efficient_attention:
423
+ if is_xformers_available():
424
+ reference_unet.enable_xformers_memory_efficient_attention()
425
+ denoising_unet.enable_xformers_memory_efficient_attention()
426
+ else:
427
+ raise ValueError(
428
+ "xformers is not available. Make sure it is installed correctly"
429
+ )
430
+
431
+ if cfg.solver.gradient_checkpointing:
432
+ reference_unet.enable_gradient_checkpointing()
433
+ denoising_unet.enable_gradient_checkpointing()
434
+
435
+ if cfg.solver.scale_lr:
436
+ learning_rate = (
437
+ cfg.solver.learning_rate
438
+ * cfg.solver.gradient_accumulation_steps
439
+ * cfg.data.train_bs
440
+ * accelerator.num_processes
441
+ )
442
+ else:
443
+ learning_rate = cfg.solver.learning_rate
444
+
445
+ # Initialize the optimizer
446
+ if cfg.solver.use_8bit_adam:
447
+ try:
448
+ import bitsandbytes as bnb
449
+ except ImportError as exc:
450
+ raise ImportError(
451
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
452
+ ) from exc
453
+
454
+ optimizer_cls = bnb.optim.AdamW8bit
455
+ else:
456
+ optimizer_cls = torch.optim.AdamW
457
+
458
+ trainable_params = list(
459
+ filter(lambda p: p.requires_grad, net.parameters()))
460
+ optimizer = optimizer_cls(
461
+ trainable_params,
462
+ lr=learning_rate,
463
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
464
+ weight_decay=cfg.solver.adam_weight_decay,
465
+ eps=cfg.solver.adam_epsilon,
466
+ )
467
+
468
+ # init scheduler
469
+ lr_scheduler = get_scheduler(
470
+ cfg.solver.lr_scheduler,
471
+ optimizer=optimizer,
472
+ num_warmup_steps=cfg.solver.lr_warmup_steps
473
+ * cfg.solver.gradient_accumulation_steps,
474
+ num_training_steps=cfg.solver.max_train_steps
475
+ * cfg.solver.gradient_accumulation_steps,
476
+ )
477
+
478
+ # get data loader
479
+ train_dataset = FaceMaskDataset(
480
+ img_size=(cfg.data.train_width, cfg.data.train_height),
481
+ data_meta_paths=cfg.data.meta_paths,
482
+ sample_margin=cfg.data.sample_margin,
483
+ )
484
+ train_dataloader = torch.utils.data.DataLoader(
485
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
486
+ )
487
+
488
+ # Prepare everything with our `accelerator`.
489
+ (
490
+ net,
491
+ optimizer,
492
+ train_dataloader,
493
+ lr_scheduler,
494
+ ) = accelerator.prepare(
495
+ net,
496
+ optimizer,
497
+ train_dataloader,
498
+ lr_scheduler,
499
+ )
500
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
501
+ num_update_steps_per_epoch = math.ceil(
502
+ len(train_dataloader) / cfg.solver.gradient_accumulation_steps
503
+ )
504
+ # Afterwards we recalculate our number of training epochs
505
+ num_train_epochs = math.ceil(
506
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
507
+ )
508
+
509
+ # We need to initialize the trackers we use, and also store our configuration.
510
+ # The trackers initializes automatically on the main process.
511
+ if accelerator.is_main_process:
512
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
513
+ accelerator.init_trackers(
514
+ cfg.exp_name,
515
+ init_kwargs={"mlflow": {"run_name": run_time}},
516
+ )
517
+ # dump config file
518
+ mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
519
+
520
+ logger.info(f"save config to {save_dir}")
521
+ OmegaConf.save(
522
+ cfg, os.path.join(save_dir, "config.yaml")
523
+ )
524
+ # Train!
525
+ total_batch_size = (
526
+ cfg.data.train_bs
527
+ * accelerator.num_processes
528
+ * cfg.solver.gradient_accumulation_steps
529
+ )
530
+
531
+ logger.info("***** Running training *****")
532
+ logger.info(f" Num examples = {len(train_dataset)}")
533
+ logger.info(f" Num Epochs = {num_train_epochs}")
534
+ logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
535
+ logger.info(
536
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
537
+ )
538
+ logger.info(
539
+ f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
540
+ )
541
+ logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
542
+ global_step = 0
543
+ first_epoch = 0
544
+
545
+ # load checkpoint
546
+ # Potentially load in the weights and states from a previous save
547
+ if cfg.resume_from_checkpoint:
548
+ logger.info(f"Loading checkpoint from {checkpoint_dir}")
549
+ global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
550
+ first_epoch = global_step // num_update_steps_per_epoch
551
+
552
+ # Only show the progress bar once on each machine.
553
+ progress_bar = tqdm(
554
+ range(global_step, cfg.solver.max_train_steps),
555
+ disable=not accelerator.is_main_process,
556
+ )
557
+ progress_bar.set_description("Steps")
558
+ net.train()
559
+ for _ in range(first_epoch, num_train_epochs):
560
+ train_loss = 0.0
561
+ for _, batch in enumerate(train_dataloader):
562
+ with accelerator.accumulate(net):
563
+ # Convert videos to latent space
564
+ pixel_values = batch["img"].to(weight_dtype)
565
+ with torch.no_grad():
566
+ latents = vae.encode(pixel_values).latent_dist.sample()
567
+ latents = latents.unsqueeze(2) # (b, c, 1, h, w)
568
+ latents = latents * 0.18215
569
+
570
+ noise = torch.randn_like(latents)
571
+ if cfg.noise_offset > 0.0:
572
+ noise += cfg.noise_offset * torch.randn(
573
+ (noise.shape[0], noise.shape[1], 1, 1, 1),
574
+ device=noise.device,
575
+ )
576
+
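A small, self-contained sketch of the noise-offset trick used above, with illustrative shapes only (the real tensors come from the VAE latents, and the offset value is hypothetical):

import torch

latents = torch.randn(2, 4, 1, 64, 64)   # (b, c, f, h, w), hypothetical sizes
noise = torch.randn_like(latents)
noise_offset = 0.05                       # stand-in for cfg.noise_offset
# A (b, c, 1, 1, 1) draw broadcasts over frames and spatial dims, shifting the
# per-sample, per-channel mean of the noise rather than adding per-pixel noise.
noise = noise + noise_offset * torch.randn(noise.shape[0], noise.shape[1], 1, 1, 1)
print(noise.shape)                        # torch.Size([2, 4, 1, 64, 64])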
577
+ bsz = latents.shape[0]
578
+ # Sample a random timestep for each video
579
+ timesteps = torch.randint(
580
+ 0,
581
+ train_noise_scheduler.num_train_timesteps,
582
+ (bsz,),
583
+ device=latents.device,
584
+ )
585
+ timesteps = timesteps.long()
586
+
587
+ face_mask_img = batch["tgt_mask"]
588
+ face_mask_img = face_mask_img.unsqueeze(
589
+ 2)
590
+ face_mask_img = face_mask_img.to(weight_dtype)
591
+
592
+ uncond_fwd = random.random() < cfg.uncond_ratio
593
+ face_emb_list = []
594
+ ref_image_list = []
595
+ for _, (ref_img, face_emb) in enumerate(
596
+ zip(batch["ref_img"], batch["face_emb"])
597
+ ):
598
+ if uncond_fwd:
599
+ face_emb_list.append(torch.zeros_like(face_emb))
600
+ else:
601
+ face_emb_list.append(face_emb)
602
+ ref_image_list.append(ref_img)
603
+
604
+ with torch.no_grad():
605
+ ref_img = torch.stack(ref_image_list, dim=0).to(
606
+ dtype=vae.dtype, device=vae.device
607
+ )
608
+ ref_image_latents = vae.encode(
609
+ ref_img
610
+ ).latent_dist.sample()
611
+ ref_image_latents = ref_image_latents * 0.18215
612
+
613
+ face_emb = torch.stack(face_emb_list, dim=0).to(
614
+ dtype=imageproj.dtype, device=imageproj.device
615
+ )
616
+
617
+ # add noise
618
+ noisy_latents = train_noise_scheduler.add_noise(
619
+ latents, noise, timesteps
620
+ )
621
+
622
+ # Get the target for loss depending on the prediction type
623
+ if train_noise_scheduler.prediction_type == "epsilon":
624
+ target = noise
625
+ elif train_noise_scheduler.prediction_type == "v_prediction":
626
+ target = train_noise_scheduler.get_velocity(
627
+ latents, noise, timesteps
628
+ )
629
+ else:
630
+ raise ValueError(
631
+ f"Unknown prediction type {train_noise_scheduler.prediction_type}"
632
+ )
633
+ model_pred = net(
634
+ noisy_latents,
635
+ timesteps,
636
+ ref_image_latents,
637
+ face_emb,
638
+ face_mask_img,
639
+ uncond_fwd,
640
+ )
641
+
642
+ if cfg.snr_gamma == 0:
643
+ loss = F.mse_loss(
644
+ model_pred.float(), target.float(), reduction="mean"
645
+ )
646
+ else:
647
+ snr = compute_snr(train_noise_scheduler, timesteps)
648
+ if train_noise_scheduler.config.prediction_type == "v_prediction":
649
+ # Velocity objective requires that we add one to SNR values before we divide by them.
650
+ snr = snr + 1
651
+ mse_loss_weights = (
652
+ torch.stack(
653
+ [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
654
+ ).min(dim=1)[0]
655
+ / snr
656
+ )
657
+ loss = F.mse_loss(
658
+ model_pred.float(), target.float(), reduction="none"
659
+ )
660
+ loss = (
661
+ loss.mean(dim=list(range(1, len(loss.shape))))
662
+ * mse_loss_weights
663
+ )
664
+ loss = loss.mean()
665
+
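The snr_gamma branch above appears to follow the min-SNR weighting scheme commonly used in diffusers training scripts. A standalone sketch of just the weight computation, with an illustrative helper name that is not part of this repository's API:

import torch

def min_snr_weights(snr: torch.Tensor, snr_gamma: float, v_prediction: bool = False) -> torch.Tensor:
    # Mirrors the logic above: shift SNR by one for v-prediction, then clip at gamma.
    if v_prediction:
        snr = snr + 1
    return torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr

print(min_snr_weights(torch.tensor([0.5, 5.0, 50.0]), snr_gamma=5.0))
# tensor([1.0000, 1.0000, 0.1000])  -- high-SNR (easy) timesteps are down-weighted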
666
+ # Gather the losses across all processes for logging (if we use distributed training).
667
+ avg_loss = accelerator.gather(
668
+ loss.repeat(cfg.data.train_bs)).mean()
669
+ train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
670
+
671
+ # Backpropagate
672
+ accelerator.backward(loss)
673
+ if accelerator.sync_gradients:
674
+ accelerator.clip_grad_norm_(
675
+ trainable_params,
676
+ cfg.solver.max_grad_norm,
677
+ )
678
+ optimizer.step()
679
+ lr_scheduler.step()
680
+ optimizer.zero_grad()
681
+
682
+ if accelerator.sync_gradients:
683
+ reference_control_reader.clear()
684
+ reference_control_writer.clear()
685
+ progress_bar.update(1)
686
+ global_step += 1
687
+ accelerator.log({"train_loss": train_loss}, step=global_step)
688
+ train_loss = 0.0
689
+ if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps:
690
+ accelerator.wait_for_everyone()
691
+ save_path = os.path.join(
692
+ checkpoint_dir, f"checkpoint-{global_step}")
693
+ if accelerator.is_main_process:
694
+ delete_additional_ckpt(checkpoint_dir, 3)
695
+ accelerator.save_state(save_path)
696
+ accelerator.wait_for_everyone()
697
+ unwrap_net = accelerator.unwrap_model(net)
698
+ if accelerator.is_main_process:
699
+ save_checkpoint(
700
+ unwrap_net.reference_unet,
701
+ module_dir,
702
+ "reference_unet",
703
+ global_step,
704
+ total_limit=3,
705
+ )
706
+ save_checkpoint(
707
+ unwrap_net.imageproj,
708
+ module_dir,
709
+ "imageproj",
710
+ global_step,
711
+ total_limit=3,
712
+ )
713
+ save_checkpoint(
714
+ unwrap_net.denoising_unet,
715
+ module_dir,
716
+ "denoising_unet",
717
+ global_step,
718
+ total_limit=3,
719
+ )
720
+ save_checkpoint(
721
+ unwrap_net.face_locator,
722
+ module_dir,
723
+ "face_locator",
724
+ global_step,
725
+ total_limit=3,
726
+ )
727
+
728
+ if global_step % cfg.val.validation_steps == 0 or global_step == 1:
729
+ if accelerator.is_main_process:
730
+ generator = torch.Generator(device=accelerator.device)
731
+ generator.manual_seed(cfg.seed)
732
+ log_validation(
733
+ vae=vae,
734
+ net=net,
735
+ scheduler=val_noise_scheduler,
736
+ accelerator=accelerator,
737
+ width=cfg.data.train_width,
738
+ height=cfg.data.train_height,
739
+ imageproj=imageproj,
740
+ cfg=cfg,
741
+ save_dir=validation_dir,
742
+ global_step=global_step,
743
+ face_analysis_model_path=cfg.face_analysis_model_path
744
+ )
745
+
746
+ logs = {
747
+ "step_loss": loss.detach().item(),
748
+ "lr": lr_scheduler.get_last_lr()[0],
749
+ }
750
+ progress_bar.set_postfix(**logs)
751
+
752
+ if global_step >= cfg.solver.max_train_steps:
753
+ # process final module weight for stage2
754
+ if accelerator.is_main_process:
755
+ move_final_checkpoint(save_dir, module_dir, "reference_unet")
756
+ move_final_checkpoint(save_dir, module_dir, "imageproj")
757
+ move_final_checkpoint(save_dir, module_dir, "denoising_unet")
758
+ move_final_checkpoint(save_dir, module_dir, "face_locator")
759
+ break
760
+
761
+ accelerator.wait_for_everyone()
762
+ accelerator.end_training()
763
+
764
+
765
+ def load_config(config_path: str) -> dict:
766
+ """
767
+ Loads the configuration file.
768
+
769
+ Args:
770
+ config_path (str): Path to the configuration file.
771
+
772
+ Returns:
773
+ dict: The configuration dictionary.
774
+ """
775
+
776
+ if config_path.endswith(".yaml"):
777
+ return OmegaConf.load(config_path)
778
+ if config_path.endswith(".py"):
779
+ return import_filename(config_path).cfg
780
+ raise ValueError("Unsupported format for config file")
781
+
782
+
783
+ if __name__ == "__main__":
784
+ parser = argparse.ArgumentParser()
785
+ parser.add_argument("--config", type=str,
786
+ default="./configs/train/stage1.yaml")
787
+ args = parser.parse_args()
788
+
789
+ try:
790
+ config = load_config(args.config)
791
+ train_stage1_process(config)
792
+ except Exception as e:
793
+ logging.error("Failed to execute the training process: %s", e)
scripts/train_stage2.py ADDED
@@ -0,0 +1,991 @@
1
+ # pylint: disable=E1101,C0415,W0718,R0801
2
+ # scripts/train_stage2.py
3
+ """
4
+ This is the main training script for stage 2 of the project.
5
+ It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.
6
+
7
+ The script includes the following classes and functions:
8
+
9
+ 1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
10
+ and face masks as input and returns the denoised latents.
11
+ 2. get_attention_mask: A function that rearranges the mask tensors to the required format.
12
+ 3. get_noise_scheduler: A function that creates and returns the noise schedulers for training and validation.
13
+ 4. process_audio_emb: A function that processes the audio embeddings to concatenate with other tensors.
14
+ 5. log_validation: A function that logs the validation information using the given VAE,
15
+ network, scheduler, accelerator, width, height, and configuration.
16
+ 6. train_stage2_process: A function that processes the training stage 2 using the given configuration.
17
+ 7. load_config: A function that loads the configuration file from the given path.
18
+
19
+ The script also includes the necessary imports and a brief description of the purpose of the file.
20
+ """
21
+
22
+ import argparse
23
+ import copy
24
+ import logging
25
+ import math
26
+ import os
27
+ import random
28
+ import time
29
+ import warnings
30
+ from datetime import datetime
31
+ from typing import List, Tuple
32
+
33
+ import diffusers
34
+ import mlflow
35
+ import torch
36
+ import torch.nn.functional as F
37
+ import torch.utils.checkpoint
38
+ import transformers
39
+ from accelerate import Accelerator
40
+ from accelerate.logging import get_logger
41
+ from accelerate.utils import DistributedDataParallelKwargs
42
+ from diffusers import AutoencoderKL, DDIMScheduler
43
+ from diffusers.optimization import get_scheduler
44
+ from diffusers.utils import check_min_version
45
+ from diffusers.utils.import_utils import is_xformers_available
46
+ from einops import rearrange, repeat
47
+ from omegaconf import OmegaConf
48
+ from torch import nn
49
+ from tqdm.auto import tqdm
50
+
51
+ from hallo.animate.face_animate import FaceAnimatePipeline
52
+ from hallo.datasets.audio_processor import AudioProcessor
53
+ from hallo.datasets.image_processor import ImageProcessor
54
+ from hallo.datasets.talk_video import TalkingVideoDataset
55
+ from hallo.models.audio_proj import AudioProjModel
56
+ from hallo.models.face_locator import FaceLocator
57
+ from hallo.models.image_proj import ImageProjModel
58
+ from hallo.models.mutual_self_attention import ReferenceAttentionControl
59
+ from hallo.models.unet_2d_condition import UNet2DConditionModel
60
+ from hallo.models.unet_3d import UNet3DConditionModel
61
+ from hallo.utils.util import (compute_snr, delete_additional_ckpt,
62
+ import_filename, init_output_dir,
63
+ load_checkpoint, save_checkpoint,
64
+ seed_everything, tensor_to_video)
65
+
66
+ warnings.filterwarnings("ignore")
67
+
68
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
69
+ check_min_version("0.10.0.dev0")
70
+
71
+ logger = get_logger(__name__, log_level="INFO")
72
+
73
+
74
+ class Net(nn.Module):
75
+ """
76
+ The Net class defines a neural network model that combines a reference UNet2DConditionModel,
77
+ a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
78
+
79
+ Args:
80
+ reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
81
+ denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
82
+ face_locator (FaceLocator): The face locator model used for face animation.
83
+ reference_control_writer: The reference control writer component.
84
+ reference_control_reader: The reference control reader component.
85
+ imageproj: The image projection model.
86
+ audioproj: The audio projection model.
87
+
88
+ Forward method:
89
+ noisy_latents (torch.Tensor): The noisy latents tensor.
90
+ timesteps (torch.Tensor): The timesteps tensor.
91
+ ref_image_latents (torch.Tensor): The reference image latents tensor.
92
+ face_emb (torch.Tensor): The face embeddings tensor.
93
+ audio_emb (torch.Tensor): The audio embeddings tensor.
94
+ mask (torch.Tensor): Hard face mask for face locator.
95
+ full_mask (torch.Tensor): Pose Mask.
96
+ face_mask (torch.Tensor): Face Mask
97
+ lip_mask (torch.Tensor): Lip Mask
98
+ uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass.
99
+ uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass.
100
+
101
+ Returns:
102
+ torch.Tensor: The output tensor of the neural network model.
103
+ """
104
+ def __init__(
105
+ self,
106
+ reference_unet: UNet2DConditionModel,
107
+ denoising_unet: UNet3DConditionModel,
108
+ face_locator: FaceLocator,
109
+ reference_control_writer,
110
+ reference_control_reader,
111
+ imageproj,
112
+ audioproj,
113
+ ):
114
+ super().__init__()
115
+ self.reference_unet = reference_unet
116
+ self.denoising_unet = denoising_unet
117
+ self.face_locator = face_locator
118
+ self.reference_control_writer = reference_control_writer
119
+ self.reference_control_reader = reference_control_reader
120
+ self.imageproj = imageproj
121
+ self.audioproj = audioproj
122
+
123
+ def forward(
124
+ self,
125
+ noisy_latents: torch.Tensor,
126
+ timesteps: torch.Tensor,
127
+ ref_image_latents: torch.Tensor,
128
+ face_emb: torch.Tensor,
129
+ audio_emb: torch.Tensor,
130
+ mask: torch.Tensor,
131
+ full_mask: torch.Tensor,
132
+ face_mask: torch.Tensor,
133
+ lip_mask: torch.Tensor,
134
+ uncond_img_fwd: bool = False,
135
+ uncond_audio_fwd: bool = False,
136
+ ):
137
+ """
138
+ Run the denoising forward pass conditioned on the reference image latents, face and audio embeddings, and the attention masks.
139
+ """
140
+ face_emb = self.imageproj(face_emb)
141
+ mask = mask.to(device="cuda")
142
+ mask_feature = self.face_locator(mask)
143
+ audio_emb = audio_emb.to(
144
+ device=self.audioproj.device, dtype=self.audioproj.dtype)
145
+ audio_emb = self.audioproj(audio_emb)
146
+
147
+ # condition forward
148
+ if not uncond_img_fwd:
149
+ ref_timesteps = torch.zeros_like(timesteps)
150
+ ref_timesteps = repeat(
151
+ ref_timesteps,
152
+ "b -> (repeat b)",
153
+ repeat=ref_image_latents.size(0) // ref_timesteps.size(0),
154
+ )
155
+ self.reference_unet(
156
+ ref_image_latents,
157
+ ref_timesteps,
158
+ encoder_hidden_states=face_emb,
159
+ return_dict=False,
160
+ )
161
+ self.reference_control_reader.update(self.reference_control_writer)
162
+
163
+ if uncond_audio_fwd:
164
+ audio_emb = torch.zeros_like(audio_emb).to(
165
+ device=audio_emb.device, dtype=audio_emb.dtype
166
+ )
167
+
168
+ model_pred = self.denoising_unet(
169
+ noisy_latents,
170
+ timesteps,
171
+ mask_cond_fea=mask_feature,
172
+ encoder_hidden_states=face_emb,
173
+ audio_embedding=audio_emb,
174
+ full_mask=full_mask,
175
+ face_mask=face_mask,
176
+ lip_mask=lip_mask
177
+ ).sample
178
+
179
+ return model_pred
180
+
181
+
182
+ def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
183
+ """
184
+ Rearrange the mask tensors to the required format.
185
+
186
+ Args:
187
+ mask (torch.Tensor): The input mask tensor.
188
+ weight_dtype (torch.dtype): The data type for the mask tensor.
189
+
190
+ Returns:
191
+ torch.Tensor: The rearranged mask tensor.
192
+ """
193
+ if isinstance(mask, List):
194
+ _mask = []
195
+ for m in mask:
196
+ _mask.append(
197
+ rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
198
+ return _mask
199
+ mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
200
+ return mask
201
+
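A quick shape check for the rearrange performed above, with illustrative sizes only:

import torch
from einops import rearrange

mask = torch.rand(2, 24, 1, 64, 64)                  # (b, f, 1, h, w), hypothetical
flat = rearrange(mask, "b f 1 h w -> (b f) (h w)")
print(flat.shape)                                    # torch.Size([48, 4096])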
202
+
203
+ def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
204
+ """
205
+ Create noise scheduler for training.
206
+
207
+ Args:
208
+ cfg (argparse.Namespace): Configuration object.
209
+
210
+ Returns:
211
+ Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
212
+ """
213
+
214
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
215
+ if cfg.enable_zero_snr:
216
+ sched_kwargs.update(
217
+ rescale_betas_zero_snr=True,
218
+ timestep_spacing="trailing",
219
+ prediction_type="v_prediction",
220
+ )
221
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
222
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
223
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
224
+
225
+ return train_noise_scheduler, val_noise_scheduler
226
+
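A minimal sketch of what the enable_zero_snr branch configures, assuming hypothetical base scheduler kwargs (the real ones live in cfg.noise_scheduler_kwargs):

from diffusers import DDIMScheduler

base = dict(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
            beta_schedule="linear")                              # illustrative values only
zero_snr = dict(rescale_betas_zero_snr=True, timestep_spacing="trailing",
                prediction_type="v_prediction")
val_sched = DDIMScheduler(**base, **zero_snr)
train_sched = DDIMScheduler(**{**base, **zero_snr, "beta_schedule": "scaled_linear"})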
227
+
228
+ def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
229
+ """
230
+ Process the audio embedding to concatenate with other tensors.
231
+
232
+ Parameters:
233
+ audio_emb (torch.Tensor): The audio embedding tensor to process.
234
+
235
+ Returns:
236
+ concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
237
+ """
238
+ concatenated_tensors = []
239
+
240
+ for i in range(audio_emb.shape[0]):
241
+ vectors_to_concat = [
242
+ audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
243
+ concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
244
+
245
+ audio_emb = torch.stack(concatenated_tensors, dim=0)
246
+
247
+ return audio_emb
248
+
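Assuming process_audio_emb as defined above is in scope, a toy example of the +/-2 frame window it builds; the shapes are illustrative (real wav2vec features come from AudioProcessor):

import torch

audio_emb = torch.rand(8, 12, 768)   # 8 frames of hypothetical (blocks, channels) features
out = process_audio_emb(audio_emb)
print(out.shape)                     # torch.Size([8, 5, 12, 768])
# Each frame now carries itself plus its two neighbours on either side,
# clamped at the clip boundaries (frame 0 reuses frame 0 for j = -2, -1).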
249
+
250
+ def log_validation(
251
+ accelerator: Accelerator,
252
+ vae: AutoencoderKL,
253
+ net: Net,
254
+ scheduler: DDIMScheduler,
255
+ width: int,
256
+ height: int,
257
+ clip_length: int = 24,
258
+ generator: torch.Generator = None,
259
+ cfg: dict = None,
260
+ save_dir: str = None,
261
+ global_step: int = 0,
262
+ times: int = None,
263
+ face_analysis_model_path: str = "",
264
+ ) -> torch.Tensor:
265
+ """
266
+ Log validation video during the training process.
267
+
268
+ Args:
269
+ accelerator (Accelerator): The accelerator for distributed training.
270
+ vae (AutoencoderKL): The autoencoder model.
271
+ net (Net): The main neural network model.
272
+ scheduler (DDIMScheduler): The scheduler for noise.
273
+ width (int): The width of the input images.
274
+ height (int): The height of the input images.
275
+ clip_length (int): The length of the video clips. Defaults to 24.
276
+ generator (torch.Generator): The random number generator. Defaults to None.
277
+ cfg (dict): The configuration dictionary. Defaults to None.
278
+ save_dir (str): The directory to save validation results. Defaults to None.
279
+ global_step (int): The current global step in training. Defaults to 0.
280
+ times (int): The number of inference times. Defaults to None.
281
+ face_analysis_model_path (str): The path to the face analysis model. Defaults to "".
282
+
283
+ Returns:
284
+ torch.Tensor: The tensor result of the validation.
285
+ """
286
+ ori_net = accelerator.unwrap_model(net)
287
+ reference_unet = ori_net.reference_unet
288
+ denoising_unet = ori_net.denoising_unet
289
+ face_locator = ori_net.face_locator
290
+ imageproj = ori_net.imageproj
291
+ audioproj = ori_net.audioproj
292
+
293
+ generator = torch.manual_seed(42)
294
+ tmp_denoising_unet = copy.deepcopy(denoising_unet)
295
+
296
+ pipeline = FaceAnimatePipeline(
297
+ vae=vae,
298
+ reference_unet=reference_unet,
299
+ denoising_unet=tmp_denoising_unet,
300
+ face_locator=face_locator,
301
+ image_proj=imageproj,
302
+ scheduler=scheduler,
303
+ )
304
+ pipeline = pipeline.to("cuda")
305
+
306
+ image_processor = ImageProcessor((width, height), face_analysis_model_path)
307
+ audio_processor = AudioProcessor(
308
+ cfg.data.sample_rate,
309
+ cfg.data.fps,
310
+ cfg.wav2vec_config.model_path,
311
+ cfg.wav2vec_config.features == "last",
312
+ os.path.dirname(cfg.audio_separator.model_path),
313
+ os.path.basename(cfg.audio_separator.model_path),
314
+ os.path.join(save_dir, '.cache', "audio_preprocess")
315
+ )
316
+
317
+ for idx, ref_img_path in enumerate(cfg.ref_img_path):
318
+ audio_path = cfg.audio_path[idx]
319
+ source_image_pixels, \
320
+ source_image_face_region, \
321
+ source_image_face_emb, \
322
+ source_image_full_mask, \
323
+ source_image_face_mask, \
324
+ source_image_lip_mask = image_processor.preprocess(
325
+ ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
326
+ audio_emb, audio_length = audio_processor.preprocess(
327
+ audio_path, clip_length)
328
+
329
+ audio_emb = process_audio_emb(audio_emb)
330
+
331
+ source_image_pixels = source_image_pixels.unsqueeze(0)
332
+ source_image_face_region = source_image_face_region.unsqueeze(0)
333
+ source_image_face_emb = source_image_face_emb.reshape(1, -1)
334
+ source_image_face_emb = torch.tensor(source_image_face_emb)
335
+
336
+ source_image_full_mask = [
337
+ (mask.repeat(clip_length, 1))
338
+ for mask in source_image_full_mask
339
+ ]
340
+ source_image_face_mask = [
341
+ (mask.repeat(clip_length, 1))
342
+ for mask in source_image_face_mask
343
+ ]
344
+ source_image_lip_mask = [
345
+ (mask.repeat(clip_length, 1))
346
+ for mask in source_image_lip_mask
347
+ ]
348
+
349
+ times = audio_emb.shape[0] // clip_length
350
+ tensor_result = []
351
+ generator = torch.manual_seed(42)
352
+ for t in range(times):
353
+ print(f"[{t+1}/{times}]")
354
+
355
+ if len(tensor_result) == 0:
356
+ # The first iteration
357
+ motion_zeros = source_image_pixels.repeat(
358
+ cfg.data.n_motion_frames, 1, 1, 1)
359
+ motion_zeros = motion_zeros.to(
360
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
361
+ pixel_values_ref_img = torch.cat(
362
+ [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
363
+ else:
364
+ motion_frames = tensor_result[-1][0]
365
+ motion_frames = motion_frames.permute(1, 0, 2, 3)
366
+ motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
367
+ motion_frames = motion_frames * 2.0 - 1.0
368
+ motion_frames = motion_frames.to(
369
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
370
+ pixel_values_ref_img = torch.cat(
371
+ [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
372
+
373
+ pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
374
+
375
+ audio_tensor = audio_emb[
376
+ t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
377
+ ]
378
+ audio_tensor = audio_tensor.unsqueeze(0)
379
+ audio_tensor = audio_tensor.to(
380
+ device=audioproj.device, dtype=audioproj.dtype)
381
+ audio_tensor = audioproj(audio_tensor)
382
+
383
+ pipeline_output = pipeline(
384
+ ref_image=pixel_values_ref_img,
385
+ audio_tensor=audio_tensor,
386
+ face_emb=source_image_face_emb,
387
+ face_mask=source_image_face_region,
388
+ pixel_values_full_mask=source_image_full_mask,
389
+ pixel_values_face_mask=source_image_face_mask,
390
+ pixel_values_lip_mask=source_image_lip_mask,
391
+ width=cfg.data.train_width,
392
+ height=cfg.data.train_height,
393
+ video_length=clip_length,
394
+ num_inference_steps=cfg.inference_steps,
395
+ guidance_scale=cfg.cfg_scale,
396
+ generator=generator,
397
+ )
398
+
399
+ tensor_result.append(pipeline_output.videos)
400
+
401
+ tensor_result = torch.cat(tensor_result, dim=2)
402
+ tensor_result = tensor_result.squeeze(0)
403
+ tensor_result = tensor_result[:, :audio_length]
404
+ audio_name = os.path.basename(audio_path).split('.')[0]
405
+ ref_name = os.path.basename(ref_img_path).split('.')[0]
406
+ output_file = os.path.join(save_dir, f"{global_step}_{ref_name}_{audio_name}.mp4")
407
+ # save the result after all iteration
408
+ tensor_to_video(tensor_result, output_file, audio_path)
409
+
410
+
411
+ # clean up
412
+ del tmp_denoising_unet
413
+ del pipeline
414
+ del image_processor
415
+ del audio_processor
416
+ torch.cuda.empty_cache()
417
+
418
+ return tensor_result
419
+
420
+
421
+ def train_stage2_process(cfg: argparse.Namespace) -> None:
422
+ """
423
+ Trains the model using the given configuration (cfg).
424
+
425
+ Args:
426
+ cfg (dict): The configuration dictionary containing the parameters for training.
427
+
428
+ Notes:
429
+ - This function trains the model using the given configuration.
430
+ - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
431
+ - The training progress is logged and tracked using the accelerator.
432
+ - The trained model is saved after the training is completed.
433
+ """
434
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
435
+ accelerator = Accelerator(
436
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
437
+ mixed_precision=cfg.solver.mixed_precision,
438
+ log_with="mlflow",
439
+ project_dir="./mlruns",
440
+ kwargs_handlers=[kwargs],
441
+ )
442
+
443
+ # Make one log on every process with the configuration for debugging.
444
+ logging.basicConfig(
445
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
446
+ datefmt="%m/%d/%Y %H:%M:%S",
447
+ level=logging.INFO,
448
+ )
449
+ logger.info(accelerator.state, main_process_only=False)
450
+ if accelerator.is_local_main_process:
451
+ transformers.utils.logging.set_verbosity_warning()
452
+ diffusers.utils.logging.set_verbosity_info()
453
+ else:
454
+ transformers.utils.logging.set_verbosity_error()
455
+ diffusers.utils.logging.set_verbosity_error()
456
+
457
+ # If passed along, set the training seed now.
458
+ if cfg.seed is not None:
459
+ seed_everything(cfg.seed)
460
+
461
+ # create output dir for training
462
+ exp_name = cfg.exp_name
463
+ save_dir = f"{cfg.output_dir}/{exp_name}"
464
+ checkpoint_dir = os.path.join(save_dir, "checkpoints")
465
+ module_dir = os.path.join(save_dir, "modules")
466
+ validation_dir = os.path.join(save_dir, "validation")
467
+ if accelerator.is_main_process:
468
+ init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])
469
+
470
+ accelerator.wait_for_everyone()
471
+
472
+ if cfg.weight_dtype == "fp16":
473
+ weight_dtype = torch.float16
474
+ elif cfg.weight_dtype == "bf16":
475
+ weight_dtype = torch.bfloat16
476
+ elif cfg.weight_dtype == "fp32":
477
+ weight_dtype = torch.float32
478
+ else:
479
+ raise ValueError(
480
+ f"Do not support weight dtype: {cfg.weight_dtype} during training"
481
+ )
482
+
483
+ # Create Models
484
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
485
+ "cuda", dtype=weight_dtype
486
+ )
487
+ reference_unet = UNet2DConditionModel.from_pretrained(
488
+ cfg.base_model_path,
489
+ subfolder="unet",
490
+ ).to(device="cuda", dtype=weight_dtype)
491
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
492
+ cfg.base_model_path,
493
+ cfg.mm_path,
494
+ subfolder="unet",
495
+ unet_additional_kwargs=OmegaConf.to_container(
496
+ cfg.unet_additional_kwargs),
497
+ use_landmark=False
498
+ ).to(device="cuda", dtype=weight_dtype)
499
+ imageproj = ImageProjModel(
500
+ cross_attention_dim=denoising_unet.config.cross_attention_dim,
501
+ clip_embeddings_dim=512,
502
+ clip_extra_context_tokens=4,
503
+ ).to(device="cuda", dtype=weight_dtype)
504
+ face_locator = FaceLocator(
505
+ conditioning_embedding_channels=320,
506
+ ).to(device="cuda", dtype=weight_dtype)
507
+ audioproj = AudioProjModel(
508
+ seq_len=5,
509
+ blocks=12,
510
+ channels=768,
511
+ intermediate_dim=512,
512
+ output_dim=768,
513
+ context_tokens=32,
514
+ ).to(device="cuda", dtype=weight_dtype)
515
+
516
+ # load module weight from stage 1
517
+ stage1_ckpt_dir = cfg.stage1_ckpt_dir
518
+ denoising_unet.load_state_dict(
519
+ torch.load(
520
+ os.path.join(stage1_ckpt_dir, "denoising_unet.pth"),
521
+ map_location="cpu",
522
+ ),
523
+ strict=False,
524
+ )
525
+ reference_unet.load_state_dict(
526
+ torch.load(
527
+ os.path.join(stage1_ckpt_dir, "reference_unet.pth"),
528
+ map_location="cpu",
529
+ ),
530
+ strict=False,
531
+ )
532
+ face_locator.load_state_dict(
533
+ torch.load(
534
+ os.path.join(stage1_ckpt_dir, "face_locator.pth"),
535
+ map_location="cpu",
536
+ ),
537
+ strict=False,
538
+ )
539
+ imageproj.load_state_dict(
540
+ torch.load(
541
+ os.path.join(stage1_ckpt_dir, "imageproj.pth"),
542
+ map_location="cpu",
543
+ ),
544
+ strict=False,
545
+ )
546
+
547
+ # Freeze
548
+ vae.requires_grad_(False)
549
+ imageproj.requires_grad_(False)
550
+ reference_unet.requires_grad_(False)
551
+ denoising_unet.requires_grad_(False)
552
+ face_locator.requires_grad_(False)
553
+ audioproj.requires_grad_(True)
554
+
555
+ # Set motion module learnable
556
+ trainable_modules = cfg.trainable_para
557
+ for name, module in denoising_unet.named_modules():
558
+ if any(trainable_mod in name for trainable_mod in trainable_modules):
559
+ for params in module.parameters():
560
+ params.requires_grad_(True)
561
+
562
+ reference_control_writer = ReferenceAttentionControl(
563
+ reference_unet,
564
+ do_classifier_free_guidance=False,
565
+ mode="write",
566
+ fusion_blocks="full",
567
+ )
568
+ reference_control_reader = ReferenceAttentionControl(
569
+ denoising_unet,
570
+ do_classifier_free_guidance=False,
571
+ mode="read",
572
+ fusion_blocks="full",
573
+ )
574
+
575
+ net = Net(
576
+ reference_unet,
577
+ denoising_unet,
578
+ face_locator,
579
+ reference_control_writer,
580
+ reference_control_reader,
581
+ imageproj,
582
+ audioproj,
583
+ ).to(dtype=weight_dtype)
584
+
585
+ # get noise scheduler
586
+ train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)
587
+
588
+ if cfg.solver.enable_xformers_memory_efficient_attention:
589
+ if is_xformers_available():
590
+ reference_unet.enable_xformers_memory_efficient_attention()
591
+ denoising_unet.enable_xformers_memory_efficient_attention()
592
+
593
+ else:
594
+ raise ValueError(
595
+ "xformers is not available. Make sure it is installed correctly"
596
+ )
597
+
598
+ if cfg.solver.gradient_checkpointing:
599
+ reference_unet.enable_gradient_checkpointing()
600
+ denoising_unet.enable_gradient_checkpointing()
601
+
602
+ if cfg.solver.scale_lr:
603
+ learning_rate = (
604
+ cfg.solver.learning_rate
605
+ * cfg.solver.gradient_accumulation_steps
606
+ * cfg.data.train_bs
607
+ * accelerator.num_processes
608
+ )
609
+ else:
610
+ learning_rate = cfg.solver.learning_rate
611
+
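For reference, the linear scaling rule applied when cfg.solver.scale_lr is set, again with made-up numbers rather than values from this repository's configs:

base_lr = 1e-5                  # stand-in for cfg.solver.learning_rate
grad_accum, per_device_bs, num_processes = 2, 4, 8
scaled_lr = base_lr * grad_accum * per_device_bs * num_processes
print(scaled_lr)                # 0.00064 -- grows with the effective global batch size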
612
+ # Initialize the optimizer
613
+ if cfg.solver.use_8bit_adam:
614
+ try:
615
+ import bitsandbytes as bnb
616
+ except ImportError as exc:
617
+ raise ImportError(
618
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
619
+ ) from exc
620
+ optimizer_cls = bnb.optim.AdamW8bit
621
+ else:
622
+ optimizer_cls = torch.optim.AdamW
623
+
624
+ trainable_params = list(
625
+ filter(lambda p: p.requires_grad, net.parameters()))
626
+ logger.info(f"Total trainable params {len(trainable_params)}")
627
+ optimizer = optimizer_cls(
628
+ trainable_params,
629
+ lr=learning_rate,
630
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
631
+ weight_decay=cfg.solver.adam_weight_decay,
632
+ eps=cfg.solver.adam_epsilon,
633
+ )
634
+
635
+ # Scheduler
636
+ lr_scheduler = get_scheduler(
637
+ cfg.solver.lr_scheduler,
638
+ optimizer=optimizer,
639
+ num_warmup_steps=cfg.solver.lr_warmup_steps
640
+ * cfg.solver.gradient_accumulation_steps,
641
+ num_training_steps=cfg.solver.max_train_steps
642
+ * cfg.solver.gradient_accumulation_steps,
643
+ )
644
+
645
+ # get data loader
646
+ train_dataset = TalkingVideoDataset(
647
+ img_size=(cfg.data.train_width, cfg.data.train_height),
648
+ sample_rate=cfg.data.sample_rate,
649
+ n_sample_frames=cfg.data.n_sample_frames,
650
+ n_motion_frames=cfg.data.n_motion_frames,
651
+ audio_margin=cfg.data.audio_margin,
652
+ data_meta_paths=cfg.data.train_meta_paths,
653
+ wav2vec_cfg=cfg.wav2vec_config,
654
+ )
655
+ train_dataloader = torch.utils.data.DataLoader(
656
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
657
+ )
658
+
659
+ # Prepare everything with our `accelerator`.
660
+ (
661
+ net,
662
+ optimizer,
663
+ train_dataloader,
664
+ lr_scheduler,
665
+ ) = accelerator.prepare(
666
+ net,
667
+ optimizer,
668
+ train_dataloader,
669
+ lr_scheduler,
670
+ )
671
+
672
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
673
+ num_update_steps_per_epoch = math.ceil(
674
+ len(train_dataloader) / cfg.solver.gradient_accumulation_steps
675
+ )
676
+ # Afterwards we recalculate our number of training epochs
677
+ num_train_epochs = math.ceil(
678
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
679
+ )
680
+
681
+ # We need to initialize the trackers we use, and also store our configuration.
682
+ # The trackers initialize automatically on the main process.
683
+ if accelerator.is_main_process:
684
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
685
+ accelerator.init_trackers(
686
+ exp_name,
687
+ init_kwargs={"mlflow": {"run_name": run_time}},
688
+ )
689
+ # dump config file
690
+ mlflow.log_dict(
691
+ OmegaConf.to_container(
692
+ cfg), "config.yaml"
693
+ )
694
+ logger.info(f"save config to {save_dir}")
695
+ OmegaConf.save(
696
+ cfg, os.path.join(save_dir, "config.yaml")
697
+ )
698
+
699
+ # Train!
700
+ total_batch_size = (
701
+ cfg.data.train_bs
702
+ * accelerator.num_processes
703
+ * cfg.solver.gradient_accumulation_steps
704
+ )
705
+
706
+ logger.info("***** Running training *****")
707
+ logger.info(f" Num examples = {len(train_dataset)}")
708
+ logger.info(f" Num Epochs = {num_train_epochs}")
709
+ logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
710
+ logger.info(
711
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
712
+ )
713
+ logger.info(
714
+ f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
715
+ )
716
+ logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
717
+ global_step = 0
718
+ first_epoch = 0
719
+
720
+ # Potentially load in the weights and states from a previous save
721
+ if cfg.resume_from_checkpoint:
722
+ logger.info(f"Loading checkpoint from {checkpoint_dir}")
723
+ global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
724
+ first_epoch = global_step // num_update_steps_per_epoch
725
+
726
+ # Only show the progress bar once on each machine.
727
+ progress_bar = tqdm(
728
+ range(global_step, cfg.solver.max_train_steps),
729
+ disable=not accelerator.is_local_main_process,
730
+ )
731
+ progress_bar.set_description("Steps")
732
+
733
+ for _ in range(first_epoch, num_train_epochs):
734
+ train_loss = 0.0
735
+ t_data_start = time.time()
736
+ for _, batch in enumerate(train_dataloader):
737
+ t_data = time.time() - t_data_start
738
+ with accelerator.accumulate(net):
739
+ # Convert videos to latent space
740
+ pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
741
+
742
+ pixel_values_face_mask = batch["pixel_values_face_mask"]
743
+ pixel_values_face_mask = get_attention_mask(
744
+ pixel_values_face_mask, weight_dtype
745
+ )
746
+ pixel_values_lip_mask = batch["pixel_values_lip_mask"]
747
+ pixel_values_lip_mask = get_attention_mask(
748
+ pixel_values_lip_mask, weight_dtype
749
+ )
750
+ pixel_values_full_mask = batch["pixel_values_full_mask"]
751
+ pixel_values_full_mask = get_attention_mask(
752
+ pixel_values_full_mask, weight_dtype
753
+ )
754
+
755
+ with torch.no_grad():
756
+ video_length = pixel_values_vid.shape[1]
757
+ pixel_values_vid = rearrange(
758
+ pixel_values_vid, "b f c h w -> (b f) c h w"
759
+ )
760
+ latents = vae.encode(pixel_values_vid).latent_dist.sample()
761
+ latents = rearrange(
762
+ latents, "(b f) c h w -> b c f h w", f=video_length
763
+ )
764
+ latents = latents * 0.18215
765
+
766
+ noise = torch.randn_like(latents)
767
+ if cfg.noise_offset > 0:
768
+ noise += cfg.noise_offset * torch.randn(
769
+ (latents.shape[0], latents.shape[1], 1, 1, 1),
770
+ device=latents.device,
771
+ )
772
+
773
+ bsz = latents.shape[0]
774
+ # Sample a random timestep for each video
775
+ timesteps = torch.randint(
776
+ 0,
777
+ train_noise_scheduler.num_train_timesteps,
778
+ (bsz,),
779
+ device=latents.device,
780
+ )
781
+ timesteps = timesteps.long()
782
+
783
+ # mask for face locator
784
+ pixel_values_mask = (
785
+ batch["pixel_values_mask"].unsqueeze(
786
+ 1).to(dtype=weight_dtype)
787
+ )
788
+ pixel_values_mask = repeat(
789
+ pixel_values_mask,
790
+ "b f c h w -> b (repeat f) c h w",
791
+ repeat=video_length,
792
+ )
793
+ pixel_values_mask = pixel_values_mask.transpose(
794
+ 1, 2)
795
+
796
+ uncond_img_fwd = random.random() < cfg.uncond_img_ratio
797
+ uncond_audio_fwd = random.random() < cfg.uncond_audio_ratio
798
+
799
+ start_frame = random.random() < cfg.start_ratio
800
+ pixel_values_ref_img = batch["pixel_values_ref_img"].to(
801
+ dtype=weight_dtype
802
+ )
803
+ # initialize the motion frames as zero maps
804
+ if start_frame:
805
+ pixel_values_ref_img[:, 1:] = 0.0
806
+
807
+ ref_img_and_motion = rearrange(
808
+ pixel_values_ref_img, "b f c h w -> (b f) c h w"
809
+ )
810
+
811
+ with torch.no_grad():
812
+ ref_image_latents = vae.encode(
813
+ ref_img_and_motion
814
+ ).latent_dist.sample()
815
+ ref_image_latents = ref_image_latents * 0.18215
816
+ image_prompt_embeds = batch["face_emb"].to(
817
+ dtype=imageproj.dtype, device=imageproj.device
818
+ )
819
+
820
+ # add noise
821
+ noisy_latents = train_noise_scheduler.add_noise(
822
+ latents, noise, timesteps
823
+ )
824
+
825
+ # Get the target for loss depending on the prediction type
826
+ if train_noise_scheduler.prediction_type == "epsilon":
827
+ target = noise
828
+ elif train_noise_scheduler.prediction_type == "v_prediction":
829
+ target = train_noise_scheduler.get_velocity(
830
+ latents, noise, timesteps
831
+ )
832
+ else:
833
+ raise ValueError(
834
+ f"Unknown prediction type {train_noise_scheduler.prediction_type}"
835
+ )
836
+
837
+ # ---- Forward!!! -----
838
+ model_pred = net(
839
+ noisy_latents=noisy_latents,
840
+ timesteps=timesteps,
841
+ ref_image_latents=ref_image_latents,
842
+ face_emb=image_prompt_embeds,
843
+ mask=pixel_values_mask,
844
+ full_mask=pixel_values_full_mask,
845
+ face_mask=pixel_values_face_mask,
846
+ lip_mask=pixel_values_lip_mask,
847
+ audio_emb=batch["audio_tensor"].to(
848
+ dtype=weight_dtype),
849
+ uncond_img_fwd=uncond_img_fwd,
850
+ uncond_audio_fwd=uncond_audio_fwd,
851
+ )
852
+
853
+ if cfg.snr_gamma == 0:
854
+ loss = F.mse_loss(
855
+ model_pred.float(),
856
+ target.float(),
857
+ reduction="mean",
858
+ )
859
+ else:
860
+ snr = compute_snr(train_noise_scheduler, timesteps)
861
+ if train_noise_scheduler.config.prediction_type == "v_prediction":
862
+ # Velocity objective requires that we add one to SNR values before we divide by them.
863
+ snr = snr + 1
864
+ mse_loss_weights = (
865
+ torch.stack(
866
+ [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
867
+ ).min(dim=1)[0]
868
+ / snr
869
+ )
870
+ loss = F.mse_loss(
871
+ model_pred.float(),
872
+ target.float(),
873
+ reduction="mean",
874
+ )
875
+ loss = (
876
+ loss.mean(dim=list(range(1, len(loss.shape))))
877
+ * mse_loss_weights
878
+ ).mean()
879
+
880
+ # Gather the losses across all processes for logging (if we use distributed training).
881
+ avg_loss = accelerator.gather(
882
+ loss.repeat(cfg.data.train_bs)).mean()
883
+ train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
884
+
885
+ # Backpropagate
886
+ accelerator.backward(loss)
887
+ if accelerator.sync_gradients:
888
+ accelerator.clip_grad_norm_(
889
+ trainable_params,
890
+ cfg.solver.max_grad_norm,
891
+ )
892
+ optimizer.step()
893
+ lr_scheduler.step()
894
+ optimizer.zero_grad()
895
+
896
+ if accelerator.sync_gradients:
897
+ reference_control_reader.clear()
898
+ reference_control_writer.clear()
899
+ progress_bar.update(1)
900
+ global_step += 1
901
+ accelerator.log({"train_loss": train_loss}, step=global_step)
902
+ train_loss = 0.0
903
+
904
+ if global_step % cfg.val.validation_steps == 0 or global_step == 1:
905
+ if accelerator.is_main_process:
906
+ generator = torch.Generator(device=accelerator.device)
907
+ generator.manual_seed(cfg.seed)
908
+
909
+ log_validation(
910
+ accelerator=accelerator,
911
+ vae=vae,
912
+ net=net,
913
+ scheduler=val_noise_scheduler,
914
+ width=cfg.data.train_width,
915
+ height=cfg.data.train_height,
916
+ clip_length=cfg.data.n_sample_frames,
917
+ cfg=cfg,
918
+ save_dir=validation_dir,
919
+ global_step=global_step,
920
+ times=cfg.single_inference_times,
921
+ face_analysis_model_path=cfg.face_analysis_model_path
922
+ )
923
+
924
+ logs = {
925
+ "step_loss": loss.detach().item(),
926
+ "lr": lr_scheduler.get_last_lr()[0],
927
+ "td": f"{t_data:.2f}s",
928
+ }
929
+ t_data_start = time.time()
930
+ progress_bar.set_postfix(**logs)
931
+
932
+ if (
933
+ global_step % cfg.checkpointing_steps == 0
934
+ or global_step == cfg.solver.max_train_steps
935
+ ):
936
+ # save model
937
+ save_path = os.path.join(
938
+ checkpoint_dir, f"checkpoint-{global_step}")
939
+ if accelerator.is_main_process:
940
+ delete_additional_ckpt(checkpoint_dir, 30)
941
+ accelerator.wait_for_everyone()
942
+ accelerator.save_state(save_path)
943
+
944
+ # save model weight
945
+ unwrap_net = accelerator.unwrap_model(net)
946
+ if accelerator.is_main_process:
947
+ save_checkpoint(
948
+ unwrap_net,
949
+ module_dir,
950
+ "net",
951
+ global_step,
952
+ total_limit=30,
953
+ )
954
+ if global_step >= cfg.solver.max_train_steps:
955
+ break
956
+
957
+ # Create the pipeline using the trained modules and save it.
958
+ accelerator.wait_for_everyone()
959
+ accelerator.end_training()
960
+
961
+
962
+ def load_config(config_path: str) -> dict:
963
+ """
964
+ Loads the configuration file.
965
+
966
+ Args:
967
+ config_path (str): Path to the configuration file.
968
+
969
+ Returns:
970
+ dict: The configuration dictionary.
971
+ """
972
+
973
+ if config_path.endswith(".yaml"):
974
+ return OmegaConf.load(config_path)
975
+ if config_path.endswith(".py"):
976
+ return import_filename(config_path).cfg
977
+ raise ValueError("Unsupported format for config file")
978
+
979
+
980
+ if __name__ == "__main__":
981
+ parser = argparse.ArgumentParser()
982
+ parser.add_argument(
983
+ "--config", type=str, default="./configs/train/stage2.yaml"
984
+ )
985
+ args = parser.parse_args()
986
+
987
+ try:
988
+ config = load_config(args.config)
989
+ train_stage2_process(config)
990
+ except Exception as e:
991
+ logging.error("Failed to execute the training process: %s", e)
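A typical way to launch this stage-2 script would be plain Python for a single-GPU run, or the standard accelerate CLI for distributed training; the exact launch command is not specified in this diff, and the config path below is simply the argparse default shown above:

python scripts/train_stage2.py --config ./configs/train/stage2.yaml
# or, for multi-GPU training:
accelerate launch scripts/train_stage2.py --config ./configs/train/stage2.yaml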