Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import json | |
| import copy | |
| import os | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
def dpg_save_fn(image, metadata, root_path):
    """Save ``image`` as ``<filename>_<seed>.png`` inside ``root_path``.

    Args:
        image: Array handed to ``PIL.Image.fromarray`` (presumably an HxW or
            HxWxC uint8 array -- confirm against the caller).
        metadata: Mapping that supplies the ``'filename'`` and ``'seed'``
            entries used to build the output file name.
        root_path: Directory to write into. It is not created here, so it
            must already exist.
    """
    # `!s` keeps the exact str() conversion of the original concatenation.
    png_name = f"{metadata['filename']!s}_{metadata['seed']!s}.png"
    destination = os.path.join(root_path, png_name)
    pil_image = Image.fromarray(image)
    pil_image.save(destination)
class DPGDataset(Dataset):
    """Map-style dataset pairing text prompts with seeded initial latents.

    Every regular file in ``prompt_path`` contributes one prompt (its first
    line, whitespace-stripped). Each prompt is repeated
    ``num_samples_per_instance`` times. Flat index ``idx`` maps to prompt
    ``idx // num_samples_per_instance`` and sample
    ``idx % num_samples_per_instance``.
    """

    def __init__(self, prompt_path, num_samples_per_instance, latent_shape):
        """Read the prompts from ``prompt_path``.

        Args:
            prompt_path: Directory containing one prompt per file. Files are
                read in sorted name order; subdirectories are skipped.
            num_samples_per_instance: Number of samples (seeds
                ``0 .. num_samples_per_instance - 1``) produced per prompt.
            latent_shape: Shape passed to ``torch.randn`` for each latent.
        """
        self.latent_shape = latent_shape
        self.prompt_path = prompt_path
        self.prompts = []
        self.filenames = []
        # os.listdir order is arbitrary. Sorting makes the idx -> prompt
        # mapping (and hence which seed/file each index produces)
        # reproducible across filesystems and machines.
        for prompt_file in sorted(os.listdir(self.prompt_path)):
            file_path = os.path.join(self.prompt_path, prompt_file)
            # open() on a directory would raise; only regular files hold prompts.
            if not os.path.isfile(file_path):
                continue
            # Explicit UTF-8 so non-ASCII prompts decode identically on every
            # platform instead of depending on the locale's default encoding.
            with open(file_path, encoding="utf-8") as fp:
                self.prompts.append(fp.readline().strip())
            # Strip only a trailing ".txt"; str.replace would also delete
            # ".txt" occurring in the middle of the name.
            if prompt_file.endswith(".txt"):
                stem = prompt_file[: -len(".txt")]
            else:
                stem = prompt_file
            self.filenames.append(stem)
        self.num_instances = len(self.prompts)
        self.num_samples_per_instance = num_samples_per_instance
        self.num_samples = self.num_instances * self.num_samples_per_instance

    def __len__(self):
        """Return the total sample count (prompts x samples per prompt)."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(latent, condition, metadata)`` for flat index ``idx``.

        ``latent`` is a float32 tensor of ``latent_shape`` drawn from a
        generator seeded with the sample index. Sample ``k`` of every prompt
        therefore starts from the same noise. ``condition`` is the prompt
        string. ``metadata`` carries ``prompt``, ``filename``, ``seed`` and
        ``save_fn`` (``dpg_save_fn``) for writing the generated image.
        """
        instance_idx = idx // self.num_samples_per_instance
        sample_idx = idx % self.num_samples_per_instance
        # The seed depends only on the sample index, not on the prompt.
        generator = torch.Generator().manual_seed(sample_idx)
        metadata = dict(
            prompt=self.prompts[instance_idx],
            filename=self.filenames[instance_idx],
            seed=sample_idx,
            save_fn=dpg_save_fn,
        )
        condition = metadata["prompt"]
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        return latent, condition, metadata