import copy
import inspect
import sys
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

import hydra.utils
import torch
from hydra.errors import InstantiationException
from transformers import (
    AutoTokenizer,
    DynamicCache,
    GenerationConfig,
    LogitsProcessorList,
    PretrainedConfig,
    PreTrainedModel,
    StoppingCriteriaList,
)
from transformers.cache_utils import Cache
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import ModelOutput

# NOTE: these local imports look unused, but they keep the backbone and
# noise-schedule classes importable from this module so that hydra
# `_target_` references can resolve them at instantiation time.
from .backbone_automodel import AutoModelFromPreTrained
from .backbone_encoder_decoder import (
    LLMasEncoderDecoder,
    LLMasEncoderDecoderShareKV,
)
from .noise_schedule_noise_schedules import (
    CosineNoise,
    ExponentialNoise,
    LinearNoise,
    LogarithmicNoise,
)


@dataclass
class DenoiserInput(OrderedDict):
    """Input to the denoiser model."""

    xt: torch.LongTensor
    x0: Optional[torch.LongTensor] = None
    attention_mask: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[torch.FloatTensor, Cache]] = None
    context_mask: Optional[torch.FloatTensor] = None
    tokens_mask: Optional[torch.FloatTensor] = None
    t: Optional[torch.FloatTensor] = None
    alpha_t: Optional[torch.FloatTensor] = None
    alpha_t_prime: Optional[torch.FloatTensor] = None
    backbone_kwargs: dict[str, Any] = field(default_factory=dict)
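
# A minimal construction sketch for `DenoiserInput` (the shapes and token ids
# below are hypothetical, chosen only for illustration):
#
#     xt = torch.randint(0, 32000, (2, 16))     # noised token ids, shape (B, L)
#     t = torch.rand(2)                         # per-sample diffusion time in [0, 1)
#     batch = DenoiserInput(xt=xt, t=t, attention_mask=torch.ones(2, 16))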


@dataclass
class LossAndNllOutput(OrderedDict):
    """Loss output for denoiser models."""

    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    other_loss_terms: dict = field(default_factory=dict)


@dataclass
class DenoiserOutput(ModelOutput):
    """Output of the denoiser model."""

    denoiser_output: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    tokens_mask: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    loss: Optional[torch.FloatTensor] = None
    nlls: Optional[torch.FloatTensor] = None
    other_loss_terms: Optional[dict[str, Any]] = None


class DenoiserConfig(PretrainedConfig):
    """Configuration class for Denoiser models.

    This class is used to initialize the model and contains all the necessary
    parameters for the model's architecture.
    """

    model_type = "denoiser"

    def __init__(
        self,
        length: Optional[int] = None,
        backbone_config: Optional[Dict[str, Any]] = None,
        noise_config: Optional[Dict[str, Any]] = None,
        tokenization_config: Optional[Dict[str, Any]] = None,
        time_conditioned_backbone: Optional[bool] = None,
        attn_backend: str = "sdpa",
        train_on_context: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Resolve special-token and vocab attributes: values provided in
        # `tokenization_config` take precedence; otherwise keep whatever
        # `PretrainedConfig.__init__` already set from **kwargs (None if unset).
        for v in [
            "vocab_size",
            "mask_token_id",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "pad_vocab_size_multiple",
        ]:
            if tokenization_config is not None and (
                getattr(self, v, None) is None or v in tokenization_config
            ):
                setattr(self, v, tokenization_config.get(v, None))
            else:
                setattr(self, v, getattr(self, v, None))
        self.backbone_config = backbone_config
        self.noise_config = noise_config
        self.tokenization_config = tokenization_config
        self.length = length
        self.time_conditioned_backbone = time_conditioned_backbone
        self.attn_backend = attn_backend
        self.train_on_context = train_on_context
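

# Example configuration (a sketch: the `_target_` paths, model name, and token
# ids below are hypothetical; `tokenizer_name` is passed through **kwargs and
# read later by `Denoiser.__init__`):
#
#     config = DenoiserConfig(
#         length=1024,
#         backbone_config={
#             "_target_": "backbone_automodel.AutoModelFromPreTrained",
#             "pretrained_model_name_or_path": "gpt2",
#         },
#         noise_config={"_target_": "noise_schedule_noise_schedules.LinearNoise"},
#         tokenization_config={"vocab_size": 50257, "mask_token_id": 50256},
#         tokenizer_name="gpt2",
#     )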


class Denoiser(ABC, PreTrainedModel):
    """Abstract base class for denoising models.

    This class defines the interface for AR, Diffusion, and Flow-based
    parametrizations.
    """

    config_class = DenoiserConfig

    def __init__(
        self,
        config: DenoiserConfig,
        **kwargs,
    ):
        """
        Initialize the Denoiser from a configuration.

        Parameters:
            config (DenoiserConfig): Configuration object for the model.
        """
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vocab_size
        self.mask_token_id = config.mask_token_id
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        try:
            self.backbone = hydra.utils.instantiate(config.backbone_config)
        except InstantiationException:
            # Fallback: checkpoints saved from a different package layout may
            # reference backbone classes by their bare module name. Alias every
            # module under this repo's root package to its short name in
            # sys.modules, then retry instantiation.
            sys_modules = copy.deepcopy(list(sys.modules.keys()))
            repo_root_module = ".".join(__name__.split(".")[:-1])
            for name in sys_modules:
                if name.startswith(repo_root_module):
                    short = name.split(".")[-1]
                    if short not in sys.modules:
                        sys.modules[short] = sys.modules[name]
            del sys_modules
            self.backbone = hydra.utils.instantiate(config.backbone_config)
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_name,
            trust_remote_code=True,
        )
        self.noise_schedule = (
            hydra.utils.instantiate(config.noise_config)
            if config.noise_config is not None
            else None
        )
        # If not set explicitly, infer time conditioning from whether the
        # backbone's forward accepts a `noise` argument.
        self.time_conditioned_backbone = (
            config.time_conditioned_backbone
            if config.time_conditioned_backbone is not None
            else "noise" in inspect.getfullargspec(self.backbone.forward).args
        )

        self.skip_params_for_push = []

    @abstractmethod
    def _prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
    ) -> DenoiserInput:
        """
        Prepare inputs for the model.

        Parameters:
            input_ids (LongTensor): Input tensor to the model.
            attention_mask (Optional[FloatTensor]): Attention mask for the model.
            context_mask (Optional[FloatTensor]): Indicator for context tokens.
            t (Optional[FloatTensor]): Time step for the model.
            past_key_values (Optional[Cache]): Past key values for the model.

        Returns:
            Denoiser inputs.
        """
        raise NotImplementedError("Denoiser subclasses must implement _prepare_inputs")

    def _prepare_inputs_inference(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        context: Optional[torch.LongTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Tuple[DenoiserInput, Dict[str, Any]]:
        """Prepare inputs at inference time.

        Returns a tuple of the denoiser inputs and the (possibly updated)
        cache dictionary.
        """
        raise NotImplementedError(
            "Denoiser subclasses must implement _prepare_inputs_inference"
        )

    @abstractmethod
    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        """
        Compute the loss for the denoising model.

        Parameters:
            model_output (FloatTensor): Output of `_forward`
                (log-probabilities over the denoised sequence).
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.

        Returns:
            LossAndNllOutput: loss (FloatTensor) and nlls (FloatTensor).
        """
        raise NotImplementedError("Denoiser subclasses must implement _compute_loss")

    def _forward(
        self,
        backbone_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> torch.FloatTensor:
        """
        Forward pass for the denoiser model; returns log-probabilities over the
        denoised sequence.

        Some classes may need to override this method.

        Parameters:
            backbone_output (FloatTensor): Output tensor from the backbone model.
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.

        Returns:
            Model outputs (FloatTensor).
        """
        return torch.log_softmax(backbone_output, dim=-1)

    def _backbone_forward(
        self,
        denoiser_inputs: DenoiserInput,
        **backbone_kwargs: Any,
    ) -> ModelOutput:
        """Forward pass for the backbone model (should return logits).

        Some classes may need to override this method.

        Parameters:
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.
            **backbone_kwargs: Extra keyword arguments forwarded to the backbone
                (e.g., `return_updated_cache=True` to return past_key_values
                instead of only logits).

        Returns:
            Backbone output (ModelOutput instance).
        """
        if self.time_conditioned_backbone:
            # Time-conditioned backbones additionally receive the noise level.
            backbone_kwargs = {"noise": denoiser_inputs.alpha_t, **backbone_kwargs}
        return self.backbone(
            denoiser_inputs.xt,
            attention_mask=denoiser_inputs.attention_mask,
            past_key_values=denoiser_inputs.past_key_values,
            **denoiser_inputs.backbone_kwargs,
            **backbone_kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        compute_loss: Optional[bool] = True,
        **kwargs,
    ) -> DenoiserOutput:
        """
        Perform a forward pass through the denoising model and
        (optionally) compute the loss.

        Parameters:
            input_ids (LongTensor): Input tensor to the model.
            attention_mask (Optional[FloatTensor]): Attention mask for the model.
            context_mask (Optional[FloatTensor]): Indicator for context tokens.
            t (Optional[FloatTensor]): Denoising time step for the model.
            past_key_values (Optional[Cache]): KV cache.
            compute_loss (Optional[bool]): Flag to compute loss.

        Returns:
            DenoiserOutput
        """
        denoiser_inputs = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_mask=context_mask,
            past_key_values=past_key_values,
            t=t,
        )

        backbone_output = self._backbone_forward(denoiser_inputs, **kwargs)
        new_past_key_values = getattr(backbone_output, "past_key_values", None)
        backbone_output = getattr(backbone_output, "logits", backbone_output[0])
        denoiser_output = self._forward(
            backbone_output,
            denoiser_inputs,
            **kwargs,
        )

        if compute_loss:
            loss_and_nll = self._compute_loss(
                model_output=denoiser_output, denoiser_inputs=denoiser_inputs, **kwargs
            )
            loss = loss_and_nll.loss
            nlls = loss_and_nll.nlls
            other_loss_terms = loss_and_nll.other_loss_terms
        else:
            loss, nlls = None, None
            other_loss_terms = {}

        return DenoiserOutput(
            denoiser_output=denoiser_output,
            logits=backbone_output,
            past_key_values=new_past_key_values,
            tokens_mask=denoiser_inputs.tokens_mask,
            loss=loss,
            nlls=nlls,
            other_loss_terms=other_loss_terms,
        )
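
    # Typical training-step usage (a sketch; assumes a concrete subclass and a
    # hypothetical tokenized batch):
    #
    #     out = model(
    #         input_ids=batch["input_ids"],
    #         attention_mask=batch["attention_mask"],
    #         compute_loss=True,
    #     )
    #     out.loss.backward()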

    @staticmethod
    def _sample_categorical(categorical_probs, do_sample=True):
        """Helper function to sample from a categorical distribution."""
        categorical_probs = categorical_probs.to(torch.float64)
        if not do_sample:
            return categorical_probs.argmax(dim=-1)
        # Exponential-race form of the Gumbel-max trick: dividing the
        # probabilities by i.i.d. Exp(1) noise and taking the argmax yields an
        # exact categorical sample. The 1e-10 terms guard against log(0).
        gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to(
            categorical_probs.dtype
        )
        return (categorical_probs / gumbel_norm).argmax(dim=-1)
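
    # Sampling sketch (hypothetical shapes; real probabilities would come from
    # the model's `_forward` log-probabilities):
    #
    #     probs = torch.softmax(torch.randn(2, 16, 32000), dim=-1)
    #     ids = Denoiser._sample_categorical(probs)                      # stochastic
    #     greedy_ids = Denoiser._sample_categorical(probs, do_sample=False)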

    @staticmethod
    def _preprocess_attention_mask(attention_mask, dtype):
        """Convert a binary {0, 1} mask into an additive float mask.

        Masked (zero) positions become the most negative finite value of
        `dtype`; attended positions become 0.0.
        """
        min_dtype = torch.finfo(dtype).min
        attention_mask = torch.where(
            (attention_mask == 0.0).bool(),
            min_dtype,
            0.0,
        ).to(dtype)
        return attention_mask
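
    # For example, a padding mask [1, 1, 0] in float16 becomes approximately
    # [0.0, 0.0, -65504.0], which can be added directly to attention scores.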

    @staticmethod
    def _get_past_key_values_seq_length(past_key_values: DynamicCache):
        """Return the longest cached sequence length across layers.

        Layers with an empty cache (zero batch dimension) are skipped.
        """
        seq_length = 0
        for i in range(len(past_key_values)):
            if past_key_values[i][0].shape[0] > 0:
                seq_length = max(
                    past_key_values[i][0].shape[-2],
                    seq_length,
                )
        return seq_length

    def update_cache(
        self,
        inputs: torch.LongTensor,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Cache the key-value pairs for the context.

        Args:
            inputs (torch.LongTensor): The context tensor.
            cache (Optional[Dict[str, Any]]): Cache objects, e.g., past_key_values.

        Returns:
            Dict: Updated cache objects, e.g., past_key_values.
        """
        context_input, cache = self._prepare_inputs_inference(
            input_ids=inputs, cache=cache, return_updated_cache=True, **backbone_kwargs
        )
        backbone_output = self._backbone_forward(
            context_input,
            return_updated_cache=True,
            **cache,
        )
        # Merge everything the backbone returned (except logits) into the cache.
        backbone_output = dict(backbone_output.items())
        backbone_output.pop("logits", None)
        cache = cache | backbone_output
        return cache
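
    # Prefill sketch for prompted generation (hypothetical prompt tensors):
    #
    #     cache = model.update_cache(prompt_ids)          # prefill the context
    #     cache = model.update_cache(more_ids, cache)     # extend an existing cache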

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        batch_size: Optional[int] = None,
        device: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate samples from the denoising model.

        Follows the signature of `transformers.GenerationMixin.generate`.
        """
        raise NotImplementedError("Denoiser subclasses must implement generate")
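
    # A subclass's `generate` would typically iterate denoising steps, e.g.
    # (illustrative pseudocode only; the step schedule and the unmasking rule
    # are subclass-specific):
    #
    #     xt = torch.full((batch_size, max_length), self.mask_token_id)
    #     for t in torch.linspace(1.0, 0.0, num_steps):
    #         out = self.forward(xt, t=t.expand(batch_size), compute_loss=False)
    #         x0_hat = self._sample_categorical(out.denoiser_output.exp())
    #         xt = torch.where(unmask_this_step, x0_hat, xt)  # hypothetical rule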