import json
from dataclasses import dataclass
from typing import List, Optional

import torch
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline


def get_HW(resolution):
    """Parse a resolution spec (an int, an [H, W] list, or a JSON string of either) into (H, W)."""
    if isinstance(resolution, str):
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    else:
        raise ValueError(f"Unsupported resolution spec: {resolution!r}")
    return H, W
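
# Quick sanity check of get_HW (values are illustrative, not trainer defaults):
#   get_HW(512)          -> (512, 512)
#   get_HW("512")        -> (512, 512)   # JSON-decoded to an int
#   get_HW("[640, 512]") -> (640, 512)   # JSON-decoded to a list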


@register("image2mvimage_trainer")
class Image2MVImageTrainer(BasicTrainer):
    """
    Trainer for generating multiview images from a single input image.
    """
    @dataclass
    class TrainerConfig(BasicTrainer.TrainerConfig):
        trainer_name: str = "image2mvimage"
        condition_image_column_name: str = "conditioning_image"
        image_column_name: str = "image"
        condition_dropout: float = 0.           # probability of dropping the condition (classifier-free guidance training)
        condition_image_resolution: str = "512"
        validation_images: Optional[List[str]] = None
        noise_offset: float = 0.1               # offset-noise strength added during training
        max_loss_drop: float = 0.
        snr_gamma: float = 5.0                  # gamma for Min-SNR loss weighting
        log_distribution: bool = False
        latents_offset: Optional[List[float]] = None
        input_perturbation: float = 0.
        noisy_condition_input: bool = False
        normal_cls_offset: int = 0
        condition_offset: bool = True
        zero_snr: bool = False                  # rescale betas for zero terminal SNR
        linear_beta_schedule: bool = False

    cfg: TrainerConfig
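
    # A hypothetical config override, sketching how these fields are typically set
    # (field names come from the dataclass above; the values are illustrative only):
    #   cfg = Image2MVImageTrainer.TrainerConfig(
    #       condition_image_resolution="[512, 512]",
    #       zero_snr=True,
    #       linear_beta_schedule=True,
    #   )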

    def configure(self) -> None:
        return super().configure()

    def init_shared_modules(self, shared_modules: dict) -> dict:
        # Frozen VAE used to encode images into latents (and decode them back).
        if 'vae' not in shared_modules:
            vae = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
            )
            vae.requires_grad_(False)
            vae.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['vae'] = vae
        # Frozen CLIP vision encoder that embeds the conditioning image.
        if 'image_encoder' not in shared_modules:
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
            )
            image_encoder.requires_grad_(False)
            image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['image_encoder'] = image_encoder
        # Preprocessor that resizes and normalizes images for the CLIP encoder.
        if 'feature_extractor' not in shared_modules:
            feature_extractor = CLIPImageProcessor.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
            )
            shared_modules['feature_extractor'] = feature_extractor
        return shared_modules
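
    # Modules are created once and then shared across trainers; a sketch of the
    # expected result (assuming `trainer` is an already-configured instance):
    #   shared = trainer.init_shared_modules({})
    #   sorted(shared)  # -> ['feature_extractor', 'image_encoder', 'vae']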

    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        raise NotImplementedError()

    def loss_rescale(self, loss, timesteps=None):
        raise NotImplementedError()

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImage2MVCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
            condition_offset=self.cfg.condition_offset,
        )
        pipeline.set_progress_bar_config(disable=True)
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')
        # Swap in an Euler Ancestral sampler for inference, reusing the training
        # scheduler's config plus any overrides collected above.
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)

        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)

        # The output canvas is tiled into square sub-images of half the height;
        # num_imgs is the number of views that fit on the (H, W) grid.
        sub_img_H = H // 2
        num_imgs = H // sub_img_H * W // sub_img_H
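
        # Worked example with the default square setting (values illustrative):
        #   resolution 512 -> H = W = 512, sub_img_H = 256,
        #   num_imgs = (512 // 256) * (512 // 256) = 2 * 2 = 4 views in a 2x2 atlas.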

        forward_args = dict(
            num_images_per_prompt=num_imgs,
            num_inference_steps=50,
            height=sub_img_H,
            width=sub_img_H,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            # Zero-terminal-SNR sampling benefits from rescaled classifier-free
            # guidance; 0.7 follows the recommendation in "Common Diffusion Noise
            # Schedules and Sample Steps are Flawed" (Lin et al., 2023).
            forward_args.update(guidance_rescale=0.7)
        return forward_args

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()
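

# A minimal inference sketch (hypothetical: assumes a configured trainer instance,
# a trained `unet`, and a conditioning image `cond_image`; the `image` keyword is
# an assumption about the custom pipeline's call signature, while the remaining
# call arguments are filled in by get_forward_args above):
#   shared = trainer.init_shared_modules({})
#   pipeline = trainer.construct_pipeline(shared, unet)
#   out = trainer.pipeline_forward(pipeline, image=cond_image)
#   out.images  # list of generated multiview images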