import torch
import torch.nn as nn
import torch.nn.functional as F

import sys, time, numpy, os, subprocess, pandas, tqdm

from loss_multi import lossAV, lossA, lossV
from model.loconet_encoder import locoencoder

import torch.distributed as dist
from xxlib.utils.distributed import all_gather, all_reduce
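
# Loconet wraps the locoencoder and the three losses so a single forward pass
# returns the combined training loss; loconet (lowercase) is the trainer and
# evaluator that handles DDP setup, the epoch loop, AVA-style evaluation, and
# checkpoint save/load.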

class Loconet(nn.Module):

    def __init__(self, cfg):
        super(Loconet, self).__init__()
        self.cfg = cfg
        self.model = locoencoder(cfg)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()
    def forward(self, audioFeature, visualFeature, labels, masks):
        # visualFeature: (b, s, t, ...) = (batch, speaker tracks, frames, ...).
        b, s, t = visualFeature.shape[:3]
        # Fold the speaker dimension into the batch so every face track is
        # scored as its own sample.
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        labels = labels.view(b * s, *labels.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Every speaker track shares the clip's audio, so tile the audio
        # embedding once per track.
        audioEmbed = audioEmbed.repeat(s, 1, 1)
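        # Note: repeat(s, 1, 1) stacks whole-batch copies (speaker-major),
        # whereas the visual flattening above is batch-major; the two layouts
        # coincide when the per-GPU batch size b is 1, which appears to be the
        # assumption here.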
        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)
        labels = labels.reshape((-1))
        masks = masks.reshape((-1))
        nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
        nlossA = self.lossA.forward(outsA, labels, masks)
        nlossV = self.lossV.forward(outsV, labels, masks)

        # Audio-visual loss plus down-weighted single-modality auxiliary losses.
        nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV

        num_frames = masks.sum()
        return nloss, prec, num_frames

class loconet(nn.Module):

    def __init__(self, cfg, rank=None, device=None):
        super(loconet, self).__init__()
        self.cfg = cfg
        self.rank = rank
        if rank is not None:
            self.device = device
            self.model = Loconet(cfg).to(device)
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank],
                                                             output_device=rank,
                                                             find_unused_parameters=False)
            self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.SOLVER.BASE_LR)
            # StepLR with step_size=1 decays the learning rate by GAMMA every epoch.
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim,
                                                             step_size=1,
                                                             gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
        else:
            # Single-GPU path: keep the raw encoder and losses unwrapped so
            # the evaluation helpers below can call them directly.
            self.model = locoencoder(cfg).cuda()
            self.lossAV = lossAV().cuda()
            self.lossA = lossA().cuda()
            self.lossV = lossV().cuda()

        print(
            time.strftime("%m-%d %H:%M:%S") + " Model parameter count = %.2fM" %
            (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
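
    # One full training epoch: advance the LR schedule, iterate the
    # distributed loader, and accumulate cross-rank loss/accuracy for logging.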
    def train_network(self, epoch, loader):
        self.model.train()
        self.scheduler.step(epoch - 1)
        index, top1, loss = 0, 0, 0
        lr = self.optim.param_groups[0]['lr']
        # Re-seed the DistributedSampler so each epoch sees a fresh shuffle.
        loader.sampler.set_epoch(epoch)
        device = self.device

        pbar = enumerate(loader, start=1)
        if self.rank == 0:
            pbar = tqdm.tqdm(pbar, total=len(loader))
        for num, (audioFeature, visualFeature, labels, masks) in pbar:
            audioFeature = audioFeature.to(device)
            visualFeature = visualFeature.to(device)
            labels = labels.to(device)
            masks = masks.to(device)
            nloss, prec, num_frames = self.model(audioFeature, visualFeature, labels, masks)
            self.optim.zero_grad()
            nloss.backward()
            self.optim.step()
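            # all_reduce comes from xxlib.utils.distributed; with average=False
            # it presumably sums each value across ranks, so the running totals
            # below are global (hence the NUM_GPUS division in the progress bar).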
            [nloss, prec, num_frames] = all_reduce([nloss, prec, num_frames], average=False)
            top1 += prec.detach().cpu().numpy()
            loss += nloss.detach().cpu().numpy()
            index += int(num_frames.detach().cpu().item())
            if self.rank == 0:
                pbar.set_postfix(
                    dict(epoch=epoch,
                         lr=lr,
                         loss=loss / (num * self.cfg.NUM_GPUS),
                         acc=(top1 / index)))
        dist.barrier()
        return loss / num, lr

    def evaluate_network(self, epoch, loader):
        # Note: this path relies on self.lossAV and the unwrapped encoder,
        # which are only created in the non-distributed (rank is None) branch
        # of __init__.
        self.eval()
        predScores = []
        evalCsvSave = os.path.join(self.cfg.WORKSPACE, "{}_res.csv".format(epoch))
        evalOrig = self.cfg.evalOrig
        for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
            with torch.no_grad():
                audioFeature = audioFeature.cuda()
                visualFeature = visualFeature.cuda()
                labels = labels.cuda()
                masks = masks.cuda()
                b, s, t = visualFeature.shape[:3]
                visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
                labels = labels.view(b * s, *labels.shape[2:])
                masks = masks.view(b * s, *masks.shape[2:])
                audioEmbed = self.model.forward_audio_frontend(audioFeature)
                visualEmbed = self.model.forward_visual_frontend(visualFeature)
                audioEmbed = audioEmbed.repeat(s, 1, 1)
                audioEmbed, visualEmbed = self.model.forward_cross_attention(
                    audioEmbed, visualEmbed)
                outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
                labels = labels.reshape((-1))
                masks = masks.reshape((-1))
                # Keep only the first speaker track of each clip (the eval
                # loader presumably places the target face there).
                outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
                labels = labels.view(b, s, t)[:, 0, :].view(b * t)
                masks = masks.view(b, s, t)[:, 0, :].view(b * t)
                _, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
                predScore = predScore[:, 1].detach().cpu().numpy()
                predScores.extend(predScore)

        # Rebuild the ground-truth CSV with predicted scores in the format the
        # official AVA scorer expects.
        evalLines = open(evalOrig).read().splitlines()[1:]
        labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
        scores = pandas.Series(predScores)
        evalRes = pandas.read_csv(evalOrig)
        evalRes['score'] = scores
        evalRes['label'] = labels
        evalRes.drop(['label_id'], axis=1, inplace=True)
        evalRes.drop(['instance_id'], axis=1, inplace=True)
        evalRes.to_csv(evalCsvSave, index=False)
cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
|
|
evalCsvSave)
|
|
mAP = float(
|
|
str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
|
|
return mAP
|
|
|
|

    def saveParameters(self, path):
        torch.save(self.state_dict(), path)

    def loadParameters(self, path):
        loadedState = torch.load(path, map_location='cpu')
        if self.rank is not None:
            info = self.load_state_dict(loadedState)
        else:
            # Checkpoints saved under DDP carry a "model.module." prefix on
            # encoder keys; strip it so they match the unwrapped encoder.
            new_state = {}
            for k, v in loadedState.items():
                new_state[k.replace("model.module.", "")] = v
            info = self.load_state_dict(new_state, strict=False)
        print(info)
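

if __name__ == "__main__":
    # Minimal smoke test of the (b, s, t) flattening used in Loconet.forward,
    # on dummy tensors only; every size below is made up for illustration and
    # does not reflect the real feature dimensions of the LoCoNet frontends.
    b, s, t, d = 1, 3, 25, 128                 # batch, speaker tracks, frames, feature dim
    visual = torch.randn(b, s, t, d)           # stand-in for per-track visual features
    flat_visual = visual.view(b * s, t, d)     # each speaker track becomes its own sample
    audio = torch.randn(b, t, d)               # one shared audio embedding per clip
    tiled_audio = audio.repeat(s, 1, 1)        # (b * s, t, d), as in Loconet.forward
    assert flat_visual.shape == tiled_audio.shape == (b * s, t, d)
    print("flattened visual:", flat_visual.shape, "tiled audio:", tiled_audio.shape)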