from dataclasses import dataclass, field
from typing import Optional, Tuple

import torch
from fairseq2.config_registry import ConfigRegistry
from fairseq2.logging import get_log_writer
from fairseq2.nn.padding import PaddingMask, get_seq_lens
from fairseq2.nn.transformer import CausalAttentionMaskFactory
from fairseq2.typing import DataType, Device
from torch import Tensor

from lcm.datasets.batch import EmbeddingsBatch
from lcm.models.abstract_lcm import (
    AbstractLCModel,
    AbstractLCModelBuilder,
    AbstractLCModelConfig,
)
from lcm.models.sonar_normalizer.builder import SonarNormalizer
from lcm.models.two_tower_diffusion_lcm.frontend import (
    EncoderFrontend,
    EncoderFrontendConfig,
)
from lcm.nn.denoisers import (
    DenoiserConfig,
    LCMDenoiser,
    LCMDenoiserTransformerFactory,
)
from lcm.nn.incremental_state import LCMIncrementalStateBag
from lcm.nn.initialization import parse_norm_order
from lcm.nn.normalization import parse_layer_norm_factory
from lcm.nn.schedulers import DDIMScheduler, DDIMSchedulerConfig
from lcm.nn.transformer import (
    LCMTransformerDecoder,
    TransformerConfig,
    TransformerFactory,
)

logger = get_log_writer(__name__)


TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE = "two_tower_diffusion_lcm"


@dataclass
class TwoTowerDiffusionLCModelConfig(AbstractLCModelConfig):
    model_type: str = TWO_TOWER_DIFFUSION_LCM_MODEL_TYPE

    max_seq_len: int = 2048

    model_dim: int = 1024

    frontend: EncoderFrontendConfig = field(
        default_factory=lambda: EncoderFrontendConfig()
    )
    """The frontend config. This module maps from `sonar_embed_dim` to `model_dim`
    and potentially adds positional embeddings."""

    context_encoder: TransformerConfig = field(
        default_factory=lambda: TransformerConfig()
    )
    """The context encoder config. This is a causal Transformer decoder."""

    noise_scheduler: DDIMSchedulerConfig = field(
        default_factory=lambda: DDIMSchedulerConfig()
    )
    """The config of the noise scheduler.
    See lcm/diffusion_schedulers/ddim for more."""

    denoiser: DenoiserConfig = field(default_factory=lambda: DenoiserConfig())
    """The config of the denoiser."""

    trained_with_cf_guidance: bool = False
    """If `True`, the model will be trained with classifier-free guidance, i.e.,
    unconditional embedding generation. The CF-guidance probability is set in
    `DiffusionLCMCriterionConfig.cf_guidance_probability`."""


lcm_archs = ConfigRegistry[TwoTowerDiffusionLCModelConfig]()
lcm_arch = lcm_archs.decorator
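
# Architectures are typically registered via the `lcm_arch` decorator. A
# minimal, purely illustrative sketch (the name "toy_example" and its field
# overrides are hypothetical, not configs shipped with this module):
#
#     @lcm_arch("toy_example")
#     def _toy_example() -> TwoTowerDiffusionLCModelConfig:
#         config = TwoTowerDiffusionLCModelConfig()
#         config.model_dim = 256
#         config.max_seq_len = 128
#         return config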


class TwoTowerDiffusionLCModel(AbstractLCModel):
    """Class for a diffusion-based LCM model."""

    config: TwoTowerDiffusionLCModelConfig

    def __init__(
        self,
        config: TwoTowerDiffusionLCModelConfig,
        sonar_normalizer: SonarNormalizer,
        encoder_frontend: EncoderFrontend,
        context_encoder: LCMTransformerDecoder,
        denoiser: LCMDenoiser,
        noise_scheduler: DDIMScheduler,
    ) -> None:
        super().__init__(config)

        self.model_dim = context_encoder.model_dim

        self.sonar_embed_dim = config.sonar_embed_dim

        self.sonar_normalizer = sonar_normalizer

        self.encoder_frontend = encoder_frontend
        """The frontend of the context encoder. This frontend simply applies a
        pre-linear projection (to increase dimensionality), then adds positional
        embeddings."""

        self.context_encoder = context_encoder
        """A causal Transformer decoder."""

        self.noise_scheduler = noise_scheduler
        """The diffusion noise scheduler."""

        self.denoiser = denoiser

    def extra_repr(self) -> str:
        """:meta private:"""
        s = super().extra_repr()
        return f"{s}, dtype={self.dtype}"

    def forward(
        self,
        batch: EmbeddingsBatch,
        noisy_batch: EmbeddingsBatch,
        cf_guidance_prob: float = 0.0,
    ) -> EmbeddingsBatch:
        """
        Arguments:
            - batch (`EmbeddingsBatch`): The clean batch of embeddings used to
                encode the context. If `unsupervised`, this is the source
                embeddings; if `supervised`, the source+target embeddings.

            - noisy_batch (`EmbeddingsBatch`): The embeddings noised by the noise
                scheduler. If `unsupervised`, this is the noised source
                embeddings; if `supervised`, the noised target-only embeddings.

            - cf_guidance_prob: Probability of training without any guiding
                context.
        """

        source_lengths = batch.source_lengths

        # Causally encode the clean embeddings to build the conditioning context.
        context = self.encode(batch)

        # Denoise the noisy embeddings conditioned on the encoded context.
        output_batch = self.denoise(
            noisy_batch=noisy_batch,
            context=context,
            source_lengths=source_lengths,
            cf_guidance_prob=cf_guidance_prob,
        )
        return output_batch

    def encode(
        self,
        batch: EmbeddingsBatch,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> EmbeddingsBatch:
        """
        The main context encoder. Takes a sequence of SONAR embeddings of shape
        (B, T, D) and returns a sequence of the same shape after causal
        contextualization.

        Main modules:
            `frontend`: linear projection to `model_dim` + optional positional
                embeddings.
            `context_encoder`: causal Transformer decoder to causally encode
                the context.
        """

        # Map the SONAR embeddings to `model_dim` and add positional embeddings.
        seqs, padding_mask = self.encoder_frontend(
            batch.seqs,
            batch.padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        # Causally contextualize with the Transformer decoder.
        seqs, padding_mask = self.context_encoder(
            seqs,
            padding_mask,
            state_bag=state_bag,
            **kwargs,
        )

        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def denoise(
        self,
        noisy_batch: EmbeddingsBatch,
        context: EmbeddingsBatch,
        source_lengths: Optional[Tensor] = None,
        cf_guidance_prob: float = 0.0,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        inference: bool = False,
    ) -> EmbeddingsBatch:
        """Denoise a noised SONAR embedding conditioned on the encoded context."""
        seqs, padding_mask = self.denoiser(
            seqs=noisy_batch.seqs,
            diffusion_timesteps=noisy_batch.diffusion_timesteps,
            padding_mask=noisy_batch.padding_mask,
            conditioning_variables=context.seqs,
            conditioning_variables_padding_mask=context.padding_mask,
            source_lengths=source_lengths,
            cf_guidance_prob=cf_guidance_prob,
            inference=inference,
        )
        return EmbeddingsBatch(seqs=seqs, padding_mask=padding_mask)

    def prep_for_denoising(self, decoding_options):
        """This setup is done once when we initialize the generator."""
        self.guidance_scale = decoding_options.guidance_scale
        self.guidance_rescale = decoding_options.guidance_rescale
        self.initial_noise_scale = decoding_options.initial_noise_scale
        self.timesteps = decoding_options.inference_timesteps
        self.clip_noise = decoding_options.clip_noise
        self.ddim_eta = decoding_options.ddim_eta
        self.epsilon_scaling = decoding_options.epsilon_scaling

        # Classifier-free guidance is a no-op at a guidance scale of 1.0.
        self.do_classifier_free_guidance = self.guidance_scale != 1.0

        # Set the number of inference timesteps in the scheduler.
        self.noise_scheduler.set_timesteps(self.timesteps, device=self.device)

        # Override the scheduler's initial noise scale.
        self.noise_scheduler.init_noise_sigma = self.initial_noise_scale

        if decoding_options.thresholding:
            self.noise_scheduler.config.thresholding = decoding_options.thresholding
            self.noise_scheduler.config.dynamic_thresholding_ratio = (
                decoding_options.dynamic_thresholding_ratio
            )
            self.noise_scheduler.config.sample_max_value = (
                decoding_options.sample_max_value
            )

    def sample_initial_noise_vectors(self, batch_size: int):
        assert hasattr(self, "clip_noise"), (
            "The model is not properly set up for decoding; make sure to call "
            "`model.prep_for_denoising()` first"
        )

        # Sample one noise vector per sequence in the batch.
        latents = torch.randn(
            batch_size, 1, self.config.sonar_embed_dim, device=self.device
        )

        # Scale the initial noise by the scheduler's initial sigma.
        latents = latents * self.noise_scheduler.init_noise_sigma

        # Clip to keep the initial noise in a stable range.
        latents = latents.clip(-self.clip_noise, self.clip_noise)
        return latents

    @torch.inference_mode()
    def predict_next_sentence(
        self,
        batch: EmbeddingsBatch,
        context: EmbeddingsBatch,
        temperature: float = 1.0,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        context_state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]:
        assert context_state_bag is not None, (
            "Expected a state_bag to incrementally encode the context"
        )

        if self.do_classifier_free_guidance:
            logger.debug("Running inference with CF-guidance...")
            return self.predict_next_sentence_with_cf_guidance(
                batch=batch,
                context=context,
                temperature=temperature,
                state_bag=state_bag,
                context_state_bag=context_state_bag,
                **kwargs,
            )

        # Normalize the incoming SONAR embeddings before encoding.
        if self.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.sonar_normalizer)

        # Incrementally encode the newly added embeddings.
        new_context = self.encode(batch, context_state_bag)
        context_state_bag.increment_step_nr(1)

        # Append the newly encoded positions to the running context.
        context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1))

        # Start from pure noise.
        latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0))

        diffusion_timesteps_schedule = self.noise_scheduler.timesteps

        for diffusion_timestep in diffusion_timesteps_schedule:
            input_batch = EmbeddingsBatch(
                seqs=latents,
                diffusion_timesteps=diffusion_timestep.long().repeat(
                    (latents.shape[0], 1)
                ),
            )

            model_prediction = self.denoise(
                noisy_batch=input_batch,
                context=context,
                state_bag=None,
                inference=True,
            )

            scheduler_outputs = self.noise_scheduler.step(
                model_output=model_prediction.seqs,
                timestep=diffusion_timestep,
                sample=latents,
                eta=self.ddim_eta,
                epsilon_scaling=self.epsilon_scaling,
            )

            # The previous-timestep sample becomes the input of the next step.
            latents = scheduler_outputs.prev_sample
            latents = latents.clip(-self.clip_noise, self.clip_noise)

        # Take the final prediction of the clean sample (x_0).
        final_seqs = scheduler_outputs.pred_original_sample

        final_seqs = self.sonar_normalizer.denormalize(final_seqs)

        return EmbeddingsBatch(final_seqs, None), context
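
    # Illustrative call pattern (hypothetical driver code, not part of this
    # module): a generator calls `prep_for_denoising(...)` once, then feeds the
    # last generated embedding back as `batch` while reusing `context` and
    # `context_state_bag` across steps:
    #
    #     model.prep_for_denoising(decoding_options)
    #     next_sent, context = model.predict_next_sentence(
    #         batch=last_sentence_embeddings,
    #         context=context,
    #         context_state_bag=context_state_bag,
    #     )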

    @torch.inference_mode()
    def predict_next_sentence_with_cf_guidance(
        self,
        batch: EmbeddingsBatch,
        context: EmbeddingsBatch,
        temperature: float = 1.0,
        state_bag: Optional[LCMIncrementalStateBag] = None,
        context_state_bag: Optional[LCMIncrementalStateBag] = None,
        **kwargs,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]:
        assert context_state_bag is not None, (
            "Expected a state_bag to incrementally encode the context"
        )

        # Normalize the incoming SONAR embeddings before encoding.
        if self.sonar_normalizer is not None:
            batch = batch.normalize_seqs(self.sonar_normalizer)

        # Incrementally encode the newly added embeddings.
        new_context = self.encode(batch, context_state_bag)
        context_state_bag.increment_step_nr(1)

        # Append the newly encoded positions to the running context.
        context = EmbeddingsBatch(torch.cat((context.seqs, new_context.seqs), dim=1))

        # Start from pure noise.
        latents = self.sample_initial_noise_vectors(batch_size=batch.seqs.size(0))

        diffusion_timesteps_schedule = self.noise_scheduler.timesteps

        _seq_lens = get_seq_lens(context.seqs, context.padding_mask)

        # Duplicate the batch: the first half keeps the conditioning context,
        # the second half gets a zeroed (unconditional) context.
        _seq_lens = torch.concat((_seq_lens, torch.zeros_like(_seq_lens)), dim=0)

        context = EmbeddingsBatch(
            torch.concat((context.seqs, torch.zeros_like(context.seqs)), dim=0),
            PaddingMask(_seq_lens, batch_seq_len=context.seqs.size(1)),
        )

        batch_multiplier = 2
        for diffusion_timestep in diffusion_timesteps_schedule:
            is_max_diffusion_step = (
                diffusion_timestep == self.noise_scheduler.num_diffusion_train_steps - 1
            )

            input_batch = EmbeddingsBatch(
                torch.concat(batch_multiplier * [latents], dim=0),
                diffusion_timesteps=diffusion_timestep.long().repeat(
                    (latents.shape[0] * batch_multiplier, 1)
                ),
            )

            model_prediction = self.denoise(
                noisy_batch=input_batch,
                context=context,
                state_bag=None,
                inference=True,
            )

            if is_max_diffusion_step:
                # At the highest train timestep the sample is still pure noise,
                # so skip guidance and step with only the conditional prediction
                # (the first chunk of the duplicated batch).
                scheduler_outputs = self.noise_scheduler.step(
                    model_output=model_prediction.seqs.chunk(batch_multiplier)[0],
                    timestep=diffusion_timestep,
                    sample=latents,
                    eta=self.ddim_eta,
                    epsilon_scaling=self.epsilon_scaling,
                )
            else:
                # Convert the model output to a noise (epsilon) prediction.
                predicted_noise = self.noise_scheduler.get_epsilon(
                    model_output=model_prediction.seqs,
                    sample=input_batch.seqs,
                    timestep=diffusion_timestep,
                )

                if self.do_classifier_free_guidance:
                    # Combine the conditional (first half) and unconditional
                    # (second half) noise predictions, with rescaling.
                    predicted_noise = self.apply_classifier_free_guidance(
                        predicted_noise
                    )

                # Step the scheduler with the guided epsilon prediction.
                scheduler_outputs = self.noise_scheduler.step(
                    model_output=predicted_noise,
                    timestep=diffusion_timestep,
                    sample=latents,
                    eta=self.ddim_eta,
                    epsilon_scaling=self.epsilon_scaling,
                    prediction_type="epsilon",
                )

            # The previous-timestep sample becomes the input of the next step.
            latents = scheduler_outputs.prev_sample
            latents = latents.clip(-self.clip_noise, self.clip_noise)

        # Take the final prediction of the clean sample (x_0).
        final_seqs = scheduler_outputs.pred_original_sample

        final_seqs = self.sonar_normalizer.denormalize(final_seqs)

        return EmbeddingsBatch(final_seqs, None), context

    def apply_classifier_free_guidance(self, predicted_noise: Tensor) -> Tensor:
        """
        Apply classifier-free guidance with rescale, as introduced in Algorithm 2
        of https://arxiv.org/pdf/2305.08891. `pos` is the conditional prediction
        `cond_prediction` and `neg` the unconditional prediction
        `uncond_prediction`. The batch during prefilling is prepared with the
        conditioning prefix in the first half.
        """
        # The first half of the batch is conditional, the second unconditional.
        cond_prediction, uncond_prediction = predicted_noise.chunk(2)

        # Standard classifier-free guidance.
        guided_noise_prediction = uncond_prediction + self.guidance_scale * (
            cond_prediction - uncond_prediction
        )

        # Rescale toward the standard deviation of the conditional prediction
        # to mitigate over-exposure from large guidance scales.
        std_pos = cond_prediction.std(dim=-1, keepdim=True)
        std_cfg = guided_noise_prediction.std(dim=-1, keepdim=True)

        factor = std_pos / std_cfg
        factor = self.guidance_rescale * factor + (1 - self.guidance_rescale)

        return factor * guided_noise_prediction
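
    # For reference, the rescale above matches Algorithm 2 of arXiv:2305.08891
    # (with phi = `guidance_rescale` and w = `guidance_scale`):
    #
    #     e_cfg      = e_neg + w * (e_pos - e_neg)
    #     e_rescaled = e_cfg * std(e_pos) / std(e_cfg)
    #     e_final    = phi * e_rescaled + (1 - phi) * e_cfg
    #
    # which is exactly `factor * guided_noise_prediction` with
    # `factor = phi * (std_pos / std_cfg) + (1 - phi)`.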


class TwoTowerDiffusionLCModelBuilder(AbstractLCModelBuilder):
    """Builds the modules of a diffusion-based LCM."""

    config: TwoTowerDiffusionLCModelConfig
    denoiser_factory: LCMDenoiserTransformerFactory

    def __init__(
        self,
        config: TwoTowerDiffusionLCModelConfig,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        super().__init__(config=config, device=device, dtype=dtype)

        self.context_encoder_factory = TransformerFactory(
            model_dim=self.config.model_dim,
            max_seq_len=self.config.max_seq_len,
            config=self.config.context_encoder,
            device=device,
            dtype=dtype,
        )

        self.denoiser_factory = LCMDenoiserTransformerFactory(
            model_dim=self.config.model_dim,
            num_diffusion_train_timesteps=self.config.noise_scheduler.num_diffusion_train_steps,
            max_seq_len=self.config.max_seq_len,
            config=self.config.denoiser,
            input_dim=self.config.sonar_embed_dim,
            device=device,
            dtype=dtype,
        )

    def build_model(self) -> TwoTowerDiffusionLCModel:
        """Build a model."""

        sonar_normalizer = self.build_sonar_normalizer()
        assert sonar_normalizer is not None, (
            "TwoTowerDiffusionLCModel expects a `sonar_normalizer`"
        )

        encoder_frontend = self.build_frontend()

        context_encoder = self.build_context_encoder()

        denoiser = self.build_denoiser()

        noise_scheduler = self.build_noise_scheduler()

        return TwoTowerDiffusionLCModel(
            config=self.config,
            sonar_normalizer=sonar_normalizer,
            context_encoder=context_encoder,
            encoder_frontend=encoder_frontend,
            denoiser=denoiser,
            noise_scheduler=noise_scheduler,
        )

    def build_frontend(self) -> EncoderFrontend:
        """Build the context encoder front-end."""

        return EncoderFrontend(
            sonar_embed_dim=self.config.sonar_embed_dim,
            model_dim=self.config.model_dim,
            config=self.config.frontend,
            pos_encoder=self.context_encoder_factory.build_pos_encoder(),
            device=self.device,
            dtype=self.dtype,
        )

    def build_context_encoder(self) -> LCMTransformerDecoder:
        """Build the context encoder."""

        config = self.config.context_encoder

        num_layers = config.num_layers
        assert num_layers > 0, "The context encoder needs a non-zero number of layers"

        layers = [self.context_encoder_factory.build_layer() for _ in range(num_layers)]

        self_attn_mask_factory = CausalAttentionMaskFactory()

        if config.final_norm_order_style is None:
            # Fall back to the per-layer norm-order style if no final style is given.
            final_norm_order = parse_norm_order(config.norm_order_style)
        else:
            final_norm_order = parse_norm_order(config.final_norm_order_style)

        layer_norm_factory = parse_layer_norm_factory(config.layer_normalization_style)

        return LCMTransformerDecoder(
            layers,
            self_attn_mask_factory=self_attn_mask_factory,
            norm_order=final_norm_order,
            layer_norm_factory=layer_norm_factory,
            dropout_p=config.final_dropout_p,
            device=self.device,
            dtype=self.dtype,
        )

    def build_noise_scheduler(self) -> DDIMScheduler:
        return DDIMScheduler(self.config.noise_scheduler)

    def build_denoiser(self) -> LCMDenoiser:
        """Build a Transformer for denoising noised latents."""
        return self.denoiser_factory.build_model()


def create_two_tower_diffusion_lcm_model(
    config: TwoTowerDiffusionLCModelConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> TwoTowerDiffusionLCModel:
    """Create a two-tower diffusion LCM model.

    :param config:
        The configuration.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
    """
    return TwoTowerDiffusionLCModelBuilder(
        config,
        device=device,
        dtype=dtype,
    ).build_model()
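

# Illustrative usage (the architecture name "some_arch" is a placeholder for
# an entry registered in `lcm_archs`, not a real config):
#
#     config = lcm_archs.get("some_arch")
#     model = create_two_tower_diffusion_lcm_model(
#         config, device=torch.device("cpu"), dtype=torch.float32
#     )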