#!/usr/bin/python
#-*- coding: utf-8 -*-
# Video 25 FPS, Audio 16000HZ
import torch
import numpy
import time, pdb, argparse, subprocess, os, math, glob
import cv2
import python_speech_features
from scipy import signal
from scipy.io import wavfile
from SyncNetModel import *
from shutil import rmtree
# ==================== Get OFFSET ====================
def calc_pdist(feat1, feat2, vshift=10):
    # For every video frame, compute the distance between its feature vector and
    # the audio features at every shift in [-vshift, +vshift].
    win_size = vshift*2+1

    feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))

    dists = []

    for i in range(0,len(feat1)):
        dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:]))

    return dists
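
# Illustrative example (not part of the original pipeline): with two dummy
# 1024-dim feature sequences of 50 frames each and vshift=10, calc_pdist
# returns 50 tensors, one per video frame, each holding 2*10+1 = 21 distances.
#
#   f_video = torch.randn(50, 1024)
#   f_audio = torch.randn(50, 1024)
#   d = calc_pdist(f_video, f_audio, vshift=10)
#   assert len(d) == 50 and d[0].shape[0] == 21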
# ==================== MAIN DEF ====================
class SyncNetInstance(torch.nn.Module):

    def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
        super(SyncNetInstance, self).__init__();

        self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();
    def evaluate(self, opt, videofile):

        self.__S__.eval();

        # ========== ==========
        # Convert files
        # ========== ==========

        if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)):
            rmtree(os.path.join(opt.tmp_dir,opt.reference))

        os.makedirs(os.path.join(opt.tmp_dir,opt.reference))

        # Extract frames as JPEG images and the audio track as 16 kHz mono PCM WAV.
        command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg')))
        output = subprocess.call(command, shell=True, stdout=None)

        command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav')))
        output = subprocess.call(command, shell=True, stdout=None)

        # ========== ==========
        # Load video
        # ========== ==========

        images = []

        flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg'))
        flist.sort()

        for fname in flist:
            img_input = cv2.imread(fname)
            img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE
            images.append(img_input)

        # Stack frames into a (1, C, T, H, W) tensor.
        im = numpy.stack(images,axis=3)
        im = numpy.expand_dims(im,axis=0)
        im = numpy.transpose(im,(0,3,4,1,2))

        imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
        # ========== ==========
        # Load audio
        # ========== ==========

        sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))
        # MFCC with default settings: 13 coefficients, 10 ms hop, i.e. 100 frames per second.
        mfcc = zip(*python_speech_features.mfcc(audio,sample_rate))
        mfcc = numpy.stack([numpy.array(i) for i in mfcc])

        cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
        cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())

        # ========== ==========
        # Check audio and video input length
        # ========== ==========

        #if (float(len(audio))/16000) != (float(len(images))/25) :
        #    print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))

        # 640 audio samples (16000 Hz / 25 fps) correspond to one video frame.
        min_length = min(len(images),math.floor(len(audio)/640))
        # ========== ==========
        # Generate video and audio feats
        # ========== ==========

        lastframe = min_length-5
        im_feat = []
        cc_feat = []

        tS = time.time()
        for i in range(0,lastframe,opt.batch_size):

            # 5-frame video windows and the matching 20-step (0.2 s) MFCC windows.
            im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
            im_in = torch.cat(im_batch,0)
            im_out  = self.__S__.forward_lip(im_in.cuda());
            im_feat.append(im_out.data.cpu())

            cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
            cc_in = torch.cat(cc_batch,0)
            cc_out  = self.__S__.forward_aud(cc_in.cuda())
            cc_feat.append(cc_out.data.cpu())

        im_feat = torch.cat(im_feat,0)
        cc_feat = torch.cat(cc_feat,0)

        # ========== ==========
        # Compute offset
        # ========== ==========

        #print('Compute time %.3f sec.' % (time.time()-tS))

        dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift)
        mdist = torch.mean(torch.stack(dists,1),1)

        minval, minidx = torch.min(mdist,0)

        # The AV offset (in video frames) is the shift with the smallest mean distance;
        # confidence is the margin between the median and the minimum distance.
        offset = opt.vshift-minidx
        conf   = torch.median(mdist) - minval

        fdist   = numpy.stack([dist[minidx].numpy() for dist in dists])
        # fdist   = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
        fconf   = torch.median(mdist).numpy() - fdist
        fconfm  = signal.medfilt(fconf,kernel_size=9)

        numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
        #print('Framewise conf: ')
        #print(fconfm)
        #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf))

        dists_npy = numpy.array([ dist.numpy() for dist in dists ])
        return offset.numpy(), conf.numpy(), minval.numpy()
    def extract_feature(self, opt, videofile):

        self.__S__.eval();

        # ========== ==========
        # Load video
        # ========== ==========

        cap = cv2.VideoCapture(videofile)

        frame_num = 1;
        images = []
        while frame_num:
            frame_num += 1
            ret, image = cap.read()
            if ret == 0:
                break
            images.append(image)

        im = numpy.stack(images,axis=3)
        im = numpy.expand_dims(im,axis=0)
        im = numpy.transpose(im,(0,3,4,1,2))

        imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())

        # ========== ==========
        # Generate video feats
        # ========== ==========

        lastframe = len(images)-4
        im_feat = []

        tS = time.time()
        for i in range(0,lastframe,opt.batch_size):

            im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
            im_in = torch.cat(im_batch,0)
            im_out  = self.__S__.forward_lipfeat(im_in.cuda());
            im_feat.append(im_out.data.cpu())

        im_feat = torch.cat(im_feat,0)

        # ========== ==========
        # Compute offset
        # ========== ==========

        print('Compute time %.3f sec.' % (time.time()-tS))

        return im_feat
    def loadParameters(self, path):
        loaded_state = torch.load(path, map_location=lambda storage, loc: storage);

        self_state = self.__S__.state_dict();

        for name, param in loaded_state.items():
            self_state[name].copy_(param);
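
# Minimal usage sketch (an assumption, not part of the original file): the option
# names below (initial_model, batch_size, vshift, tmp_dir, reference) mirror the
# attributes this class reads from `opt`; the model and video paths are placeholders.
if __name__ == '__main__':
    opt = argparse.Namespace(initial_model='data/syncnet_v2.model',   # hypothetical path
                             batch_size=20, vshift=15,
                             tmp_dir='data/work/pytmp', reference='demo')

    s = SyncNetInstance()
    s.loadParameters(opt.initial_model)

    offset, conf, dist = s.evaluate(opt, videofile='data/example.avi')  # hypothetical path
    print('AV offset: %d, confidence: %.3f, min dist: %.3f' % (offset, conf, dist))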