LinB203
update
bab971b
raw
history blame
No virus
1.72 kB
import sys
sys.path.append(".")
from PIL import Image
import torch
from torchvision.transforms import ToTensor, Compose, Resize, Normalize
from torch.nn import functional as F
from opensora.models.ae.videobase import CausalVAEModel
import argparse
import numpy as np
def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:
transform = Compose(
[
ToTensor(),
Normalize((0.5), (0.5)),
Resize(size=short_size),
]
)
outputs = transform(video_data)
outputs = outputs.unsqueeze(0).unsqueeze(2)
return outputs
def main(args: argparse.Namespace):
image_path = args.image_path
resolution = args.resolution
device = args.device
vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt)
vqvae.eval()
vqvae = vqvae.to(device)
with torch.no_grad():
x_vae = preprocess(Image.open(image_path), resolution)
x_vae = x_vae.to(device)
latents = vqvae.encode(x_vae)
recon = vqvae.decode(latents.sample())
x = recon[0, :, 0, :, :]
x = x.squeeze()
x = x.detach().cpu().numpy()
x = np.clip(x, -1, 1)
x = (x + 1) / 2
x = (255*x).astype(np.uint8)
x = x.transpose(1,2,0)
image = Image.fromarray(x)
image.save(args.rec_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image-path', type=str, default='')
parser.add_argument('--rec-path', type=str, default='')
parser.add_argument('--ckpt', type=str, default='')
parser.add_argument('--resolution', type=int, default=336)
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
main(args)