LGM / pipeline.py
dylanebert's picture
dylanebert HF staff
Correct final rotation
1bad10f
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline
class LGMPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapper around an LGM model.

    The wrapped ``lgm`` module is cast to half precision and moved to CUDA at
    construction time, and ``__call__`` places its inputs on CUDA as well, so
    a CUDA device is required.
    """

    # Order in which the caller's four views are stacked before being fed to
    # the model: views 1, 2, 3 first, then view 0.
    # NOTE(review): presumably this aligns the views with the camera poses
    # used by ``lgm.prepare_default_rays`` -- confirm against the model.
    _VIEW_ORDER = (1, 2, 3, 0)

    def __init__(self, lgm):
        """Cast ``lgm`` to fp16 on CUDA and register it as a pipeline module.

        Args:
            lgm: The LGM model. It must support ``half()`` and ``cuda()``,
                expose ``prepare_default_rays`` and ``gs.save_ply``, and be
                callable on the tensor prepared by ``__call__``.
        """
        super().__init__()
        # ImageNet statistics used to normalize the input views.
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)
        lgm = lgm.half().cuda()
        self.register_modules(lgm=lgm)

    def save_ply(self, gaussians, path):
        """Write ``gaussians`` to ``path`` using the model's ``gs.save_ply``."""
        self.lgm.gs.save_ply(gaussians, path)

    @torch.no_grad()
    def __call__(self, images):
        """Run the LGM model on four input views and return its raw output.

        Args:
            images: Exactly four images, indexable by position. They are
                combined with ``np.stack`` and converted with
                ``torch.from_numpy``, so each must be a NumPy array of shape
                ``(H, W, 3)`` (3 channels to match the normalization
                statistics), all with the same shape.
                NOTE(review): values are normalized with ImageNet statistics
                without any rescaling, so they presumably need to be floats
                in ``[0, 1]`` -- confirm with the caller.

        Returns:
            Whatever ``self.lgm`` returns for the prepared fp16 batch;
            presumably the Gaussians that ``save_ply`` writes -- confirm
            against the model.

        Raises:
            ValueError: If ``images`` does not contain exactly four views.
        """
        expected_views = len(self._VIEW_ORDER)
        if len(images) != expected_views:
            raise ValueError(
                f"LGMPipeline expects exactly {expected_views} input views, "
                f"got {len(images)}"
            )

        # Reorder the views and stack them into a (V, H, W, C) array.
        views = np.stack([images[i] for i in self._VIEW_ORDER], axis=0)
        # (V, H, W, C) -> (V, C, H, W), float32 on the GPU.
        views = torch.from_numpy(views).permute(0, 3, 1, 2).float().cuda()
        # NOTE(review): 256x256 presumably has to match the resolution of the
        # ray embeddings returned below, since they are concatenated per
        # pixel -- confirm against the model's configured input size.
        views = F.interpolate(
            views,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        views = TF.normalize(
            views, self.imagenet_default_mean, self.imagenet_default_std
        )

        # Append the per-view ray embeddings as extra channels, then add a
        # leading batch dimension.
        rays_embeddings = self.lgm.prepare_default_rays("cuda", elevation=0)
        model_input = torch.cat([views, rays_embeddings], dim=1).unsqueeze(0)
        # torch.cat only succeeds when both operands share a device, and
        # ``views`` is already on CUDA, so only the dtype needs to change to
        # match the fp16 model.
        model_input = model_input.half()
        return self.lgm(model_input)