# structured-prompted-gif / inference.py
# Copyright 2023 Natural Synthetics Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("/")  # presumably so the bundled hotshot_xl package resolves when the repo is mounted at the image root
import os
import argparse
import torch
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
from hotshot_xl.pipelines.hotshot_xl_controlnet_pipeline import HotshotXLControlNetPipeline
from hotshot_xl.models.unet import UNet3DConditionModel
import torchvision.transforms as transforms
from einops import rearrange
from hotshot_xl.utils import save_as_gif, save_as_mp4, extract_gif_frames_from_midpoint, scale_aspect_fill
from torch import autocast
from diffusers import ControlNetModel
from contextlib import contextmanager
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler

SCHEDULERS = {
    'EulerAncestralDiscreteScheduler': EulerAncestralDiscreteScheduler,
    'EulerDiscreteScheduler': EulerDiscreteScheduler,
    'default': None,  # None keeps whatever scheduler the pipeline was saved with
    # add more here
}
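
# Schedulers are looked up by name from --scheduler. A sketch of registering
# another one (assumption: DDIMScheduler is illustrative only and not used by
# the original script):
#
#     from diffusers.schedulers.scheduling_ddim import DDIMScheduler
#     SCHEDULERS['DDIMScheduler'] = DDIMScheduler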


def parse_args():
    parser = argparse.ArgumentParser(description="Hotshot-XL inference")
    parser.add_argument("--pretrained_path", type=str, default="hotshotco/Hotshot-XL")
    parser.add_argument("--xformers", action="store_true")
    parser.add_argument("--spatial_unet_base", type=str)
    parser.add_argument("--lora", type=str)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--steps", type=int, default=30)
    parser.add_argument("--prompt", type=str,
                        default="a bulldog in the captain's chair of a spaceship, hd, high quality")
    parser.add_argument("--negative_prompt", type=str, default="blurry")
    parser.add_argument("--seed", type=int, default=455)
    parser.add_argument("--width", type=int, default=672)
    parser.add_argument("--height", type=int, default=384)
    parser.add_argument("--target_width", type=int, default=512)
    parser.add_argument("--target_height", type=int, default=512)
    parser.add_argument("--og_width", type=int, default=1920)
    parser.add_argument("--og_height", type=int, default=1080)
    parser.add_argument("--video_length", type=int, default=8)
    parser.add_argument("--video_duration", type=int, default=1000)  # total clip duration in ms
    parser.add_argument("--low_vram_mode", action="store_true")
    parser.add_argument('--scheduler', type=str, default='EulerAncestralDiscreteScheduler',
                        help='Name of the scheduler to use')
    parser.add_argument("--control_type", type=str, default=None, choices=["depth", "canny"])
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
    parser.add_argument("--control_guidance_start", type=float, default=0.0)
    parser.add_argument("--control_guidance_end", type=float, default=1.0)
    parser.add_argument("--gif", type=str, default=None)
    parser.add_argument("--precision", type=str, default='f16', choices=[
        'f16', 'f32', 'bf16'
    ])
    parser.add_argument("--autocast", type=str, default=None, choices=[
        'f16', 'bf16'
    ])
    return parser.parse_args()
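
# Example text-to-GIF invocation (a sketch; out.gif is a placeholder output
# path, all other flags are the ones defined above):
#
#     python inference.py --output out.gif --seed 455 --steps 30 \
#         --prompt "a bulldog in the captain's chair of a spaceship, hd, high quality"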


to_pil = transforms.ToPILImage()


def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
    # (batch, channels, frames, width, height) -> (batch, frames, channels, width, height)
    video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
    bsz = video_frames.shape[0]
    images = []
    for i in range(bsz):
        video = video_frames[i]
        for j in range(video.shape[0]):
            if output_type == "pil":
                images.append(to_pil(video[j]))
            else:
                images.append(video[j])
    return images
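
# Usage sketch (assumption: frames arrive as floats in [0, 1], which is what
# transforms.ToPILImage expects for float tensors):
#
#     frames = torch.rand(1, 3, 8, 384, 672)   # (batch, channels, frames, w, h)
#     pil_frames = to_pil_images(frames)       # flat list of 8 PIL images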


@contextmanager
def maybe_auto_cast(data_type):
    if data_type:
        with autocast("cuda", dtype=data_type):
            yield
    else:
        yield
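
# Usage sketch: the context manager enters torch.autocast("cuda") only when a
# dtype is passed; with None the body runs at the pipeline's native precision.
#
#     with maybe_auto_cast(torch.bfloat16):
#         ...  # forward passes run under bf16 autocast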


def main():
    args = parse_args()

    if args.control_type and not args.gif:
        raise ValueError("--control_type was specified but no --gif was provided to take control frames from!")

    if args.gif and not args.control_type:
        print("warning: --gif was specified but no --control_type was given, so the gif will be ignored.")

    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda")

    control_net_model_pretrained_path = None
    if args.control_type:
        control_type_to_model_map = {
            "canny": "diffusers/controlnet-canny-sdxl-1.0",
            "depth": "diffusers/controlnet-depth-sdxl-1.0",
        }
        control_net_model_pretrained_path = control_type_to_model_map[args.control_type]

    # map the --precision flag to a torch dtype (f16 by default)
    data_type = torch.float32
    if args.precision == 'f16':
        data_type = torch.half
    elif args.precision == 'f32':
        data_type = torch.float32
    elif args.precision == 'bf16':
        data_type = torch.bfloat16

    pipe_line_args = {
        "torch_dtype": data_type,
        "use_safetensors": True
    }

    PipelineClass = HotshotXLPipeline

    if control_net_model_pretrained_path:
        PipelineClass = HotshotXLControlNetPipeline
        pipe_line_args['controlnet'] = \
            ControlNetModel.from_pretrained(control_net_model_pretrained_path, torch_dtype=data_type)

    if args.spatial_unet_base:
        # Build the UNet from a custom spatial (2D) base, then transplant the
        # pretrained temporal layers into it.
        unet_3d = UNet3DConditionModel.from_pretrained(
            args.pretrained_path, subfolder="unet", torch_dtype=data_type).to(device)

        unet = UNet3DConditionModel.from_pretrained_spatial(args.spatial_unet_base).to(device, dtype=data_type)

        temporal_layers = {}
        unet_3d_sd = unet_3d.state_dict()

        for k, v in unet_3d_sd.items():
            if 'temporal' in k:
                temporal_layers[k] = v

        unet.load_state_dict(temporal_layers, strict=False)

        pipe_line_args['unet'] = unet

        del unet_3d_sd, unet_3d, temporal_layers

    pipe = PipelineClass.from_pretrained(args.pretrained_path, **pipe_line_args).to(device)

    if args.lora:
        pipe.load_lora_weights(args.lora)

    SchedulerClass = SCHEDULERS[args.scheduler]
    if SchedulerClass is not None:
        pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)

    if args.xformers:
        pipe.enable_xformers_memory_efficient_attention()

    generator = torch.Generator().manual_seed(args.seed) if args.seed else None

    autocast_type = None
    if args.autocast == 'f16':
        autocast_type = torch.half
    elif args.autocast == 'bf16':
        autocast_type = torch.bfloat16

    # only the base pipeline accepts low_vram_mode
    if type(pipe) is HotshotXLControlNetPipeline:
        kwargs = {}
    else:
        kwargs = {
            "low_vram_mode": args.low_vram_mode
        }

    if args.gif and type(pipe) is HotshotXLControlNetPipeline:
        # pull video_length control frames from the source gif and fit each one
        # to the render size
        kwargs['control_images'] = [
            scale_aspect_fill(img, args.width, args.height).convert("RGB")
            for img in
            extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
        ]
        kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
        kwargs['control_guidance_start'] = args.control_guidance_start
        kwargs['control_guidance_end'] = args.control_guidance_end

    with maybe_auto_cast(autocast_type):
        images = pipe(args.prompt,
                      negative_prompt=args.negative_prompt,
                      width=args.width,
                      height=args.height,
                      original_size=(args.og_width, args.og_height),
                      target_size=(args.target_width, args.target_height),
                      num_inference_steps=args.steps,
                      video_length=args.video_length,
                      generator=generator,
                      output_type="tensor", **kwargs).videos

    images = to_pil_images(images, output_type="pil")

    if args.video_length > 1:
        # per-frame duration is the total clip duration split across the frames
        if args.output.split(".")[-1] == "gif":
            save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
        else:
            save_as_mp4(images, args.output, duration=args.video_duration // args.video_length)
    else:
        images[0].save(args.output, format='JPEG', quality=95)  # single frame: save a still image


if __name__ == "__main__":
    main()
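
# Example ControlNet run (a sketch; input.gif is a placeholder clip whose
# frames condition the depth ControlNet mapped above):
#
#     python inference.py --output out.gif --control_type depth --gif input.gif \
#         --prompt "a robot dancing, hd, high quality"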