import torch from torch.utils.data import DataLoader, Subset import sys sys.path.append(".") from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset num_workers = 4 batch_size = 12 torch.manual_seed(0) torch.set_grad_enabled(False) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000' data_path = '/remote-home1/dataset/UCF-101' video_num_frames = 17 resolution = 128 sample_rate = 10 vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path) vae.to(device) dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate) subset_indices = list(range(1000)) subset_dataset = Subset(dataset, subset_indices) loader = DataLoader(subset_dataset, batch_size=8, pin_memory=True) all_latents = [] for video_data in loader: video_data = video_data['video'].to(device) latents = vae.encode(video_data).sample() all_latents.append(video_data.cpu()) all_latents_tensor = torch.cat(all_latents) std = all_latents_tensor.std().item() normalizer = 1 / std print(f'{normalizer = }')