Spaces:
Runtime error
Runtime error
# General | |
import os | |
from os.path import join as opj | |
import datetime | |
import torch | |
from einops import rearrange, repeat | |
# Utilities | |
from t2v_enhanced.inference_utils import * | |
from modelscope.outputs import OutputKeys | |
import imageio | |
from PIL import Image | |
import numpy as np | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from diffusers.utils import load_image | |
# Shared PIL-image -> tensor converter for the short-video generators in this
# file (ad_short_gen, svd_short_gen); those callers divide its output by 255.
transform = transforms.Compose([transforms.PILToTensor()])
def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
    """Generate a short 256x256 clip with the ModelScope text-to-video pipeline.

    Args:
        prompt: Text prompt for ``ms_model``.
        ms_model: Text-to-video pipeline; called with ``eta=1.0``, 256x256
            resolution and no initial latents.
        inference_generator: Random generator forwarded to the pipeline.
        t: Number of denoising steps (``num_inference_steps``).
        device: Device the resulting tensor is moved to.

    Returns:
        float32 tensor on ``device``: the first element of the stacked
        ``.frames`` output with its last axis moved to position 1
        (channels-last -> channels-first). Assumes ``.frames`` yields numpy
        arrays, presumably (frames, H, W, C) per item — TODO confirm.
    """
    pipeline_output = ms_model(
        prompt,
        num_inference_steps=t,
        generator=inference_generator,
        eta=1.0,
        height=256,
        width=256,
        latents=None,
    )
    stacked = torch.stack([torch.from_numpy(item) for item in pipeline_output.frames])
    stacked = stacked.to(device).to(torch.float32)
    first_clip = stacked[0]
    # Move channels from the last axis to axis 1.
    return rearrange(first_clip, "F W H C -> F C W H")
def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
    """Generate a short clip with the AnimateDiff pipeline, resized to 256x256.

    Args:
        prompt: Text prompt for ``ad_model``.
        ad_model: AnimateDiff-style pipeline; 16 frames are requested with a
            fixed negative prompt and guidance scale 7.5.
        inference_generator: Random generator forwarded to the pipeline.
        t: Number of denoising steps (``num_inference_steps``).
        device: Device the resulting tensor is moved to.

    Returns:
        float32 tensor on ``device`` of shape (frames, C, 256, 256), with the
        pixel values divided by 255.
    """
    result = ad_model(
        prompt,
        negative_prompt="bad quality, worse quality",
        num_frames=16,
        num_inference_steps=t,
        generator=inference_generator,
        guidance_scale=7.5,
    )
    # Only the first video of the batch is used.
    images = result.frames[0]
    clip = torch.stack([transform(img) for img in images])
    clip = clip.to(device).to(torch.float32)
    clip = F.interpolate(clip, size=256)
    return clip / 255.0
def sdxl_image_gen(prompt, sdxl_model):
    """Return the first image ``sdxl_model`` produces for ``prompt``."""
    generated = sdxl_model(prompt=prompt).images
    return generated[0]
def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"):
    """Generate a short 256x256 clip with Stable Video Diffusion (image-to-video).

    The conditioning image comes from one of three sources:

    - ``image is None``: generated from ``prompt`` with ``sdxl_model`` and
      resized to 576x576;
    - ``str``: loaded with ``load_image`` (path or URL), then passed through
      ``resize_and_keep`` and ``center_crop`` (presumably yielding 576x576 —
      confirm in inference_utils);
    - anything else: treated as pixel data, converted with
      ``Image.fromarray(np.uint8(image))``, then resized/cropped as above.

    The image then gets 224-pixel black bars on the left and right before it
    is passed to ``svd_model``; the same 224 columns are cropped off each side
    of the generated frames.

    Args:
        image: None, a path/URL string, or array-like pixel data (see above).
        prompt: Text prompt, only used when ``image`` is None.
        svd_model: Stable Video Diffusion pipeline.
        sdxl_model: Text-to-image pipeline, only used when ``image`` is None.
        inference_generator: Random generator forwarded to ``svd_model``.
        t: Number of denoising steps, forwarded as ``num_inference_steps``.
        device: Device the resulting tensor is moved to.

    Returns:
        float32 tensor on ``device`` of shape (<=16, C, 256, 256), with the
        pixel values divided by 255.
    """
    margin = 224  # width of the black side bars added before and removed after SVD

    if image is None:
        image = sdxl_image_gen(prompt, sdxl_model).resize((576, 576))
    else:
        if isinstance(image, str):
            image = load_image(image)
        else:
            image = Image.fromarray(np.uint8(image))
        image = center_crop(resize_and_keep(image))
    image = add_margin(image, 0, margin, 0, margin, (0, 0, 0))

    # Previously ``t`` was accepted but never passed on, so the pipeline always
    # ran its own default step count; forward it explicitly.
    frames = svd_model(
        image,
        decode_chunk_size=4,
        generator=inference_generator,
        num_inference_steps=t,
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Keep at most 16 frames and strip the side bars added above.
    frames = frames[:16, :, :, margin:-margin]
    frames = F.interpolate(frames, size=256)
    return frames / 255.0
def stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, result_file_stem, stream_cli, stream_model):
    """Run the streaming long-video stage through ``stream_cli``'s trainer.

    Restricts prediction to a single batch, stores the generation settings on
    ``trainer.predict_cfg`` and calls ``trainer.predict`` with ``stream_model``
    and the CLI's datamodule. Nothing is returned; presumably the model writes
    its output under ``predict_dir`` (``stream_cli.config["result_fol"]``)
    using ``result_file_stem`` — confirm in the model's predict step.

    Args:
        prompt: Text prompt for the generation.
        short_video: Initial short clip to extend (stored under ``"video"``).
        n_autoreg_gen: Number of autoregressive generations.
        seed: Random seed.
        t: Number of denoising steps.
        image_guidance: Guidance scale.
        result_file_stem: File stem for the results.
        stream_cli: CLI object exposing ``trainer``, ``config`` and ``datamodule``.
        stream_model: Model passed to ``trainer.predict``.
    """
    trainer = stream_cli.trainer
    trainer.limit_predict_batches = 1
    output_dir = stream_cli.config["result_fol"].as_posix()
    trainer.predict_cfg = dict(
        predict_dir=output_dir,
        result_file_stem=result_file_stem,
        prompt=prompt,
        video=short_video,
        seed=seed,
        num_inference_steps=t,
        guidance_scale=image_guidance,
        n_autoregressive_generations=n_autoreg_gen,
    )
    trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)
def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
    """Enhance a video with the video-to-video model and return the result path.

    The input frames are downscaled by ``cfg_v2v['downscale']``, fitted to
    ``cfg_v2v['upscale_size']`` with ``resize_to_fit`` (and padded to it with
    ``pad_to_fit`` when ``cfg_v2v['pad']`` is true), and written to a
    temporary 8 fps MP4 in ``where_to_log``. That file and ``prompt`` are
    passed to ``model_v2v``, which is asked to write
    ``<prompt>_<time>_enhanced.mp4`` into ``where_to_log``.

    Args:
        prompt: Text guidance for the model. Its first 100 characters, with
            every character that is not alphanumeric, ``-`` or ``_`` replaced
            by ``_``, are used in the output file name.
        video: Path of the input video.
        where_to_log: Directory that receives the temporary and enhanced videos.
        cfg_v2v: Mapping with keys ``'downscale'`` (integer divisor for the
            frame size), ``'upscale_size'`` (passed to ``resize_to_fit`` /
            ``pad_to_fit``) and ``'pad'`` (whether to pad the frames).
        model_v2v: Callable pipeline invoked as
            ``model_v2v({'video_path': ..., 'text': ...}, output_video=path)``.
        square: When true and the frames were padded, each enhanced frame is
            cropped to its centred square to remove the padding. Assumes the
            padding is horizontal around square content — confirm against
            ``pad_to_fit``.

    Returns:
        Path of the enhanced MP4.
    """
    downscale = cfg_v2v['downscale']
    upscale_size = cfg_v2v['upscale_size']
    pad = cfg_v2v['pad']

    now = str(datetime.datetime.now().time()).replace(":", "_").replace(".", "_")
    # The prompt is user text: keep only filename-safe characters so that "/"
    # or "../" in it cannot move the output outside ``where_to_log``.
    safe_prompt = "".join(c if c.isalnum() or c in "-_" else "_" for c in prompt[:100])
    enhanced_video_mp4 = opj(where_to_log, safe_prompt + "_" + now + "_enhanced.mp4")
    temp_video_mp4 = opj(where_to_log, 'temp_' + now + '.mp4')

    # memtest=False: mimread otherwise raises once the decoded frames exceed
    # imageio's default memory limit, which long videos can reach.
    video_frames = imageio.mimread(video, memtest=False)
    h, w, _ = video_frames[0].shape

    # Downscale video, then resize to fit the upscale size
    frames = [Image.fromarray(frame).resize((w // downscale, h // downscale)) for frame in video_frames]
    frames = [resize_to_fit(frame, upscale_size) for frame in frames]
    if pad:
        frames = [pad_to_fit(frame, upscale_size) for frame in frames]
    # Hand imageio plain ndarrays regardless of what the resize/pad helpers return.
    # NOTE: this temporary input video is left in ``where_to_log``.
    imageio.mimsave(temp_video_mp4, [np.asarray(frame) for frame in frames], fps=8)

    model_v2v({'video_path': temp_video_mp4, 'text': prompt}, output_video=enhanced_video_mp4)

    # Remove padding -- only meaningful when padding was actually added.
    if pad and square:
        enhanced_frames = imageio.mimread(enhanced_video_mp4, memtest=False)
        cropped = []
        for frame in enhanced_frames:
            fh, fw = frame.shape[:2]
            # Centred square crop; for 1280x720 output this is exactly the
            # former hard-coded frame[:, 280:-280]. ``fw - margin`` (not
            # ``-margin``) keeps margin == 0 from yielding an empty slice.
            margin = max((fw - fh) // 2, 0)
            cropped.append(frame[:, margin:fw - margin])
        # NOTE(review): no fps is passed, so imageio's writer default applies;
        # confirm it matches the frame rate the model wrote.
        imageio.mimsave(enhanced_video_mp4, cropped)
    return enhanced_video_mp4