# Page-header residue from the hosting site (not part of the module): "Spaces: Sleeping"
import functools | |
import torch | |
import torch.nn as nn | |
from networks.base_model import BaseModel | |
import sys | |
from models import get_model | |
class Trainer(BaseModel):
    """Train a classifier on features produced by a frozen CLIP backbone.

    ``self.clip_model`` encodes image frames into feature vectors and
    ``self.model`` (built via ``get_model("FeatureTransformer")``) maps those
    features to logits. Only the parameters of ``self.model`` are handed to
    the optimizer; the loss is ``nn.BCEWithLogitsLoss``.
    """

    def name(self):
        """Return the identifier of this model wrapper."""
        return 'Trainer'

    def __init__(self, opt):
        """Build the networks, the optimizer and the loss function.

        Args:
            opt: Options object. It is forwarded to ``BaseModel.__init__``
                and this constructor reads ``opt.optim`` (``'adam'`` or
                ``'sgd'``), ``opt.lr``, ``opt.weight_decay`` and, for
                ``'adam'`` only, ``opt.beta1``.

        Raises:
            ValueError: If ``opt.optim`` is neither ``'adam'`` nor ``'sgd'``.
        """
        super().__init__(opt)
        self.opt = opt
        self.model = get_model("FeatureTransformer")
        self.clip_model = get_model("CLIP:ViT-L/14")

        # The backbone is always frozen here (no option is consulted):
        # every CLIP parameter except ``fc.weight`` / ``fc.bias`` stops
        # requiring gradients. Those two are left untouched, but note that
        # they are NOT part of the optimizer below.
        for param_name, param in self.clip_model.named_parameters():
            if param_name not in ("fc.weight", "fc.bias"):
                param.requires_grad = False

        # Place the trainable model on its device before the optimizer is
        # constructed, so the optimizer is built over parameters that are
        # already in their final location.
        # NOTE(review): ``self.device`` is presumably set by
        # ``BaseModel.__init__`` - confirm.
        self.model.to(self.device)

        if opt.optim == 'adam':
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=opt.lr,
                betas=(opt.beta1, 0.999),
                weight_decay=opt.weight_decay,
            )
        elif opt.optim == 'sgd':
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=opt.lr,
                momentum=0.0,
                weight_decay=opt.weight_decay,
            )
        else:
            raise ValueError("optim should be [adam, sgd]")

        self.loss_fn = nn.BCEWithLogitsLoss()

    def adjust_learning_rate(self, min_lr=1e-6):
        """Divide the learning rate of the optimizer's param groups by 10.

        Groups are processed in order; as soon as one group's new learning
        rate is below ``min_lr`` the method returns immediately, leaving any
        later groups unchanged.

        Args:
            min_lr: Lower bound used for the stop signal.

        Returns:
            ``False`` if a decayed learning rate fell below ``min_lr``,
            ``True`` otherwise.
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] /= 10.
            if param_group['lr'] < min_lr:
                return False
        return True

    def set_input(self, input):
        """Encode a batch with CLIP and store the features and labels.

        Sets ``self.input`` (CLIP features) and ``self.label`` (labels cast
        to float), both on ``self.device``.

        Args:
            input: Indexable pair ``(frames, labels)``. ``frames`` is viewed
                as ``(-1, 3, 224, 224)`` for the CLIP forward pass and the
                result is viewed as ``(-1, frames.shape[1], 768)``, so it is
                presumably laid out as
                ``(batch, frames_per_clip, 3, 224, 224)`` - TODO confirm
                against the data loader.
        """
        frames = input[0]

        # The backbone is moved onto the device only for the duration of the
        # forward pass and sent back to the CPU afterwards (presumably to
        # limit accelerator memory use - confirm before changing).
        self.clip_model.to(self.device)
        # Only ``self.model`` is optimized, so no gradient is ever needed
        # through the CLIP forward pass.
        with torch.no_grad():
            features = self.clip_model.forward(
                x=frames.to(self.device).view(-1, 3, 224, 224),
                return_feature=True,
            )
        self.clip_model.to('cpu')

        self.input = features.view(-1, frames.shape[1], 768).to(self.device)
        self.label = input[1].to(self.device).float()

    def forward(self):
        """Run ``self.model`` on ``self.input`` and store ``self.output``.

        The model result is flattened and stored as a tensor with a trailing
        dimension of size 1.
        """
        logits = self.model(self.input)
        self.output = logits.view(-1).unsqueeze(1)

    def get_loss(self):
        """Return the BCE-with-logits loss of ``self.output`` vs ``self.label``."""
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        """Perform one optimization step on the current input.

        Runs the forward pass, stores the loss in ``self.loss``, clears old
        gradients, back-propagates and steps the optimizer.
        """
        self.forward()
        self.loss = self.get_loss()
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()