# coding: utf-8

"""
Benchmark the inference speed of each module in LivePortrait.

TODO: heavy GPT style, need to refactor
"""

import time

import numpy as np
import torch
import yaml

from src.config.inference_config import InferenceConfig
from src.utils.helper import load_model, concat_feat


def initialize_inputs(batch_size=1):
    """
    Generate random input tensors and move them to GPU
    """
    feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
    kp_source = torch.randn(batch_size, 21, 3).cuda().half()
    kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
    source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
    generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
    eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
    lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
    feat_stitching = concat_feat(kp_source, kp_driving).half()
    feat_eye = concat_feat(kp_source, eye_close_ratio).half()
    feat_lip = concat_feat(kp_source, lip_close_ratio).half()

    inputs = {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
        'feat_stitching': feat_stitching,
        'feat_eye': feat_eye,
        'feat_lip': feat_lip
    }

    return inputs


def load_and_compile_models(cfg, model_config):
    """
    Load and compile models for inference
    """
    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')

    models_with_params = [
        ('Appearance Feature Extractor', appearance_feature_extractor),
        ('Motion Extractor', motion_extractor),
        ('Warping Network', warping_module),
        ('SPADE Decoder', spade_generator)
    ]

    compiled_models = {}
    for name, model in models_with_params:
        model = model.half()
        model = torch.compile(model, mode='max-autotune')  # Optimize for inference
        model.eval()  # Switch to evaluation mode
        compiled_models[name] = model

    retargeting_models = ['stitching', 'eye', 'lip']
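    # The stitching/retargeting checkpoint is expected to hold three small
    # sub-networks ('stitching', 'eye', 'lip'); compile each one the same way as above.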
    for retarget in retargeting_models:
        module = stitching_retargeting_module[retarget].half()
        module = torch.compile(module, mode='max-autotune')  # Optimize for inference
        module.eval()  # Switch to evaluation mode
        stitching_retargeting_module[retarget] = module

    return compiled_models, stitching_retargeting_module


def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Warm up models to prepare them for benchmarking
    """
    print("Warm up start!")
    with torch.no_grad():
        for _ in range(10):
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            compiled_models['Motion Extractor'](inputs['source_image'])
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            compiled_models['SPADE Decoder'](inputs['generator_input'])  # Adjust input as required
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
    print("Warm up end!")


def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
    """
    Measure inference times for each model
    """
    times = {name: [] for name in compiled_models.keys()}
    times['Retargeting Models'] = []

    overall_times = []

    with torch.no_grad():
        for _ in range(100):
            torch.cuda.synchronize()
            overall_start = time.time()

            start = time.time()
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            torch.cuda.synchronize()
            times['Appearance Feature Extractor'].append(time.time() - start)

            start = time.time()
            compiled_models['Motion Extractor'](inputs['source_image'])
            torch.cuda.synchronize()
            times['Motion Extractor'].append(time.time() - start)

            start = time.time()
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            torch.cuda.synchronize()
            times['Warping Network'].append(time.time() - start)

            start = time.time()
            compiled_models['SPADE Decoder'](inputs['generator_input'])  # Adjust input as required
            torch.cuda.synchronize()
            times['SPADE Decoder'].append(time.time() - start)

            start = time.time()
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
            torch.cuda.synchronize()
            times['Retargeting Models'].append(time.time() - start)

            overall_times.append(time.time() - overall_start)

    return times, overall_times
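

# Optional, illustrative sketch (not called from main()): timing a single module
# with CUDA events instead of host-side time.time(). `measure_with_cuda_events`
# is a hypothetical helper added for reference, assuming a CUDA device is
# available; event-based timing excludes host-side launch overhead.
def measure_with_cuda_events(model, *model_inputs, runs=100):
    """
    Return (mean, std) GPU time in milliseconds for `model` over `runs` iterations.
    """
    timings = []
    with torch.no_grad():
        for _ in range(runs):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            model(*model_inputs)
            end_event.record()
            torch.cuda.synchronize()  # wait for the GPU before reading the events
            timings.append(start_event.elapsed_time(end_event))  # elapsed_time() is in ms
    return float(np.mean(timings)), float(np.std(timings))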


def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print benchmark results with average and standard deviation of inference times
    """
    average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
    std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}

    for name, model in compiled_models.items():
        num_params = sum(p.numel() for p in model.parameters())
        num_params_in_millions = num_params / 1e6
        print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")

    for index, retarget in enumerate(retargeting_models):
        num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
        num_params_in_millions = num_params / 1e6
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")

    for name, avg_time in average_times.items():
        std_time = std_times[name]
        print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")


def main():
    """
    Main function to benchmark speed and model parameters
    """
    # Sample input tensors
    inputs = initialize_inputs()

    # Load configuration
    cfg = InferenceConfig(device_id=0)
    model_config_path = cfg.models_config
    with open(model_config_path, 'r') as file:
        model_config = yaml.safe_load(file)

    # Load and compile models
    compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

    # Warm up models
    warm_up_models(compiled_models, stitching_retargeting_module, inputs)

    # Measure inference times
    times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)

    # Print benchmark results
    print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)


if __name__ == "__main__":
    main()