Spaces:
Runtime error
Runtime error
File size: 5,860 Bytes
bab971b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import math
import random
import argparse
from typing import Optional
import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from decord import VideoReader, cpu
from torch.nn import functional as F
from pytorchvideo.transforms import ShortSideScale
from torchvision.transforms import Lambda, Compose
import sys
sys.path.append(".")
from opensora.models.ae import getae_wrapper
from opensora.dataset.transform import CenterCropVideo, resize
from opensora.models.ae.videobase import CausalVAEModel
def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
    """Encode a sequence of RGB frames into an MP4 file using OpenCV.

    Args:
        image_array: Sequence of frames indexed first by time; each frame is
            expected to be an (H, W, 3) RGB image. NOTE(review): frames are
            assumed to be uint8, which is what ``custom_to_video`` passes in.
        fps: Frame rate written to the container.
        output_file: Destination path of the video.

    Raises:
        ValueError: If ``image_array`` contains no frames.
        RuntimeError: If OpenCV could not open a writer for ``output_file``
            (e.g. unsupported codec or unwritable path). Previously this
            failed silently and no file was produced.
    """
    # len() works for both lists and ndarrays (truthiness of an ndarray is ambiguous).
    if len(image_array) == 0:
        raise ValueError('image_array must contain at least one frame')
    height, width = image_array[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
    if not video_writer.isOpened():
        raise RuntimeError(f'could not open video writer for {output_file!r}')
    try:
        for image in image_array:
            # OpenCV writes frames in BGR channel order.
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            video_writer.write(image_bgr)
    finally:
        # Always finalize the container, even if a frame conversion fails.
        video_writer.release()
def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
    """Write a tensor of frames with values in [-1, 1] to a video file.

    Args:
        x: Frames laid out (T, C, H, W); the axes are permuted to
            (T, H, W, C) before being handed to ``array_to_video``.
            Values outside [-1, 1] are clamped.
        fps: Frame rate of the written video.
        output_file: Destination path of the video.
    """
    frames = x.detach().cpu().clamp(-1, 1)
    # Map [-1, 1] -> [0, 1] and move channels last for OpenCV.
    frames = ((frames + 1) / 2).permute(0, 2, 3, 1).numpy()
    pixels = (frames * 255).astype(np.uint8)
    array_to_video(pixels, fps=fps, output_file=output_file)
def read_video(video_path: str, num_frames: int, sample_rate: int, *, random_start: bool = False) -> torch.Tensor:
    """Decode ``num_frames`` frames spaced roughly ``sample_rate`` apart.

    Frames are picked evenly (``np.linspace``) from a window of
    ``sample_rate * num_frames`` consecutive frames. If the video is shorter
    than that window, the whole video is used and proportionally fewer frames
    are returned (but always at least one).

    Args:
        video_path: Path of the video file to decode.
        num_frames: Number of frames requested (must be >= 1).
        sample_rate: Approximate spacing between sampled frames (must be >= 1).
        random_start: If True, the window starts at a random offset inside the
            video; by default it starts at frame 0, as before.

    Returns:
        Tensor laid out (C, T, H, W).

    Raises:
        ValueError: If ``num_frames``/``sample_rate`` are not positive or the
            video contains no frames.
    """
    if num_frames < 1 or sample_rate < 1:
        raise ValueError(f'num_frames and sample_rate must be >= 1, got {num_frames} and {sample_rate}')
    decord_vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(decord_vr)
    if total_frames == 0:
        raise ValueError(f'video has no frames: {video_path}')
    sample_frames_len = sample_rate * num_frames
    # ">=": a video exactly as long as the window can be sampled in full.
    if total_frames >= sample_frames_len:
        s = random.randint(0, total_frames - sample_frames_len) if random_start else 0
        e = s + sample_frames_len
    else:
        s = 0
        e = total_frames
        # Scale the frame count down to what the short video can provide,
        # keeping at least one frame so the batch is never empty.
        num_frames = max(1, int(total_frames / sample_frames_len * num_frames))
        print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
              total_frames)
    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data
class ResizeVideo:
    """Resize a clip so its shorter spatial side equals ``size``.

    The aspect ratio is preserved; the longer side is floored. The clip must be
    a 4-D tensor whose last two dimensions are height and width; the two
    leading dimensions are passed through ``F.interpolate`` unchanged.
    """

    # Modes for which F.interpolate accepts ``align_corners`` and ``antialias``
    # on 4-D input; passing them with e.g. "nearest" raises a ValueError.
    _SMOOTH_MODES = ("bilinear", "bicubic")

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        _, _, h, w = clip.shape
        if w < h:
            new_h = int(math.floor((float(h) / w) * self.size))
            new_w = self.size
        else:
            new_h = self.size
            new_w = int(math.floor((float(w) / h) * self.size))
        extra_kwargs = {}
        if self.interpolation_mode in self._SMOOTH_MODES:
            extra_kwargs = {"align_corners": False, "antialias": True}
        return torch.nn.functional.interpolate(
            clip, size=(new_h, new_w), mode=self.interpolation_mode, **extra_kwargs
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
    """Normalize, resize and optionally center-crop a clip, then add a batch axis.

    Args:
        video_data: Clip tensor; values are divided by 255, so raw 0-255
            pixel values are presumably expected (TODO confirm with callers).
            ``read_video`` in this script produces (C, T, H, W).
        short_size: Target length of the shorter spatial side.
        crop_size: If given, size of the center crop applied after resizing.

    Returns:
        The transformed clip with a new leading batch dimension of size 1.
    """
    steps = [
        Lambda(lambda x: ((x / 255.0) * 2 - 1)),  # [0, 255] -> [-1, 1]
        ResizeVideo(size=short_size),
    ]
    if crop_size is not None:
        steps.append(CenterCropVideo(crop_size))
    return Compose(steps)(video_data).unsqueeze(0)
def _save_reconstruction(video_recon: torch.Tensor, rec_path: str, fps: float) -> None:
    """Save the first reconstructed sample: a JPEG for one frame, else a video.

    ``video_recon`` is indexed as (batch, time, channel, height, width), the
    ``b t c h w`` layout the decoder output is annotated with in ``main``.
    A single-frame result is written next to ``rec_path`` with a ``.jpg``
    extension; otherwise ``rec_path`` receives the video.
    """
    import os  # local: only needed to derive the image path

    if video_recon.shape[1] == 1:  # time axis; shape[2] is channels
        frame = video_recon[0, 0].detach().cpu().numpy()  # (c, h, w)
        frame = np.clip(frame, -1, 1)
        frame = ((frame + 1) / 2 * 255).astype(np.uint8)
        frame = frame.transpose(1, 2, 0)  # (h, w, c) for PIL
        if frame.shape[-1] == 1:
            frame = frame[..., 0]  # PIL expects (h, w) for single-channel images
        # Swap only the extension; str.replace('mp4', ...) would also rewrite directory names.
        image_path = os.path.splitext(rec_path)[0] + '.jpg'
        Image.fromarray(frame).save(image_path)
    else:
        custom_to_video(video_recon[0], fps=fps, output_file=rec_path)


def main(args: argparse.Namespace):
    """Encode a video with the selected autoencoder, decode it, and save the result.

    The model runs in float16 on ``args.device``. NOTE(review):
    ``args.model_path`` and ``args.enable_time_chunk`` are parsed by the CLI
    but not used here.
    """
    device = args.device
    vae = getae_wrapper(args.ae)(args.ae_path).to(device)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor
    vae.eval()
    vae = vae.half()

    with torch.no_grad():
        video = read_video(args.video_path, args.num_frames, args.sample_rate)
        x_vae = preprocess(video, args.resolution, args.crop_size)
        x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w
        latents = vae.encode(x_vae).to(torch.float16)
        video_recon = vae.decode(latents)  # b t c h w

    _save_reconstruction(video_recon, args.rec_path, args.fps)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Reconstruct a video (or a single frame) by passing it through a video autoencoder.'
    )
    parser.add_argument('--video_path', type=str, default='', help='Input video to reconstruct.')
    parser.add_argument('--rec_path', type=str, default='',
                        help='Output path of the reconstruction (a single-frame result is saved as .jpg).')
    parser.add_argument('--ae', type=str, default='', help='Autoencoder name passed to getae_wrapper.')
    parser.add_argument('--ae_path', type=str, default='', help='Path passed to the autoencoder wrapper constructor.')
    parser.add_argument('--model_path', type=str, default='results/pretrained', help='Currently unused by main().')
    parser.add_argument('--fps', type=int, default=30, help='Frame rate of the written reconstruction video.')
    parser.add_argument('--resolution', type=int, default=336, help='Target length of the shorter side after resizing.')
    parser.add_argument('--crop_size', type=int, default=None, help='Optional center-crop size applied after resizing.')
    parser.add_argument('--num_frames', type=int, default=100, help='Number of frames to sample from the input.')
    parser.add_argument('--sample_rate', type=int, default=1, help='Approximate stride between sampled frames.')
    parser.add_argument('--device', type=str, default="cuda", help='Torch device to run the model on.')
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25,
                        help='Tile overlap factor, used only with --enable_tiling.')
    parser.add_argument('--enable_tiling', action='store_true', help='Call enable_tiling() on the underlying VAE.')
    parser.add_argument('--enable_time_chunk', action='store_true', help='Currently unused by main().')
    args = parser.parse_args()
    main(args)
|