Commit 96b91d1
Parent(s): 8a943d8
Upload inference.py
inference.py ADDED (+231 -0)
@@ -0,0 +1,231 @@
# Copyright 2023 Natural Synthetics Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

sys.path.append("/")
import os
import argparse
import torch
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
from hotshot_xl.pipelines.hotshot_xl_controlnet_pipeline import HotshotXLControlNetPipeline
from hotshot_xl.models.unet import UNet3DConditionModel
import torchvision.transforms as transforms
from einops import rearrange
from hotshot_xl.utils import save_as_gif, save_as_mp4, extract_gif_frames_from_midpoint, scale_aspect_fill
from torch import autocast
from diffusers import ControlNetModel
from contextlib import contextmanager
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler

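# Map --scheduler names to diffusers scheduler classes; 'default' (None) keeps
# whatever scheduler ships with the pretrained pipeline.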
SCHEDULERS = {
    'EulerAncestralDiscreteScheduler': EulerAncestralDiscreteScheduler,
    'EulerDiscreteScheduler': EulerDiscreteScheduler,
    'default': None,
    # add more here
}

def parse_args():
    parser = argparse.ArgumentParser(description="Hotshot-XL inference")
    parser.add_argument("--pretrained_path", type=str, default="hotshotco/Hotshot-XL")
    parser.add_argument("--xformers", action="store_true")
    parser.add_argument("--spatial_unet_base", type=str)
    parser.add_argument("--lora", type=str)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--prompt", type=str,
                        default="a bulldog in the captains chair of a spaceship, hd, high quality")
    parser.add_argument("--negative_prompt", type=str, default="blurry")
    parser.add_argument("--seed", type=int, default=455)
    parser.add_argument("--width", type=int, default=672)
    parser.add_argument("--height", type=int, default=384)
    parser.add_argument("--target_width", type=int, default=512)
    parser.add_argument("--target_height", type=int, default=512)
    parser.add_argument("--og_width", type=int, default=1920)
    parser.add_argument("--og_height", type=int, default=1080)
    parser.add_argument("--video_length", type=int, default=8)
    parser.add_argument("--video_duration", type=int, default=1000)
    parser.add_argument("--low_vram_mode", action="store_true")
    parser.add_argument('--scheduler', type=str, default='EulerAncestralDiscreteScheduler',
                        help='Name of the scheduler to use')

    parser.add_argument("--control_type", type=str, default=None, choices=["depth", "canny"])
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
    parser.add_argument("--control_guidance_start", type=float, default=0.0)
    parser.add_argument("--control_guidance_end", type=float, default=1.0)
    parser.add_argument("--gif", type=str, default=None)
    parser.add_argument("--precision", type=str, default='f16', choices=[
        'f16', 'f32', 'bf16'
    ])
    parser.add_argument("--autocast", type=str, default=None, choices=[
        'f16', 'bf16'
    ])

    return parser.parse_args()


to_pil = transforms.ToPILImage()


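# Flatten the pipeline's (batch, channels, frames, w, h) video tensor into a
# flat list of frames, converting each frame to a PIL image when requested.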
def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
    video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
    bsz = video_frames.shape[0]
    images = []
    for i in range(bsz):
        video = video_frames[i]
        for j in range(video.shape[0]):
            if output_type == "pil":
                images.append(to_pil(video[j]))
            else:
                images.append(video[j])
    return images

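# Run the wrapped block under torch autocast on CUDA when a dtype is given;
# otherwise yield with no casting at all.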
@contextmanager
def maybe_auto_cast(data_type):
    if data_type:
        with autocast("cuda", dtype=data_type):
            yield
    else:
        yield


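# Entry point: validate the flags, assemble the pipeline, generate the frames,
# and save the result.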
def main():
    args = parse_args()

    if args.control_type and not args.gif:
        raise ValueError("Controlnet specified but you didn't specify a gif!")

    if args.gif and not args.control_type:
        print("warning: gif was specified but no control type was specified. gif will be ignored.")

    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda")

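    # Resolve --control_type to the matching pretrained SDXL ControlNet
    # checkpoint on the Hub.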
    control_net_model_pretrained_path = None
    if args.control_type:
        control_type_to_model_map = {
            "canny": "diffusers/controlnet-canny-sdxl-1.0",
            "depth": "diffusers/controlnet-depth-sdxl-1.0",
        }
        control_net_model_pretrained_path = control_type_to_model_map[args.control_type]

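    # Map the --precision flag onto a torch dtype (fp16 by default).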
    data_type = torch.float32

    if args.precision == 'f16':
        data_type = torch.half
    elif args.precision == 'f32':
        data_type = torch.float32
    elif args.precision == 'bf16':
        data_type = torch.bfloat16

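    # Use the plain text-to-GIF pipeline unless a ControlNet was requested, in
    # which case the ControlNet variant needs the controlnet model passed in.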
    pipe_line_args = {
        "torch_dtype": data_type,
        "use_safetensors": True
    }

    PipelineClass = HotshotXLPipeline

    if control_net_model_pretrained_path:
        PipelineClass = HotshotXLControlNetPipeline
        pipe_line_args['controlnet'] = \
            ControlNetModel.from_pretrained(control_net_model_pretrained_path, torch_dtype=data_type)

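    # Optionally swap in a different spatial (SDXL) UNet base: load the stock
    # 3D UNet only to harvest its temporal layers, graft them onto the new
    # spatial base with strict=False, then free the donor weights.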
    if args.spatial_unet_base:

        unet_3d = UNet3DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet", torch_dtype=data_type).to(device)

        unet = UNet3DConditionModel.from_pretrained_spatial(args.spatial_unet_base).to(device, dtype=data_type)

        temporal_layers = {}
        unet_3d_sd = unet_3d.state_dict()

        for k, v in unet_3d_sd.items():
            if 'temporal' in k:
                temporal_layers[k] = v

        unet.load_state_dict(temporal_layers, strict=False)

        pipe_line_args['unet'] = unet

        del unet_3d_sd
        del unet_3d
        del temporal_layers

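    # Build the pipeline, apply optional LoRA weights, and honor the
    # --scheduler / --xformers flags.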
    pipe = PipelineClass.from_pretrained(args.pretrained_path, **pipe_line_args).to(device)

    if args.lora:
        pipe.load_lora_weights(args.lora)

    SchedulerClass = SCHEDULERS[args.scheduler]
    if SchedulerClass is not None:
        pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)

    if args.xformers:
        pipe.enable_xformers_memory_efficient_attention()

    generator = torch.Generator().manual_seed(args.seed) if args.seed else None

    autocast_type = None
    if args.autocast == 'f16':
        autocast_type = torch.half
    elif args.autocast == 'bf16':
        autocast_type = torch.bfloat16

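    # The ControlNet pipeline takes its conditioning frames and guidance knobs
    # as call kwargs; the plain pipeline only takes the low-VRAM switch.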
    if type(pipe) is HotshotXLControlNetPipeline:
        kwargs = {}
    else:
        kwargs = {
            "low_vram_mode": args.low_vram_mode
        }

    if args.gif and type(pipe) is HotshotXLControlNetPipeline:
        kwargs['control_images'] = [
            scale_aspect_fill(img, args.width, args.height).convert("RGB")
            for img in
            extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
        ]
        kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
        kwargs['control_guidance_start'] = args.control_guidance_start
        kwargs['control_guidance_end'] = args.control_guidance_end

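    # Generate the clip, optionally under autocast; the pipeline returns its
    # frames as a video tensor in .videos.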
    with maybe_auto_cast(autocast_type):

        images = pipe(args.prompt,
                      negative_prompt=args.negative_prompt,
                      width=args.width,
                      height=args.height,
                      original_size=(args.og_width, args.og_height),
                      target_size=(args.target_width, args.target_height),
                      num_inference_steps=args.steps,
                      video_length=args.video_length,
                      generator=generator,
                      output_type="tensor", **kwargs).videos

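    # Convert to PIL frames, then save as an animated GIF or MP4 (duration is
    # the per-frame display time in ms), or as a single JPEG for one frame.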
    images = to_pil_images(images, output_type="pil")

    if args.video_length > 1:
        if args.output.split(".")[-1] == "gif":
            save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
        else:
            save_as_mp4(images, args.output, duration=args.video_duration // args.video_length)
    else:
        images[0].save(args.output, format='JPEG', quality=95)


if __name__ == "__main__":
    main()
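For reference, a couple of plausible invocations of the script as committed. Every flag below is defined in parse_args above; hotshotco/Hotshot-XL is the built-in default model, a CUDA device is required, and the input/output paths are placeholders:

# Text-to-GIF with the default prompt, seed, and fp16 weights:
python inference.py --output outputs/bulldog.gif

# Depth-ControlNet guidance from an existing GIF:
python inference.py --gif driving.gif --control_type depth --controlnet_conditioning_scale 0.7 --output outputs/controlled.gif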