from typing import Any, List

import torch
from polygraphy import cuda

from live2diff.animatediff.models.unet_depth_streaming import UNet3DConditionStreamingOutput

from .utilities import Engine


try:
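    # `AutoencoderTinyOutput` has moved between diffusers releases; fall back to a
    # minimal local dataclass with the same field if the import is unavailable.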
    from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput
except ImportError:
    from dataclasses import dataclass

    from diffusers.utils import BaseOutput

    @dataclass
    class AutoencoderTinyOutput(BaseOutput):
        """
        Output of AutoencoderTiny encoding method.

        Args:
            latents (`torch.Tensor`): Encoded outputs of the `Encoder`.

        """

        latents: torch.Tensor


try:
    from diffusers.models.vae import DecoderOutput
except ImportError:
    from dataclasses import dataclass

    from diffusers.utils import BaseOutput

    @dataclass
    class DecoderOutput(BaseOutput):
        r"""
        Output of decoding method.

        Args:
            sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                The decoded output sample from the last layer of the model.
        """

        sample: torch.FloatTensor


class AutoencoderKLEngine:
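    """
    TensorRT drop-in for the pipeline VAE: the encoder and decoder run as two
    separate engines while `encode`/`decode` keep diffusers-style return types.

    Rough usage sketch (engine paths and the scale factor of 8 are placeholders):

        stream = cuda.Stream()
        vae = AutoencoderKLEngine("vae.encoder.engine", "vae.decoder.engine", stream, scaling_factor=8)
        latents = vae.encode(images).latents
        frames = vae.decode(latents).sample
    """
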
    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        stream: cuda.Stream,
        scaling_factor: int,
        use_cuda_graph: bool = False,
    ):
        self.encoder = Engine(encoder_path)
        self.decoder = Engine(decoder_path)
        self.stream = stream
        self.vae_scale_factor = scaling_factor
        self.use_cuda_graph = use_cuda_graph

        self.encoder.load()
        self.decoder.load()
        self.encoder.activate()
        self.decoder.activate()

    def encode(self, images: torch.Tensor, **kwargs):
        self.encoder.allocate_buffers(
            shape_dict={
                "images": images.shape,
                "latent": (
                    images.shape[0],
                    4,
                    images.shape[2] // self.vae_scale_factor,
                    images.shape[3] // self.vae_scale_factor,
                ),
            },
            device=images.device,
        )
        latents = self.encoder.infer(
            {"images": images},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return AutoencoderTinyOutput(latents=latents)

    def decode(self, latent: torch.Tensor, **kwargs):
        self.decoder.allocate_buffers(
            shape_dict={
                "latent": latent.shape,
                "images": (
                    latent.shape[0],
                    3,
                    latent.shape[2] * self.vae_scale_factor,
                    latent.shape[3] * self.vae_scale_factor,
                ),
            },
            device=latent.device,
        )
        images = self.decoder.infer(
            {"latent": latent},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["images"]
        return DecoderOutput(sample=images)

    def to(self, *args, **kwargs):
        # No-op: the TensorRT engines already live on the target device.
        pass

    def forward(self, *args, **kwargs):
        # No-op: kept so this engine can stand in for the torch VAE module.
        pass


class UNet2DConditionModelDepthEngine:
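    """
    TensorRT wrapper for the streaming, depth-conditioned UNet: all inputs
    (including the per-layer kv-cache tensors) are bound by name, and the
    updated caches are returned together with the noise prediction.
    """
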
    def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        self.init_profiler()

        self.engine.load()
        self.engine.activate(profiler=self.profiler)
        self.has_allocated = False

    def init_profiler(self):
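        # TensorRT IProfiler that prints the execution time of each layer when
        # layer profiling is active; useful for benchmarking but verbose.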
        import tensorrt

        class Profiler(tensorrt.IProfiler):
            def __init__(self):
                tensorrt.IProfiler.__init__(self)

            def report_layer_time(self, layer_name, ms):
                print(f"{layer_name}: {ms} ms")

        self.profiler = Profiler()

    def __call__(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temporal_attention_mask: torch.Tensor,
        depth_sample: torch.Tensor,
        kv_cache: List[torch.Tensor],
        pe_idx: torch.Tensor,
        update_idx: torch.Tensor,
        **kwargs,
    ) -> Any:
        if timestep.dtype != torch.float32:
            timestep = timestep.float()

        feed_dict = {
            "sample": latent_model_input,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "temporal_attention_mask": temporal_attention_mask,
            "depth_sample": depth_sample,
            "pe_idx": pe_idx,
            "update_idx": update_idx,
        }
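        # Each cached key/value tensor is bound as an individual named engine input.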
        for idx, cache in enumerate(kv_cache):
            feed_dict[f"kv_cache_{idx}"] = cache
        shape_dict = {k: v.shape for k, v in feed_dict.items()}

        if not self.has_allocated:
            self.engine.allocate_buffers(
                shape_dict=shape_dict,
                device=latent_model_input.device,
            )
            self.has_allocated = True

        output = self.engine.infer(
            feed_dict,
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )

        noise_pred = output["latent"]
        kv_cache = [output[f"kv_cache_out_{idx}"] for idx in range(len(kv_cache))]
        return UNet3DConditionStreamingOutput(sample=noise_pred, kv_cache=kv_cache)

    def to(self, *args, **kwargs):
        pass

    def forward(self, *args, **kwargs):
        pass


class MidasEngine:
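    """
    TensorRT engine for MiDaS depth estimation. Buffers are reallocated whenever
    the batch size changes; inputs are expected at 384x384 resolution.
    """
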
    def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        self.engine.load()
        self.engine.activate()
        self.has_allocated = False
        self.default_batch_size = 1

    def __call__(
        self,
        images: torch.Tensor,
        **kwargs,
    ) -> Any:
        if not self.has_allocated or images.shape[0] != self.default_batch_size:
            bz = images.shape[0]
            self.engine.allocate_buffers(
                shape_dict={
                    "images": (bz, 3, 384, 384),
                    "depth_map": (bz, 384, 384),
                },
                device=images.device,
            )
            self.has_allocated = True
            self.default_batch_size = bz

        depth_map = self.engine.infer(
            {
                "images": images,
            },
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["depth_map"]  # (bz, 384, 384)

        return depth_map

    def norm(self, x):
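        # Min-max normalize the depth map to the [0, 1] range.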
        return (x - x.min()) / (x.max() - x.min())

    def to(self, *args, **kwargs):
        pass

    def forward(self, *args, **kwargs):
        pass