File size: 2,729 Bytes
3cc4a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import functools
import torch
import torch.nn as nn
from networks.base_model import BaseModel
import sys
from models import get_model


class Trainer(BaseModel):
    """Train a "FeatureTransformer" head on features from a frozen CLIP model.

    The CLIP model ("CLIP:ViT-L/14") is used only as a feature extractor in
    ``set_input``; the optimizer is built exclusively from the parameters of
    ``self.model``. ``self.device`` is read but never assigned in this class
    -- presumably it is provided by ``BaseModel.__init__`` (TODO confirm).

    Options read from ``opt``: ``optim`` ('adam' or 'sgd'), ``lr``,
    ``weight_decay`` and, for 'adam', ``beta1``.
    """

    def name(self) -> str:
        """Return the identifier of this model wrapper."""
        return 'Trainer'

    def __init__(self, opt):
        """Build both networks, the optimizer and the loss.

        Args:
            opt: Options object; see the class docstring for the fields used.

        Raises:
            ValueError: If ``opt.optim`` is neither 'adam' nor 'sgd'.
        """
        super(Trainer, self).__init__(opt)
        self.opt = opt
        self.model = get_model("FeatureTransformer")
        self.clip_model = get_model("CLIP:ViT-L/14")

        # Freeze the CLIP model. Parameters named "fc.weight" / "fc.bias" are
        # left untouched, exactly as before; they are not handed to the
        # optimizer below, so they are never updated either.
        for param_name, p in self.clip_model.named_parameters():
            if param_name not in ("fc.weight", "fc.bias"):
                p.requires_grad = False

        # NOTE: the 'adam' option selects AdamW (decoupled weight decay).
        if opt.optim == 'adam':
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=opt.lr,
                betas=(opt.beta1, 0.999),
                weight_decay=opt.weight_decay,
            )
        elif opt.optim == 'sgd':
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=opt.lr,
                momentum=0.0,
                weight_decay=opt.weight_decay,
            )
        else:
            raise ValueError("optim should be [adam, sgd]")

        self.loss_fn = nn.BCEWithLogitsLoss()

        self.model.to(self.device)

    def adjust_learning_rate(self, min_lr: float = 1e-6) -> bool:
        """Divide the learning rate of every param group by 10.

        All groups are decayed before the result is decided, so groups never
        drift apart when one of them crosses ``min_lr`` first.

        Args:
            min_lr: Lower bound below which training should stop.

        Returns:
            False if any group's learning rate fell below ``min_lr`` after
            the decay, True otherwise.
        """
        keep_training = True
        for param_group in self.optimizer.param_groups:
            param_group['lr'] /= 10.
            if param_group['lr'] < min_lr:
                keep_training = False
        return keep_training

    def set_input(self, input):
        """Extract CLIP features for a batch and store features and labels.

        Args:
            input: Indexable batch; ``input[0]`` holds the frames and
                ``input[1]`` the labels. ``input[0]`` is flattened to
                (-1, 3, 224, 224) and the features are regrouped to
                (-1, input[0].shape[1], 768) -- assumes ``input[0]`` is
                (batch, frames, 3, 224, 224); TODO confirm against the
                data loader.
        """
        frames = input[0]

        # The CLIP model is moved onto the device only for the duration of
        # the forward pass and parked on the CPU afterwards -- presumably to
        # keep device memory free for the trainable model (TODO confirm).
        self.clip_model.to(self.device)
        # The extractor is never optimised, so no autograd graph is needed.
        with torch.no_grad():
            features = self.clip_model.forward(
                x=frames.to(self.device).view(-1, 3, 224, 224),
                return_feature=True,
            )
        self.clip_model.to('cpu')

        # .to() is a no-op when the features already live on self.device.
        self.input = features.view(-1, frames.shape[1], 768).to(self.device)
        self.label = input[1].to(self.device).float()

    def forward(self):
        """Run the trainable model on ``self.input``.

        The result is flattened and stored in ``self.output`` as a column
        of shape (N, 1).
        """
        self.output = self.model(self.input)
        self.output = self.output.view(-1).unsqueeze(1)

    def get_loss(self):
        """Return the BCE-with-logits loss of ``self.output`` vs ``self.label``."""
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        """Perform one training step: forward, loss, backward, optimizer step.

        The loss of the step is kept in ``self.loss``.
        """
        self.forward()
        self.loss = self.get_loss()
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()