import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline


class LGMPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapping an LGM model that consumes four views.

    The wrapped model is cast to half precision and moved to CUDA when the
    pipeline is constructed, and inputs are moved to CUDA in ``__call__``, so
    this pipeline requires a CUDA device.
    """

    def __init__(self, lgm):
        """Cast ``lgm`` to fp16 on CUDA and register it as the pipeline module.

        Args:
            lgm: The LGM model. It must provide ``prepare_default_rays``, a
                ``gs.save_ply`` helper, and be callable on the prepared batch
                (these are the only members this pipeline uses).
        """
        super().__init__()
        # ImageNet channel statistics used to normalize the input views.
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)
        lgm = lgm.half().cuda()
        self.register_modules(lgm=lgm)

    def save_ply(self, gaussians, path):
        """Write ``gaussians`` to ``path`` using the model's ``gs.save_ply``."""
        self.lgm.gs.save_ply(gaussians, path)

    @torch.no_grad()
    def __call__(self, images):
        """Run the LGM model on four input views.

        Args:
            images: Exactly four indexable views that ``np.stack`` can combine
                into a (4, H, W, C) channels-last array (the permute below
                fixes that layout). Presumably float RGB in [0, 1], given the
                ImageNet normalization applied afterwards — TODO confirm.

        Returns:
            Whatever the wrapped ``lgm`` model returns for the prepared batch.

        Raises:
            ValueError: If ``images`` does not contain exactly four views.
        """
        if len(images) != 4:
            raise ValueError(
                f"LGMPipeline expects exactly 4 views, got {len(images)}"
            )
        # Rotate the view order so input view 0 ends up last; presumably this
        # matches the camera order produced by prepare_default_rays — TODO confirm.
        images = np.stack([images[1], images[2], images[3], images[0]], axis=0)
        # (4, H, W, C) -> (4, C, H, W) float tensor on the GPU.
        images = torch.from_numpy(images).permute(0, 3, 1, 2).float().cuda()
        images = F.interpolate(
            images,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        images = TF.normalize(
            images, self.imagenet_default_mean, self.imagenet_default_std
        )
        # Build the ray embeddings on the same device as the images instead of
        # a separately hard-coded device string.
        rays_embeddings = self.lgm.prepare_default_rays(images.device, elevation=0)
        # Append ray embeddings as extra channels, then add a batch dimension.
        images = torch.cat([images, rays_embeddings], dim=1).unsqueeze(0)
        # Match the fp16 weights set up in __init__; the tensor is already on CUDA.
        images = images.half()
        return self.lgm(images)