"""Reconstruct a single image through a CausalVAE checkpoint.

The image at ``--image-path`` is resized, normalized to [-1, 1], encoded into
latents, decoded back, and the first decoded frame is written to
``--rec-path``.
"""
import argparse
import sys

import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F  # noqa: F401  (kept from the original file; unused here)
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

# Append the current working directory to the import path so the local
# `opensora` package resolves (presumably the script is run from the repo root).
sys.path.append(".")

from opensora.models.ae.videobase import CausalVAEModel  # noqa: E402


def preprocess(video_data: Image.Image, short_size: int = 128) -> torch.Tensor:
    """Turn a PIL image into a one-frame video batch for the VAE.

    Args:
        video_data: Input image. It is converted to RGB first, so grayscale,
            palette and RGBA inputs all yield a 3-channel tensor instead of
            1 or 4 channels (the VAE is assumed to take RGB — confirm).
        short_size: Target length of the shorter image edge; the other edge
            is scaled proportionally by ``Resize``.

    Returns:
        Tensor of shape ``(1, 3, 1, H, W)`` with values mapped from [0, 1]
        to roughly [-1, 1].
    """
    transform = Compose(
        [
            ToTensor(),  # (C, H, W) float in [0, 1]
            Normalize((0.5), (0.5)),  # (x - 0.5) / 0.5 -> [-1, 1]
            Resize(size=short_size),
        ]
    )
    outputs = transform(video_data.convert("RGB"))
    # Add a batch dim at 0 and a time dim at 2: (C, H, W) -> (1, C, 1, H, W).
    outputs = outputs.unsqueeze(0).unsqueeze(2)
    return outputs


def _tensor_to_image(recon: torch.Tensor) -> Image.Image:
    """Convert the first frame of the first sample of a decoded batch to an image.

    NOTE(review): assumes ``recon`` uses the same (batch, channel, time, H, W)
    layout that ``preprocess`` produces — confirm against the VAE's decode.
    """
    frame = recon[0, :, 0, :, :]  # (C, H, W); no squeeze, so size-1 dims survive
    # .float() first: numpy has no bfloat16, and half-precision outputs are common.
    frame = frame.detach().float().cpu().numpy()
    frame = np.clip(frame, -1, 1)
    frame = (frame + 1) / 2  # [-1, 1] -> [0, 1]
    # Round to nearest rather than truncating, which biased every pixel downward.
    frame = np.round(frame * 255).astype(np.uint8)
    return Image.fromarray(frame.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)


def main(args: argparse.Namespace) -> None:
    """Encode and decode ``args.image_path`` and save the reconstruction.

    Args:
        args: Parsed CLI namespace. Reads ``ckpt``, ``device``, ``image_path``,
            ``resolution`` and ``rec_path``.
    """
    device = args.device

    vae = CausalVAEModel.load_from_checkpoint(args.ckpt)
    vae.eval()
    vae = vae.to(device)

    # Close the source file once its pixels are loaded (convert() forces the load).
    with Image.open(args.image_path) as img:
        x_vae = preprocess(img, args.resolution)

    with torch.no_grad():
        x_vae = x_vae.to(device)
        latents = vae.encode(x_vae)
        # encode() returns an object exposing .sample(); presumably a
        # posterior distribution over latents — confirm in CausalVAEModel.
        recon = vae.decode(latents.sample())

    image = _tensor_to_image(recon)
    image.save(args.rec_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Reconstruct an image through a CausalVAE checkpoint."
    )
    parser.add_argument('--image-path', type=str, required=True,
                        help="Path of the input image.")
    parser.add_argument('--rec-path', type=str, required=True,
                        help="Path where the reconstructed image is written.")
    parser.add_argument('--ckpt', type=str, required=True,
                        help="CausalVAE checkpoint to load.")
    parser.add_argument('--resolution', type=int, default=336,
                        help="Length of the shorter image edge after resizing.")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Torch device to run the model on.")
    args = parser.parse_args()
    main(args)