File size: 1,704 Bytes
3e0e9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from src.dinov2.models.vision_transformer import vit_base
from src.options import opts

def freeze_model(m):
    """Turn off gradient tracking for ``m`` in place.

    ``m`` is expected to be an ``nn.Module`` (every parameter it owns is
    frozen), but any object exposing ``requires_grad_`` is accepted.
    Returns ``None``.
    """
    m.requires_grad_(requires_grad=False)

def freeze_all_but_bn(m):
    """Freeze the ``weight``/``bias`` of ``m`` unless it is a LayerNorm.

    Despite the "bn" in the name, the exempted module type is
    ``torch.nn.LayerNorm`` (BatchNorm layers ARE frozen). Only attributes
    directly on ``m`` are touched, not its children, so this is presumably
    meant to be used via ``model.apply(freeze_all_but_bn)`` — confirm at the
    call site. Missing or ``None`` weight/bias attributes are skipped.
    """
    if isinstance(m, torch.nn.LayerNorm):
        # LayerNorm parameters stay trainable; announce the skip.
        print("LayerNorm")
        return

    for attr_name in ('weight', 'bias'):
        tensor = getattr(m, attr_name, None)
        if tensor is not None:
            tensor.requires_grad_(False)

class Model(pl.LightningModule):
    """DINOv2 ``vit_base`` backbone with learnable image and sketch prompts.

    A single backbone (``self.dino``) encodes both modalities; which modality
    is being encoded is selected in ``forward`` by the prompt tensor passed
    to the backbone.
    """

    def __init__(self):
        super().__init__()

        # Global options object imported from ``src.options``.
        self.opts = opts

        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)

        # Prompt engineering: one learnable (n_prompts, prompt_dim) tensor per
        # modality, initialised from a standard normal distribution.
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))

    def configure_optimizers(self):
        """Build an Adam optimizer with two parameter groups.

        * encoder parameters (``dino`` and, if present, ``clip_sk``) use
          ``opts.clip_LN_lr``;
        * the two prompt tensors use ``opts.prompt_lr``.

        Returns:
            torch.optim.Adam: the configured optimizer.

        Raises:
            AttributeError: if ``opts.model_type`` is not ``'one_encoder'`` but
                no separate sketch encoder ``clip_sk`` has been attached to
                this module (``__init__`` never creates one).
        """
        model_params = list(self.dino.parameters())
        if self.opts.model_type != 'one_encoder':
            # A two-encoder setup needs a dedicated sketch encoder. This class
            # does not build one, so fail with an actionable message instead
            # of the bare AttributeError from accessing ``self.clip_sk``.
            sk_encoder = getattr(self, 'clip_sk', None)
            if sk_encoder is None:
                raise AttributeError(
                    f"opts.model_type={self.opts.model_type!r} requires a "
                    "separate sketch encoder 'clip_sk', but Model only "
                    "defines the shared 'dino' backbone; use "
                    "model_type='one_encoder' or attach 'clip_sk'."
                )
            model_params += list(sk_encoder.parameters())

        optimizer = torch.optim.Adam([
            {'params': model_params, 'lr': self.opts.clip_LN_lr},
            {'params': [self.sk_prompt, self.img_prompt], 'lr': self.opts.prompt_lr},
        ])
        return optimizer

    def forward(self, data, dtype='image'):
        """Encode ``data`` with the backbone and the prompt for its modality.

        Args:
            data: input batch; only ``data.shape[0]`` (batch size) is read
                here. Presumably (B, C, H, W) pixels — TODO confirm.
            dtype: ``'image'`` selects ``img_prompt``; any other value
                selects ``sk_prompt``.

        Returns:
            Whatever ``self.dino(data, prompt=...)`` returns.
        """
        prompt = self.img_prompt if dtype == 'image' else self.sk_prompt
        # Broadcast the shared (n_prompts, prompt_dim) prompt to every sample
        # as a view (no copy): -> (batch, n_prompts, prompt_dim).
        return self.dino(data, prompt=prompt.expand(data.shape[0], -1, -1))