Commit 69f3483 ("up")
Parent(s): e08f02e

Files changed:
- streamv2v/__init__.py +1 -0
- streamv2v/acceleration/__init__.py +0 -0
- streamv2v/acceleration/sfast/__init__.py +33 -0
- streamv2v/acceleration/tensorrt/__init__.py +188 -0
- streamv2v/acceleration/tensorrt/builder.py +94 -0
- streamv2v/acceleration/tensorrt/engine.py +123 -0
- streamv2v/acceleration/tensorrt/models.py +434 -0
- streamv2v/acceleration/tensorrt/utilities.py +441 -0
- streamv2v/image_filter.py +45 -0
- streamv2v/image_utils.py +173 -0
- streamv2v/models/__init__.py +0 -0
- streamv2v/models/attention_processor.py +352 -0
- streamv2v/models/utils.py +127 -0
- streamv2v/pip_utils.py +52 -0
- streamv2v/pipeline.py +495 -0
- streamv2v/tools/__init__.py +0 -0
- streamv2v/tools/install-tensorrt.py +54 -0
streamv2v/__init__.py
ADDED
@@ -0,0 +1 @@
from .pipeline import StreamV2V
streamv2v/acceleration/__init__.py
ADDED
File without changes
streamv2v/acceleration/sfast/__init__.py
ADDED
@@ -0,0 +1,33 @@
from typing import Optional

from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile

from ...pipeline import StreamV2V


def accelerate_with_stable_fast(
    stream: StreamV2V,
    config: Optional[CompilationConfig] = None,
):
    if config is None:
        config = CompilationConfig.Default()
    # xformers and Triton are suggested for achieving best performance.
    try:
        import xformers

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skip")
    try:
        import triton

        config.enable_triton = True
    except ImportError:
        print("Triton not installed, skip")
    # CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead.
    config.enable_cuda_graph = True
    stream.pipe = compile(stream.pipe, config)
    stream.unet = stream.pipe.unet
    stream.vae = stream.pipe.vae
    stream.text_encoder = stream.pipe.text_encoder
    return stream
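
A minimal usage sketch (not part of this commit) for the stable-fast path above. It assumes a StreamV2V instance has already been constructed elsewhere (the constructor lives in streamv2v/pipeline.py, which is not shown here); the wrapper function name enable_stable_fast is purely illustrative.

# Illustrative sketch, not part of the commit.
from streamv2v import StreamV2V
from streamv2v.acceleration.sfast import accelerate_with_stable_fast


def enable_stable_fast(stream: StreamV2V) -> StreamV2V:
    # compile() rewrites stream.pipe in place; the helper re-binds
    # unet / vae / text_encoder to the compiled pipeline and returns the stream.
    return accelerate_with_stable_fast(stream)
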
streamv2v/acceleration/tensorrt/__init__.py
ADDED
@@ -0,0 +1,188 @@
import gc
import os

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)
from polygraphy import cuda

from ...pipeline import StreamV2V
from .builder import EngineBuilder, create_onnx_path
from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine
from .models import VAE, BaseModel, UNet, VAEEncoder


class TorchVAEEncoder(torch.nn.Module):
    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        self.vae = vae

    def forward(self, x: torch.Tensor):
        return retrieve_latents(self.vae.encode(x))


def compile_vae_encoder(
    vae: TorchVAEEncoder,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: dict = {},
):
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **engine_build_options,
    )


def compile_vae_decoder(
    vae: AutoencoderKL,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: dict = {},
):
    vae = vae.to(torch.device("cuda"))
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **engine_build_options,
    )


def compile_unet(
    unet: UNet2DConditionModel,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: dict = {},
):
    unet = unet.to(torch.device("cuda"), dtype=torch.float16)
    builder = EngineBuilder(model_data, unet, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **engine_build_options,
    )


def accelerate_with_tensorrt(
    stream: StreamV2V,
    engine_dir: str,
    max_batch_size: int = 2,
    min_batch_size: int = 1,
    use_cuda_graph: bool = False,
    engine_build_options: dict = {},
):
    if "opt_batch_size" not in engine_build_options or engine_build_options["opt_batch_size"] is None:
        engine_build_options["opt_batch_size"] = max_batch_size
    text_encoder = stream.text_encoder
    unet = stream.unet
    vae = stream.vae

    del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae

    vae_config = vae.config
    vae_dtype = vae.dtype

    unet.to(torch.device("cpu"))
    vae.to(torch.device("cpu"))

    gc.collect()
    torch.cuda.empty_cache()

    onnx_dir = os.path.join(engine_dir, "onnx")
    os.makedirs(onnx_dir, exist_ok=True)

    unet_engine_path = f"{engine_dir}/unet.engine"
    vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine"
    vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine"

    unet_model = UNet(
        fp16=True,
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
        embedding_dim=text_encoder.config.hidden_size,
        unet_dim=unet.config.in_channels,
    )
    vae_decoder_model = VAE(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )
    vae_encoder_model = VAEEncoder(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )

    if not os.path.exists(unet_engine_path):
        compile_unet(
            unet,
            unet_model,
            create_onnx_path("unet", onnx_dir, opt=False),
            create_onnx_path("unet", onnx_dir, opt=True),
            unet_engine_path,
            **engine_build_options,
        )
    else:
        del unet

    if not os.path.exists(vae_decoder_engine_path):
        vae.forward = vae.decode
        compile_vae_decoder(
            vae,
            vae_decoder_model,
            create_onnx_path("vae_decoder", onnx_dir, opt=False),
            create_onnx_path("vae_decoder", onnx_dir, opt=True),
            vae_decoder_engine_path,
            **engine_build_options,
        )

    if not os.path.exists(vae_encoder_engine_path):
        vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda"))
        compile_vae_encoder(
            vae_encoder,
            vae_encoder_model,
            create_onnx_path("vae_encoder", onnx_dir, opt=False),
            create_onnx_path("vae_encoder", onnx_dir, opt=True),
            vae_encoder_engine_path,
            **engine_build_options,
        )

    del vae

    cuda_stream = cuda.Stream()

    stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_stream, use_cuda_graph=use_cuda_graph)
    stream.vae = AutoencoderKLEngine(
        vae_encoder_engine_path,
        vae_decoder_engine_path,
        cuda_stream,
        stream.pipe.vae_scale_factor,
        use_cuda_graph=use_cuda_graph,
    )
    setattr(stream.vae, "config", vae_config)
    setattr(stream.vae, "dtype", vae_dtype)

    gc.collect()
    torch.cuda.empty_cache()

    return stream
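
A hedged usage sketch (not part of this commit) of the TensorRT entry point above. The engine directory is only an example; on the first call the ONNX export and engine builds run, on later calls the cached .engine files are loaded. The wrapper name enable_tensorrt is illustrative.

# Illustrative sketch, not part of the commit.
from streamv2v import StreamV2V
from streamv2v.acceleration.tensorrt import accelerate_with_tensorrt


def enable_tensorrt(stream: StreamV2V) -> StreamV2V:
    # Engines are cached under engine_dir, so rebuilding is skipped when the
    # serialized files already exist.
    return accelerate_with_tensorrt(
        stream,
        engine_dir="engines/streamv2v",  # example cache directory, not mandated by the code
        max_batch_size=2,
        use_cuda_graph=False,
    )
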
streamv2v/acceleration/tensorrt/builder.py
ADDED
@@ -0,0 +1,94 @@
import gc
import os
from typing import *

import torch

from .models import BaseModel
from .utilities import (
    build_engine,
    export_onnx,
    optimize_onnx,
)


def create_onnx_path(name, onnx_dir, opt=True):
    return os.path.join(onnx_dir, name + (".opt" if opt else "") + ".onnx")


class EngineBuilder:
    def __init__(
        self,
        model: BaseModel,
        network: Any,
        device=torch.device("cuda"),
    ):
        self.device = device

        self.model = model
        self.network = network

    def build(
        self,
        onnx_path: str,
        onnx_opt_path: str,
        engine_path: str,
        opt_image_height: int = 512,
        opt_image_width: int = 512,
        opt_batch_size: int = 1,
        min_image_resolution: int = 256,
        max_image_resolution: int = 1024,
        build_enable_refit: bool = False,
        build_static_batch: bool = False,
        build_dynamic_shape: bool = False,
        build_all_tactics: bool = False,
        onnx_opset: int = 17,
        force_engine_build: bool = False,
        force_onnx_export: bool = False,
        force_onnx_optimize: bool = False,
    ):
        if not force_onnx_export and os.path.exists(onnx_path):
            print(f"Found cached model: {onnx_path}")
        else:
            print(f"Exporting model: {onnx_path}")
            export_onnx(
                self.network,
                onnx_path=onnx_path,
                model_data=self.model,
                opt_image_height=opt_image_height,
                opt_image_width=opt_image_width,
                opt_batch_size=opt_batch_size,
                onnx_opset=onnx_opset,
            )
        del self.network
        gc.collect()
        torch.cuda.empty_cache()
        if not force_onnx_optimize and os.path.exists(onnx_opt_path):
            print(f"Found cached model: {onnx_opt_path}")
        else:
            print(f"Generating optimized model: {onnx_opt_path}")
            optimize_onnx(
                onnx_path=onnx_path,
                onnx_opt_path=onnx_opt_path,
                model_data=self.model,
            )
        self.model.min_latent_shape = min_image_resolution // 8
        self.model.max_latent_shape = max_image_resolution // 8
        if not force_engine_build and os.path.exists(engine_path):
            print(f"Found cached engine: {engine_path}")
        else:
            build_engine(
                engine_path=engine_path,
                onnx_opt_path=onnx_opt_path,
                model_data=self.model,
                opt_image_height=opt_image_height,
                opt_image_width=opt_image_width,
                opt_batch_size=opt_batch_size,
                build_static_batch=build_static_batch,
                build_dynamic_shape=build_dynamic_shape,
                build_all_tactics=build_all_tactics,
                build_enable_refit=build_enable_refit,
            )

        gc.collect()
        torch.cuda.empty_cache()
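
A short sketch (not part of this commit) of how create_onnx_path and EngineBuilder fit together for a single model; accelerate_with_tensorrt above does essentially this for the UNet and both VAE halves. The function name build_unet_engine and all argument values are illustrative assumptions.

# Illustrative sketch, not part of the commit: building one engine by hand.
import torch
from streamv2v.acceleration.tensorrt.builder import EngineBuilder, create_onnx_path
from streamv2v.acceleration.tensorrt.models import UNet


def build_unet_engine(unet_module, onnx_dir: str, engine_path: str) -> None:
    # `unet_module` is assumed to be a fp16 UNet2DConditionModel already on CUDA.
    model_data = UNet(fp16=True, device="cuda", max_batch_size=2, embedding_dim=768, unet_dim=4)
    builder = EngineBuilder(model_data, unet_module, device=torch.device("cuda"))
    builder.build(
        create_onnx_path("unet", onnx_dir, opt=False),  # raw ONNX export path
        create_onnx_path("unet", onnx_dir, opt=True),   # optimized ONNX path
        engine_path,
        opt_batch_size=2,
    )
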
streamv2v/acceleration/tensorrt/engine.py
ADDED
@@ -0,0 +1,123 @@
from typing import *

import torch
from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.vae import DecoderOutput
from polygraphy import cuda

from .utilities import Engine


class UNet2DConditionModelEngine:
    def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False):
        self.engine = Engine(filepath)
        self.stream = stream
        self.use_cuda_graph = use_cuda_graph

        self.engine.load()
        self.engine.activate()

    def __call__(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        **kwargs,
    ) -> Any:
        if timestep.dtype != torch.float32:
            timestep = timestep.float()

        self.engine.allocate_buffers(
            shape_dict={
                "sample": latent_model_input.shape,
                "timestep": timestep.shape,
                "encoder_hidden_states": encoder_hidden_states.shape,
                "latent": latent_model_input.shape,
            },
            device=latent_model_input.device,
        )

        noise_pred = self.engine.infer(
            {
                "sample": latent_model_input,
                "timestep": timestep,
                "encoder_hidden_states": encoder_hidden_states,
            },
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return UNet2DConditionOutput(sample=noise_pred)

    def to(self, *args, **kwargs):
        pass

    def forward(self, *args, **kwargs):
        pass


class AutoencoderKLEngine:
    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        stream: cuda.Stream,
        scaling_factor: int,
        use_cuda_graph: bool = False,
    ):
        self.encoder = Engine(encoder_path)
        self.decoder = Engine(decoder_path)
        self.stream = stream
        self.vae_scale_factor = scaling_factor
        self.use_cuda_graph = use_cuda_graph

        self.encoder.load()
        self.decoder.load()
        self.encoder.activate()
        self.decoder.activate()

    def encode(self, images: torch.Tensor, **kwargs):
        self.encoder.allocate_buffers(
            shape_dict={
                "images": images.shape,
                "latent": (
                    images.shape[0],
                    4,
                    images.shape[2] // self.vae_scale_factor,
                    images.shape[3] // self.vae_scale_factor,
                ),
            },
            device=images.device,
        )
        latents = self.encoder.infer(
            {"images": images},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["latent"]
        return AutoencoderTinyOutput(latents=latents)

    def decode(self, latent: torch.Tensor, **kwargs):
        self.decoder.allocate_buffers(
            shape_dict={
                "latent": latent.shape,
                "images": (
                    latent.shape[0],
                    3,
                    latent.shape[2] * self.vae_scale_factor,
                    latent.shape[3] * self.vae_scale_factor,
                ),
            },
            device=latent.device,
        )
        images = self.decoder.infer(
            {"latent": latent},
            self.stream,
            use_cuda_graph=self.use_cuda_graph,
        )["images"]
        return DecoderOutput(sample=images)

    def to(self, *args, **kwargs):
        pass

    def forward(self, *args, **kwargs):
        pass
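
A sketch (not part of this commit) of the drop-in behaviour: the wrapper mimics the diffusers UNet call signature and return type, so pipeline code can call it unchanged. The engine path and tensor shapes below are example assumptions, and the snippet only runs if that engine file actually exists.

# Illustrative sketch, not part of the commit.
import torch
from polygraphy import cuda
from streamv2v.acceleration.tensorrt.engine import UNet2DConditionModelEngine

stream = cuda.Stream()
unet = UNet2DConditionModelEngine("engines/unet.engine", stream)  # example path

latents = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16)
t = torch.tensor([999.0, 999.0], device="cuda")
ctx = torch.randn(2, 77, 768, device="cuda", dtype=torch.float16)
# Returns a UNet2DConditionOutput, just like the torch UNet it replaces.
noise_pred = unet(latents, t, encoder_hidden_states=ctx).sample
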
streamv2v/acceleration/tensorrt/models.py
ADDED
@@ -0,0 +1,434 @@
#! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/models.py

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import onnx_graphsurgeon as gs
import torch
from onnx import shape_inference
from polygraphy.backend.onnx.loader import fold_constants


class Optimizer:
    def __init__(self, onnx_graph, verbose=False):
        self.graph = gs.import_onnx(onnx_graph)
        self.verbose = verbose

    def info(self, prefix):
        if self.verbose:
            print(
                f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs"
            )

    def cleanup(self, return_onnx=False):
        self.graph.cleanup().toposort()
        if return_onnx:
            return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        self.graph.outputs = [self.graph.outputs[o] for o in keep]
        if names:
            for i, name in enumerate(names):
                self.graph.outputs[i].name = name

    def fold_constants(self, return_onnx=False):
        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph

    def infer_shapes(self, return_onnx=False):
        onnx_graph = gs.export_onnx(self.graph)
        if onnx_graph.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        else:
            onnx_graph = shape_inference.infer_shapes(onnx_graph)

        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph


class BaseModel:
    def __init__(
        self,
        fp16=False,
        device="cuda",
        verbose=True,
        max_batch_size=16,
        min_batch_size=1,
        embedding_dim=768,
        text_maxlen=77,
    ):
        self.name = "SD Model"
        self.fp16 = fp16
        self.device = device
        self.verbose = verbose

        self.min_batch = min_batch_size
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

    def get_model(self):
        pass

    def get_input_names(self):
        pass

    def get_output_names(self):
        pass

    def get_dynamic_axes(self):
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        pass

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        return None

    def optimize(self, onnx_graph):
        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info(self.name + ": original")
        opt.cleanup()
        opt.info(self.name + ": cleanup")
        opt.fold_constants()
        opt.info(self.name + ": fold constants")
        opt.infer_shapes()
        opt.info(self.name + ": shape inference")
        onnx_opt_graph = opt.cleanup(return_onnx=True)
        opt.info(self.name + ": finished")
        return onnx_opt_graph

    def check_dims(self, batch_size, image_height, image_width):
        assert batch_size >= self.min_batch and batch_size <= self.max_batch
        assert image_height % 8 == 0 or image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_shape else self.min_image_shape
        max_image_height = image_height if static_shape else self.max_image_shape
        min_image_width = image_width if static_shape else self.min_image_shape
        max_image_width = image_width if static_shape else self.max_image_shape
        min_latent_height = latent_height if static_shape else self.min_latent_shape
        max_latent_height = latent_height if static_shape else self.max_latent_shape
        min_latent_width = latent_width if static_shape else self.min_latent_shape
        max_latent_width = latent_width if static_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )


class CLIP(BaseModel):
    def __init__(self, device, max_batch_size, embedding_dim, min_batch_size=1):
        super(CLIP, self).__init__(
            device=device,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            embedding_dim=embedding_dim,
        )
        self.name = "CLIP"

    def get_input_names(self):
        return ["input_ids"]

    def get_output_names(self):
        return ["text_embeddings", "pooler_output"]

    def get_dynamic_axes(self):
        return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_shape
        )
        return {
            "input_ids": [
                (min_batch, self.text_maxlen),
                (batch_size, self.text_maxlen),
                (max_batch, self.text_maxlen),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return {
            "input_ids": (batch_size, self.text_maxlen),
            "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize(self, onnx_graph):
        opt = Optimizer(onnx_graph)
        opt.info(self.name + ": original")
        opt.select_outputs([0])  # delete graph output#1
        opt.cleanup()
        opt.info(self.name + ": remove output[1]")
        opt.fold_constants()
        opt.info(self.name + ": fold constants")
        opt.infer_shapes()
        opt.info(self.name + ": shape inference")
        opt.select_outputs([0], names=["text_embeddings"])  # rename network output
        opt.info(self.name + ": remove output[0]")
        opt_onnx_graph = opt.cleanup(return_onnx=True)
        opt.info(self.name + ": finished")
        return opt_onnx_graph


class UNet(BaseModel):
    def __init__(
        self,
        fp16=False,
        device="cuda",
        max_batch_size=16,
        min_batch_size=1,
        embedding_dim=768,
        text_maxlen=77,
        unet_dim=4,
    ):
        super(UNet, self).__init__(
            fp16=fp16,
            device=device,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim
        self.name = "UNet"

    def get_input_names(self):
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        return ["latent"]

    def get_dynamic_axes(self):
        return {
            "sample": {0: "2B", 2: "H", 3: "W"},
            "timestep": {0: "2B"},
            "encoder_hidden_states": {0: "2B"},
            "latent": {0: "2B", 2: "H", 3: "W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _,
            _,
            _,
            _,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            "sample": [
                (min_batch, self.unet_dim, min_latent_height, min_latent_width),
                (batch_size, self.unet_dim, latent_height, latent_width),
                (max_batch, self.unet_dim, max_latent_height, max_latent_width),
            ],
            "timestep": [(min_batch,), (batch_size,), (max_batch,)],
            "encoder_hidden_states": [
                (min_batch, self.text_maxlen, self.embedding_dim),
                (batch_size, self.text_maxlen, self.embedding_dim),
                (max_batch, self.text_maxlen, self.embedding_dim),
            ],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
            "timestep": (2 * batch_size,),
            "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
            "latent": (2 * batch_size, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        return (
            torch.randn(
                2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
            ),
            torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device),
            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
        )


class VAE(BaseModel):
    def __init__(self, device, max_batch_size, min_batch_size=1):
        super(VAE, self).__init__(
            device=device,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            embedding_dim=None,
        )
        self.name = "VAE decoder"

    def get_input_names(self):
        return ["latent"]

    def get_output_names(self):
        return ["images"]

    def get_dynamic_axes(self):
        return {
            "latent": {0: "B", 2: "H", 3: "W"},
            "images": {0: "B", 2: "8H", 3: "8W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _,
            _,
            _,
            _,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            "latent": [
                (min_batch, 4, min_latent_height, min_latent_width),
                (batch_size, 4, latent_height, latent_width),
                (max_batch, 4, max_latent_height, max_latent_width),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "latent": (batch_size, 4, latent_height, latent_width),
            "images": (batch_size, 3, image_height, image_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return torch.randn(
            batch_size,
            4,
            latent_height,
            latent_width,
            dtype=torch.float32,
            device=self.device,
        )


class VAEEncoder(BaseModel):
    def __init__(self, device, max_batch_size, min_batch_size=1):
        super(VAEEncoder, self).__init__(
            device=device,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            embedding_dim=None,
        )
        self.name = "VAE encoder"

    def get_input_names(self):
        return ["images"]

    def get_output_names(self):
        return ["latent"]

    def get_dynamic_axes(self):
        return {
            "images": {0: "B", 2: "8H", 3: "8W"},
            "latent": {0: "B", 2: "H", 3: "W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        assert batch_size >= self.min_batch and batch_size <= self.max_batch
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            _,
            _,
            _,
            _,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)

        return {
            "images": [
                (min_batch, 3, min_image_height, min_image_width),
                (batch_size, 3, image_height, image_width),
                (max_batch, 3, max_image_height, max_image_width),
            ],
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "images": (batch_size, 3, image_height, image_width),
            "latent": (batch_size, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return torch.randn(
            batch_size,
            3,
            image_height,
            image_width,
            dtype=torch.float32,
            device=self.device,
        )
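
A small worked sketch (not part of this commit) of what the profile helpers produce. The numbers are for a batch-1, 512x512 configuration with max_batch_size=2; the min/max extremes follow from BaseModel's 256-1024 resolution bounds (latents are image size divided by 8).

# Illustrative sketch, not part of the commit: inspecting the UNet input profile
# that build_engine later turns into a TensorRT optimization profile.
from streamv2v.acceleration.tensorrt.models import UNet

unet_model = UNet(fp16=True, device="cuda", max_batch_size=2, embedding_dim=768, unet_dim=4)
profile = unet_model.get_input_profile(
    batch_size=1, image_height=512, image_width=512, static_batch=False, static_shape=False
)
# profile["sample"] == [(1, 4, 32, 32), (1, 4, 64, 64), (2, 4, 128, 128)]  # (min, opt, max)
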
streamv2v/acceleration/tensorrt/utilities.py
ADDED
@@ -0,0 +1,441 @@
#! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py

#
# Copyright 2022 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import gc
from collections import OrderedDict
from typing import *

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import tensorrt as trt
import torch
from cuda import cudart
from PIL import Image
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
    CreateConfig,
    Profile,
    engine_from_bytes,
    engine_from_network,
    network_from_onnx_path,
    save_engine,
)
from polygraphy.backend.trt import util as trt_util

from .models import CLIP, VAE, BaseModel, UNet, VAEEncoder


TRT_LOGGER = trt.Logger(trt.Logger.ERROR)

# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
if np.version.full_version >= "1.24.0":
    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
    numpy_to_torch_dtype_dict[np.bool] = torch.bool

# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


def CUASSERT(cuda_ret):
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(
            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
        )
    if len(cuda_ret) > 1:
        return cuda_ret[1]
    return None


class Engine:
    def __init__(
        self,
        engine_path,
    ):
        self.engine_path = engine_path
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()
        self.cuda_graph_instance = None  # cuda graph

    def __del__(self):
        [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
        del self.engine
        del self.context
        del self.buffers
        del self.tensors

    def refit(self, onnx_path, onnx_refit_path):
        def convert_int64(arr):
            # TODO: smarter conversion
            if len(arr.shape) == 0:
                return np.int32(arr)
            return arr

        def add_to_map(refit_dict, name, values):
            if name in refit_dict:
                assert refit_dict[name] is None
                if values.dtype == np.int64:
                    values = convert_int64(values)
                refit_dict[name] = values

        print(f"Refitting TensorRT engine with {onnx_refit_path} weights")
        refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes

        # Construct mapping from weight names in refit model -> original model
        name_map = {}
        for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes):
            refit_node = refit_nodes[n]
            assert node.op == refit_node.op
            # Constant nodes in ONNX do not have inputs but have a constant output
            if node.op == "Constant":
                name_map[refit_node.outputs[0].name] = node.outputs[0].name
            # Handle scale and bias weights
            elif node.op == "Conv":
                if node.inputs[1].__class__ == gs.Constant:
                    name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL"
                if node.inputs[2].__class__ == gs.Constant:
                    name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS"
            # For all other nodes: find node inputs that are initializers (gs.Constant)
            else:
                for i, inp in enumerate(node.inputs):
                    if inp.__class__ == gs.Constant:
                        name_map[refit_node.inputs[i].name] = inp.name

        def map_name(name):
            if name in name_map:
                return name_map[name]
            return name

        # Construct refit dictionary
        refit_dict = {}
        refitter = trt.Refitter(self.engine, TRT_LOGGER)
        all_weights = refitter.get_all()
        for layer_name, role in zip(all_weights[0], all_weights[1]):
            # for specialized roles, use a unique name in the map:
            if role == trt.WeightsRole.KERNEL:
                name = layer_name + "_TRTKERNEL"
            elif role == trt.WeightsRole.BIAS:
                name = layer_name + "_TRTBIAS"
            else:
                name = layer_name

            assert name not in refit_dict, "Found duplicate layer: " + name
            refit_dict[name] = None

        for n in refit_nodes:
            # Constant nodes in ONNX do not have inputs but have a constant output
            if n.op == "Constant":
                name = map_name(n.outputs[0].name)
                print(f"Add Constant {name}\n")
                add_to_map(refit_dict, name, n.outputs[0].values)

            # Handle scale and bias weights
            elif n.op == "Conv":
                if n.inputs[1].__class__ == gs.Constant:
                    name = map_name(n.name + "_TRTKERNEL")
                    add_to_map(refit_dict, name, n.inputs[1].values)

                if n.inputs[2].__class__ == gs.Constant:
                    name = map_name(n.name + "_TRTBIAS")
                    add_to_map(refit_dict, name, n.inputs[2].values)

            # For all other nodes: find node inputs that are initializers (AKA gs.Constant)
            else:
                for inp in n.inputs:
                    name = map_name(inp.name)
                    if inp.__class__ == gs.Constant:
                        add_to_map(refit_dict, name, inp.values)

        for layer_name, weights_role in zip(all_weights[0], all_weights[1]):
            if weights_role == trt.WeightsRole.KERNEL:
                custom_name = layer_name + "_TRTKERNEL"
            elif weights_role == trt.WeightsRole.BIAS:
                custom_name = layer_name + "_TRTBIAS"
            else:
                custom_name = layer_name

            # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model
            if layer_name.startswith("onnx::Trilu"):
                continue

            if refit_dict[custom_name] is not None:
                refitter.set_weights(layer_name, weights_role, refit_dict[custom_name])
            else:
                print(f"[W] No refit weights for layer: {layer_name}")

        if not refitter.refit_cuda_engine():
            print("Failed to refit!")
            exit(0)

    def build(
        self,
        onnx_path,
        fp16,
        input_profile=None,
        enable_refit=False,
        enable_all_tactics=False,
        timing_cache=None,
        workspace_size=0,
    ):
        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
        if input_profile:
            for name, dims in input_profile.items():
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

        config_kwargs = {}

        if workspace_size > 0:
            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
        if not enable_all_tactics:
            config_kwargs["tactic_sources"] = []

        engine = engine_from_network(
            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
            config=CreateConfig(
                fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
            ),
            save_timing_cache=timing_cache,
        )
        save_engine(engine, path=self.engine_path)

    def load(self):
        print(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))

    def activate(self, reuse_device_memory=None):
        if reuse_device_memory:
            self.context = self.engine.create_execution_context_without_device_memory()
            self.context.device_memory = reuse_device_memory
        else:
            self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device="cuda"):
        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
            binding = self.engine[idx]
            if shape_dict and binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(idx, shape)
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
            self.tensors[binding] = tensor

    def infer(self, feed_dict, stream, use_cuda_graph=False):
        for name, buf in feed_dict.items():
            self.tensors[name].copy_(buf)

        for name, tensor in self.tensors.items():
            self.context.set_tensor_address(name, tensor.data_ptr())

        if use_cuda_graph:
            if self.cuda_graph_instance is not None:
                CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr))
                CUASSERT(cudart.cudaStreamSynchronize(stream.ptr))
            else:
                # do inference before CUDA graph capture
                noerror = self.context.execute_async_v3(stream.ptr)
                if not noerror:
                    raise ValueError("ERROR: inference failed.")
                # capture cuda graph
                CUASSERT(
                    cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
                )
                self.context.execute_async_v3(stream.ptr)
                self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
                self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
        else:
            noerror = self.context.execute_async_v3(stream.ptr)
            if not noerror:
                raise ValueError("ERROR: inference failed.")

        return self.tensors


def decode_images(images: torch.Tensor):
    images = (
        ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
    )
    return [Image.fromarray(x) for x in images]


def preprocess_image(image: Image.Image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h))
    init_image = np.array(image).astype(np.float32) / 255.0
    init_image = init_image[None].transpose(0, 3, 1, 2)
    init_image = torch.from_numpy(init_image).contiguous()
    return 2.0 * init_image - 1.0


def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image):
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0
    if isinstance(mask, Image.Image):
        mask = np.array(mask.convert("L"))
        mask = mask.astype(np.float32) / 255.0
        mask = mask[None, None]
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous()

    masked_image = image * (mask < 0.5)

    return mask, masked_image


def create_models(
    model_id: str,
    use_auth_token: Optional[str],
    device: Union[str, torch.device],
    max_batch_size: int,
    unet_in_channels: int = 4,
    embedding_dim: int = 768,
):
    models = {
        "clip": CLIP(
            hf_token=use_auth_token,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        ),
        "unet": UNet(
            hf_token=use_auth_token,
            fp16=True,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            unet_dim=unet_in_channels,
        ),
        "vae": VAE(
            hf_token=use_auth_token,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        ),
        "vae_encoder": VAEEncoder(
            hf_token=use_auth_token,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
        ),
    }
    return models


def build_engine(
    engine_path: str,
    onnx_opt_path: str,
    model_data: BaseModel,
    opt_image_height: int,
    opt_image_width: int,
    opt_batch_size: int,
    build_static_batch: bool = False,
    build_dynamic_shape: bool = False,
    build_all_tactics: bool = False,
    build_enable_refit: bool = False,
):
    _, free_mem, _ = cudart.cudaMemGetInfo()
    GiB = 2**30
    if free_mem > 6 * GiB:
        activation_carveout = 4 * GiB
        max_workspace_size = free_mem - activation_carveout
    else:
        max_workspace_size = 0
    engine = Engine(engine_path)
    input_profile = model_data.get_input_profile(
        opt_batch_size,
        opt_image_height,
        opt_image_width,
        static_batch=build_static_batch,
        static_shape=not build_dynamic_shape,
    )
    engine.build(
        onnx_opt_path,
        fp16=True,
        input_profile=input_profile,
        enable_refit=build_enable_refit,
        enable_all_tactics=build_all_tactics,
        workspace_size=max_workspace_size,
    )

    return engine


def export_onnx(
    model,
    onnx_path: str,
    model_data: BaseModel,
    opt_image_height: int,
    opt_image_width: int,
    opt_batch_size: int,
    onnx_opset: int,
):
    with torch.inference_mode(), torch.autocast("cuda"):
        inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
        torch.onnx.export(
            model,
            inputs,
            onnx_path,
            export_params=True,
            opset_version=onnx_opset,
            do_constant_folding=True,
            input_names=model_data.get_input_names(),
            output_names=model_data.get_output_names(),
            dynamic_axes=model_data.get_dynamic_axes(),
        )
    del model
    gc.collect()
    torch.cuda.empty_cache()


def optimize_onnx(
    onnx_path: str,
    onnx_opt_path: str,
    model_data: BaseModel,
):
    onnx_opt_graph = model_data.optimize(onnx.load(onnx_path))
    onnx.save(onnx_opt_graph, onnx_opt_path)
    del onnx_opt_graph
    gc.collect()
    torch.cuda.empty_cache()
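
A small sketch (not part of this commit) of the image helpers' round trip: preprocess_image maps a PIL image into the [-1, 1] tensor range the engines work in, and decode_images maps model output back to PIL. The input filename is only an example.

# Illustrative sketch, not part of the commit.
from PIL import Image
from streamv2v.acceleration.tensorrt.utilities import decode_images, preprocess_image

img = Image.open("frame.png").convert("RGB")  # example input path
x = preprocess_image(img)                     # (1, 3, H, W) float tensor in [-1, 1], H/W rounded to multiples of 32
restored = decode_images(x)[0]                # back to a uint8 PIL.Image
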
streamv2v/image_filter.py
ADDED
@@ -0,0 +1,45 @@
from typing import Optional
import random

import torch


class SimilarImageFilter:
    def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
        self.threshold = threshold
        self.prev_tensor = None
        self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
        self.max_skip_frame = max_skip_frame
        self.skip_count = 0

    def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        if self.prev_tensor is None:
            self.prev_tensor = x.detach().clone()
            return x
        else:
            cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item()
            sample = random.uniform(0, 1)
            if self.threshold >= 1:
                skip_prob = 0
            else:
                skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold))

            # not skip frame
            if skip_prob < sample:
                self.prev_tensor = x.detach().clone()
                return x
            # skip frame
            else:
                if self.skip_count > self.max_skip_frame:
                    self.skip_count = 0
                    self.prev_tensor = x.detach().clone()
                    return x
                else:
                    self.skip_count += 1
                    return None

    def set_threshold(self, threshold: float) -> None:
        self.threshold = threshold

    def set_max_skip_frame(self, max_skip_frame: float) -> None:
        self.max_skip_frame = max_skip_frame
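
A sketch (not part of this commit) of how the filter is meant to sit in a frame loop: a frame very similar to the previous one is skipped (the call returns None) with a probability that grows with its cosine similarity, until max_skip_frame consecutive skips force a refresh. The function name maybe_process is illustrative.

# Illustrative sketch, not part of the commit.
import torch
from streamv2v.image_filter import SimilarImageFilter

sim_filter = SimilarImageFilter(threshold=0.98, max_skip_frame=10)


def maybe_process(frame: torch.Tensor):
    kept = sim_filter(frame)  # returns the frame, or None if it should be skipped
    if kept is None:
        return None           # caller can reuse the previous output frame
    return kept               # hand the frame to the diffusion pipeline
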
streamv2v/image_utils.py
ADDED
@@ -0,0 +1,173 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torchvision


def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
    """
    Denormalize an image array to [0,1].
    """
    return (images / 2 + 0.5).clamp(0, 1)


def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
    """
    Convert a PyTorch tensor to a NumPy image.
    """
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return images


def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
    """
    Convert a NumPy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [
            PIL.Image.fromarray(image.squeeze(), mode="L") for image in images
        ]
    else:
        pil_images = [PIL.Image.fromarray(image) for image in images]

    return pil_images


def postprocess_image(
    image: torch.Tensor,
    output_type: str = "pil",
    do_denormalize: Optional[List[bool]] = None,
) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]:
    if not isinstance(image, torch.Tensor):
        raise ValueError(
            f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
        )

    if output_type == "latent":
        return image

    do_normalize_flg = True
    if do_denormalize is None:
        do_denormalize = [do_normalize_flg] * image.shape[0]

    image = torch.stack(
        [
            denormalize(image[i]) if do_denormalize[i] else image[i]
            for i in range(image.shape[0])
        ]
    )

    if output_type == "pt":
        return image

    image = pt_to_numpy(image)

    if output_type == "np":
        return image

    if output_type == "pil":
        return numpy_to_pil(image)


def process_image(
    image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1)
) -> Tuple[torch.Tensor, PIL.Image.Image]:
    image = torchvision.transforms.ToTensor()(image_pil)
    r_min, r_max = range[0], range[1]
    image = image * (r_max - r_min) + r_min
    return image[None, ...], image_pil


def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor:
    height = image_pil.height
    width = image_pil.width
    imgs = []
    img, _ = process_image(image_pil)
    imgs.append(img)
    imgs = torch.vstack(imgs)
    images = torch.nn.functional.interpolate(
        imgs, size=(height, width), mode="bilinear"
    )
    image_tensors = images.to(torch.float16)
    return image_tensors

### Optical flow utils

def coords_grid(b, h, w, homogeneous=False, device=None):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid

def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
    b, c, h, w = feature.size()
    assert flow.size(1) == 2
+
|
124 |
+
grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
|
125 |
+
|
126 |
+
return bilinear_sample(feature, grid, padding_mode=padding_mode,
|
127 |
+
return_mask=mask)
|
128 |
+
|
129 |
+
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
|
130 |
+
# img: [B, C, H, W]
|
131 |
+
# sample_coords: [B, 2, H, W] in image scale
|
132 |
+
if sample_coords.size(1) != 2: # [B, H, W, 2]
|
133 |
+
sample_coords = sample_coords.permute(0, 3, 1, 2)
|
134 |
+
|
135 |
+
b, _, h, w = sample_coords.shape
|
136 |
+
|
137 |
+
# Normalize to [-1, 1]
|
138 |
+
x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
|
139 |
+
y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
|
140 |
+
|
141 |
+
grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
|
142 |
+
|
143 |
+
img = torch.nn.functional.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
|
144 |
+
|
145 |
+
if return_mask:
|
146 |
+
mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
|
147 |
+
|
148 |
+
return img, mask
|
149 |
+
|
150 |
+
return img
|
151 |
+
|
152 |
+
def forward_backward_consistency_check(fwd_flow, bwd_flow,
|
153 |
+
alpha=0.1,
|
154 |
+
beta=0.5
|
155 |
+
):
|
156 |
+
# fwd_flow, bwd_flow: [B, 2, H, W]
|
157 |
+
# alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
|
158 |
+
assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
|
159 |
+
assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
|
160 |
+
flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
|
161 |
+
|
162 |
+
warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
|
163 |
+
warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
|
164 |
+
|
165 |
+
diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
|
166 |
+
diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
|
167 |
+
|
168 |
+
threshold = alpha * flow_mag + beta
|
169 |
+
|
170 |
+
fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
|
171 |
+
bwd_occ = (diff_bwd > threshold).float()
|
172 |
+
|
173 |
+
return fwd_occ, bwd_occ
|
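A sketch of how the optical-flow helpers above could be combined with torchvision's RAFT model (the same raft_small that pipeline.py loads); the two random frames and the 256x256 resolution are placeholders.

# Illustrative sketch: estimate forward/backward flow with RAFT, then derive
# occlusion masks with forward_backward_consistency_check.
import torch
from torchvision.models.optical_flow import raft_small

flow_model = raft_small(pretrained=True, progress=False).eval()
frame1 = torch.rand(1, 3, 256, 256) * 2 - 1  # RAFT expects inputs scaled to [-1, 1]
frame2 = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    fwd_flow = flow_model(frame1, frame2)[-1]  # [B, 2, H, W], last refinement iteration
    bwd_flow = flow_model(frame2, frame1)[-1]

fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)  # [B, H, W] occlusion masks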
streamv2v/models/__init__.py
ADDED
File without changes
streamv2v/models/attention_processor.py
ADDED
@@ -0,0 +1,352 @@
from importlib import import_module
from typing import Callable, Optional, Union
from collections import deque

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

from .utils import get_nn_feats, random_bipartite_soft_matching

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class CachedSTAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, name=None, use_feature_injection=False,
                 feature_injection_strength=0.8,
                 feature_similarity_threshold=0.98,
                 interval=4,
                 max_frames=1,
                 use_tome_cache=False,
                 tome_metric="keys",
                 use_grid=False,
                 tome_ratio=0.5):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.name = name
        self.use_feature_injection = use_feature_injection
        self.fi_strength = feature_injection_strength
        self.threshold = feature_similarity_threshold
        self.zero_tensor = torch.tensor(0)
        self.frame_id = torch.tensor(0)
        self.interval = torch.tensor(interval)
        self.max_frames = max_frames
        self.cached_key = None
        self.cached_value = None
        self.cached_output = None
        self.use_tome_cache = use_tome_cache
        self.tome_metric = tome_metric
        self.use_grid = use_grid
        self.tome_ratio = tome_ratio

    def _tome_step_kvout(self, keys, values, outputs):
        keys = torch.cat([self.cached_key, keys], dim=1)
        values = torch.cat([self.cached_value, values], dim=1)
        outputs = torch.cat([self.cached_output, outputs], dim=1)
        m_kv_out, _, _ = random_bipartite_soft_matching(metric=keys, use_grid=self.use_grid, ratio=self.tome_ratio)
        compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
        self.cached_key = compact_keys
        self.cached_value = compact_values
        self.cached_output = compact_outputs

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)

        is_selfattn = False
        if encoder_hidden_states is None:
            is_selfattn = True
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        if is_selfattn:
            cached_key = key.clone()
            cached_value = value.clone()

            # Avoid if statement -> replace the dynamic graph to static graph
            if torch.equal(self.frame_id, self.zero_tensor):
                # ONNX
                self.cached_key = cached_key
                self.cached_value = cached_value

            key = torch.cat([key, self.cached_key], dim=1)
            value = torch.cat([value, self.cached_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if is_selfattn:
            cached_output = hidden_states.clone()

            if torch.equal(self.frame_id, self.zero_tensor):
                self.cached_output = cached_output

            if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name):
                nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold)
                hidden_states = hidden_states * (1 - self.fi_strength) + self.fi_strength * nn_hidden_states

        mod_result = torch.remainder(self.frame_id, self.interval)
        if torch.equal(mod_result, self.zero_tensor) and is_selfattn:
            self._tome_step_kvout(cached_key, cached_value, cached_output)

        self.frame_id = self.frame_id + 1

        return hidden_states


class CachedSTXFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None, name=None,
                 use_feature_injection=False, feature_injection_strength=0.8, feature_similarity_threshold=0.98,
                 interval=4, max_frames=4, use_tome_cache=False, tome_metric="keys", use_grid=False, tome_ratio=0.5):
        self.attention_op = attention_op
        self.name = name
        self.use_feature_injection = use_feature_injection
        self.fi_strength = feature_injection_strength
        self.threshold = feature_similarity_threshold
        self.frame_id = 0
        self.interval = interval
        self.cached_key = deque(maxlen=max_frames)
        self.cached_value = deque(maxlen=max_frames)
        self.cached_output = deque(maxlen=max_frames)
        self.use_tome_cache = use_tome_cache
        self.tome_metric = tome_metric
        self.use_grid = use_grid
        self.tome_ratio = tome_ratio

    def _tome_step_kvout(self, keys, values, outputs):
        if len(self.cached_value) == 1:
            keys = torch.cat(list(self.cached_key) + [keys], dim=1)
            values = torch.cat(list(self.cached_value) + [values], dim=1)
            outputs = torch.cat(list(self.cached_output) + [outputs], dim=1)
            m_kv_out, _, _ = random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio)
            compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
            self.cached_key.append(compact_keys)
            self.cached_value.append(compact_values)
            self.cached_output.append(compact_outputs)
        else:
            self.cached_key.append(keys)
            self.cached_value.append(values)
            self.cached_output.append(outputs)

    def _tome_step_kv(self, keys, values):
        if len(self.cached_value) == 1:
            keys = torch.cat(list(self.cached_key) + [keys], dim=1)
            values = torch.cat(list(self.cached_value) + [values], dim=1)
            _, m_kv, _ = random_bipartite_soft_matching(metric=eval(self.tome_metric), use_grid=self.use_grid, ratio=self.tome_ratio)
            compact_keys, compact_values = m_kv(keys, values)
            self.cached_key.append(compact_keys)
            self.cached_value.append(compact_values)
        else:
            self.cached_key.append(keys)
            self.cached_value.append(values)

    def _tome_step_out(self, outputs):
        if len(self.cached_value) == 1:
            outputs = torch.cat(list(self.cached_output) + [outputs], dim=1)
            _, _, m_out = random_bipartite_soft_matching(metric=outputs, use_grid=self.use_grid, ratio=self.tome_ratio)
            compact_outputs = m_out(outputs)
            self.cached_output.append(compact_outputs)
        else:
            self.cached_output.append(outputs)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        is_selfattn = False
        if encoder_hidden_states is None:
            is_selfattn = True
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        if is_selfattn:
            cached_key = key.clone()
            cached_value = value.clone()

            if len(self.cached_key) > 0:
                key = torch.cat([key] + list(self.cached_key), dim=1)
                value = torch.cat([value] + list(self.cached_value), dim=1)

        ## Code for storing and visualizing features
        # if self.frame_id % self.interval == 0:
        #     # if "down_blocks.0" in self.name or "up_blocks.3" in self.name:
        #     #     feats = {
        #     #         "hidden_states": hidden_states.clone().cpu(),
        #     #         "query": query.clone().cpu(),
        #     #         "key": cached_key.cpu(),
        #     #         "value": cached_value.cpu(),
        #     #     }
        #     #     torch.save(feats, f'./outputs/self_attn_feats_SD/{self.name}.frame{self.frame_id}.pt')
        #     if self.use_tome_cache:
        #         cached_key, cached_value = self._tome_step(cached_key, cached_value)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        if is_selfattn:
            cached_output = hidden_states.clone()
            if self.use_feature_injection and ("up_blocks.0" in self.name or "up_blocks.1" in self.name or 'mid_block' in self.name):
                if len(self.cached_output) > 0:
                    nn_hidden_states = get_nn_feats(hidden_states, self.cached_output, threshold=self.threshold)
                    hidden_states = hidden_states * (1 - self.fi_strength) + self.fi_strength * nn_hidden_states

        if self.frame_id % self.interval == 0:
            if is_selfattn:
                if self.use_tome_cache:
                    self._tome_step_kvout(cached_key, cached_value, cached_output)
                else:
                    self.cached_key.append(cached_key)
                    self.cached_value.append(cached_value)
                    self.cached_output.append(cached_output)
        self.frame_id += 1

        return hidden_states
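A sketch of how these cached processors might be installed on a diffusers UNet, with each processor keyed by its module name so that it keeps a per-layer cache; the pipe variable and the keyword values are assumptions mirroring the defaults above (the xFormers variant additionally requires xformers to be installed).

# Illustrative sketch: attach CachedSTXFormersAttnProcessor to every attention layer.
attn_processors = {}
for name in pipe.unet.attn_processors.keys():
    attn_processors[name] = CachedSTXFormersAttnProcessor(
        name=name,
        use_feature_injection=True,
        feature_injection_strength=0.8,
        interval=4,
        max_frames=4,
        use_tome_cache=True,
    )
pipe.unet.set_attn_processor(attn_processors)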
streamv2v/models/utils.py
ADDED
@@ -0,0 +1,127 @@
from collections import deque
from typing import Tuple, Callable

from einops import rearrange
import torch
import torch.nn.functional as F


def get_nn_feats(x, y, threshold=0.9):

    if type(x) is deque:
        x = torch.cat(list(x), dim=1)
    if type(y) is deque:
        y = torch.cat(list(y), dim=1)

    x_norm = F.normalize(x, p=2, dim=-1)
    y_norm = F.normalize(y, p=2, dim=-1)

    cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2))

    max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1)
    mask = max_cosine_values < threshold
    # print('mask ratio', torch.sum(mask)/x.shape[0]/x.shape[1])
    indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1))
    nearest_neighbor_tensor = torch.gather(y, 1, indices_expanded)
    selected_tensor = torch.where(mask.unsqueeze(-1), x, nearest_neighbor_tensor)

    return selected_tensor


def get_nn_latent(x, y, threshold=0.9):

    assert len(x.shape) == 4
    _, c, h, w = x.shape
    x_ = rearrange(x, 'n c h w -> n (h w) c')
    y_ = []
    for i in range(len(y)):
        y_.append(rearrange(y[i], 'n c h w -> n (h w) c'))
    y_ = torch.cat(y_, dim=1)
    x_norm = F.normalize(x_, p=2, dim=-1)
    y_norm = F.normalize(y_, p=2, dim=-1)

    cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2))

    max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1)
    mask = max_cosine_values < threshold
    indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1))
    nearest_neighbor_tensor = torch.gather(y_, 1, indices_expanded)

    # Use values from x where the cosine similarity is below the threshold
    x_expanded = x_.expand_as(nearest_neighbor_tensor)
    selected_tensor = torch.where(mask.unsqueeze(-1), x_expanded, nearest_neighbor_tensor)

    selected_tensor = rearrange(selected_tensor, 'n (h w) c -> n c h w', h=h, w=w, c=c)

    return selected_tensor


def random_bipartite_soft_matching(
    metric: torch.Tensor, use_grid: bool = False, ratio: float = 0.5
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with the two sets as (r chosen randomly, the rest).
    Input size is [batch, tokens, channels].

    This will reduce the number of tokens by a ratio of ratio/2.
    """

    with torch.no_grad():
        B, N, _ = metric.shape
        if use_grid:
            assert ratio == 0.5
            sample = torch.randint(2, size=(B, N//2, 1), device=metric.device)
            sample_alternate = 1 - sample
            grid = torch.arange(0, N, 2).view(1, N//2, 1).to(device=metric.device)
            grid = grid.repeat(4, 1, 1)
            rand_idx = torch.cat([sample + grid, sample_alternate + grid], dim=1)
        else:
            rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
        r = int(ratio * N)
        a_idx = rand_idx[:, :r, :]
        b_idx = rand_idx[:, r:, :]

        def split(x):
            C = x.shape[-1]
            a = x.gather(dim=1, index=a_idx.expand(B, r, C))
            b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
            return a, b

        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def merge_kv_out(keys: torch.Tensor, values: torch.Tensor, outputs: torch.Tensor, mode="mean") -> torch.Tensor:
        src_keys, dst_keys = split(keys)
        C_keys = src_keys.shape[-1]
        dst_keys = dst_keys.scatter_reduce(-2, dst_idx.expand(B, r, C_keys), src_keys, reduce=mode)

        src_values, dst_values = split(values)
        C_values = src_values.shape[-1]
        dst_values = dst_values.scatter_reduce(-2, dst_idx.expand(B, r, C_values), src_values, reduce=mode)

        src_outputs, dst_outputs = split(outputs)
        C_outputs = src_outputs.shape[-1]
        dst_outputs = dst_outputs.scatter_reduce(-2, dst_idx.expand(B, r, C_outputs), src_outputs, reduce=mode)

        return dst_keys, dst_values, dst_outputs

    def merge_kv(keys: torch.Tensor, values: torch.Tensor, mode="mean") -> torch.Tensor:
        src_keys, dst_keys = split(keys)
        C_keys = src_keys.shape[-1]
        dst_keys = dst_keys.scatter_reduce(-2, dst_idx.expand(B, r, C_keys), src_keys, reduce=mode)

        src_values, dst_values = split(values)
        C_values = src_values.shape[-1]
        dst_values = dst_values.scatter_reduce(-2, dst_idx.expand(B, r, C_values), src_values, reduce=mode)

        return dst_keys, dst_values

    def merge_out(outputs: torch.Tensor, mode="mean") -> torch.Tensor:
        src_outputs, dst_outputs = split(outputs)
        C_outputs = src_outputs.shape[-1]
        dst_outputs = dst_outputs.scatter_reduce(-2, dst_idx.expand(B, r, C_outputs), src_outputs, reduce=mode)

        return dst_outputs

    return merge_kv_out, merge_kv, merge_out
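A shape-level sketch of random_bipartite_soft_matching; the tensors are random and only illustrate how the returned merge functions shrink the token dimension of cached keys, values, and outputs.

# Illustrative sketch: with ratio=0.5, half of the tokens ("src") are folded into
# their most similar tokens in the other half ("dst"), so N tokens become N - int(0.5 * N).
import torch

B, N, C = 2, 4096, 320
keys = torch.randn(B, N, C)
values = torch.randn(B, N, C)
outputs = torch.randn(B, N, C)

m_kv_out, m_kv, m_out = random_bipartite_soft_matching(metric=keys, ratio=0.5)
compact_keys, compact_values, compact_outputs = m_kv_out(keys, values, outputs)
print(compact_keys.shape)  # torch.Size([2, 2048, 320])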
streamv2v/pip_utils.py
ADDED
@@ -0,0 +1,52 @@
import importlib
import importlib.util
import os
import subprocess
import sys
from typing import Dict, Optional

from packaging.version import Version


python = sys.executable
index_url = os.environ.get("INDEX_URL", "")


def version(package: str) -> Optional[Version]:
    try:
        return Version(importlib.import_module(package).__version__)
    except ModuleNotFoundError:
        return None


def is_installed(package: str) -> bool:
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


def run_python(command: str, env: Dict[str, str] = None) -> str:
    run_kwargs = {
        "args": f"\"{python}\" {command}",
        "shell": True,
        "env": os.environ if env is None else env,
        "encoding": "utf8",
        "errors": "ignore",
    }

    print(run_kwargs["args"])

    result = subprocess.run(**run_kwargs)

    if result.returncode != 0:
        print(f"Error running command: {command}", file=sys.stderr)
        raise RuntimeError(f"Error running command: {command}")

    return result.stdout or ""


def run_pip(command: str, env: Dict[str, str] = None) -> str:
    return run_python(f"-m pip {command}", env)
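A brief sketch of how these helpers are used (tools/install-tensorrt.py below follows the same pattern); the polygraphy pin is copied from that script and is illustrative here.

# Illustrative sketch: install a pinned package only if it is missing.
from streamv2v.pip_utils import is_installed, run_pip, version

if not is_installed("polygraphy"):
    run_pip("install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com")
else:
    print("polygraphy", version("polygraphy"))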
streamv2v/pipeline.py
ADDED
@@ -0,0 +1,495 @@
import glob
import os
import time
from typing import List, Optional, Union, Any, Dict, Tuple, Literal
from collections import deque

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from torchvision.models.optical_flow import raft_small

from diffusers import LCMScheduler, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)
from .image_utils import postprocess_image, forward_backward_consistency_check
from .models.utils import get_nn_latent
from .image_filter import SimilarImageFilter


class StreamV2V:
    def __init__(
        self,
        pipe: StableDiffusionPipeline,
        t_index_list: List[int],
        torch_dtype: torch.dtype = torch.float16,
        width: int = 512,
        height: int = 512,
        do_add_noise: bool = True,
        use_denoising_batch: bool = True,
        frame_buffer_size: int = 1,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
    ) -> None:
        self.device = pipe.device
        self.dtype = torch_dtype
        self.generator = None

        self.height = height
        self.width = width

        self.latent_height = int(height // pipe.vae_scale_factor)
        self.latent_width = int(width // pipe.vae_scale_factor)

        self.frame_bff_size = frame_buffer_size
        self.denoising_steps_num = len(t_index_list)

        self.cfg_type = cfg_type

        if use_denoising_batch:
            self.batch_size = self.denoising_steps_num * frame_buffer_size
            if self.cfg_type == "initialize":
                self.trt_unet_batch_size = (
                    self.denoising_steps_num + 1
                ) * self.frame_bff_size
            elif self.cfg_type == "full":
                self.trt_unet_batch_size = (
                    2 * self.denoising_steps_num * self.frame_bff_size
                )
            else:
                self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
        else:
            self.trt_unet_batch_size = self.frame_bff_size
            self.batch_size = frame_buffer_size

        self.t_list = t_index_list

        self.do_add_noise = do_add_noise
        self.use_denoising_batch = use_denoising_batch

        self.similar_image_filter = False
        self.similar_filter = SimilarImageFilter()
        self.prev_image_tensor = None
        self.prev_x_t_latent = None
        self.prev_image_result = None

        self.pipe = pipe
        self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)

        self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet
        self.vae = pipe.vae

        self.flow_model = raft_small(pretrained=True, progress=False).to(device=pipe.device).eval()

        self.cached_x_t_latent = deque(maxlen=4)

        self.inference_time_ema = 0

    def load_lcm_lora(
        self,
        pretrained_model_name_or_path_or_dict: Union[
            str, Dict[str, torch.Tensor]
        ] = "latent-consistency/lcm-lora-sdv1-5",
        adapter_name: Optional[Any] = 'lcm',
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def load_lora(
        self,
        pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[Any] = None,
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
    ) -> None:
        self.pipe.fuse_lora(
            fuse_unet=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
        )

    def enable_similar_image_filter(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
        self.similar_image_filter = True
        self.similar_filter.set_threshold(threshold)
        self.similar_filter.set_max_skip_frame(max_skip_frame)

    def disable_similar_image_filter(self) -> None:
        self.similar_image_filter = False

    @torch.no_grad()
    def prepare(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
        generator: Optional[torch.Generator] = torch.Generator(),
        seed: int = 2,
    ) -> None:
        self.generator = generator
        self.generator.manual_seed(seed)
        # initialize x_t_latent (it can be any random tensor)
        if self.denoising_steps_num > 1:
            self.x_t_latent_buffer = torch.zeros(
                (
                    (self.denoising_steps_num - 1) * self.frame_bff_size,
                    4,
                    self.latent_height,
                    self.latent_width,
                ),
                dtype=self.dtype,
                device=self.device,
            )
        else:
            self.x_t_latent_buffer = None

        if self.cfg_type == "none":
            self.guidance_scale = 1.0
        else:
            self.guidance_scale = guidance_scale
        self.delta = delta

        do_classifier_free_guidance = False
        if self.guidance_scale > 1.0:
            do_classifier_free_guidance = True

        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)
        self.null_prompt_embeds = encoder_output[1]

        if self.use_denoising_batch and self.cfg_type == "full":
            uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
        elif self.cfg_type == "initialize":
            uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)

        if self.guidance_scale > 1.0 and (
            self.cfg_type == "initialize" or self.cfg_type == "full"
        ):
            self.prompt_embeds = torch.cat(
                [uncond_prompt_embeds, self.prompt_embeds], dim=0
            )

        self.scheduler.set_timesteps(num_inference_steps, self.device)
        self.timesteps = self.scheduler.timesteps.to(self.device)

        # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
        self.sub_timesteps = []
        for t in self.t_list:
            self.sub_timesteps.append(self.timesteps[t])

        sub_timesteps_tensor = torch.tensor(
            self.sub_timesteps, dtype=torch.long, device=self.device
        )
        self.sub_timesteps_tensor = torch.repeat_interleave(
            sub_timesteps_tensor,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

        self.init_noise = torch.randn(
            (self.batch_size, 4, self.latent_height, self.latent_width),
            generator=generator,
        ).to(device=self.device, dtype=self.dtype)

        self.randn_noise = self.init_noise[:1].clone()
        self.warp_noise = self.init_noise[:1].clone()

        self.stock_noise = torch.zeros_like(self.init_noise)

        c_skip_list = []
        c_out_list = []
        for timestep in self.sub_timesteps:
            c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(
                timestep
            )
            c_skip_list.append(c_skip)
            c_out_list.append(c_out)

        self.c_skip = (
            torch.stack(c_skip_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.c_out = (
            torch.stack(c_out_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )

        alpha_prod_t_sqrt_list = []
        beta_prod_t_sqrt_list = []
        for timestep in self.sub_timesteps:
            alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
            beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
            alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
            beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
        alpha_prod_t_sqrt = (
            torch.stack(alpha_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        beta_prod_t_sqrt = (
            torch.stack(beta_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.alpha_prod_t_sqrt = torch.repeat_interleave(
            alpha_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )
        self.beta_prod_t_sqrt = torch.repeat_interleave(
            beta_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

    @torch.no_grad()
    def update_prompt(self, prompt: str) -> None:
        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        t_index: int,
    ) -> torch.Tensor:
        noisy_samples = (
            self.alpha_prod_t_sqrt[t_index] * original_samples
            + self.beta_prod_t_sqrt[t_index] * noise
        )
        return noisy_samples

    def scheduler_step_batch(
        self,
        model_pred_batch: torch.Tensor,
        x_t_latent_batch: torch.Tensor,
        idx: Optional[int] = None,
    ) -> torch.Tensor:
        # TODO: use t_list to select beta_prod_t_sqrt
        if idx is None:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch
            ) / self.alpha_prod_t_sqrt
            denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch
        else:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch
            ) / self.alpha_prod_t_sqrt[idx]
            denoised_batch = (
                self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
            )

        return denoised_batch

    def unet_step(
        self,
        x_t_latent: torch.Tensor,
        t_list: Union[torch.Tensor, list[int]],
        idx: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
            t_list = torch.concat([t_list[0:1], t_list], dim=0)
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
            t_list = torch.concat([t_list, t_list], dim=0)
        else:
            x_t_latent_plus_uc = x_t_latent

        model_pred = self.unet(
            x_t_latent_plus_uc,
            t_list,
            encoder_hidden_states=self.prompt_embeds,
            return_dict=False,
        )[0]

        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            noise_pred_text = model_pred[1:]
            self.stock_noise = torch.concat(
                [model_pred[0:1], self.stock_noise[1:]], dim=0
            )  # comment this out here for self-output CFG
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        else:
            noise_pred_text = model_pred
        if self.guidance_scale > 1.0 and (
            self.cfg_type == "self" or self.cfg_type == "initialize"
        ):
            noise_pred_uncond = self.stock_noise * self.delta
        if self.guidance_scale > 1.0 and self.cfg_type != "none":
            model_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            model_pred = noise_pred_text

        # compute the previous noisy sample x_t -> x_t-1
        if self.use_denoising_batch:
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
            if self.cfg_type == "self" or self.cfg_type == "initialize":
                scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
                delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
                alpha_next = torch.concat(
                    [
                        self.alpha_prod_t_sqrt[1:],
                        torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = alpha_next * delta_x
                beta_next = torch.concat(
                    [
                        self.beta_prod_t_sqrt[1:],
                        torch.ones_like(self.beta_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = delta_x / beta_next
                init_noise = torch.concat(
                    [self.init_noise[1:], self.init_noise[0:1]], dim=0
                )
                self.stock_noise = init_noise + delta_x

        else:
            # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

        return denoised_batch, model_pred

    def norm_noise(self, noise):
        # Compute mean and std of blended_noise
        mean = noise.mean()
        std = noise.std()

        # Normalize blended_noise to have mean=0 and std=1
        normalized_noise = (noise - mean) / std
        return normalized_noise

    def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
        image_tensors = image_tensors.to(
            device=self.device,
            dtype=self.vae.dtype,
        )
        img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator)
        img_latent = img_latent * self.vae.config.scaling_factor
        x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0)
        return x_t_latent

    def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
        output_latent = self.vae.decode(
            x_0_pred_out / self.vae.config.scaling_factor, return_dict=False
        )[0]
        return output_latent

    def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
        prev_latent_batch = self.x_t_latent_buffer
        if self.use_denoising_batch:
            t_list = self.sub_timesteps_tensor
            if self.denoising_steps_num > 1:
                x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
                self.stock_noise = torch.cat(
                    (self.init_noise[0:1], self.stock_noise[:-1]), dim=0
                )
            x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list)

            if self.denoising_steps_num > 1:
                x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
                if self.do_add_noise:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                        + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
                    )
                else:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                    )
            else:
                x_0_pred_out = x_0_pred_batch
                self.x_t_latent_buffer = None
        else:
            self.init_noise = x_t_latent
            for idx, t in enumerate(self.sub_timesteps_tensor):
                t = t.view(
                    1,
                ).repeat(
                    self.frame_bff_size,
                )
                x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx)
                if idx < len(self.sub_timesteps_tensor) - 1:
                    if self.do_add_noise:
                        x_t_latent = self.alpha_prod_t_sqrt[
                            idx + 1
                        ] * x_0_pred + self.beta_prod_t_sqrt[
                            idx + 1
                        ] * torch.randn_like(
                            x_0_pred, device=self.device, dtype=self.dtype
                        )
                    else:
                        x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
            x_0_pred_out = x_0_pred
        return x_0_pred_out

    @torch.no_grad()
    def __call__(
        self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
    ) -> torch.Tensor:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        if x is not None:
            x = self.image_processor.preprocess(x, self.height, self.width).to(
                device=self.device, dtype=self.dtype
            )
            if self.similar_image_filter:
                x = self.similar_filter(x)
                if x is None:
                    time.sleep(self.inference_time_ema)
                    return self.prev_image_result
            x_t_latent = self.encode_image(x)
        else:
            # TODO: check the dimension of x_t_latent
            x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
                device=self.device, dtype=self.dtype
            )
        x_0_pred_out = self.predict_x0_batch(x_t_latent)
        x_output = self.decode_image(x_0_pred_out).detach().clone()

        self.prev_image_result = x_output
        end.record()
        torch.cuda.synchronize()
        inference_time = start.elapsed_time(end) / 1000
        self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
        return x_output
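A minimal end-to-end sketch of driving StreamV2V, assuming a CUDA device, the runwayml/stable-diffusion-v1-5 checkpoint, and illustrative t_index values; video_frames is assumed to be an iterable of PIL images.

# Illustrative sketch (not from this commit): wrap a StableDiffusionPipeline with
# StreamV2V and translate frames one by one.
import torch
from diffusers import StableDiffusionPipeline
from streamv2v.pipeline import StreamV2V
from streamv2v.image_utils import postprocess_image

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

stream = StreamV2V(pipe, t_index_list=[32, 45], width=512, height=512)
stream.load_lcm_lora()  # LCM-LoRA enables few-step denoising
stream.fuse_lora()
stream.prepare(prompt="a watercolor painting of a person", guidance_scale=1.2)

for frame in video_frames:  # assumed iterable of PIL.Image frames
    out_tensor = stream(frame)
    out_image = postprocess_image(out_tensor, output_type="pil")[0]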
streamv2v/tools/__init__.py
ADDED
File without changes
streamv2v/tools/install-tensorrt.py
ADDED
@@ -0,0 +1,54 @@
from typing import Literal, Optional

import fire
from packaging.version import Version

from ..pip_utils import is_installed, run_pip, version
import platform


def get_cuda_version_from_torch() -> Optional[Literal["11", "12"]]:
    try:
        import torch
    except ImportError:
        return None

    return torch.version.cuda.split(".")[0]


def install(cu: Optional[Literal["11", "12"]] = get_cuda_version_from_torch()):
    if cu is None or cu not in ["11", "12"]:
        print("Could not detect CUDA version. Please specify manually.")
        return
    print("Installing TensorRT requirements...")

    if is_installed("tensorrt"):
        if version("tensorrt") < Version("9.0.0"):
            run_pip("uninstall -y tensorrt")

    cudnn_name = f"nvidia-cudnn-cu{cu}==8.9.4.25"

    if not is_installed("tensorrt"):
        run_pip(f"install {cudnn_name} --no-cache-dir")
        run_pip(
            "install --pre --extra-index-url https://pypi.nvidia.com tensorrt==9.0.1.post11.dev4 --no-cache-dir"
        )

    if not is_installed("polygraphy"):
        run_pip(
            "install polygraphy==0.47.1 --extra-index-url https://pypi.ngc.nvidia.com"
        )
    if not is_installed("onnx_graphsurgeon"):
        run_pip(
            "install onnx-graphsurgeon==0.3.26 --extra-index-url https://pypi.ngc.nvidia.com"
        )
    # if platform.system() == 'Windows' and not is_installed("pywin32"):
    #     run_pip(
    #         "install pywin32"
    #     )

    pass


if __name__ == "__main__":
    fire.Fire(install)