import torch
import torch.nn as nn
import torch.nn.functional as F
import sys, time, numpy, os, subprocess, pandas, tqdm

from loss_multi import lossAV, lossA, lossV
from model.loconet_encoder import locoencoder

import torch.distributed as dist
from xxlib.utils.distributed import all_gather, all_reduce


class Loconet(nn.Module):
    """Wraps the LoCoNet encoder and the AV / audio-only / visual-only losses so the
    whole forward-plus-loss computation runs inside DistributedDataParallel."""

    def __init__(self, cfg):
        super(Loconet, self).__init__()
        self.cfg = cfg
        self.model = locoencoder(cfg)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, labels, masks):
        # visualFeature: (batch, speakers, time, ...). Fold the speaker dimension
        # into the batch dimension before running the visual frontend.
        b, s, t = visualFeature.shape[:3]
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        labels = labels.view(b * s, *labels.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)    # B, C, T, 4
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Repeat the audio embedding once per candidate speaker so it lines up
        # with the flattened (b * s) visual embeddings.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        labels = labels.reshape((-1))
        masks = masks.reshape((-1))
        nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
        nlossA = self.lossA.forward(outsA, labels, masks)
        nlossV = self.lossV.forward(outsV, labels, masks)

        # Audio-visual loss plus down-weighted auxiliary single-modality losses.
        nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV

        num_frames = masks.sum()
        return nloss, prec, num_frames


class loconet(nn.Module):

    def __init__(self, cfg, rank=None, device=None):
        super(loconet, self).__init__()
        self.cfg = cfg
        self.rank = rank
        if rank is not None:
            # Distributed training path: SyncBatchNorm + DistributedDataParallel.
            self.device = device

            self.model = Loconet(cfg).to(device)
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank],
                                                             output_device=rank,
                                                             find_unused_parameters=False)
            self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.SOLVER.BASE_LR)
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optim, step_size=1, gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
        else:
            # Single-GPU path used for evaluation / inference.
            self.model = locoencoder(cfg).cuda()
            self.lossAV = lossAV().cuda()
            self.lossA = lossA().cuda()
            self.lossV = lossV().cuda()
        print(
            time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
            (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))

    def train_network(self, epoch, loader):
        self.model.train()
        self.scheduler.step(epoch - 1)
        index, top1, loss = 0, 0, 0
        lr = self.optim.param_groups[0]['lr']
        loader.sampler.set_epoch(epoch)
        device = self.device

        pbar = enumerate(loader, start=1)
        if self.rank == 0:
            pbar = tqdm.tqdm(pbar, total=len(loader))

        for num, (audioFeature, visualFeature, labels, masks) in pbar:
            audioFeature = audioFeature.to(device)
            visualFeature = visualFeature.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            nloss, prec, num_frames = self.model(
                audioFeature,
                visualFeature,
                labels,
                masks,
            )

            self.optim.zero_grad()
            nloss.backward()
            self.optim.step()

            # Sum loss, correct-frame count and frame count across ranks so the
            # running statistics reflect the global batch.
            [nloss, prec, num_frames] = all_reduce([nloss, prec, num_frames], average=False)
            top1 += prec.detach().cpu().numpy()
            loss += nloss.detach().cpu().numpy()
            index += int(num_frames.detach().cpu().item())
            if self.rank == 0:
                pbar.set_postfix(
                    dict(epoch=epoch,
                         lr=lr,
                         loss=loss / (num * self.cfg.NUM_GPUS),
                         acc=(top1 / index)))
        dist.barrier()
        return loss / num, lr

    def evaluate_network(self, epoch, loader):
        self.eval()
        predScores = []
        evalCsvSave = os.path.join(self.cfg.WORKSPACE, "{}_res.csv".format(epoch))
        evalOrig = self.cfg.evalOrig
        for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
            with torch.no_grad():
                audioFeature = audioFeature.cuda()
                visualFeature = visualFeature.cuda()
                labels = labels.cuda()
                masks = masks.cuda()
                b, s, t = visualFeature.shape[0], visualFeature.shape[1], visualFeature.shape[2]
                visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
                labels = labels.view(b * s, *labels.shape[2:])
                masks = masks.view(b * s, *masks.shape[2:])
                audioEmbed = self.model.forward_audio_frontend(audioFeature)
                visualEmbed = self.model.forward_visual_frontend(visualFeature)
                audioEmbed = audioEmbed.repeat(s, 1, 1)
                audioEmbed, visualEmbed = self.model.forward_cross_attention(
                    audioEmbed, visualEmbed)
                outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
                labels = labels.reshape((-1))
                masks = masks.reshape((-1))
                # Score only the first candidate speaker track (index 0) of each clip.
                outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
                labels = labels.view(b, s, t)[:, 0, :].view(b * t).cuda()
                masks = masks.view(b, s, t)[:, 0, :].view(b * t)
                _, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
                predScore = predScore[:, 1].detach().cpu().numpy()
                predScores.extend(predScore)
        # Write the predictions into an AVA-style CSV in the workspace and run the
        # official active-speaker evaluation script on it.
        evalLines = open(evalOrig).read().splitlines()[1:]
        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
        scores = pandas.Series(predScores)
        evalRes = pandas.read_csv(evalOrig)
        evalRes['score'] = scores
        evalRes['label'] = labels
        evalRes.drop(['label_id'], axis=1, inplace=True)
        evalRes.drop(['instance_id'], axis=1, inplace=True)
        evalRes.to_csv(evalCsvSave, index=False)
        cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (
            evalOrig, evalCsvSave)
        # Parse the mAP value out of the evaluation script's stdout.
        mAP = float(
            str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
        return mAP

    def saveParameters(self, path):
        torch.save(self.state_dict(), path)

    def loadParameters(self, path):
        selfState = self.state_dict()
        loadedState = torch.load(path, map_location='cpu')
        if self.rank is not None:
            info = self.load_state_dict(loadedState)
        else:
            # Checkpoints saved from the DDP-wrapped model carry a "model.module."
            # prefix; strip it before loading into the single-GPU encoder.
            new_state = {}
            for k, v in loadedState.items():
                new_state[k.replace("model.module.", "")] = v
            info = self.load_state_dict(new_state, strict=False)
        print(info)