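# Page metadata captured with this file (Hugging Face file view):
# uploaded by fffiloni in commit b3f324b ("Upload 244 files", verified);
# scanned "No virus"; size 1.72 kB. Page links: raw, history, blame.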
import sys
sys.path.append(".")
from PIL import Image
import torch
from torchvision.transforms import ToTensor, Compose, Resize, Normalize
from torch.nn import functional as F
from opensora.models.ae.videobase import CausalVAEModel
import argparse
import numpy as np
def preprocess(video_data: "Image.Image | np.ndarray", short_size: int = 128) -> torch.Tensor:
    """Turn a single image into a one-frame batch for the causal VAE.

    The input goes through ``ToTensor``, which scales PIL / uint8 input from
    [0, 255] to [0, 1]. ``Normalize`` then maps it to [-1, 1]. Finally
    ``Resize`` scales the image so its shorter side equals ``short_size``,
    preserving the aspect ratio.

    Args:
        video_data: A PIL image or an ``H x W x C`` NumPy array, i.e.
            anything ``torchvision.transforms.ToTensor`` accepts. A
            ``torch.Tensor`` is NOT accepted by ``ToTensor``. Despite the
            name, this is a single frame. PIL images that are not in RGB
            mode (e.g. RGBA, grayscale, palette) are converted to RGB first.
        short_size: Target length, in pixels, of the shorter image side.

    Returns:
        A tensor of shape ``(1, C, 1, H, W)``: a batch of one, then channels,
        then a singleton axis (presumably the VAE's time/frame axis; confirm
        against CausalVAEModel), then the resized spatial dimensions.
    """
    if isinstance(video_data, Image.Image) and video_data.mode != "RGB":
        # ToTensor keeps every band, so RGBA / L / P images would otherwise
        # yield 4 / 1 / 1 channels instead of 3.
        video_data = video_data.convert("RGB")

    transform = Compose(
        [
            ToTensor(),
            # One-element sequences broadcast over every channel: [0, 1] -> [-1, 1].
            Normalize(mean=[0.5], std=[0.5]),
            # An int size matches the shorter edge and keeps the aspect ratio.
            Resize(size=short_size),
        ]
    )
    frame = transform(video_data)  # (C, H, W)
    # Add the batch axis at 0, then a singleton axis at 2 -> (1, C, 1, H, W).
    return frame.unsqueeze(0).unsqueeze(2)
def main(args: argparse.Namespace):
    """Encode and decode one image with a CausalVAE checkpoint and save it.

    Args:
        args: Parsed CLI arguments. Reads:
            ``image_path``: the input image.
            ``resolution``: the shorter-side size passed to ``preprocess``.
            ``device``: a torch device string.
            ``ckpt``: the checkpoint given to
            ``CausalVAEModel.load_from_checkpoint``.
            ``rec_path``: where the reconstruction is written. PIL infers
            the file format from its extension.
    """
    image_path = args.image_path
    resolution = args.resolution
    device = args.device

    vae = CausalVAEModel.load_from_checkpoint(args.ckpt)
    vae.eval()
    vae = vae.to(device)

    # Use a context manager so the file handle is released. Forcing RGB makes
    # RGBA / grayscale / palette inputs produce the 3 channels that the saved
    # HWC reconstruction below relies on. convert() loads the pixel data, so
    # `rgb` stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")

    with torch.no_grad():
        x_vae = preprocess(rgb, resolution).to(device)
        latents = vae.encode(x_vae)
        # NOTE(review): `encode` appears to return a distribution-like object;
        # `sample()` presumably draws a latent rather than taking its mean.
        recon = vae.decode(latents.sample())

    # Batch 0, frame 0 -> (C, H, W). Do not squeeze(): it would also drop any
    # other size-1 axis (e.g. a 1-pixel-high image) and break the transpose.
    # .float() lets half/bfloat16 outputs be converted to NumPy.
    frame = recon[0, :, 0, :, :].detach().float().cpu().numpy()
    frame = (np.clip(frame, -1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
    pixels = (255 * frame).astype(np.uint8)  # truncates toward zero
    Image.fromarray(pixels.transpose(1, 2, 0)).save(args.rec_path)  # CHW -> HWC
if __name__ == '__main__':
    # Command-line entry point: reconstruct a single image with a CausalVAE.
    parser = argparse.ArgumentParser(
        description="Encode and decode one image with a CausalVAE checkpoint "
                    "and save the reconstruction."
    )
    # The paths are mandatory. An empty-string default only fails later, with
    # a less helpful error from PIL or the checkpoint loader.
    parser.add_argument('--image-path', type=str, required=True,
                        help="Path of the input image.")
    parser.add_argument('--rec-path', type=str, required=True,
                        help="Output path for the reconstructed image; the "
                             "extension selects the format.")
    parser.add_argument('--ckpt', type=str, required=True,
                        help="Path of the CausalVAE checkpoint.")
    parser.add_argument('--resolution', type=int, default=336,
                        help="Shorter-side length after resizing (default: 336).")
    # Keep CUDA as the default when present, but don't crash on CPU-only hosts.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help="Torch device (default: cuda if available, else cpu).")
    args = parser.parse_args()
    main(args)