# point_navigation/pipeline.py
# Author: hvent90. Uploaded with huggingface_hub (commit 85436ff, verified).
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline, DDPMScheduler
from diffusers.utils import BaseOutput

from src.pipelines.point_navigation.components.model import TrajectoryDiffusionModel
@dataclass
class TrajectoryPipelineOutput(BaseOutput):
    """Output container returned by ``PointNavigationPipeline.__call__``.

    Declared as a dataclass, as diffusers expects for ``BaseOutput``
    subclasses, so the field is registered both as an attribute and as a
    mapping/tuple entry.

    Args:
        trajectories: Denoised trajectories from the final scheduler step.
            As produced by ``PointNavigationPipeline`` this has shape
            ``(batch_size, 32, 2)`` (the shape of the initial noise, which the
            DDPM step preserves).
    """

    trajectories: torch.Tensor
class PointNavigationPipeline(DiffusionPipeline):
    """Sample 2-D point-navigation trajectories with a DDPM reverse process.

    A ``TrajectoryDiffusionModel`` predicts the noise in a batch of noisy
    trajectories, conditioned on a start and a target point. A
    ``DDPMScheduler`` turns each prediction into the next, less noisy sample.

    Args:
        model: Noise-prediction network, called as
            ``model(noisy_trajectory, timesteps, observation)``.
        scheduler: DDPM scheduler driving the reverse-diffusion loop.
    """

    model: TrajectoryDiffusionModel
    scheduler: DDPMScheduler

    def __init__(self, model: TrajectoryDiffusionModel, scheduler: DDPMScheduler):
        super().__init__()
        # Registering the components lets DiffusionPipeline track them
        # (e.g. for save_pretrained/from_pretrained and for `self.device`).
        self.register_modules(model=model, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        start: Union[List[float], Tuple[float, float]],
        target: Union[List[float], Tuple[float, float]],
        batch_size: int = 1,
        num_inference_steps: int = 1000,
        generator: Optional[torch.Generator] = None,
    ) -> TrajectoryPipelineOutput:
        """Generate ``batch_size`` trajectories conditioned on ``start`` and ``target``.

        Args:
            start: ``(x, y)`` start point; only the first two entries are read.
            target: ``(x, y)`` target point; only the first two entries are read.
            batch_size: Number of trajectories to sample; must be >= 1.
            num_inference_steps: Number of reverse-diffusion steps, passed to
                ``scheduler.set_timesteps``.
            generator: Optional RNG used for the initial noise AND for the
                per-step DDPM variance noise, so a seeded generator yields
                reproducible samples.

        Returns:
            TrajectoryPipelineOutput whose ``trajectories`` tensor has shape
            ``(batch_size, 32, 2)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        device = self.device

        # One conditioning row [start_x, start_y, target_x, target_y] per sample.
        observation = torch.tensor(
            [[start[0], start[1], target[0], target[1]]],
            device=device,
            dtype=torch.float32,
        ).repeat(batch_size, 1)

        # Draw the initial noise on the generator's own device (a CPU generator
        # cannot drive torch.randn on CUDA), then move it to the pipeline device.
        # NOTE(review): the (32, 2) trajectory shape is hard-coded; presumably it
        # must match what TrajectoryDiffusionModel expects — confirm.
        noise_device = generator.device if generator is not None else device
        trajectory = torch.randn(
            (batch_size, 32, 2),
            device=noise_device,
            generator=generator,
        ).to(device)
        # init_noise_sigma is 1.0 for DDPM; applied for scheduler-agnostic scaling.
        trajectory = trajectory * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.scheduler.timesteps:
            # Broadcast the scalar timestep to one entry per batch element,
            # keeping the scheduler's timestep dtype.
            timestep_batch = t.to(device).repeat(batch_size)
            noise_pred = self.model(trajectory, timestep_batch, observation)
            # Forward the generator so the variance noise DDPM adds at every
            # step is reproducible too, not just the initial sample.
            trajectory = self.scheduler.step(
                noise_pred, t, trajectory, generator=generator
            ).prev_sample

        return TrajectoryPipelineOutput(trajectories=trajectory)