import functools
import sys
from typing import Mapping

import torch
import torch.nn as nn

from models import get_model
from networks.base_model import BaseModel


class Validator(BaseModel):
    """Evaluation wrapper: CLIP ViT-L/14 feature extractor + FeatureTransformer head.

    ``set_input`` encodes a batch of frames with CLIP (without tracking
    gradients) and regroups the per-frame features; ``forward`` scores them
    with the ``FeatureTransformer`` model.
    """

    # Frames per group and feature width used to regroup CLIP's per-frame
    # features (the values the original code hard-coded).
    # NOTE(review): 16 frames per clip is inferred from the original reshape —
    # confirm against the data loader.
    NUM_FRAMES = 16
    FEATURE_DIM = 768
    # Layout each frame is viewed as before CLIP encoding.
    FRAME_SHAPE = (3, 224, 224)

    def name(self):
        return 'Validator'

    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        self.model = get_model("FeatureTransformer")
        # CLIP is kept off the accelerator between batches; set_input moves it
        # to self.device only for the duration of feature extraction.
        self.clip_model = get_model("CLIP:ViT-L/14")
        # self.device is presumably set by BaseModel.__init__ — TODO confirm.
        self.model.to(self.device)

    def set_input(self, input):
        """Encode a batch with CLIP and store features and labels on the device.

        Args:
            input: indexable batch; ``input[0]`` holds frames that can be viewed
                as ``(-1, 3, 224, 224)`` and ``input[1]`` holds labels.
                Presumably a ``(frames, labels)`` pair from a DataLoader —
                TODO confirm.

        Sets ``self.input`` to CLIP features viewed as
        ``(-1, NUM_FRAMES, FEATURE_DIM)`` and ``self.label`` to float labels.
        """
        frames = input[0].to(self.device).view(-1, *self.FRAME_SHAPE)
        self.clip_model.to(self.device)
        try:
            # CLIP is only used as a feature extractor here; without no_grad the
            # whole ViT-L/14 autograd graph would be kept alive with self.input.
            with torch.no_grad():
                features = self.clip_model.forward(x=frames, return_feature=True)
        finally:
            # Release accelerator memory held by CLIP even if encoding fails.
            self.clip_model.to('cpu')
        # The features were produced on self.device, so no extra .to() is needed.
        self.input = features.view(-1, self.NUM_FRAMES, self.FEATURE_DIM)
        self.label = input[1].to(self.device).float()

    def forward(self):
        """Run the FeatureTransformer on ``self.input``; store a column in ``self.output``."""
        self.output = self.model(self.input)
        # Flatten to a single column of shape (N, 1).
        self.output = self.output.view(-1).unsqueeze(1)

    def load_state_dict(self, ckpt_path):
        """Restore FeatureTransformer weights from ``ckpt_path`` and set it to eval mode.

        The checkpoint is read as a dict whose ``'model'`` entry holds the
        FeatureTransformer state dict. Only ``self.model`` is restored; the
        CLIP weights come from ``get_model`` in ``__init__``.
        """
        # SECURITY NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(state_dict['model'])
        self.model.eval()