Spaces:
Build error
Build error
| import functools | |
| from typing import Mapping | |
| import torch | |
| import torch.nn as nn | |
| from networks.base_model import BaseModel | |
| import sys | |
| from models import get_model | |
class Validator(BaseModel):
    """Inference wrapper pairing a CLIP ViT-L/14 encoder with a FeatureTransformer head.

    ``set_input`` encodes every image with the CLIP model and groups the
    resulting feature vectors into fixed-length sequences. ``forward`` runs the
    ``FeatureTransformer`` head over those sequences.
    """

    # Reshape sizes used by ``set_input``. The values match the literal sizes
    # the original implementation hard-coded.
    FRAMES_PER_CLIP = 16  # feature vectors grouped per sequence
    FEATURE_DIM = 768  # size of each CLIP feature vector
    IMAGE_SIZE = 224  # images are reshaped to (3, IMAGE_SIZE, IMAGE_SIZE)

    def name(self):
        return 'Validator'

    def __init__(self, opt):
        super(Validator, self).__init__(opt)
        self.opt = opt
        self.model = get_model("FeatureTransformer")
        self.clip_model = get_model("CLIP:ViT-L/14")
        # CLIP is only used as a feature extractor here, never trained.
        self.clip_model.eval()
        # Only the head lives on the device permanently; CLIP is moved there
        # per batch in set_input.
        self.model.to(self.device)

    def set_input(self, input):
        """Encode a batch with CLIP and store features/labels on the device.

        Args:
            input: indexable pair ``(frames, labels)``. ``frames`` must be
                reshapable to ``(-1, 3, IMAGE_SIZE, IMAGE_SIZE)`` and contain a
                multiple of ``FRAMES_PER_CLIP`` images. It is presumably
                ``(batch, FRAMES_PER_CLIP, 3, H, W)`` — TODO confirm against the
                data loader.

        Sets:
            self.input: features viewed as ``(-1, FRAMES_PER_CLIP, FEATURE_DIM)``.
            self.label: ``labels`` as float on ``self.device``.
        """
        frames, labels = input[0], input[1]
        size = self.IMAGE_SIZE
        # CLIP is parked on the CPU between batches to save accelerator memory.
        # Move it over only for the encode, and always move it back, even if
        # the forward pass raises.
        self.clip_model.to(self.device)
        try:
            # The encoder is frozen, so there is no need to build an autograd
            # graph through it. This saves a large amount of activation memory.
            with torch.no_grad():
                features = self.clip_model.forward(
                    x=frames.to(self.device).view(-1, 3, size, size),
                    return_feature=True,
                )
        finally:
            self.clip_model.to('cpu')
        self.input = features.view(
            -1, self.FRAMES_PER_CLIP, self.FEATURE_DIM
        ).to(self.device)
        self.label = labels.to(self.device).float()

    def forward(self):
        """Run the head on ``self.input`` and store the scores in ``self.output``."""
        self.output = self.model(self.input)
        # Flatten whatever the head returns into a column of shape (N, 1).
        self.output = self.output.view(-1).unsqueeze(1)

    def load_state_dict(self, ckpt_path):
        """Restore the head's weights from ``ckpt_path`` and put it in eval mode.

        The checkpoint must be a mapping holding the head's state dict under
        the ``'model'`` key. Only ``self.model`` is restored; the CLIP encoder
        keeps the weights ``get_model`` gave it.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True where the torch version
        # supports it and the checkpoint holds only tensors/primitive types.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(state_dict['model'])
        self.model.eval()