import gradio as gr
import os
import torch
from shutil import rmtree, copy, copytree
from torch import nn
from torch.nn import functional as F
import numpy as np
import subprocess
import cv2
import pickle
import librosa
from ultralytics import YOLO
from decord import VideoReader
from decord import cpu, gpu
from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *
import scenedetect
from scenedetect.video_manager import VideoManager
from scenedetect.scene_manager import SceneManager
from scenedetect.stats_manager import StatsManager
from scenedetect.detectors import ContentDetector
from scipy.interpolate import interp1d
from scipy import signal
from tqdm import tqdm
from glob import glob
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
import warnings
import spaces

mp_holistic = mp.solutions.holistic
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Initialize global variables
CHECKPOINT_PATH = "model_rgb.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
print("Use cuda status: ", use_cuda)

batch_size = 24
fps = 25
n_negative_samples = 100
facedet_scale = 0.25
crop_scale = 0
min_track = 50
frame_rate = 25
num_failed_det = 25
min_frame_size = 64
print("Device: ", device)

# Initialize the MediaPipe Holistic keypoint detection model
holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
def bb_intersection_over_union(boxA, boxB):
    # Compute the intersection-over-union (IoU) of two [x1, y1, x2, y2] boxes
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interArea = max(0, xB - xA) * max(0, yB - yA)
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou
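# Worked IoU example (illustrative values only, not part of the pipeline): for
# boxA = [0, 0, 10, 10] and boxB = [5, 5, 15, 15], the intersection is 5 * 5 = 25
# and each box has area 100, so bb_intersection_over_union(boxA, boxB) =
# 25 / (100 + 100 - 25) ≈ 0.143, well below the iouThres of 0.5 used in track_shot().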
def track_shot(scenefaces): | |
iouThres = 0.5 # Minimum IOU between consecutive face detections | |
tracks = [] | |
while True: | |
track = [] | |
for framefaces in scenefaces: | |
for face in framefaces: | |
if track == []: | |
track.append(face) | |
framefaces.remove(face) | |
elif face['frame'] - track[-1]['frame'] <= num_failed_det: | |
iou = bb_intersection_over_union(face['bbox'], track[-1]['bbox']) | |
if iou > iouThres: | |
track.append(face) | |
framefaces.remove(face) | |
continue | |
else: | |
break | |
if track == []: | |
break | |
elif len(track) > min_track: | |
framenum = np.array([f['frame'] for f in track]) | |
bboxes = np.array([np.array(f['bbox']) for f in track]) | |
frame_i = np.arange(framenum[0], framenum[-1] + 1) | |
bboxes_i = [] | |
for ij in range(0, 4): | |
interpfn = interp1d(framenum, bboxes[:, ij]) | |
bboxes_i.append(interpfn(frame_i)) | |
bboxes_i = np.stack(bboxes_i, axis=1) | |
if max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1])) > min_frame_size: | |
tracks.append({'frame': frame_i, 'bbox': bboxes_i}) | |
return tracks | |
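# Note on track_shot(): detections are greedily chained into a track whenever a new
# detection lies within num_failed_det (25) frames of the previous one and overlaps
# it with IoU > 0.5; gaps inside a track are filled by per-coordinate linear
# interpolation (interp1d), and tracks with at most min_track detections or with a
# mean box size below min_frame_size pixels are discarded.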
def check_folder(folder): | |
if os.path.exists(folder): | |
return True | |
return False | |
def del_folder(folder): | |
if os.path.exists(folder): | |
rmtree(folder) | |
def read_video(o, start_idx): | |
with open(o, 'rb') as o: | |
video_stream = VideoReader(o) | |
if start_idx > 0: | |
video_stream.skip_frames(start_idx) | |
return video_stream | |
def crop_video(avi_dir, tmp_dir, track, cropfile, tight_scale=1): | |
fourcc = cv2.VideoWriter_fourcc(*'XVID') | |
vOut = cv2.VideoWriter(cropfile + '.avi', fourcc, frame_rate, (480, 270)) | |
dets = {'x': [], 'y': [], 's': [], 'bbox': track['bbox'], 'frame': track['frame']} | |
for det in track['bbox']: | |
# Reduce the size of the bounding box by a small factor if tighter crops are needed (default -> no reduction in size) | |
width = (det[2] - det[0]) * tight_scale | |
height = (det[3] - det[1]) * tight_scale | |
center_x = (det[0] + det[2]) / 2 | |
center_y = (det[1] + det[3]) / 2 | |
dets['s'].append(max(height, width) / 2) | |
dets['y'].append(center_y) # crop center y | |
dets['x'].append(center_x) # crop center x | |
# Smooth detections | |
dets['s'] = signal.medfilt(dets['s'], kernel_size=13) | |
dets['x'] = signal.medfilt(dets['x'], kernel_size=13) | |
dets['y'] = signal.medfilt(dets['y'], kernel_size=13) | |
videofile = os.path.join(avi_dir, 'video.avi') | |
frame_no_to_start = track['frame'][0] | |
video_stream = cv2.VideoCapture(videofile) | |
video_stream.set(cv2.CAP_PROP_POS_FRAMES, frame_no_to_start) | |
for fidx, frame in enumerate(track['frame']): | |
cs = crop_scale | |
bs = dets['s'][fidx] # Detection box size | |
bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount | |
image = video_stream.read()[1] | |
frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), 'constant', constant_values=(110, 110)) | |
my = dets['y'][fidx] + bsi # BBox center Y | |
mx = dets['x'][fidx] + bsi # BBox center X | |
face = frame[int(my - bs):int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)):int(mx + bs * (1 + cs))] | |
vOut.write(cv2.resize(face, (480, 270))) | |
video_stream.release() | |
audiotmp = os.path.join(tmp_dir, 'audio.wav') | |
audiostart = (track['frame'][0]) / frame_rate | |
audioend = (track['frame'][-1] + 1) / frame_rate | |
vOut.release() | |
# ========== CROP AUDIO FILE ========== | |
command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ss %.3f -to %.3f %s" % (os.path.join(avi_dir, 'audio.wav'), audiostart, audioend, audiotmp)) | |
output = subprocess.call(command, shell=True, stdout=None) | |
copy(audiotmp, cropfile + '.wav') | |
# print('Written %s' % cropfile) | |
# print('Mean pos: x %.2f y %.2f s %.2f' % (np.mean(dets['x']), np.mean(dets['y']), np.mean(dets['s']))) | |
return {'track': track, 'proc_track': dets} | |
def inference_video(avi_dir, work_dir, padding=0): | |
videofile = os.path.join(avi_dir, 'video.avi') | |
vidObj = cv2.VideoCapture(videofile) | |
yolo_model = YOLO("yolov9m.pt") | |
global dets, fidx | |
dets = [] | |
fidx = 0 | |
print("Detecting people in the video using YOLO...") | |
def generate_detections(): | |
global dets, fidx | |
while True: | |
success, image = vidObj.read() | |
if not success: | |
break | |
image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
# Perform person detection | |
results = yolo_model(image_np, verbose=False) | |
detections = results[0].boxes | |
dets.append([]) | |
for i, det in enumerate(detections): | |
x1, y1, x2, y2 = det.xyxy[0].detach().cpu().numpy() | |
cls = det.cls[0].detach().cpu().numpy() | |
conf = det.conf[0].detach().cpu().numpy() | |
if int(cls) == 0 and conf>0.7: # Class 0 is 'person' in COCO dataset | |
x1 = max(0, int(x1) - padding) | |
y1 = max(0, int(y1) - padding) | |
x2 = min(image_np.shape[1], int(x2) + padding) | |
y2 = min(image_np.shape[0], int(y2) + padding) | |
dets[-1].append({'frame': fidx, 'bbox': [x1, y1, x2, y2], 'conf': conf}) | |
fidx += 1 | |
yield | |
return dets | |
for _ in tqdm(generate_detections()): | |
pass | |
print("Successfully detected people in the video") | |
savepath = os.path.join(work_dir, 'faces.pckl') | |
with open(savepath, 'wb') as fil: | |
pickle.dump(dets, fil) | |
return dets | |
def scene_detect(avi_dir, work_dir): | |
video_manager = VideoManager([os.path.join(avi_dir, 'video.avi')]) | |
stats_manager = StatsManager() | |
scene_manager = SceneManager(stats_manager) | |
scene_manager.add_detector(ContentDetector()) | |
base_timecode = video_manager.get_base_timecode() | |
video_manager.set_downscale_factor() | |
video_manager.start() | |
scene_manager.detect_scenes(frame_source=video_manager) | |
scene_list = scene_manager.get_scene_list(base_timecode) | |
savepath = os.path.join(work_dir, 'scene.pckl') | |
if scene_list == []: | |
scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())] | |
with open(savepath, 'wb') as fil: | |
pickle.dump(scene_list, fil) | |
print('%s - scenes detected %d' % (os.path.join(avi_dir, 'video.avi'), len(scene_list))) | |
return scene_list | |
def process_video_asd(file, sd_root, work_root, data_root, avi_dir, tmp_dir, work_dir, crop_dir, frames_dir): | |
video_file_name = os.path.basename(file.strip()) | |
sd_dest_folder = sd_root | |
work_dest_folder = work_root | |
del_folder(sd_dest_folder) | |
del_folder(work_dest_folder) | |
videofile = file | |
if os.path.exists(work_dir): | |
rmtree(work_dir) | |
if os.path.exists(crop_dir): | |
rmtree(crop_dir) | |
if os.path.exists(avi_dir): | |
rmtree(avi_dir) | |
if os.path.exists(frames_dir): | |
rmtree(frames_dir) | |
if os.path.exists(tmp_dir): | |
rmtree(tmp_dir) | |
os.makedirs(work_dir) | |
os.makedirs(crop_dir) | |
os.makedirs(avi_dir) | |
os.makedirs(frames_dir) | |
os.makedirs(tmp_dir) | |
command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -qscale:v 2 -async 1 -r 25 %s" % (videofile, | |
os.path.join(avi_dir, | |
'video.avi'))) | |
status = subprocess.call(command, shell=True, stdout=None) | |
if status != 0: | |
msg = "Error in pre-processing the video, please check the input video and try again" | |
return msg | |
command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (os.path.join(avi_dir, | |
'video.avi'), | |
os.path.join(avi_dir, | |
'audio.wav'))) | |
status = subprocess.call(command, shell=True, stdout=None) | |
if status != 0: | |
msg = "Error in pre-processing the video, please check the input video and try again" | |
return msg | |
try: | |
faces = inference_video(avi_dir, work_dir) | |
except: | |
msg = "Error in pre-processing the video, please check the input video and try again" | |
return msg | |
print("YOLO done") | |
print("Detecting scenes in the video...") | |
try: | |
scene = scene_detect(avi_dir, work_dir) | |
except: | |
msg = "Error in detecting the scenes in the video, please check the input video and try again" | |
return msg | |
print("Scene detect done") | |
print("Tracking video...") | |
allscenes = [] | |
for shot in scene: | |
if shot[1].frame_num - shot[0].frame_num >= min_track: | |
allscenes.append(track_shot(faces[shot[0].frame_num:shot[1].frame_num])) | |
print("Cropping video...") | |
alltracks = [] | |
for sc_num in range(len(allscenes)): | |
vidtracks = [] | |
for ii, track in enumerate(allscenes[sc_num]): | |
os.makedirs(os.path.join(crop_dir, 'scene_'+str(sc_num)), exist_ok=True) | |
vidtracks.append(crop_video(avi_dir, tmp_dir, track, os.path.join(crop_dir, 'scene_'+str(sc_num), '%05d' % ii))) | |
alltracks.append(vidtracks) | |
savepath = os.path.join(work_dir, 'tracks.pckl') | |
with open(savepath, 'wb') as fil: | |
pickle.dump(alltracks, fil) | |
rmtree(tmp_dir) | |
rmtree(avi_dir) | |
rmtree(frames_dir) | |
copytree(crop_dir, sd_dest_folder) | |
copytree(work_dir, work_dest_folder) | |
return "success" | |
def get_person_detection(all_frames, frame_count, padding=20): | |
try: | |
# Load YOLOv9 model (pre-trained on COCO dataset) | |
yolo_model = YOLO("yolov9s.pt") | |
print("Loaded the YOLO model") | |
person_videos = {} | |
person_tracks = {} | |
print("Processing the frames...") | |
for frame_idx in tqdm(range(frame_count)): | |
frame = all_frames[frame_idx] | |
# Perform person detection | |
results = yolo_model(frame, verbose=False) | |
detections = results[0].boxes | |
for i, det in enumerate(detections): | |
x1, y1, x2, y2 = det.xyxy[0] | |
cls = det.cls[0] | |
if int(cls) == 0: # Class 0 is 'person' in COCO dataset | |
x1 = max(0, int(x1) - padding) | |
y1 = max(0, int(y1) - padding) | |
x2 = min(frame.shape[1], int(x2) + padding) | |
y2 = min(frame.shape[0], int(y2) + padding) | |
if i not in person_videos: | |
person_videos[i] = [] | |
person_tracks[i] = [] | |
person_videos[i].append(frame) | |
person_tracks[i].append([x1,y1,x2,y2]) | |
num_persons = 0 | |
for i in person_videos.keys(): | |
if len(person_videos[i]) >= frame_count//2: | |
num_persons+=1 | |
if num_persons==0: | |
msg = "No person detected in the video! Please give a video with one person as input" | |
return None, None, msg | |
if num_persons>1: | |
msg = "More than one person detected in the video! Please give a video with only one person as input" | |
return None, None, msg | |
except: | |
msg = "Error in detecting person in the video, please check the input video and try again" | |
return None, None, msg | |
return person_videos, person_tracks, "success" | |
def preprocess_video(path, result_folder, apply_preprocess, padding=20): | |
''' | |
This function preprocesses the input video to extract the audio and crop the frames using YOLO model | |
Args: | |
- path (string) : Path of the input video file | |
- result_folder (string) : Path of the folder to save the extracted audio and cropped video | |
    - apply_preprocess (bool) : Flag to apply the YOLO-based pre-processing or not
    - padding (int) : Padding to add to the bounding box
Returns: | |
- wav_file (string) : Path of the extracted audio file | |
- fps (int) : FPS of the input video | |
- video_output (string) : Path of the cropped video file | |
- msg (string) : Message to be returned | |
''' | |
# Load all video frames | |
try: | |
vr = VideoReader(path, ctx=cpu(0)) | |
fps = vr.get_avg_fps() | |
frame_count = len(vr) | |
except: | |
msg = "Oops! Could not load the video. Please check the input video and try again." | |
return None, None, None, msg | |
if frame_count < 25: | |
msg = "Not enough frames to process! Please give a longer video as input" | |
return None, None, None, msg | |
# Extract the audio from the input video file using ffmpeg | |
wav_file = os.path.join(result_folder, "audio.wav") | |
status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \ | |
-acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True) | |
if status != 0: | |
msg = "Oops! Could not load the audio file. Please check the input video and try again." | |
return None, None, None, msg | |
print("Extracted the audio from the video") | |
    # Note: the Gradio checkbox passes a bool, while direct calls may pass the string "True"
    if apply_preprocess in (True, "True"):
all_frames = [] | |
for k in range(len(vr)): | |
all_frames.append(vr[k].asnumpy()) | |
all_frames = np.asarray(all_frames) | |
print("Extracted the frames for pre-processing") | |
person_videos, person_tracks, msg = get_person_detection(all_frames, frame_count, padding) | |
if msg != "success": | |
return None, None, None, msg | |
# For the person detected, crop the frame based on the bounding box | |
if len(person_videos[0]) > frame_count-10: | |
crop_filename = os.path.join(result_folder, "preprocessed_video.avi") | |
fourcc = cv2.VideoWriter_fourcc(*'DIVX') | |
# Get bounding box coordinates based on person_tracks[i] | |
max_x1 = min([track[0] for track in person_tracks[0]]) | |
max_y1 = min([track[1] for track in person_tracks[0]]) | |
max_x2 = max([track[2] for track in person_tracks[0]]) | |
max_y2 = max([track[3] for track in person_tracks[0]]) | |
max_width = max_x2 - max_x1 | |
max_height = max_y2 - max_y1 | |
out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height)) | |
for frame in person_videos[0]: | |
crop = frame[max_y1:max_y2, max_x1:max_x2] | |
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) | |
out.write(crop) | |
out.release() | |
no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4' | |
status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True) | |
if status != 0: | |
msg = "Oops! Could not preprocess the video. Please check the input video and try again." | |
return None, None, None, msg | |
video_output = crop_filename.split('.')[0] + '.mp4' | |
status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' % | |
(wav_file , no_sound_video, video_output), shell=True) | |
if status != 0: | |
msg = "Oops! Could not preprocess the video. Please check the input video and try again." | |
return None, None, None, msg | |
os.remove(crop_filename) | |
os.remove(no_sound_video) | |
print("Successfully saved the pre-processed video: ", video_output) | |
else: | |
msg = "Could not track the person in the full video! Please give a single-speaker video as input" | |
return None, None, None, msg | |
else: | |
video_output = path | |
return wav_file, fps, video_output, "success" | |
def resample_video(video_file, video_fname, result_folder): | |
''' | |
This function resamples the video to 25 fps | |
Args: | |
- video_file (string) : Path of the input video file | |
- video_fname (string) : Name of the input video file | |
- result_folder (string) : Path of the folder to save the resampled video | |
Returns: | |
- video_file_25fps (string) : Path of the resampled video file | |
- msg (string) : Message to be returned | |
''' | |
video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname)) | |
# Resample the video to 25 fps | |
status = subprocess.call("ffmpeg -hide_banner -loglevel panic -y -i {} -c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}".format(video_file, video_file_25fps), shell=True) | |
if status != 0: | |
msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again." | |
return None, msg | |
print('Resampled the video to 25 fps: {}'.format(video_file_25fps)) | |
return video_file_25fps, "success" | |
def load_checkpoint(path, model): | |
''' | |
This function loads the trained model from the checkpoint | |
Args: | |
- path (string) : Path of the checkpoint file | |
- model (object) : Model object | |
Returns: | |
- model (object) : Model object with the weights loaded from the checkpoint | |
''' | |
# Load the checkpoint | |
checkpoint = torch.load(path, map_location="cpu") | |
s = checkpoint["state_dict"] | |
new_s = {} | |
for k, v in s.items(): | |
new_s[k.replace('module.', '')] = v | |
model.load_state_dict(new_s) | |
print("Loaded checkpoint from: {}".format(path)) | |
return model.eval() | |
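# Usage sketch (mirrors how the model is loaded later in this script):
#   model = Transformer_RGB()
#   model = load_checkpoint(CHECKPOINT_PATH, model)
# Stripping the 'module.' prefix above lets checkpoints saved with nn.DataParallel
# load into a single-GPU/CPU model as well.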
def load_video_frames(video_file): | |
''' | |
This function extracts the frames from the video | |
Args: | |
- video_file (string) : Path of the video file | |
Returns: | |
- frames (list) : List of frames extracted from the video | |
- msg (string) : Message to be returned | |
''' | |
# Read the video | |
try: | |
vr = VideoReader(video_file, ctx=cpu(0)) | |
except: | |
msg = "Oops! Could not load the input video file" | |
return None, msg | |
# Extract the frames | |
frames = [] | |
for k in range(len(vr)): | |
frames.append(vr[k].asnumpy()) | |
frames = np.asarray(frames) | |
return frames, "success" | |
def get_keypoints(frames): | |
''' | |
This function extracts the keypoints from the frames using MediaPipe Holistic pipeline | |
Args: | |
- frames (list) : List of frames extracted from the video | |
Returns: | |
- kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames | |
- msg (string) : Message to be returned | |
''' | |
try: | |
holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) | |
resolution = frames[0].shape | |
all_frame_kps = [] | |
for frame in frames: | |
results = holistic.process(frame) | |
pose, left_hand, right_hand, face = None, None, None, None | |
if results.pose_landmarks is not None: | |
pose = protobuf_to_dict(results.pose_landmarks)['landmark'] | |
if results.left_hand_landmarks is not None: | |
left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark'] | |
if results.right_hand_landmarks is not None: | |
right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark'] | |
if results.face_landmarks is not None: | |
face = protobuf_to_dict(results.face_landmarks)['landmark'] | |
frame_dict = {"pose":pose, "left_hand":left_hand, "right_hand":right_hand, "face":face} | |
all_frame_kps.append(frame_dict) | |
kp_dict = {"kps":all_frame_kps, "resolution":resolution} | |
except Exception as e: | |
print("Error: ", e) | |
return None, "Error: Could not extract keypoints from the frames" | |
return kp_dict, "success" | |
def check_visible_gestures(kp_dict): | |
''' | |
This function checks if the gestures in the video are visible | |
Args: | |
- kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames | |
Returns: | |
- msg (string) : Message to be returned | |
''' | |
keypoints = kp_dict['kps'] | |
keypoints = np.array(keypoints) | |
if len(keypoints)<25: | |
msg = "Not enough keypoints to process! Please give a longer video as input" | |
return msg | |
pose_count, hand_count = 0, 0 | |
for frame_kp_dict in keypoints: | |
pose = frame_kp_dict["pose"] | |
left_hand = frame_kp_dict["left_hand"] | |
right_hand = frame_kp_dict["right_hand"] | |
if pose is None: | |
pose_count += 1 | |
if left_hand is None and right_hand is None: | |
hand_count += 1 | |
if hand_count/len(keypoints) > 0.6 or pose_count/len(keypoints) > 0.6: | |
msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input." | |
return msg | |
print("Successfully verified the input video - Gestures are visible!") | |
return "success" | |
def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_frames=25, width=480, height=270): | |
''' | |
This function masks the faces using the keypoints extracted from the frames | |
Args: | |
- input_frames (list) : List of frames extracted from the video | |
- kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames | |
- asd (bool) : Whether to use padding (needed for active speaker detection task) or not | |
- stride (int) : Stride to extract the frames | |
- window_frames (int) : Number of frames in each window that is given as input to the model | |
- width (int) : Width of the frames | |
- height (int) : Height of the frames | |
Returns: | |
- input_frames (array) : Frame window to be given as input to the model | |
- num_frames (int) : Number of frames to extract | |
- orig_masked_frames (array) : Masked frames extracted from the video | |
- msg (string) : Message to be returned | |
''' | |
print("Creating masked input frames...") | |
input_frames_masked = [] | |
if kp_dict is None: | |
for img in tqdm(input_frames): | |
img = cv2.resize(img, (width, height)) | |
masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1) | |
input_frames_masked.append(masked_img) | |
else: | |
# Face indices to extract the face-coordinates needed for masking | |
face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172, | |
176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454] | |
input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution'] | |
print("Input keypoints: ", len(input_keypoints)) | |
for i, frame_kp_dict in tqdm(enumerate(input_keypoints)): | |
img = input_frames[i] | |
face = frame_kp_dict["face"] | |
if face is None: | |
img = cv2.resize(img, (width, height)) | |
masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1) | |
else: | |
face_kps = [] | |
for idx in range(len(face)): | |
if idx in face_oval_idx: | |
x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0]) | |
face_kps.append((x,y)) | |
face_kps = np.array(face_kps) | |
x1, y1 = min(face_kps[:,0]), min(face_kps[:,1]) | |
x2, y2 = max(face_kps[:,0]), max(face_kps[:,1]) | |
masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1) | |
            if masked_img.shape[0] != height or masked_img.shape[1] != width:
                masked_img = cv2.resize(masked_img, (width, height))
input_frames_masked.append(masked_img) | |
orig_masked_frames = np.array(input_frames_masked) | |
input_frames = np.array(input_frames_masked) / 255. | |
if asd: | |
input_frames = np.pad(input_frames, ((12, 12), (0,0), (0,0), (0,0)), 'edge') | |
input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])]) | |
print("Successfully created masked input frames") | |
num_frames = input_frames.shape[0] | |
if num_frames<10: | |
msg = "Not enough frames to process! Please give a longer video as input." | |
return None, None, None, msg | |
return input_frames, num_frames, orig_masked_frames, "success" | |
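# Shape sketch (assuming the defaults above): with stride=1 and window_frames=25,
# N masked frames resized to 270x480 yield (N - 24) overlapping windows, i.e. an
# array of shape (N - 24, 25, 270, 480, 3) scaled to [0, 1]; with asd=True the
# frames are first edge-padded by 12 on each side, so every original frame gets a
# centred 25-frame window and the window count stays N.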
def load_spectrograms(wav_file, asd=False, num_frames=None, window_frames=25, stride=4): | |
''' | |
This function extracts the spectrogram from the audio file | |
Args: | |
- wav_file (string) : Path of the extracted audio file | |
- asd (bool) : Whether to use padding (needed for active speaker detection task) or not | |
- num_frames (int) : Number of frames to extract | |
- window_frames (int) : Number of frames in each window that is given as input to the model | |
- stride (int) : Stride to extract the audio frames | |
Returns: | |
- spec (array) : Spectrogram array window to be used as input to the model | |
- orig_spec (array) : Spectrogram array extracted from the audio file | |
- msg (string) : Message to be returned | |
''' | |
# Extract the audio from the input video file using ffmpeg | |
try: | |
wav = librosa.load(wav_file, sr=16000)[0] | |
except: | |
msg = "Oops! Could extract the spectrograms from the audio file. Please check the input and try again." | |
return None, None, msg | |
# Convert to tensor | |
wav = torch.FloatTensor(wav).unsqueeze(0) | |
mel, _, _, _ = wav2filterbanks(wav) | |
spec = mel.squeeze(0).cpu().numpy() | |
orig_spec = spec | |
spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])]) | |
if num_frames is not None: | |
if len(spec) != num_frames: | |
spec = spec[:num_frames] | |
frame_diff = np.abs(len(spec) - num_frames) | |
if frame_diff > 60: | |
print("The input video and audio length do not match - The results can be unreliable! Please check the input video.") | |
if asd: | |
pad_frames = (window_frames//2) | |
spec = np.pad(spec, ((pad_frames, pad_frames), (0,0), (0,0)), 'edge') | |
return spec, orig_spec, "success" | |
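# Shape sketch (assuming wav2filterbanks produces mel frames at 4x the 25 fps video
# rate, i.e. ~100 mel frames per second): each spectrogram window spans
# window_frames*stride = 100 mel frames and the hop of stride=4 mel frames advances
# it by one video frame, so the number of audio windows tracks num_frames; with
# asd=True the windows are edge-padded by window_frames//2 = 12 on each side to
# match the padded video windows.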
def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model): | |
''' | |
This function calculates the audio-visual offset between the video and audio | |
Args: | |
- vid_emb (array) : Video embedding array | |
- aud_emb (array) : Audio embedding array | |
- num_avg_frames (int) : Number of frames to average the scores | |
- model (object) : Model object | |
Returns: | |
- offset (int) : Optimal audio-visual offset | |
- msg (string) : Message to be returned | |
''' | |
pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames) | |
if status != "success": | |
return None, status | |
scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model) | |
offset = scores.argmax()*stride - pos_idx | |
return offset.item(), "success" | |
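# Worked example of the offset computation above (hypothetical best-scoring index):
# if create_online_sync_negatives() returns 21 candidate audio windows with stride=5
# and the in-sync window at frame 50, and the highest-scoring window is index 9,
# then offset = 9*5 - 50 = -5, i.e. a 5-frame (200 ms at 25 fps) audio-visual
# misalignment is predicted.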
def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5): | |
''' | |
This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset | |
Args: | |
- vid_emb (array) : Video embedding array | |
- aud_emb (array) : Audio embedding array | |
- num_avg_frames (int) : Number of frames to average the scores | |
- stride (int) : Stride to extract the negative windows | |
Returns: | |
- vid_emb_pos (array) : Positive video embedding array | |
- aud_emb_posneg (array) : All possible combinations of audio embedding array | |
- pos_idx_frame (int) : Positive video embedding array frame | |
- stride (int) : Stride used to extract the negative windows | |
- msg (string) : Message to be returned | |
''' | |
slice_size = num_avg_frames | |
aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride) | |
aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3]) | |
aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1] | |
pos_idx = (aud_emb_posneg.shape[1]//2) | |
pos_idx_frame = pos_idx*stride | |
min_offset_frames = -(pos_idx)*stride | |
max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride | |
print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames)) | |
vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size] | |
if vid_emb_pos.shape[2] != slice_size: | |
msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size) | |
return None, None, None, None, msg | |
return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success" | |
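# With the defaults n_negative_samples=100 and stride=5 (and a sufficiently long
# input), aud_emb_posneg keeps int(100/5)+1 = 21 sliding audio windows, the positive
# (in-sync) window sits at index 21//2 = 10 (frame 50), and the printed predictable
# offset range becomes [-50, +50] frames, i.e. +/- 2 seconds at 25 fps.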
def calc_av_scores(vid_emb, aud_emb, model): | |
''' | |
This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings | |
Args: | |
- vid_emb (array) : Video embedding array | |
- aud_emb (array) : Audio embedding array | |
- model (object) : Model object | |
Returns: | |
- scores (array) : Audio-visual similarity scores | |
- att_map (array) : Attention map | |
''' | |
scores = calc_att_map(vid_emb, aud_emb, model) | |
att_map = logsoftmax_2d(torch.Tensor(scores)) | |
scores = scores.mean(-1) | |
return scores, att_map | |
def calc_att_map(vid_emb, aud_emb, model): | |
''' | |
This function calculates the similarity between the video and audio embeddings | |
Args: | |
- vid_emb (array) : Video embedding array | |
- aud_emb (array) : Audio embedding array | |
- model (object) : Model object | |
Returns: | |
- scores (array) : Audio-visual similarity scores | |
''' | |
vid_emb = vid_emb[:, :, None] | |
aud_emb = aud_emb.transpose(1, 2) | |
scores = run_func_in_parts(lambda x, y: (x * y).sum(1), | |
vid_emb, | |
aud_emb, | |
part_len=10, | |
dim=3) | |
scores = model.logits_scale(scores[..., None]).squeeze(-1) | |
return scores.detach().cpu().numpy() | |
def generate_video(frames, audio_file, video_fname): | |
''' | |
This function generates the video from the frames and audio file | |
Args: | |
- frames (array) : Frames to be used to generate the video | |
- audio_file (string) : Path of the audio file | |
- video_fname (string) : Path of the video file | |
Returns: | |
- video_output (string) : Path of the video file | |
- msg (string) : Message to be returned | |
''' | |
fname = 'inference.avi' | |
video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0])) | |
for i in range(len(frames)): | |
video.write(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)) | |
video.release() | |
no_sound_video = video_fname + '_nosound.mp4' | |
status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True) | |
if status != 0: | |
msg = "Oops! Could not generate the video. Please check the input video and try again." | |
return None, msg | |
video_output = video_fname + '.mp4' | |
status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -c:v libx264 -preset veryslow -crf 18 -pix_fmt yuv420p -strict -2 -q:v 1 -shortest %s' % | |
(audio_file, no_sound_video, video_output), shell=True) | |
if status != 0: | |
msg = "Oops! Could not generate the video. Please check the input video and try again." | |
return None, msg | |
os.remove(fname) | |
os.remove(no_sound_video) | |
return video_output, "success" | |
def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25): | |
''' | |
This function corrects the video and audio to sync with each other | |
Args: | |
- video_path (string) : Path of the video file | |
- frames (array) : Frames to be used to generate the video | |
- wav_file (string) : Path of the audio file | |
- offset (int) : Predicted sync-offset to be used to correct the video | |
- result_folder (string) : Path of the result folder to save the output sync-corrected video | |
- sample_rate (int) : Sample rate of the audio | |
- fps (int) : Frames per second of the video | |
Returns: | |
- video_output (string) : Path of the video file | |
- msg (string) : Message to be returned | |
''' | |
if offset == 0: | |
print("The input audio and video are in-sync! No need to perform sync correction.") | |
return video_path, "success" | |
print("Performing Sync Correction...") | |
corrected_frames = np.zeros_like(frames) | |
if offset > 0: | |
audio_offset = int(offset*(sample_rate/fps)) | |
wav = librosa.core.load(wav_file, sr=sample_rate)[0] | |
corrected_wav = wav[audio_offset:] | |
corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav") | |
write(corrected_wav_file, sample_rate, corrected_wav) | |
wav_file = corrected_wav_file | |
corrected_frames = frames | |
elif offset < 0: | |
corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):] | |
corrected_frames = corrected_frames[:len(frames)-np.abs(offset)] | |
corrected_video_path = os.path.join(result_folder, "result_sync_corrected") | |
video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path) | |
if status != "success": | |
return None, status | |
return video_output, "success" | |
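# Example of the correction above: a predicted offset of +10 frames trims
# int(10 * (16000 / 25)) = 6400 samples (0.4 s) from the start of the audio, whereas
# a negative offset discards the first |offset| video frames (and the now-unmatched
# tail) before generate_video() muxes the corrected audio-video pair.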
def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder): | |
''' | |
This function loads the masked input frames from the video | |
Args: | |
- test_videos (list) : List of videos to be processed (speaker-specific tracks) | |
- spec (array) : Spectrogram of the audio | |
- wav_file (string) : Path of the audio file | |
- scene_num (int) : Scene number to be used to save the input masked video | |
- result_folder (string) : Path of the folder to save the input masked video | |
Returns: | |
- all_frames (list) : List of masked input frames window to be used as input to the model | |
- all_orig_frames (list) : List of original masked input frames | |
''' | |
all_frames, all_orig_frames = [], [] | |
for video_num, video in enumerate(test_videos): | |
print("Processing video: ", video) | |
# Load the video frames | |
frames, status = load_video_frames(video) | |
if status != "success": | |
return None, None, status | |
print("Successfully loaded the video frames") | |
# Extract the keypoints from the frames | |
# kp_dict, status = get_keypoints(frames) | |
# if status != "success": | |
# return None, None, status | |
# print("Successfully extracted the keypoints") | |
# Mask the frames using the keypoints extracted from the frames and prepare the input to the model | |
masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict=None, asd=True) | |
if status != "success": | |
return None, None, status | |
print("Successfully loaded the masked frames") | |
# Check if the length of the input frames is equal to the length of the spectrogram | |
if spec.shape[2]!=masked_frames.shape[0]: | |
num_frames = spec.shape[2] | |
masked_frames = masked_frames[:num_frames] | |
orig_masked_frames = orig_masked_frames[:num_frames] | |
frame_diff = np.abs(spec.shape[2] - num_frames) | |
if frame_diff > 60: | |
print("The input video and audio length do not match - The results can be unreliable! Please check the input video.") | |
# Transpose the frames to the correct format | |
frames = np.transpose(masked_frames, (4, 0, 1, 2, 3)) | |
frames = torch.FloatTensor(np.array(frames)).unsqueeze(0) | |
print("Successfully converted the frames to tensor") | |
all_frames.append(frames) | |
all_orig_frames.append(orig_masked_frames) | |
return all_frames, all_orig_frames, "success" | |
def extract_audio(video, result_folder): | |
''' | |
This function extracts the audio from the video file | |
Args: | |
- video (string) : Path of the video file | |
- result_folder (string) : Path of the folder to save the extracted audio file | |
Returns: | |
- wav_file (string) : Path of the extracted audio file | |
- msg (string) : Message to be returned | |
''' | |
wav_file = os.path.join(result_folder, "audio.wav") | |
status = subprocess.call('ffmpeg -hide_banner -loglevel panic -threads 1 -y -i %s -async 1 -ac 1 -vn \ | |
-acodec pcm_s16le -ar 16000 %s' % (video, wav_file), shell=True) | |
if status != 0: | |
msg = "Oops! Could not load the audio file in the given input video. Please check the input and try again" | |
return None, msg | |
return wav_file, "success" | |
def get_embeddings(video_sequences, audio_sequences, model, asd=False, calc_aud_emb=True): | |
''' | |
This function extracts the video and audio embeddings from the input frames and audio sequences | |
Args: | |
- video_sequences (array) : Array of video frames to be used as input to the model | |
- audio_sequences (array) : Array of audio frames to be used as input to the model | |
- model (object) : Model object | |
- asd (bool) : Active speaker detection task flag to return the correct dimensions for the embeddings | |
- calc_aud_emb (bool) : Flag to calculate the audio embedding | |
Returns: | |
- video_emb (array) : Video embedding | |
- audio_emb (array) : Audio embedding | |
''' | |
video_emb = [] | |
audio_emb = [] | |
for i in range(0, len(video_sequences), batch_size): | |
video_inp = video_sequences[i:i+batch_size, ] | |
vid_emb = model.forward_vid(video_inp, return_feats=False) | |
vid_emb = torch.mean(vid_emb, axis=-1) | |
if not asd: | |
vid_emb = vid_emb.unsqueeze(-1) | |
video_emb.extend(vid_emb.detach().cpu().numpy()) | |
if calc_aud_emb: | |
audio_inp = audio_sequences[i:i+batch_size, ] | |
aud_emb = model.forward_aud(audio_inp) | |
audio_emb.extend(aud_emb.detach().cpu().numpy()) | |
torch.cuda.empty_cache() | |
video_emb = np.array(video_emb) | |
print("Video Embedding Shape: ", video_emb.shape) | |
if calc_aud_emb: | |
audio_emb = np.array(audio_emb) | |
print("Audio Embedding Shape: ", audio_emb.shape) | |
return video_emb, audio_emb | |
return video_emb | |
def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model): | |
''' | |
This function predicts the active speaker in each frame | |
Args: | |
- all_video_embeddings (array) : Array of video embeddings of all speakers | |
- audio_embedding (array) : Audio embedding | |
- global_score (bool) : Flag to calculate the global score | |
- num_avg_frames (int) : Number of frames to average the scores | |
- model (object) : Model object | |
Returns: | |
- pred_speaker (list) : List of active speakers in each frame | |
- num_avg_frames (int) : Number of frames to average the scores | |
''' | |
cos = nn.CosineSimilarity(dim=1) | |
audio_embedding = torch.tensor(audio_embedding).squeeze(2) | |
scores = [] | |
for i in range(len(all_video_embeddings)): | |
video_embedding = torch.tensor(all_video_embeddings[i]) | |
# Compute the similarity of each speaker's video embeddings with the audio embedding | |
sim = cos(video_embedding, audio_embedding) | |
# Apply the logits scale to the similarity scores (scaling the scores) | |
output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1) | |
if global_score=="True": | |
score = output.mean(0) | |
else: | |
if output.shape[0]<num_avg_frames: | |
num_avg_frames = output.shape[0] | |
output_batch = output.unfold(0, num_avg_frames, 1) | |
score = torch.mean(output_batch, axis=-1) | |
scores.append(score.detach().cpu().numpy()) | |
if global_score=="True": | |
print("Using global predictions") | |
pred_speaker = np.argmax(scores) | |
else: | |
print("Using per-frame predictions") | |
pred_speaker = [] | |
num_negs = list(range(0, len(all_video_embeddings))) | |
for frame_idx in range(len(scores[0])): | |
score = [scores[i][frame_idx] for i in num_negs] | |
pred_idx = np.argmax(score) | |
pred_speaker.append(pred_idx) | |
return pred_speaker, num_avg_frames | |
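# Note on the per-frame branch above: output.unfold(0, num_avg_frames, 1) turns a
# length-T score sequence into (T - num_avg_frames + 1) overlapping windows, so with
# the UI default num_avg_frames=75 each prediction averages 3 seconds of 25 fps
# context; the speaker with the highest averaged score wins each frame.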
def save_video(output_tracks, input_frames, wav_file, result_folder): | |
''' | |
This function saves the output video with the active speaker detections | |
Args: | |
- output_tracks (list) : List of active speakers in each frame | |
- input_frames (array) : Frames to be used to generate the video | |
- wav_file (string) : Path of the audio file | |
- result_folder (string) : Path of the result folder to save the output video | |
Returns: | |
- video_output (string) : Path of the output video | |
- msg (string) : Message to be returned | |
''' | |
try: | |
output_frames = [] | |
for i in range(len(input_frames)): | |
# If the active speaker is found, draw a bounding box around the active speaker | |
if i in output_tracks: | |
bbox = output_tracks[i] | |
x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) | |
out = cv2.rectangle(input_frames[i].copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3) | |
else: | |
out = input_frames[i] | |
output_frames.append(out) | |
# Generate the output video | |
output_video_fname = os.path.join(result_folder, "result_active_speaker_det") | |
video_output, status = generate_video(output_frames, wav_file, output_video_fname) | |
if status != "success": | |
return None, status | |
except Exception as e: | |
return None, f"Error: {str(e)}" | |
return video_output, "success" | |
def preprocess_asd(video_path, result_folder_input): | |
''' | |
This function preprocesses the video for the active speaker detection task | |
Args: | |
- video_path (string) : Path of the video file | |
- result_folder_input (string) : Path of the folder to save the input video | |
Returns: | |
- msg (string) : Message to be returned | |
''' | |
file = video_path | |
data_dir = os.path.join(result_folder_input, 'temp') | |
sd_root = os.path.join(result_folder_input, 'crops') | |
work_root = os.path.join(result_folder_input, 'metadata') | |
data_root = result_folder_input | |
os.makedirs(sd_root, exist_ok=True) | |
os.makedirs(work_root, exist_ok=True) | |
avi_dir = os.path.join(data_dir, 'pyavi') | |
tmp_dir = os.path.join(data_dir, 'pytmp') | |
work_dir = os.path.join(data_dir, 'pywork') | |
crop_dir = os.path.join(data_dir, 'pycrop') | |
frames_dir = os.path.join(data_dir, 'pyframes') | |
status = process_video_asd(file, sd_root, work_root, data_root, avi_dir, tmp_dir, work_dir, crop_dir, frames_dir) | |
if status != "success": | |
return status | |
print("Successfully pre-processed the video") | |
return "success" | |
def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess): | |
''' | |
This function processes the video for the sync offset prediction task | |
Args: | |
- video_path (string) : Path of the video file | |
- num_avg_frames (int) : Number of frames to average the scores | |
- apply_preprocess (bool) : Flag to apply the pre-processing steps or not | |
Returns: | |
- video_output (string) : Path of the output video | |
- msg (string) : Message to be returned | |
''' | |
try: | |
# Extract the video filename | |
video_fname = os.path.basename(video_path.split(".")[0]) | |
# Create folders to save the inputs and results | |
result_folder = os.path.join("results", video_fname) | |
result_folder_input = os.path.join(result_folder, "input") | |
result_folder_output = os.path.join(result_folder, "output") | |
if os.path.exists(result_folder): | |
rmtree(result_folder) | |
os.makedirs(result_folder) | |
os.makedirs(result_folder_input) | |
os.makedirs(result_folder_output) | |
# Preprocess the video | |
print("Applying preprocessing: ", apply_preprocess) | |
wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess) | |
if status != "success": | |
return None, status | |
print("Successfully preprocessed the video") | |
# Resample the video to 25 fps if it is not already 25 fps | |
print("FPS of video: ", fps) | |
if fps!=25: | |
vid_path, status = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input) | |
if status != "success": | |
return None, status | |
orig_vid_path_25fps, status = resample_video(video_path, "input_video_25fps", result_folder_input) | |
if status != "success": | |
return None, status | |
else: | |
vid_path = vid_path_processed | |
orig_vid_path_25fps = video_path | |
# Load the original video frames (before pre-processing) - Needed for the final sync-correction | |
orig_frames, status = load_video_frames(orig_vid_path_25fps) | |
if status != "success": | |
return None, status | |
# Load the pre-processed video frames | |
frames, status = load_video_frames(vid_path) | |
if status != "success": | |
return None, status | |
print("Successfully extracted the video frames") | |
if len(frames) < num_avg_frames: | |
msg = "Error: The input video is too short. Please use a longer input video." | |
return None, msg | |
# Load keypoints and check if gestures are visible | |
kp_dict, status = get_keypoints(frames) | |
if status != "success": | |
return None, status | |
print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"])) | |
status = check_visible_gestures(kp_dict) | |
if status != "success": | |
return None, status | |
# Load RGB frames | |
rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=False, window_frames=25, width=480, height=270) | |
if status != "success": | |
return None, status | |
print("Successfully loaded the RGB frames") | |
# Convert frames to tensor | |
rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3)) | |
rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0) | |
B = rgb_frames.size(0) | |
print("Successfully converted the frames to tensor") | |
# Load spectrograms | |
spec, orig_spec, status = load_spectrograms(wav_file, asd=False, num_frames=num_frames) | |
if status != "success": | |
return None, status | |
spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3) | |
print("Successfully loaded the spectrograms") | |
# Create input windows | |
video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0) | |
audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0) | |
# Load the trained model | |
model = Transformer_RGB() | |
model = load_checkpoint(CHECKPOINT_PATH, model) | |
print("Successfully loaded the model") | |
# Extract embeddings | |
print("Obtaining audio and video embeddings...") | |
video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, asd=False, calc_aud_emb=True) | |
# L2 normalize embeddings | |
print("Normalizing embeddings") | |
video_emb = torch.tensor(video_emb) | |
video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1) | |
audio_emb = torch.tensor(audio_emb) | |
audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1) | |
audio_emb = torch.split(audio_emb, B, dim=0) | |
audio_emb = torch.stack(audio_emb, dim=2) | |
audio_emb = audio_emb.squeeze(3) | |
audio_emb = audio_emb[:, None] | |
video_emb = torch.split(video_emb, B, dim=0) | |
video_emb = torch.stack(video_emb, dim=2) | |
video_emb = video_emb.squeeze(3) | |
print("Successfully extracted GestSync embeddings") | |
# Calculate sync offset | |
print("Calculating sync offset...") | |
pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model) | |
if status != "success": | |
return None, status | |
print("Predicted offset: ", pred_offset) | |
# Generate sync-corrected video | |
video_output, status = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps) | |
if status != "success": | |
return None, status | |
print("Successfully generated the video:", video_output) | |
return video_output, f"Predicted offset: {pred_offset}" | |
except Exception as e: | |
return None, f"Error: {str(e)}" | |
def process_video_activespeaker(video_path, global_speaker, num_avg_frames): | |
''' | |
This function processes the video for the active speaker detection task | |
Args: | |
- video_path (string) : Path of the video file | |
- global_speaker (string) : Flag to use global or per-frame predictions | |
- num_avg_frames (int) : Number of frames to average the scores | |
Returns: | |
- video_output (string) : Path of the output video | |
- msg (string) : Message to be returned | |
''' | |
try: | |
# Extract the video filename | |
video_fname = os.path.basename(video_path.split(".")[0]) | |
# Create folders to save the inputs and results | |
result_folder = os.path.join("results", video_fname) | |
result_folder_input = os.path.join(result_folder, "input") | |
result_folder_output = os.path.join(result_folder, "output") | |
if os.path.exists(result_folder): | |
rmtree(result_folder) | |
os.makedirs(result_folder) | |
os.makedirs(result_folder_input) | |
os.makedirs(result_folder_output) | |
if global_speaker=="per-frame-prediction" and num_avg_frames<25: | |
msg = "Number of frames to average need to be set to a minimum of 25 frames. Atleast 1-second context is needed for the model. Please change the num_avg_frames and try again..." | |
return None, msg | |
# Read the video | |
try: | |
vr = VideoReader(video_path, ctx=cpu(0)) | |
except: | |
msg = "Oops! Could not load the input video file" | |
return None, msg | |
# Get the FPS of the video | |
fps = vr.get_avg_fps() | |
print("FPS of video: ", fps) | |
# Resample the video to 25 FPS if the original video is of a different frame-rate | |
if fps!=25: | |
test_video_25fps, status = resample_video(video_path, video_fname, result_folder_input) | |
if status != "success": | |
return None, status | |
else: | |
test_video_25fps = video_path | |
# Load the video frames | |
orig_frames, status = load_video_frames(test_video_25fps) | |
if status != "success": | |
return None, status | |
# Extract and save the audio file | |
orig_wav_file, status = extract_audio(video_path, result_folder) | |
if status != "success": | |
return None, status | |
# Pre-process and extract per-speaker tracks in each scene | |
status = preprocess_asd(video_path, result_folder_input) | |
if status != "success": | |
return None, status | |
# Load the tracks file saved during pre-processing | |
with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file: | |
tracks = pickle.load(file) | |
# Create a dictionary of all tracks found along with the bounding-boxes | |
track_dict = {} | |
for scene_num in range(len(tracks)): | |
track_dict[scene_num] = {} | |
for i in range(len(tracks[scene_num])): | |
track_dict[scene_num][i] = {} | |
for frame_num, bbox in zip(tracks[scene_num][i]['track']['frame'], tracks[scene_num][i]['track']['bbox']): | |
track_dict[scene_num][i][frame_num] = bbox | |
# Get the total number of scenes | |
test_scenes = os.listdir("{}/crops".format(result_folder_input)) | |
print("Total scenes found in the input video = ", len(test_scenes)) | |
# Load the trained model | |
model = Transformer_RGB() | |
model = load_checkpoint(CHECKPOINT_PATH, model) | |
# Compute the active speaker in each scene | |
output_tracks = {} | |
for scene_num in tqdm(range(len(test_scenes))): | |
test_videos = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.avi")) | |
test_videos.sort(key=lambda x: int(os.path.basename(x).split('.')[0])) | |
print("Scene {} -> Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos))) | |
if len(test_videos)<=1: | |
msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..." | |
return None, msg | |
# Load the audio file | |
audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0] | |
spec, _, status = load_spectrograms(audio_file, asd=True) | |
if status != "success": | |
return None, status | |
spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3) | |
print("Successfully loaded the spectrograms") | |
# Load the masked input frames | |
all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input) | |
if status != "success": | |
return None, status | |
print("Successfully loaded the masked input frames") | |
# Prepare the audio and video sequences for the model | |
audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0) | |
print("Obtaining audio and video embeddings...") | |
all_video_embs = [] | |
for idx in tqdm(range(len(all_masked_frames))): | |
with torch.no_grad(): | |
video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0) | |
if idx==0: | |
video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, asd=True, calc_aud_emb=True) | |
else: | |
video_emb = get_embeddings(video_sequences, audio_sequences, model, asd=True, calc_aud_emb=False) | |
all_video_embs.append(video_emb) | |
print("Successfully extracted GestSync embeddings") | |
# Predict the active speaker in each scene | |
if global_speaker=="per-frame-prediction": | |
predictions, num_avg_frames = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model) | |
else: | |
predictions, _ = predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model) | |
# Get the frames present in the scene | |
frames_scene = tracks[scene_num][0]['track']['frame'] | |
# Prepare the active speakers list to draw the bounding boxes | |
if global_speaker=="global-prediction": | |
print("Aggregating scores using global predictions") | |
active_speakers = [predictions]*len(frames_scene) | |
start, end = 0, len(frames_scene) | |
else: | |
print("Aggregating scores using per-frame predictions") | |
active_speakers = [0]*len(frames_scene) | |
mid = num_avg_frames//2 | |
if num_avg_frames%2==0: | |
frame_pred = len(frames_scene)-(mid*2)+1 | |
start, end = mid, len(frames_scene)-mid+1 | |
else: | |
frame_pred = len(frames_scene)-(mid*2) | |
start, end = mid, len(frames_scene)-mid | |
print("Frame scene: {} | Avg frames: {} | Frame predictions: {}".format(len(frames_scene), num_avg_frames, frame_pred)) | |
if len(predictions) != frame_pred: | |
msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred) | |
return None, msg | |
active_speakers[start:end] = predictions[0:] | |
                # Depending on num_avg_frames, interpolate the initial and final frame predictions to get a full-video output
initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count) | |
active_speakers[0:start] = [initial_preds] * start | |
final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count) | |
active_speakers[end:] = [final_preds] * (len(frames_scene) - end) | |
start, end = 0, len(active_speakers) | |
# Get the output tracks for each frame | |
pred_idx = 0 | |
for frame in frames_scene[start:end]: | |
label = active_speakers[pred_idx] | |
pred_idx += 1 | |
output_tracks[frame] = track_dict[scene_num][label][frame] | |
# Save the output video | |
video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output) | |
if status != "success": | |
return None, status | |
print("Successfully saved the output video: ", video_output) | |
return video_output, "success" | |
except Exception as e: | |
return None, f"Error: {str(e)}" | |
if __name__ == "__main__": | |
# Custom CSS and HTML | |
custom_css = """ | |
<style> | |
body { | |
background-color: #ffffff; | |
color: #333333; /* Default text color */ | |
} | |
.container { | |
max-width: 100% !important; | |
padding-left: 0 !important; | |
padding-right: 0 !important; | |
} | |
.header { | |
background-color: #f0f0f0; | |
color: #333333; | |
padding: 30px; | |
margin-bottom: 30px; | |
text-align: center; | |
font-family: 'Helvetica Neue', Arial, sans-serif; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.1); | |
} | |
.header h1 { | |
font-size: 36px; | |
margin-bottom: 15px; | |
font-weight: bold; | |
color: #333333; /* Explicitly set heading color */ | |
} | |
.header h2 { | |
font-size: 24px; | |
margin-bottom: 10px; | |
color: #333333; /* Explicitly set subheading color */ | |
} | |
.header p { | |
font-size: 18px; | |
margin: 5px 0; | |
color: #666666; | |
} | |
.blue-text { | |
color: #4a90e2; | |
} | |
/* Custom styles for slider container */ | |
.slider-container { | |
background-color: white !important; | |
padding-top: 0.9em; | |
padding-bottom: 0.9em; | |
} | |
/* Add gap before examples */ | |
.examples-holder { | |
margin-top: 2em; | |
} | |
/* Set fixed size for example videos */ | |
.gradio-container .gradio-examples .gr-sample { | |
width: 240px !important; | |
height: 135px !important; | |
object-fit: cover; | |
display: inline-block; | |
margin-right: 10px; | |
} | |
.gradio-container .gradio-examples { | |
display: flex; | |
flex-wrap: wrap; | |
gap: 10px; | |
} | |
/* Ensure the parent container does not stretch */ | |
.gradio-container .gradio-examples { | |
max-width: 100%; | |
overflow: hidden; | |
} | |
/* Additional styles to ensure proper sizing in Safari */ | |
.gradio-container .gradio-examples .gr-sample img { | |
width: 240px !important; | |
height: 135px !important; | |
object-fit: cover; | |
} | |
</style> | |
""" | |
custom_html = custom_css + """ | |
<div class="header"> | |
<h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1> | |
<h2>Synchronization and Active Speaker Detection Demo</h2> | |
<p><a href='https://www.robots.ox.ac.uk/~vgg/research/gestsync/'>Project Page</a> | <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> | <a href='https://arxiv.org/abs/2310.05304'>Paper</a></p> | |
</div> | |
""" | |
tips = """ | |
<div> | |
<br><br> | |
Please give us a 🌟 on <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> if you like our work! | |
Tips to get better results: | |
<ul> | |
<li>Number of Average Frames: the higher the number, the better the results.</li>
<li>Clicking on "apply pre-processing" gives better results for synchronization, but it is an expensive operation and might take a while.</li>
<li>Input videos with clearly visible gestures work better.</li> | |
</ul> | |
Inference time: | |
<ul> | |
<li>Synchronization-correction: ~1 minute for a 10-second video</li> | |
<li>Active-speaker-detection: ~2 minutes for a 10-second video</li> | |
</ul> | |
Note: Occasionally, there may be a delay in acquiring a GPU, as the model runs on a free community GPU from ZeroGPU. | |
</div> | |
""" | |
# Define functions | |
def toggle_slider(global_speaker): | |
if global_speaker == "per-frame-prediction": | |
return gr.update(visible=True) | |
else: | |
return gr.update(visible=False) | |
def toggle_demo(demo_choice): | |
if demo_choice == "Synchronization-correction": | |
return ( | |
gr.update(value=None, visible=True), # video_input | |
gr.update(value=75, visible=True), # num_avg_frames | |
gr.update(value=None, visible=True), # apply_preprocess | |
gr.update(value="global-prediction", visible=False), # global_speaker | |
gr.update(value=None, visible=True), # output_video | |
gr.update(value="", visible=True), # result_text | |
gr.update(visible=True), # submit_button | |
gr.update(visible=True), # clear_button | |
gr.update(visible=True), # sync_examples | |
gr.update(visible=False), # asd_examples | |
gr.update(visible=True) # tips | |
) | |
else: | |
return ( | |
gr.update(value=None, visible=True), # video_input | |
gr.update(value=75, visible=True), # num_avg_frames | |
gr.update(value=None, visible=False), # apply_preprocess | |
gr.update(value="global-prediction", visible=True), # global_speaker | |
gr.update(value=None, visible=True), # output_video | |
gr.update(value="", visible=True), # result_text | |
gr.update(visible=True), # submit_button | |
gr.update(visible=True), # clear_button | |
gr.update(visible=False), # sync_examples | |
gr.update(visible=True), # asd_examples | |
gr.update(visible=True) # tips | |
) | |
def clear_inputs(): | |
return None, None, "global-prediction", 75, None, "", None | |
def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess): | |
if demo_choice == "Synchronization-correction": | |
return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess) | |
else: | |
return process_video_activespeaker(video_input, global_speaker, num_avg_frames) | |
# Define paths to sample videos | |
sync_sample_videos = [ | |
["samples/sync_sample_1.mp4"], | |
["samples/sync_sample_2.mp4"] | |
] | |
asd_sample_videos = [ | |
["samples/asd_sample_1.mp4"], | |
["samples/asd_sample_2.mp4"] | |
] | |
# Define Gradio interface | |
with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo: | |
gr.HTML(custom_html) | |
demo_choice = gr.Radio( | |
choices=["Synchronization-correction", "Active-speaker-detection"], | |
label="Please select the task you want to perform" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
video_input = gr.Video(label="Upload Video", height=400, visible=False) | |
num_avg_frames = gr.Slider( | |
minimum=50, | |
maximum=150, | |
step=5, | |
value=75, | |
label="Number of Average Frames", | |
visible=False | |
) | |
apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False) | |
global_speaker = gr.Radio( | |
choices=["global-prediction", "per-frame-prediction"], | |
value="global-prediction", | |
label="Global Speaker Prediction", | |
visible=False | |
) | |
global_speaker.change( | |
fn=toggle_slider, | |
inputs=global_speaker, | |
outputs=num_avg_frames | |
) | |
with gr.Column(): | |
output_video = gr.Video(label="Output Video", height=400, visible=False) | |
result_text = gr.Textbox(label="Result", visible=False) | |
with gr.Row(): | |
submit_button = gr.Button("Submit", variant="primary", visible=False) | |
clear_button = gr.Button("Clear", visible=False) | |
# Add a gap before examples | |
gr.HTML('<div class="examples-holder"></div>') | |
# Add examples that only populate the video input | |
sync_examples = gr.Dataset( | |
samples=sync_sample_videos, | |
components=[video_input], | |
type="values", | |
visible=False | |
) | |
asd_examples = gr.Dataset( | |
samples=asd_sample_videos, | |
components=[video_input], | |
type="values", | |
visible=False | |
) | |
tips = gr.Markdown(tips, visible=False) | |
demo_choice.change( | |
fn=toggle_demo, | |
inputs=demo_choice, | |
outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, output_video, result_text, submit_button, clear_button, sync_examples, asd_examples, tips] | |
) | |
sync_examples.select( | |
fn=lambda x: gr.update(value=x[0], visible=True), | |
inputs=sync_examples, | |
outputs=video_input | |
) | |
asd_examples.select( | |
fn=lambda x: gr.update(value=x[0], visible=True), | |
inputs=asd_examples, | |
outputs=video_input | |
) | |
submit_button.click( | |
fn=process_video, | |
inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess], | |
outputs=[output_video, result_text] | |
) | |
clear_button.click( | |
fn=clear_inputs, | |
inputs=[], | |
outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video] | |
) | |
# Launch the interface | |
demo.launch(allowed_paths=["."], share=True) |