"""DLC DiT replaces class label conditioning with DLC conditioning

class labels are a single discrete token between 0 and num_embeds_ada_norm-1
DLCs are a fixed-length sequence of L discrete tokens between 0 and V-1

we replace LabelEmbedder with DLCEmbedder
- maintain the embedding matrix and drop_token
- but apply it to a DLC sequence of L tokens, instead of a single class
"""

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import Dinov2WithRegistersModel
from transformers.utils import ModelOutput

from .configuration_sem_dinov2 import SEMDinov2Config
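

# The DLCEmbedder mentioned in the module docstring lives on the DiT side and is
# not defined in this file. The sketch below is illustrative only: the name,
# signature and dropout handling are assumptions, not the released
# implementation. It keeps an embedding matrix plus a drop_token and applies it
# to a sequence of L tokens instead of a single class label.
class _DLCEmbedderSketch(nn.Module):
    """Hypothetical DLC embedder: V code tokens plus one drop token."""

    def __init__(self, V: int, hidden_size: int, dropout_prob: float = 0.1):
        super().__init__()
        self.V = V
        self.dropout_prob = dropout_prob
        # Index V is reserved for the drop token (classifier-free guidance).
        self.embedding_table = nn.Embedding(V + 1, hidden_size)

    def forward(self, dlc: torch.LongTensor, train: bool = False) -> torch.Tensor:
        # dlc: (batch, L) integer codes in [0, V)
        if train and self.dropout_prob > 0:
            drop = torch.rand(dlc.shape[0], device=dlc.device) < self.dropout_prob
            dlc = torch.where(drop[:, None], torch.full_like(dlc, self.V), dlc)
        return self.embedding_table(dlc)  # (batch, L, hidden_size)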


@dataclass
class SEMOutput(ModelOutput):
    dlc: Optional[torch.LongTensor] = None
    sem: Optional[torch.FloatTensor] = None


class SEMDinov2Model(Dinov2WithRegistersModel):
    config_class = SEMDinov2Config

    def __init__(self, *args, **kwargs):
        # We use dlc_embed_l and dlc_embed_v instead of num_embeds_ada_norm_zero.
        # num_embeds_ada_norm_zero still has to be set, since the DiT code checks
        # for it, but it is overridden in our code by DLCEmbedding.
        super().__init__(*args, **kwargs)

        self.L = self.config.dlc_L        # number of DLC tokens
        self.V = self.config.dlc_V        # DLC vocabulary size per token
        self.temp = self.config.sem_temp  # softmax temperature
        # Project the concatenated [CLS] + register tokens to L * V logits.
        sem_in = self.config.hidden_size * (1 + self.config.num_register_tokens)
        sem_out = self.L * self.V
        self.sem_embed = nn.Linear(sem_in, sem_out, bias=False)
        self.sem_norm = nn.LayerNorm(sem_out, eps=1e-6)

    def forward(self, *args, **kwargs):
        # Always request hidden states; we read the last layer below.
        kwargs["output_hidden_states"] = True
        out = super().forward(*args, **kwargs)
        # Keep the [CLS] token and the register tokens from the last layer,
        # then flatten them into a single feature vector per image.
        out = out.hidden_states[-1][:, : self.config.num_register_tokens + 1]
        out = out.view(len(out), -1)
        # Project to L * V logits, normalize, and take a per-token softmax.
        out = self.sem_embed(out)
        out = self.sem_norm(out)
        out = out.view(-1, self.L, self.V)
        out = torch.softmax(out / self.temp, dim=-1)
        # sem: (batch, L, V) probabilities; dlc: (batch, L) argmax codes.
        return SEMOutput(sem=out, dlc=out.argmax(-1))
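

if __name__ == "__main__":
    # Minimal smoke test, not part of the model code. All config values below
    # are illustrative assumptions; real checkpoints define their own
    # hidden_size, dlc_L, dlc_V and sem_temp.
    config = SEMDinov2Config(
        hidden_size=384,
        num_hidden_layers=2,
        num_attention_heads=6,
        num_register_tokens=4,
        dlc_L=16,
        dlc_V=512,
        sem_temp=1.0,
    )
    model = SEMDinov2Model(config).eval()
    pixel_values = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        out = model(pixel_values=pixel_values)
    # sem holds one probability distribution over the V codes per DLC position;
    # dlc holds the argmax codes that condition the downstream DiT.
    print(out.sem.shape)  # expected: torch.Size([2, 16, 512])
    print(out.dlc.shape)  # expected: torch.Size([2, 16])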