Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from accelerate import Accelerator | |
from accelerate.logging import MultiProcessAdapter | |
from dataclasses import dataclass, field | |
from typing import Optional, Union | |
from datasets import load_dataset | |
import json | |
import abc | |
from diffusers.utils import make_image_grid | |
import numpy as np | |
import wandb | |
from custum_3d_diffusion.trainings.utils import load_config | |
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig | |
class BasicTrainer(torch.nn.Module, abc.ABC): | |
accelerator: Accelerator | |
logger: MultiProcessAdapter | |
unet: ConfigurableUNet2DConditionModel | |
train_dataloader: torch.utils.data.DataLoader | |
test_dataset: torch.utils.data.Dataset | |
attn_config: AttnConfig | |
class TrainerConfig: | |
trainer_name: str = "basic" | |
pretrained_model_name_or_path: str = "" | |
attn_config: dict = field(default_factory=dict) | |
dataset_name: str = "" | |
dataset_config_name: Optional[str] = None | |
resolution: str = "1024" | |
dataloader_num_workers: int = 4 | |
pair_sampler_group_size: int = 1 | |
num_views: int = 4 | |
max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps) | |
training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps | |
max_train_samples: Optional[int] = None | |
seed: Optional[int] = None # For dataset related operations and validation stuff | |
train_batch_size: int = 1 | |
validation_interval: int = 5000 | |
debug: bool = False | |
cfg: TrainerConfig # only enable_xxx is used | |
def __init__( | |
self, | |
accelerator: Accelerator, | |
logger: MultiProcessAdapter, | |
unet: ConfigurableUNet2DConditionModel, | |
config: Union[dict, str], | |
weight_dtype: torch.dtype, | |
index: int, | |
): | |
super().__init__() | |
self.index = index # index in all trainers | |
self.accelerator = accelerator | |
self.logger = logger | |
self.unet = unet | |
self.weight_dtype = weight_dtype | |
self.ext_logs = {} | |
self.cfg = load_config(self.TrainerConfig, config) | |
self.attn_config = load_config(AttnConfig, self.cfg.attn_config) | |
self.test_dataset = None | |
self.validate_trainer_config() | |
self.configure() | |
def get_HW(self): | |
resolution = json.loads(self.cfg.resolution) | |
if isinstance(resolution, int): | |
H = W = resolution | |
elif isinstance(resolution, list): | |
H, W = resolution | |
return H, W | |
def unet_update(self): | |
self.unet.update_config(self.attn_config) | |
def validate_trainer_config(self): | |
pass | |
def is_train_finished(self, current_step): | |
assert isinstance(self.cfg.max_train_steps, int) | |
return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps | |
def next_train_step(self, current_step): | |
if self.is_train_finished(current_step): | |
return None | |
return current_step + self.cfg.training_step_interval | |
def make_image_into_grid(cls, all_imgs, rows=2, columns=2): | |
catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)] | |
return make_image_grid(catted, rows=1, cols=len(catted)) | |
def configure(self) -> None: | |
pass | |
def init_shared_modules(self, shared_modules: dict) -> dict: | |
pass | |
def load_dataset(self): | |
dataset = load_dataset( | |
self.cfg.dataset_name, | |
self.cfg.dataset_config_name, | |
trust_remote_code=True | |
) | |
return dataset | |
def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader: | |
"""Both init train_dataloader and test_dataset, but returns train_dataloader only""" | |
pass | |
def forward_step( | |
self, | |
*args, | |
**kwargs | |
) -> torch.Tensor: | |
""" | |
input a batch | |
return a loss | |
""" | |
self.unet_update() | |
pass | |
def construct_pipeline(self, shared_modules, unet): | |
pass | |
def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: | |
""" | |
For inference time forward. | |
""" | |
pass | |
def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: | |
pass | |
def do_validation( | |
self, | |
shared_modules, | |
unet, | |
global_step, | |
): | |
self.unet_update() | |
self.logger.info("Running validation... ") | |
pipeline = self.construct_pipeline(shared_modules, unet) | |
pipeline.set_progress_bar_config(disable=True) | |
titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.]) | |
for tracker in self.accelerator.trackers: | |
if tracker.name == "tensorboard": | |
np_images = np.stack([np.asarray(img) for img in images]) | |
tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") | |
elif tracker.name == "wandb": | |
[image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation | |
tracker.log({"validation": [ | |
wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg") | |
for i, image in enumerate(images)]}) | |
else: | |
self.logger.warn(f"image logging not implemented for {tracker.name}") | |
del pipeline | |
torch.cuda.empty_cache() | |
return images | |
def log_validation( | |
self, | |
shared_modules, | |
unet, | |
global_step, | |
force=False | |
): | |
if self.accelerator.is_main_process: | |
for tracker in self.accelerator.trackers: | |
if tracker.name == "wandb": | |
tracker.log(self.ext_logs) | |
self.ext_logs = {} | |
if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force: | |
self.unet_update() | |
if self.accelerator.is_main_process: | |
self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step) | |
def save_model(self, unwrap_unet, shared_modules, save_dir): | |
if self.accelerator.is_main_process: | |
pipeline = self.construct_pipeline(shared_modules, unwrap_unet) | |
pipeline.save_pretrained(save_dir) | |
self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}") | |
def save_debug_info(self, save_name="debug", **kwargs): | |
if self.cfg.debug: | |
to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()} | |
import pickle | |
import os | |
if os.path.exists(f"{save_name}.pkl"): | |
for i in range(100): | |
if not os.path.exists(f"{save_name}_v{i}.pkl"): | |
save_name = f"{save_name}_v{i}" | |
break | |
with open(f"{save_name}.pkl", "wb") as f: | |
pickle.dump(to_saves, f) |