Spaces:
Runtime error
Runtime error
| from typing import * | |
| import copy | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| from easydict import EasyDict as edict | |
| from ..basic import BasicTrainer | |
| from ...pipelines import samplers | |
| from ...utils.general_utils import dict_reduce | |
| from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin | |
| from .mixins.text_conditioned import TextConditionedMixin | |
| from .mixins.image_conditioned import ImageConditionedMixin | |
class FlowMatchingTrainer(BasicTrainer):
    """
    Trainer for diffusion model with flow matching objective.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        t_schedule (dict): Time schedule for flow matching, of the form
            ``{'name': ..., 'args': {...}}``. Supported names are 'uniform'
            and 'logitNormal' (which reads 'mean' and 'std' from 'args').
            When None, a logit-normal schedule with mean 0.0 and std 1.0 is used.
        sigma_min (float): Minimum noise level.
    """

    def __init__(
        self,
        *args,
        t_schedule: Optional[dict] = None,
        sigma_min: float = 1e-5,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        if t_schedule is None:
            # Build the default per instance: a dict literal in the signature
            # would be a single object shared (and mutable) across all trainers.
            t_schedule = {
                'name': 'logitNormal',
                'args': {
                    'mean': 0.0,
                    'std': 1.0,
                },
            }
        self.t_schedule = t_schedule
        self.sigma_min = sigma_min

    def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).

        Computes ``x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise``.

        Args:
            x_0: The [N x C x ...] tensor of noiseless inputs.
            t: The [N] tensor of diffusion steps [0-1].
            noise: If specified, use this noise instead of generating new noise.

        Returns:
            x_t, the noisy version of x_0 under timestep t.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        assert noise.shape == x_0.shape, "noise must have same shape as x_0"

        # Reshape t to [N, 1, 1, ...] so it broadcasts over the non-batch dims.
        t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
        x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise

        return x_t

    def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        Get original image from noisy version under timestep t.

        Algebraic inverse of ``diffuse`` given the same ``noise``. The result
        is divided by ``(1 - t)``, so entries with t == 1 are not finite.

        Args:
            x_t: The noisy tensor.
            t: The [N] tensor of diffusion steps.
            noise: The noise that was mixed into x_t; must have x_t's shape.

        Returns:
            The reconstructed x_0.
        """
        assert noise.shape == x_t.shape, "noise must have same shape as x_t"
        # Reshape t to [N, 1, 1, ...] so it broadcasts over the non-batch dims.
        t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
        x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
        return x_0

    def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the velocity of the diffusion process at time t.

        This is the derivative of the ``diffuse`` interpolation with respect
        to t, ``(1 - sigma_min) * noise - x_0``; it does not depend on ``t``,
        which is accepted only to keep the signature uniform.
        """
        return (1 - self.sigma_min) * noise - x_0

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data.

        The base implementation returns ``cond`` unchanged and ignores kwargs.
        """
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning data for inference.

        Returns a dict holding ``cond`` under the key 'cond' together with all
        extra keyword arguments.
        """
        return {'cond': cond, **kwargs}

    def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
        """
        Get the sampler for the diffusion process.

        Builds a ``FlowEulerSampler`` from ``self.sigma_min``; kwargs are unused here.
        """
        return samplers.FlowEulerSampler(self.sigma_min)

    def vis_cond(self, **kwargs):
        """
        Visualize the conditioning data.

        The base implementation has nothing to visualize and returns an empty dict.
        """
        return {}

    def sample_t(self, batch_size: int) -> torch.Tensor:
        """
        Sample timesteps.

        Args:
            batch_size: Number of timesteps to draw.

        Returns:
            A [batch_size] tensor: ``torch.rand`` draws for the 'uniform'
            schedule, or the sigmoid of normal draws scaled by 'std' and
            shifted by 'mean' for the 'logitNormal' schedule.

        Raises:
            ValueError: If the configured schedule name is not recognised.
        """
        schedule_name = self.t_schedule['name']
        if schedule_name == 'uniform':
            t = torch.rand(batch_size)
        elif schedule_name == 'logitNormal':
            mean = self.t_schedule['args']['mean']
            std = self.t_schedule['args']['std']
            t = torch.sigmoid(torch.randn(batch_size) * std + mean)
        else:
            raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
        return t

    def training_losses(
        self,
        x_0: torch.Tensor,
        cond=None,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses for a single timestep.

        Args:
            x_0: The [N x C x ...] tensor of noiseless inputs.
            cond: The [N x ...] tensor of additional conditions.
            kwargs: Additional arguments to pass to the backbone.

        Returns:
            A tuple ``(terms, status)``. ``terms`` holds "mse" and "loss"
            (the same mean-reduced MSE between the predicted and target
            velocity) plus, for each of the ten equal-width time bins that
            contains at least one sample, ``bin_<i>`` -> ``{"mse": ...}``
            with that bin's mean per-sample MSE for logging. ``status`` is
            always an empty dict here.
        """
        noise = torch.randn_like(x_0)
        t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
        x_t = self.diffuse(x_0, t, noise=noise)
        cond = self.get_cond(cond, **kwargs)

        # The denoiser receives the timestep scaled by 1000.
        pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
        assert pred.shape == noise.shape == x_0.shape
        target = self.get_v(x_0, noise, t)
        terms = edict()
        terms["mse"] = F.mse_loss(pred, target)
        terms["loss"] = terms["mse"]

        # log loss with time bins
        # One vectorized reduction and a single device-to-host copy replace a
        # per-sample loop of `.item()` calls (one sync per sample per step).
        # Detached because these values are for logging only.
        squared_error = (pred.detach() - target.detach()) ** 2
        mse_per_instance = squared_error.flatten(1).mean(dim=1).double().cpu().numpy()
        time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
        for i in range(10):
            in_bin = time_bin == i
            if in_bin.any():
                terms[f"bin_{i}"] = {"mse": mse_per_instance[in_bin].mean()}

        return terms, {}

    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        """
        Generate samples from the current denoiser for visual inspection.

        Args:
            num_samples: Total number of samples to generate.
            batch_size: Maximum number of samples generated per sampler call.
            verbose: Forwarded to the sampler.

        Returns:
            A dict with 'sample_gt' (the dataset's x_0) and 'sample' (the
            generated output), each of the form ``{'value': tensor, 'type':
            'sample'}``, updated with the reduced output of ``vis_cond``.
        """
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=getattr(self.dataset, 'collate_fn', None),
        )

        # inference
        sampler = self.get_sampler()
        sample_gt = []
        sample = []
        cond_vis = []
        for i in range(0, num_samples, batch_size):
            # The last batch may be smaller than batch_size.
            batch = min(batch_size, num_samples - i)
            # A new iterator is started for every batch, so each draw is the
            # first batch of a freshly shuffled pass and the loader is never
            # exhausted, even if the dataset has fewer than num_samples items.
            data = next(iter(dataloader))
            data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
            noise = torch.randn_like(data['x_0'])
            sample_gt.append(data['x_0'])
            cond_vis.append(self.vis_cond(**data))
            # Everything except x_0 is treated as conditioning input.
            del data['x_0']
            args = self.get_inference_cond(**data)
            res = sampler.sample(
                self.models['denoiser'],
                noise=noise,
                **args,
                steps=50, cfg_strength=3.0, verbose=verbose,
            )
            sample.append(res.samples)

        sample_gt = torch.cat(sample_gt, dim=0)
        sample = torch.cat(sample, dim=0)
        sample_dict = {
            'sample_gt': {'value': sample_gt, 'type': 'sample'},
            'sample': {'value': sample, 'type': 'sample'},
        }
        # Merge the per-batch conditioning visualizations: 'value' entries are
        # concatenated along dim 0 and 'type' is taken from the first batch.
        sample_dict.update(dict_reduce(cond_vis, None, {
            'value': lambda x: torch.cat(x, dim=0),
            'type': lambda x: x[0],
        }))

        return sample_dict
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
    """
    Flow matching trainer combined with classifier-free guidance.

    The class adds no code of its own; it only composes
    ``ClassifierFreeGuidanceMixin`` with ``FlowMatchingTrainer``.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
        p_uncond (float): Probability of dropping conditions.
    """
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
    """
    Text-conditioned flow matching trainer with classifier-free guidance.

    The class adds no code of its own; it only composes
    ``TextConditionedMixin`` with ``FlowMatchingCFGTrainer``.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
        p_uncond (float): Probability of dropping conditions.
        text_cond_model (str): Text conditioning model.
    """
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
    """
    Image-conditioned flow matching trainer with classifier-free guidance.

    The class adds no code of its own; it only composes
    ``ImageConditionedMixin`` with ``FlowMatchingCFGTrainer``.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.

        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
        p_uncond (float): Probability of dropping conditions.
        image_cond_model (str): Image conditioning model.
    """