import torch
from networks.base_model import BaseModel
from models import get_model


class Validator(BaseModel):
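    """Validation model: scores videos with a FeatureTransformer head applied to
    per-frame CLIP ViT-L/14 features."""
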
    def name(self):
        return 'Validator'

    def __init__(self, opt):
        super(Validator, self).__init__(opt)
        self.opt = opt
        # FeatureTransformer operates on sequences of CLIP frame features;
        # CLIP ViT-L/14 supplies the per-frame 768-d features in set_input().
        self.model = get_model("FeatureTransformer")
        self.clip_model = get_model("CLIP:ViT-L/14")

        # Earlier experiment (left disabled): collect only CLIP's classification
        # head parameters for training and freeze the rest of CLIP.
        # params = []
        # for name, p in self.clip_model.named_parameters():
        #     if name == "fc.weight" or name == "fc.bias":
        #         params.append(p)
        #     else:
        #         p.requires_grad = False
        # del params

        # Only the FeatureTransformer stays resident on the device; CLIP is moved
        # onto it temporarily while features are extracted in set_input().
        self.model.to(self.device)


    def set_input(self, input):
        # input[0]: batched video frames (16 frames per video), input[1]: labels.
        # Earlier per-video variant:
        # self.input = torch.cat([self.clip_model.forward(x=video_frames, return_feature=True).unsqueeze(0) for video_frames in input[0]])
        # Move CLIP onto the device, extract one 768-d feature per frame, then
        # return CLIP to the CPU and regroup the features as (videos, 16, 768).
        self.clip_model.to(self.device)
        frames = input[0].to(self.device).view(-1, 3, 224, 224)
        self.input = self.clip_model.forward(x=frames, return_feature=True).view(-1, 16, 768)
        self.clip_model.to('cpu')
        self.input = self.input.to(self.device)
        self.label = input[1].to(self.device).float()


    def forward(self):
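        # Score each video's CLIP feature sequence and reshape to (batch, 1).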
        self.output = self.model(self.input)
        self.output = self.output.view(-1).unsqueeze(1)

    def load_state_dict(self, ckpt_path):
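        # Restore FeatureTransformer weights from a checkpoint and switch to eval mode.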
        state_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(state_dict['model'])
        self.model.eval()