voidDescriptor committed
Commit 96b91d1
Parent: 8a943d8

Upload inference.py

Files changed (1)
  1. inference.py +231 -0
inference.py ADDED
@@ -0,0 +1,231 @@
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import sys
+
+ sys.path.append("/")  # presumably so the hotshot_xl package is importable from the filesystem root
+ import os
+ import argparse
+ import torch
+ from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
+ from hotshot_xl.pipelines.hotshot_xl_controlnet_pipeline import HotshotXLControlNetPipeline
+ from hotshot_xl.models.unet import UNet3DConditionModel
+ import torchvision.transforms as transforms
+ from einops import rearrange
+ from hotshot_xl.utils import save_as_gif, save_as_mp4, extract_gif_frames_from_midpoint, scale_aspect_fill
+ from torch import autocast
+ from diffusers import ControlNetModel
+ from contextlib import contextmanager
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+ from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
+
+ SCHEDULERS = {
+     'EulerAncestralDiscreteScheduler': EulerAncestralDiscreteScheduler,
+     'EulerDiscreteScheduler': EulerDiscreteScheduler,
+     'default': None,  # None keeps whatever scheduler the checkpoint ships with
+     # add more here
+ }
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Hotshot-XL inference")
+     parser.add_argument("--pretrained_path", type=str, default="hotshotco/Hotshot-XL")
+     parser.add_argument("--xformers", action="store_true")
+     parser.add_argument("--spatial_unet_base", type=str)
+     parser.add_argument("--lora", type=str)
+     parser.add_argument("--output", type=str, required=True)
+     parser.add_argument("--steps", type=int, default=30)
+     parser.add_argument("--prompt", type=str,
+                         default="a bulldog in the captains chair of a spaceship, hd, high quality")
+     parser.add_argument("--negative_prompt", type=str, default="blurry")
+     parser.add_argument("--seed", type=int, default=455)
+     parser.add_argument("--width", type=int, default=672)
+     parser.add_argument("--height", type=int, default=384)
+     parser.add_argument("--target_width", type=int, default=512)
+     parser.add_argument("--target_height", type=int, default=512)
+     parser.add_argument("--og_width", type=int, default=1920)
+     parser.add_argument("--og_height", type=int, default=1080)
+     parser.add_argument("--video_length", type=int, default=8)
+     parser.add_argument("--video_duration", type=int, default=1000)
+     parser.add_argument("--low_vram_mode", action="store_true")
+     parser.add_argument('--scheduler', type=str, default='EulerAncestralDiscreteScheduler',
+                         help='Name of the scheduler to use')
+
+     parser.add_argument("--control_type", type=str, default=None, choices=["depth", "canny"])
+     parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
+     parser.add_argument("--control_guidance_start", type=float, default=0.0)
+     parser.add_argument("--control_guidance_end", type=float, default=1.0)
+     parser.add_argument("--gif", type=str, default=None)
+     parser.add_argument("--precision", type=str, default='f16', choices=[
+         'f16', 'f32', 'bf16'
+     ])
+     parser.add_argument("--autocast", type=str, default=None, choices=[
+         'f16', 'bf16'
+     ])
+
+     return parser.parse_args()
+
+
+ to_pil = transforms.ToPILImage()
+
+
+ def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
+     video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
+     bsz = video_frames.shape[0]
+     images = []
+     for i in range(bsz):
+         video = video_frames[i]
+         for j in range(video.shape[0]):
+             if output_type == "pil":
+                 images.append(to_pil(video[j]))
+             else:
+                 images.append(video[j])
+     return images
+
+ @contextmanager
+ def maybe_auto_cast(data_type):
+     if data_type:
+         with autocast("cuda", dtype=data_type):
+             yield
+     else:
+         yield
+
+
+ def main():
+     args = parse_args()
+
+     if args.control_type and not args.gif:
+         raise ValueError("Controlnet specified but you didn't specify a gif!")
+
+     if args.gif and not args.control_type:
+         print("warning: gif was specified but no control type was specified. gif will be ignored.")
+
+     output_dir = os.path.dirname(args.output)
+     if output_dir:
+         os.makedirs(output_dir, exist_ok=True)
+
+     device = torch.device("cuda")
+
+     control_net_model_pretrained_path = None
+     if args.control_type:
+         control_type_to_model_map = {
+             "canny": "diffusers/controlnet-canny-sdxl-1.0",
+             "depth": "diffusers/controlnet-depth-sdxl-1.0",
+         }
+         control_net_model_pretrained_path = control_type_to_model_map[args.control_type]
+
+     data_type = torch.float32
+
+     if args.precision == 'f16':
+         data_type = torch.half
+     elif args.precision == 'f32':
+         data_type = torch.float32
+     elif args.precision == 'bf16':
+         data_type = torch.bfloat16
+
+     pipe_line_args = {
+         "torch_dtype": data_type,
+         "use_safetensors": True
+     }
+
+     PipelineClass = HotshotXLPipeline
+
+     if control_net_model_pretrained_path:
+         PipelineClass = HotshotXLControlNetPipeline
+         pipe_line_args['controlnet'] = \
+             ControlNetModel.from_pretrained(control_net_model_pretrained_path, torch_dtype=data_type)
+
+     if args.spatial_unet_base:
+
+         unet_3d = UNet3DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet", torch_dtype=data_type).to(device)
+
+         unet = UNet3DConditionModel.from_pretrained_spatial(args.spatial_unet_base).to(device, dtype=data_type)
+
+         temporal_layers = {}
+         unet_3d_sd = unet_3d.state_dict()
+
+         for k, v in unet_3d_sd.items():
+             if 'temporal' in k:
+                 temporal_layers[k] = v
+
+         unet.load_state_dict(temporal_layers, strict=False)  # graft Hotshot-XL's temporal layers onto the alternate spatial UNet
+
+         pipe_line_args['unet'] = unet
+
+         del unet_3d_sd
+         del unet_3d
+         del temporal_layers
+
+     pipe = PipelineClass.from_pretrained(args.pretrained_path, **pipe_line_args).to(device)
+
+     if args.lora:
+         pipe.load_lora_weights(args.lora)
+
+     SchedulerClass = SCHEDULERS[args.scheduler]
+     if SchedulerClass is not None:
+         pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)
+
+     if args.xformers:
+         pipe.enable_xformers_memory_efficient_attention()
+
+     generator = torch.Generator().manual_seed(args.seed) if args.seed else None  # note: a seed of 0 disables seeding
+
+     autocast_type = None
+     if args.autocast == 'f16':
+         autocast_type = torch.half
+     elif args.autocast == 'bf16':
+         autocast_type = torch.bfloat16
+
+     if type(pipe) is HotshotXLControlNetPipeline:
+         kwargs = {}
+     else:
+         kwargs = {
+             "low_vram_mode": args.low_vram_mode
+         }
+
+     if args.gif and type(pipe) is HotshotXLControlNetPipeline:
+         kwargs['control_images'] = [
+             scale_aspect_fill(img, args.width, args.height).convert("RGB")
+             for img in
+             extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
+         ]
+         kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
+         kwargs['control_guidance_start'] = args.control_guidance_start
+         kwargs['control_guidance_end'] = args.control_guidance_end
+
+     with maybe_auto_cast(autocast_type):
+
+         images = pipe(args.prompt,
+                       negative_prompt=args.negative_prompt,
+                       width=args.width,
+                       height=args.height,
+                       original_size=(args.og_width, args.og_height),
+                       target_size=(args.target_width, args.target_height),
+                       num_inference_steps=args.steps,
+                       video_length=args.video_length,
+                       generator=generator,
+                       output_type="tensor", **kwargs).videos
+
+     images = to_pil_images(images, output_type="pil")
+
+     if args.video_length > 1:
+         if args.output.split(".")[-1] == "gif":
+             save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
+         else:
+             save_as_mp4(images, args.output, duration=args.video_duration // args.video_length)
+     else:
+         images[0].save(args.output, format='JPEG', quality=95)
+
+
+ if __name__ == "__main__":
+     main()
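
For reference, a minimal example invocation of the uploaded script. The flag names and values come straight from parse_args above; the output path is hypothetical. Note that the device is hard-coded to torch.device("cuda"), so a CUDA GPU is required, and the output extension picks the writer: ".gif" goes through save_as_gif, any other extension through save_as_mp4, and a run with --video_length 1 saves a single JPEG.

    python inference.py \
        --prompt "a bulldog in the captains chair of a spaceship, hd, high quality" \
        --output outputs/bulldog.gif \
        --precision f16 \
        --video_length 8 \
        --video_duration 1000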