# Residue from the hosting site's file page ("Wuvin's picture / init / 37aeb5b"),
# kept as a comment so the module parses: initial commit 37aeb5b, uploaded by Wuvin.
import torch
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from dataclasses import dataclass, field
from typing import Optional, Union
from datasets import load_dataset
import json
import abc
from diffusers.utils import make_image_grid
import numpy as np
import wandb
from custum_3d_diffusion.trainings.utils import load_config
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
class BasicTrainer(torch.nn.Module, abc.ABC):
    """Abstract base for trainers driving a ConfigurableUNet2DConditionModel.

    Subclasses supply dataset/dataloader construction, the per-step training
    forward pass, and pipeline construction for validation/inference; this base
    handles config parsing, step bookkeeping, validation logging, and saving.
    """
    # Collaborators injected via __init__ and the init_* hooks.
    accelerator: Accelerator
    logger: MultiProcessAdapter
    unet: ConfigurableUNet2DConditionModel
    train_dataloader: torch.utils.data.DataLoader
    test_dataset: torch.utils.data.Dataset
    attn_config: AttnConfig

    @dataclass
    class TrainerConfig:
        # Schema that load_config() fills from the user-supplied dict or path.
        trainer_name: str = "basic"
        pretrained_model_name_or_path: str = ""
        attn_config: dict = field(default_factory=dict)
        dataset_name: str = ""
        dataset_config_name: Optional[str] = None
        resolution: str = "1024"  # JSON-encoded: an int or an [H, W] list
        dataloader_num_workers: int = 4
        pair_sampler_group_size: int = 1
        num_views: int = 4
        max_train_steps: int = -1  # -1 means infinity, otherwise [0, max_train_steps)
        training_step_interval: int = 1  # train on step i*interval, stop at max_train_steps
        max_train_samples: Optional[int] = None
        seed: Optional[int] = None  # For dataset related operations and validation stuff
        train_batch_size: int = 1
        validation_interval: int = 5000
        debug: bool = False

    cfg: TrainerConfig  # only enable_xxx is used

    def __init__(
        self,
        accelerator: Accelerator,
        logger: MultiProcessAdapter,
        unet: ConfigurableUNet2DConditionModel,
        config: Union[dict, str],
        weight_dtype: torch.dtype,
        index: int,
    ):
        """Parse `config` into TrainerConfig/AttnConfig, then run the subclass hooks."""
        super().__init__()
        self.index = index  # index in all trainers
        self.accelerator = accelerator
        self.logger = logger
        self.unet = unet
        self.weight_dtype = weight_dtype
        self.ext_logs = {}  # extra scalars buffered until the next wandb flush in log_validation
        self.cfg = load_config(self.TrainerConfig, config)
        self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
        self.test_dataset = None
        self.validate_trainer_config()
        self.configure()

    def get_HW(self):
        """Return (H, W) parsed from cfg.resolution.

        cfg.resolution is a JSON string holding either a single int (square)
        or an [H, W] list.

        Raises:
            ValueError: if the parsed value is neither an int nor a list.
                (The original fell through and raised an opaque
                UnboundLocalError at the return statement.)
        """
        resolution = json.loads(self.cfg.resolution)
        if isinstance(resolution, int):
            H = W = resolution
        elif isinstance(resolution, list):
            H, W = resolution
        else:
            raise ValueError(f"Unsupported resolution spec: {self.cfg.resolution!r}")
        return H, W

    def unet_update(self):
        # Re-apply this trainer's attention config to the shared unet.
        self.unet.update_config(self.attn_config)

    def validate_trainer_config(self):
        # Hook for subclasses to reject invalid configs early; no-op by default.
        pass

    def is_train_finished(self, current_step):
        """True once current_step reaches cfg.max_train_steps (-1 = never finishes)."""
        assert isinstance(self.cfg.max_train_steps, int)
        return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps

    def next_train_step(self, current_step):
        """Next global step this trainer runs on, or None when training is finished."""
        if self.is_train_finished(current_step):
            return None
        return current_step + self.cfg.training_step_interval

    @classmethod
    def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
        """Tile images into rows*columns grids, then join the grids in a single row."""
        catted = [make_image_grid(all_imgs[i:i + rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
        return make_image_grid(catted, rows=1, cols=len(catted))

    def configure(self) -> None:
        # Hook run at the end of __init__; no-op by default.
        pass

    @abc.abstractmethod
    def init_shared_modules(self, shared_modules: dict) -> dict:
        """Populate/extend the dict of modules shared across trainers; return it."""
        pass

    def load_dataset(self):
        """Load the dataset named by the config via `datasets.load_dataset`.

        NOTE(review): trust_remote_code=True executes code from the dataset
        repo — only safe for trusted dataset sources.
        """
        dataset = load_dataset(
            self.cfg.dataset_name,
            self.cfg.dataset_config_name,
            trust_remote_code=True
        )
        return dataset

    @abc.abstractmethod
    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
        pass

    @abc.abstractmethod
    def forward_step(
        self,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        input a batch
        return a loss

        Subclasses may call super().forward_step() first to refresh the
        unet's attention config before computing the loss.
        """
        self.unet_update()
        pass

    @abc.abstractmethod
    def construct_pipeline(self, shared_modules, unet):
        """Build an inference pipeline from the shared modules and (unwrapped) unet."""
        pass

    @abc.abstractmethod
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """
        For inference time forward.
        """
        pass

    @abc.abstractmethod
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """Run validation inference; returns (titles, images) — TODO confirm in subclasses."""
        pass

    def do_validation(
        self,
        shared_modules,
        unet,
        global_step,
    ):
        """Run validation inference and log the images to every active tracker."""
        self.unet_update()
        self.logger.info("Running validation... ")
        pipeline = self.construct_pipeline(shared_modules, unet)
        pipeline.set_progress_bar_config(disable=True)
        titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
        for tracker in self.accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
            elif tracker.name == "wandb":
                # thumbnail() resizes in place; images titled 'noresize' keep full size.
                # (Was a side-effect list comprehension; a plain loop states the intent.)
                for image, title in zip(images, titles):
                    if 'noresize' not in title:
                        image.thumbnail((512, 512))
                tracker.log({"validation": [
                    wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
                    for i, image in enumerate(images)]})
            else:
                # logger.warn is a deprecated alias of logger.warning
                self.logger.warning(f"image logging not implemented for {tracker.name}")
        del pipeline
        torch.cuda.empty_cache()
        return images

    @torch.no_grad()
    def log_validation(
        self,
        shared_modules,
        unet,
        global_step,
        force=False
    ):
        """Flush buffered ext_logs to wandb, then run validation on the schedule.

        Validation runs every cfg.validation_interval steps (main process only),
        or unconditionally when `force` is set.
        """
        if self.accelerator.is_main_process:
            for tracker in self.accelerator.trackers:
                if tracker.name == "wandb":
                    tracker.log(self.ext_logs)
                    self.ext_logs = {}
        if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
            self.unet_update()
            if self.accelerator.is_main_process:
                self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)

    def save_model(self, unwrap_unet, shared_modules, save_dir):
        """Save the constructed pipeline to save_dir (main process only)."""
        if self.accelerator.is_main_process:
            pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
            pipeline.save_pretrained(save_dir)
            self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")

    def save_debug_info(self, save_name="debug", **kwargs):
        """Pickle kwargs (tensors detached to CPU) when cfg.debug is enabled.

        Never overwrites: if `{save_name}.pkl` exists, the first free
        `{save_name}_v{i}.pkl` is used instead. (The original capped the search
        at 100 versions and silently overwrote the base file when exhausted.)
        """
        if self.cfg.debug:
            to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
            import pickle
            import os
            import itertools
            if os.path.exists(f"{save_name}.pkl"):
                # unbounded search for the first unused version suffix
                for i in itertools.count():
                    if not os.path.exists(f"{save_name}_v{i}.pkl"):
                        save_name = f"{save_name}_v{i}"
                        break
            with open(f"{save_name}.pkl", "wb") as f:
                pickle.dump(to_saves, f)