"""sedd_wrapper.py
=========================================
This module provides a minimal HuggingFace-compatible wrapper around the
`SEDD` architecture that is implemented in :pyfile:`model/transformer.py`.

The wrapper closely follows the design used in the Aero implementation that
lives in this code-base (see :pyfile:`configuration_aero.py` and
:pyfile:`modeling_aero.py`).  Concretely we expose three public objects:

* ``SEDDConfig``: a :class:`transformers.PretrainedConfig` subclass that
  stores the hyper-parameters needed to instantiate a ``SEDD`` model.
* ``SEDDModel``: a :class:`transformers.PreTrainedModel` subclass that
  internally contains an instance of the original ``SEDD`` network and maps
  from ``input_ids`` + ``sigma`` to the vocabulary logits.
* ``SEDDOutput``: a thin :class:`transformers.modeling_outputs.ModelOutput`
  dataclass that mirrors the usual "logits / loss" structure.

With this wrapper a trained model checkpoint can be pushed to / loaded from
πŸ€— Hub via ``SEDDModel.push_to_hub`` / ``SEDDModel.from_pretrained`` the same
way as any other ``transformers`` model.
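
A minimal round-trip sketch (the repository id is a placeholder and the import
path should be adjusted to wherever this module lives)::

    from sedd_wrapper import SEDDConfig, SEDDModel

    config = SEDDConfig()          # "small"-ish defaults
    model = SEDDModel(config)      # requires `omegaconf` and the original SEDD code

    model.push_to_hub("your-username/sedd-small")
    reloaded = SEDDModel.from_pretrained("your-username/sedd-small")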
"""

from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict, Any, Union

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

# Original SEDD implementation
from model.transformer import SEDD as _OrigSEDD

try:
    from omegaconf import OmegaConf
except ImportError:  # pragma: no cover – omegaconf is an explicit dependency of SEDD
    OmegaConf = None  # type: ignore

logger = logging.get_logger(__name__)

###############################################################################
# Configuration                                                               #
###############################################################################


class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD architecture.

    The defaults reproduce *roughly* the "small" configuration shipped in
    ``configs/model/small.yaml``.  Additional keys that are present in the
    original Hydra config but not required for instantiation (e.g. *training*
    hyper-parameters) are deliberately omitted here – they can still be stored
    as *extra* fields in the underlying JSON if a user wishes to preserve them.
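
    Example (a sketch; ``noise_type`` is only an illustrative extra key, not a
    field this class defines itself)::

        cfg = SEDDConfig(model_hidden_size=1024, noise_type="loglinear")
        cfg.model_hidden_size   # -> 1024
        cfg.noise_type          # kept as an extra attribute and serialised to JSON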
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # graph section
        graph_type: str = "absorb",
        # model section (mirrors configs/model/*.yaml)
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # miscellaneous
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Top-level attributes (kept flat for simplicity)
        self.tokens = tokens
        self.graph_type = graph_type

        # Model hyper-parameters
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ---------------------------------------------------------------------
    # Serialization helpers – these optionally bridge to the original Hydra
    # config structure that the reference implementation expects.
    # ---------------------------------------------------------------------

    def to_hydra(self):
        """Convert this *flat* config to the nested OmegaConf structure that
        the reference ``SEDD`` implementation expects.
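
        Example (a sketch; the nested keys mirror the dictionary built below)::

            cfg = SEDDConfig().to_hydra()
            cfg.model.hidden_size   # -> 768
            cfg.graph.type          # -> "absorb"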
        """

        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": {
                "type": self.graph_type,
            },
            "model": {
                "hidden_size": self.model_hidden_size,
                "cond_dim": self.model_cond_dim,
                "length": self.model_length,
                "n_blocks": self.model_n_blocks,
                "n_heads": self.model_n_heads,
                "scale_by_sigma": self.model_scale_by_sigma,
                "dropout": self.model_dropout,
            },
        }
        return OmegaConf.create(nested)

###############################################################################
# Output container                                                            #
###############################################################################


@dataclass
class SEDDOutput(ModelOutput):
    """Standard output for :class:`SEDDModel`.

    Attributes
    ----------
    loss:
        *Optional* scalar returned when ``labels`` are provided.
    logits:
        The raw vocabulary logits computed by the model of shape
        ``(batch_size, sequence_length, vocab_size)``.
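
    Like any :class:`~transformers.modeling_outputs.ModelOutput`, fields that
    are ``None`` are dropped when the output is indexed or converted to a
    tuple (a sketch; ``some_logits`` is a placeholder tensor)::

        out = SEDDOutput(loss=None, logits=some_logits)
        out.logits is out[0]   # True, because the ``None`` loss is skipped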
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None

###############################################################################
# Model                                                                       #
###############################################################################


class SEDDModel(PreTrainedModel):
    """HuggingFace *Transformers* wrapper around the original ``SEDD`` model."""

    config_class = SEDDConfig
    base_model_prefix = "score_model"
    _no_split_modules: List[str] = [
        "DDiTBlock",  # ensure these blocks are not split when using FSDP/TP
    ]

    def __init__(self, config: SEDDConfig):
        super().__init__(config)

        # ------------------------------------------------------------------
        # Instantiate the original SEDD architecture using the Hydra cfg that
        # the implementation expects.
        # ------------------------------------------------------------------
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to instantiate SEDD")

        hydra_cfg = config.to_hydra()
        self.score_model = _OrigSEDD(hydra_cfg)

        # Run the standard *transformers* post-initialisation (weight-init hooks etc.).
        self.post_init()

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.LongTensor,
        sigma: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[SEDDOutput, Tuple]:
        """Run a forward pass.

        Parameters
        ----------
        input_ids:
            Token indices of shape ``(batch_size, seq_len)``.
        sigma:
            Noise level ("time-step") of shape ``(batch_size,)``.
        labels:
            *Optional* label tensor used to compute a cross-entropy training
            loss.  If provided the returned :class:`SEDDOutput` will contain a
            ``loss`` field.
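
        Example (a shape-only sketch with a randomly initialised model; the
        vocabulary dimension of ``logits`` is whatever the wrapped network
        produces)::

            import torch

            model = SEDDModel(SEDDConfig())
            ids = torch.randint(0, model.config.tokens, (2, model.config.model_length))
            sigma = torch.rand(2)
            out = model(input_ids=ids, sigma=sigma)
            out.logits.shape  # (batch, model_length, vocab_size)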
        """

        logits = self.score_model(indices=input_ids, sigma=sigma)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Standard CE loss over the vocabulary dimension; ``reshape`` is
            # used so non-contiguous logits do not trip up ``view``.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        if not self.config.use_return_dict:
            output: Tuple[Any, ...] = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SEDDOutput(loss=loss, logits=logits)

    # ------------------------------------------------------------------
    # Weight loading helpers – a fallback loader so that checkpoints trained
    # with the original SEDD implementation can be re-used.
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Overrides the default method to allow loading legacy SEDD checkpoints
        whose weights are saved via ``torch.save({'model': state_dict, ...})``.
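
        A legacy run directory is expected to look roughly like::

            run_dir/
            ├── config.json               # SEDDConfig (e.g. written via ``save_pretrained``)
            └── checkpoints-meta/
                └── checkpoint.pth        # torch.save({'model': state_dict, ...})

        after which ``SEDDModel.from_pretrained("run_dir")`` will pick up the
        fallback path below.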
        """

        try:
            # First try the regular *transformers* loading routine – this will
            # succeed if the repository follows the standard file-naming
            # conventions (i.e. contains a ``pytorch_model.bin`` / safetensors).
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        except (EnvironmentError, RuntimeError) as e:
            logger.info(
                "Falling back to legacy SEDD checkpoint format because standard "
                "loading raised: %s", e,
            )

            # ----------------------------------------------------------
            # 1. Load config the usual way so we get a `SEDDConfig` instance.
            # ----------------------------------------------------------
            config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
                pretrained_model_name_or_path
            )
            # ``SEDDModel.__init__`` only accepts a config; loading-specific
            # kwargs (``cache_dir``, ``revision``, ...) are not forwarded here.
            model = cls(config)

            # ----------------------------------------------------------
            # 2. Attempt to locate the legacy *.pth* checkpoint and load it.
            # ----------------------------------------------------------
            import os

            checkpoint_path = os.path.join(
                pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
            )
            if not os.path.isfile(checkpoint_path):
                raise FileNotFoundError(
                    "Could not find legacy SEDD checkpoint at " f"{checkpoint_path}"
                )

            ckpt = torch.load(checkpoint_path, map_location="cpu")
            state_dict = ckpt.get("model", ckpt)
            # Strip the DistributedDataParallel "module." prefix if present.
            state_dict = {
                (k[len("module."):] if k.startswith("module.") else k): v
                for k, v in state_dict.items()
            }
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            if missing:
                logger.warning("Missing keys when loading SEDD weights: %s", missing)
            if unexpected:
                logger.warning(
                    "Unexpected keys when loading SEDD weights: %s", unexpected
                )
            return model

###############################################################################
# Public API                                                                  #
###############################################################################

__all__ = [
    "SEDDConfig",
    "SEDDModel",
    "SEDDOutput",
]