SEM_dinov2_L512 / modeling_sem_dinov2.py
"""DLC DiT replaces class label conditioning with DLC conditioning
class labels are a single discrete token between 0 and num_embeds_ada_norm-1
DLCs are a fixed-length sequence of L discrete tokens between 0 and V-1
we replace LabelEmbedder with DLCEmbedder
- maintain the embedding matrix and drop_token
- but apply it to a DLC sequence of L tokens, instead of a single class
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from transformers import Dinov2WithRegistersModel
from transformers.utils import ModelOutput
from .configuration_sem_dinov2 import SEMDinov2Config
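
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this model): one way the DLCEmbedder
# mentioned in the module docstring could look. It keeps the LabelEmbedder
# recipe (an embedding table plus a drop_token for classifier-free guidance)
# but embeds a (batch, L) sequence of DLC tokens instead of a single class
# label. The class name, the drop-token index (V) and the dropout mechanics
# below are assumptions; the real DLCEmbedder lives in the DiT modeling code.
# ---------------------------------------------------------------------------
class DLCEmbedder(nn.Module):
    def __init__(self, V: int, hidden_size: int, dropout_prob: float = 0.1):
        super().__init__()
        # one extra row acts as the drop_token used when conditioning is dropped
        self.embedding_table = nn.Embedding(V + 1, hidden_size)
        self.V = V
        self.dropout_prob = dropout_prob

    def forward(self, dlc: torch.LongTensor) -> torch.FloatTensor:
        # dlc: (batch, L) token indices in [0, V-1] -> (batch, L, hidden_size)
        if self.training and self.dropout_prob > 0:
            drop = torch.rand(dlc.shape[0], device=dlc.device) < self.dropout_prob
            dlc = torch.where(drop[:, None], torch.full_like(dlc, self.V), dlc)
        return self.embedding_table(dlc)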

@dataclass
class SEMOutput(ModelOutput):
    dlc: Optional[torch.LongTensor] = None
    sem: Optional[torch.FloatTensor] = None

class SEMDinov2Model(Dinov2WithRegistersModel):
    config_class = SEMDinov2Config

    def __init__(self, *args, **kwargs):
        # we use dlc_L and dlc_V instead of num_embeds_ada_norm_zero;
        # num_embeds_ada_norm_zero still has to be set because of a check in the
        # DiT code, but it is overridden in our code with the DLC embedding
        super().__init__(*args, **kwargs)
        self.L = self.config.dlc_L
        self.V = self.config.dlc_V
        self.temp = self.config.sem_temp
        # SEM head: project the concatenated CLS + register token states to
        # L * V logits, one V-way distribution per DLC position
        sem_in = self.config.hidden_size * (1 + self.config.num_register_tokens)
        sem_out = self.L * self.V
        self.sem_embed = nn.Linear(sem_in, sem_out, bias=False)
        self.sem_norm = nn.LayerNorm(sem_out, eps=1e-6)

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs, output_hidden_states=True)
        # keep the CLS token and the register tokens from the last hidden state
        out = out.hidden_states[-1][:, : self.config.num_register_tokens + 1]
        # flatten to (batch, hidden_size * (1 + num_register_tokens)) and project to L * V logits
        out = out.view(len(out), -1)
        out = self.sem_embed(out)
        out = self.sem_norm(out)
        # reshape to (batch, L, V) and take a temperature-scaled softmax over the V codes per position
        out = out.view(-1, self.L, self.V)
        out = torch.softmax(out / self.temp, dim=-1)
        # sem: softmax probabilities over the V codes at each of the L positions;
        # dlc: the argmax token index per position
        return SEMOutput(sem=out, dlc=out.argmax(-1))
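
# ---------------------------------------------------------------------------
# Minimal usage sketch under assumptions: SEMDinov2Config accepts dlc_L, dlc_V
# and sem_temp alongside the usual Dinov2-with-registers arguments, and images
# are already preprocessed into pixel_values. The values below are
# illustrative only; in practice the model would be loaded with
# SEMDinov2Model.from_pretrained(...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SEMDinov2Config(dlc_L=512, dlc_V=16, sem_temp=1.0)
    model = SEMDinov2Model(config).eval()

    pixel_values = torch.randn(2, 3, config.image_size, config.image_size)
    with torch.no_grad():
        out = model(pixel_values=pixel_values)

    print(out.sem.shape)  # (2, config.dlc_L, config.dlc_V) probabilities over V codes
    print(out.dlc.shape)  # (2, config.dlc_L) argmax token indices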