| from typing import List, Optional, Tuple |
| import os |
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
def load_data_from_dir(
    data_folder: str, limit: Optional[int] = 200
) -> Tuple[
    List[torch.Tensor],
    List[torch.Tensor],
    List[Optional[torch.Tensor]],
    List[Optional[torch.Tensor]],
]:
    """Load saved samples from the ``.pt`` files in ``data_folder``.

    Only regular files whose name ends with the ``.pt`` extension are
    considered. The directory is not searched recursively. Files are visited
    in plain lexicographic filename order, so ``"10.pt"`` sorts before
    ``"2.pt"``. At most ``limit`` of them are loaded.

    Each file is read with ``torch.load`` and is expected to be a mapping with
    the keys ``"latent"`` and ``"img"``. The keys ``"c"`` and ``"uc"`` are
    optional; when absent, ``None`` is stored in their place.

    Args:
        data_folder: Directory to scan.
        limit: Maximum number of files to load, applied as a slice of the
            sorted file list. ``None`` loads all matching files.

    Returns:
        Four parallel lists ``(latents, targets, conditions, unconditions)``,
        taken from the ``"latent"``, ``"img"``, ``"c"`` and ``"uc"`` entries
        of each file respectively.

    Raises:
        KeyError: If a loaded file lacks the ``"latent"`` or ``"img"`` key.
    """
    # Require the real ".pt" extension: a bare endswith("pt") would also pick
    # up names like "model.ckpt". Also skip directories that happen to match.
    pt_files = sorted(
        name
        for name in os.listdir(data_folder)
        if name.endswith(".pt") and os.path.isfile(os.path.join(data_folder, name))
    )

    latents, targets, conditions, unconditions = [], [], [], []
    for file_name in pt_files[:limit]:
        # NOTE(review): torch.load unpickles its input; only point this at
        # trusted data directories.
        data = torch.load(os.path.join(data_folder, file_name))
        latents.append(data["latent"])
        targets.append(data["img"])
        conditions.append(data.get("c", None))
        unconditions.append(data.get("uc", None))
    return latents, targets, conditions, unconditions
|
|
|
|
class LD3Dataset(Dataset):
    """Map-style dataset over five parallel, index-aligned lists.

    Item ``i`` is assembled from position ``i`` of each list. The dataset's
    length is taken from ``ori_latent`` alone; the other lists are assumed
    (not checked) to be at least as long.
    """

    def __init__(
        self,
        ori_latent: List[torch.Tensor],
        latent: List[torch.Tensor],
        target: List[torch.Tensor],
        condition: List[Optional[torch.Tensor]],
        uncondition: List[Optional[torch.Tensor]],
    ):
        # The lists are stored by reference; no copying or validation happens.
        self.ori_latent = ori_latent
        self.latent = latent
        self.target = target
        self.condition = condition
        self.uncondition = uncondition

    def __len__(self) -> int:
        """Return the number of samples, defined as ``len(ori_latent)``."""
        return len(self.ori_latent)

    def __getitem__(self, idx: int):
        """Return ``(target, latent, ori_latent, condition, uncondition)`` at ``idx``.

        Entries are looked up in exactly that order, so an out-of-range index
        surfaces from the first list that is too short.
        """
        return (
            self.target[idx],
            self.latent[idx],
            self.ori_latent[idx],
            self.condition[idx],
            self.uncondition[idx],
        )