"""DLC DiT replaces class label conditioning with DLC conditioning.

Class labels are a single discrete token between 0 and num_embeds_ada_norm - 1.
DLCs are a fixed-length sequence of L discrete tokens, each between 0 and V - 1.

We replace LabelEmbedder with DLCEmbedder (an illustrative sketch follows the
imports below):
- maintain the embedding matrix and drop_token
- but apply it to a DLC sequence of L tokens, instead of a single class label
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import Dinov2WithRegistersModel
from transformers.utils import ModelOutput

from .configuration_sem_dinov2 import SEMDinov2Config
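

# Illustrative sketch only (not used by the encoder below): the DLCEmbedder that
# the module docstring says replaces DiT's LabelEmbedder. Like LabelEmbedder it
# keeps an embedding matrix plus a drop token for classifier-free guidance, but
# it embeds a DLC sequence of L tokens (each in [0, V - 1]) rather than a single
# class label. Argument names and the sequence-level drop behaviour here are
# assumptions, not the reference implementation.
class DLCEmbedder(nn.Module):
    def __init__(self, V: int, hidden_size: int, dropout_prob: float = 0.1):
        super().__init__()
        # index V is reserved as the drop token
        self.embedding_table = nn.Embedding(V + 1, hidden_size)
        self.drop_token = V
        self.dropout_prob = dropout_prob

    def forward(self, dlc: torch.LongTensor, train: bool = True) -> torch.FloatTensor:
        # dlc: (batch, L) integer tokens in [0, V - 1]
        if train and self.dropout_prob > 0:
            # drop whole DLC sequences to the drop token with probability dropout_prob
            drop = torch.rand(dlc.shape[0], device=dlc.device) < self.dropout_prob
            dlc = torch.where(drop[:, None], torch.full_like(dlc, self.drop_token), dlc)
        # (batch, L, hidden_size): one embedding per DLC position
        return self.embedding_table(dlc)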


@dataclass
class SEMOutput(ModelOutput):
    dlc: Optional[torch.LongTensor] = None
    sem: Optional[torch.FloatTensor] = None


class SEMDinov2Model(Dinov2WithRegistersModel):
    config_class = SEMDinov2Config

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # DLC shape: L token positions, each with a vocabulary of V codes
        self.L = self.config.dlc_L
        self.V = self.config.dlc_V
        # softmax temperature for the SEM distributions
        self.temp = self.config.sem_temp
        # SEM head input: the CLS token plus the register tokens, flattened
        sem_in = self.config.hidden_size * (1 + self.config.num_register_tokens)
        sem_out = self.L * self.V
        self.sem_embed = nn.Linear(sem_in, sem_out, bias=False)
        self.sem_norm = nn.LayerNorm(sem_out, eps=1e-6)

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs, output_hidden_states=True)
        # take the last hidden states and keep the CLS token plus the register tokens
        out = out.hidden_states[-1][:, : self.config.num_register_tokens + 1]
        # flatten to (batch, hidden_size * (1 + num_register_tokens))
        out = out.view(len(out), -1)
        out = self.sem_embed(out)
        out = self.sem_norm(out)
        # one logit vector per DLC position, then a tempered softmax over the V codes
        out = out.view(-1, self.L, self.V)
        out = torch.softmax(out / self.temp, dim=-1)
        # sem: soft distributions over codes; dlc: hard argmax token per position
        return SEMOutput(sem=out, dlc=out.argmax(-1))
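

# Usage sketch (assumptions, not taken from the reference code): SEMDinov2Config
# is expected to carry the usual Dinov2-with-registers fields plus dlc_L, dlc_V
# and sem_temp, which is how the model above reads it. Given a batch of images,
# forward() returns `sem` of shape (batch, L, V), a softmax distribution over the
# V codes at each of the L DLC positions, and `dlc` of shape (batch, L), the
# argmax code at each position:
#
#   config = SEMDinov2Config(dlc_L=32, dlc_V=4096, sem_temp=0.1)  # hypothetical values
#   model = SEMDinov2Model(config)
#   pixel_values = torch.randn(2, 3, 224, 224)
#   out = model(pixel_values=pixel_values)
#   out.sem.shape  # (2, 32, 4096)
#   out.dlc.shape  # (2, 32)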