gestsync / app.py
sindhuhegde's picture
Update app
4b11292
raw
history blame
No virus
49.1 kB
import gradio as gr
import os
import torch
from shutil import rmtree
from torch import nn
from torch.nn import functional as F
import numpy as np
import subprocess
import cv2
import pickle
import librosa
from decord import VideoReader
from decord import cpu, gpu
from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *
from tqdm import tqdm
from glob import glob
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
import warnings
import spaces
# MediaPipe Holistic solution module (pose + hand + face landmark detection)
mp_holistic = mp.solutions.holistic
# Silence noisy deprecation/user warnings emitted by the dependencies
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Initialize global variables
CHECKPOINT_PATH = "model_rgb.pth"  # trained GestSync model weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
batch_size = 12  # mini-batch size for embedding extraction
fps = 25  # target frame rate for all processing
n_negative_samples = 100  # number of negative audio windows used for sync scoring

# Initialize the mediapipe holistic keypoint detection model
holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
@spaces.GPU(duration=300)
def preprocess_video(path, result_folder, apply_preprocess, padding=20):
    '''
    This function preprocesses the input video: it extracts the audio and, when
    requested, detects the person using a YOLO model and crops the frames around
    that person.

    Args:
        - path (string) : Path of the input video file
        - result_folder (string) : Path of the folder to save the extracted audio and cropped video
        - apply_preprocess (string) : "True" to run person detection and cropping, anything else to skip
        - padding (int) : Padding (pixels) to add to the bounding box
    Returns:
        - wav_file (string) : Path of the extracted audio file
        - fps (float) : FPS of the input video
        - video_output (string) : Path of the cropped (or original) video file
        - msg (string) : "success" or an error message
    '''

    # Load all video frames
    try:
        vr = VideoReader(path, ctx=cpu(0))
        fps = vr.get_avg_fps()
        frame_count = len(vr)
    except Exception:
        msg = "Oops! Could not load the video. Please check the input video and try again."
        return None, None, None, msg

    if frame_count < 25:
        msg = "Not enough frames to process! Please give a longer video as input"
        return None, None, None, msg

    # Extract the audio from the input video file using ffmpeg
    wav_file = os.path.join(result_folder, "audio.wav")
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \
                            -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True)
    if status != 0:
        msg = "Oops! Could not load the audio file. Please check the input video and try again."
        return None, None, None, msg
    print("Extracted the audio from the video")

    if apply_preprocess == "True":
        all_frames = []
        for k in range(len(vr)):
            all_frames.append(vr[k].asnumpy())
        all_frames = np.asarray(all_frames)
        print("Extracted the frames for pre-processing")

        # Load YOLOv9 model (pre-trained on COCO dataset)
        yolo_model = YOLO("yolov9s.pt")
        print("Loaded the YOLO model")

        # Per-person frame and bounding-box history, keyed by YOLO detection index
        person_videos = {}
        person_tracks = {}

        print("Processing the frames...")
        for frame_idx in tqdm(range(frame_count)):
            frame = all_frames[frame_idx]

            # Perform person detection
            results = yolo_model(frame, verbose=False)
            detections = results[0].boxes

            for i, det in enumerate(detections):
                x1, y1, x2, y2 = det.xyxy[0]
                cls = det.cls[0]

                if int(cls) == 0:  # Class 0 is 'person' in COCO dataset
                    # Expand the bounding box by `padding`, clipped to the frame bounds
                    x1 = max(0, int(x1) - padding)
                    y1 = max(0, int(y1) - padding)
                    x2 = min(frame.shape[1], int(x2) + padding)
                    y2 = min(frame.shape[0], int(y2) + padding)

                    if i not in person_videos:
                        person_videos[i] = []
                        person_tracks[i] = []

                    person_videos[i].append(frame)
                    person_tracks[i].append([x1, y1, x2, y2])

        # A detection index counts as a real person if it appears in at least half
        # the frames. Remember that person's key: it is NOT necessarily 0 (e.g. when
        # the first YOLO detection is a non-person object), so indexing
        # person_videos[0] unconditionally would raise a KeyError.
        num_persons = 0
        person_key = None
        for i in person_videos.keys():
            if len(person_videos[i]) >= frame_count // 2:
                num_persons += 1
                person_key = i

        if num_persons == 0:
            msg = "No person detected in the video! Please give a video with one person as input"
            return None, None, None, msg
        if num_persons > 1:
            msg = "More than one person detected in the video! Please give a video with only one person as input"
            return None, None, None, msg

        # For the person detected, crop the frame based on the bounding box
        if len(person_videos[person_key]) > frame_count - 10:
            crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
            fourcc = cv2.VideoWriter_fourcc(*'DIVX')

            # Union of all per-frame bounding boxes so the crop size stays constant
            max_x1 = min([track[0] for track in person_tracks[person_key]])
            max_y1 = min([track[1] for track in person_tracks[person_key]])
            max_x2 = max([track[2] for track in person_tracks[person_key]])
            max_y2 = max([track[3] for track in person_tracks[person_key]])

            max_width = max_x2 - max_x1
            max_height = max_y2 - max_y1

            out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height))
            for frame in person_videos[person_key]:
                crop = frame[max_y1:max_y2, max_x1:max_x2]
                # decord yields RGB frames; swap channels for OpenCV's writer
                crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                out.write(crop)
            out.release()

            # Strip the (nonexistent) audio stream from the crop, then mux in the wav
            no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4'
            status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True)
            if status != 0:
                msg = "Oops! Could not preprocess the video. Please check the input video and try again."
                return None, None, None, msg

            video_output = crop_filename.split('.')[0] + '.mp4'
            status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
                                     (wav_file, no_sound_video, video_output), shell=True)
            if status != 0:
                msg = "Oops! Could not preprocess the video. Please check the input video and try again."
                return None, None, None, msg

            os.remove(crop_filename)
            os.remove(no_sound_video)
            print("Successfully saved the pre-processed video: ", video_output)
        else:
            msg = "Could not track the person in the full video! Please give a single-speaker video as input"
            return None, None, None, msg
    else:
        video_output = path

    return wav_file, fps, video_output, "success"
def resample_video(video_file, video_fname, result_folder):
    '''
    Resample the given video to 25 fps with ffmpeg.

    Args:
        - video_file (string) : Path of the input video file
        - video_fname (string) : Name of the input video file
        - result_folder (string) : Path of the folder to save the resampled video
    Returns:
        - video_file_25fps (string) : Path of the resampled video file (None on failure)
        - msg (string) : "success" or an error message
    '''
    video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))

    # Re-encode losslessly (crf 0) while forcing a constant 25 fps stream
    cmd = "ffmpeg -hide_banner -loglevel panic -y -i {} -c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}".format(video_file, video_file_25fps)
    if subprocess.call(cmd, shell=True) != 0:
        return None, "Oops! Could not resample the video to 25 FPS. Please check the input video and try again."

    print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
    return video_file_25fps, "success"
def load_checkpoint(path, model):
    '''
    Load trained weights from a checkpoint file into the given model.

    Args:
        - path (string) : Path of the checkpoint file
        - model (object) : Model object
    Returns:
        - model (object) : Model (in eval mode) with the checkpoint weights loaded
    '''
    # Load the checkpoint, mapping to CPU when no GPU is available
    if use_cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location="cpu")

    # Strip the "module." prefix left over from DataParallel training
    state = checkpoint["state_dict"]
    cleaned_state = {key.replace('module.', ''): weight for key, weight in state.items()}
    model.load_state_dict(cleaned_state)

    if use_cuda:
        model.cuda()

    print("Loaded checkpoint from: {}".format(path))
    return model.eval()
def load_video_frames(video_file):
    '''
    Read every frame of the given video into a numpy array.

    Args:
        - video_file (string) : Path of the video file
    Returns:
        - frames (np.ndarray) : Frames extracted from the video (None on failure)
        - msg (string) : "success" or an error message
    '''
    # Read the video
    try:
        reader = VideoReader(video_file, ctx=cpu(0))
    except:
        return None, "Oops! Could not load the input video file"

    # Decode all frames up front
    frames = np.asarray([reader[idx].asnumpy() for idx in range(len(reader))])
    return frames, "success"
def get_keypoints(frames):
    '''
    Run the MediaPipe Holistic pipeline on each frame and collect the keypoints.

    Args:
        - frames (list) : Frames extracted from the video
    Returns:
        - kp_dict (dict) : {"kps": per-frame keypoint dicts, "resolution": frame shape}
                           (None on failure)
        - msg (string) : "success" or an error message
    '''
    try:
        detector = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)

        resolution = frames[0].shape
        all_frame_kps = []

        for frame in frames:
            results = detector.process(frame)

            # Convert each available landmark group into a plain list of dicts;
            # groups MediaPipe did not detect stay None
            frame_dict = {}
            for name, landmarks in [("pose", results.pose_landmarks),
                                    ("left_hand", results.left_hand_landmarks),
                                    ("right_hand", results.right_hand_landmarks),
                                    ("face", results.face_landmarks)]:
                frame_dict[name] = protobuf_to_dict(landmarks)['landmark'] if landmarks is not None else None

            all_frame_kps.append(frame_dict)

        kp_dict = {"kps": all_frame_kps, "resolution": resolution}
    except Exception as e:
        print("Error: ", e)
        return None, "Error: Could not extract keypoints from the frames"

    return kp_dict, "success"
def check_visible_gestures(kp_dict):
    '''
    Verify that body/hand gestures are visible in enough frames of the video.

    Args:
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
    Returns:
        - msg (string) : "success" if gestures are sufficiently visible, else an error message
    '''
    keypoints = np.array(kp_dict['kps'])

    if len(keypoints) < 25:
        return "Not enough keypoints to process! Please give a longer video as input"

    # Count frames with no detected pose, and frames with neither hand detected
    pose_count = sum(1 for frame_kp in keypoints if frame_kp["pose"] is None)
    hand_count = sum(1 for frame_kp in keypoints
                     if frame_kp["left_hand"] is None and frame_kp["right_hand"] is None)

    # If the majority (>60%) of frames lack hands or pose, gestures are not usable
    if hand_count / len(keypoints) > 0.6 or pose_count / len(keypoints) > 0.6:
        return "The gestures in the input video are not visible! Please give a video with visible gestures as input."

    print("Successfully verified the input video - Gestures are visible!")
    return "success"
def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_frames=25, width=480, height=270):
    '''
    This function masks the faces using the keypoints extracted from the frames

    Args:
        - input_frames (list) : List of frames extracted from the video
        - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
        - asd (bool) : If True, pad 12 frames on each side (active-speaker detection mode)
        - stride (int) : Stride to extract the frames
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - width (int) : Width of the output frames
        - height (int) : Height of the output frames
    Returns:
        - input_frames (array) : Frame window to be given as input to the model
        - num_frames (int) : Number of frames to extract
        - orig_masked_frames (array) : Masked frames extracted from the video
        - msg (string) : Message to be returned
    '''

    print("Creating masked input frames...")

    input_frames_masked = []
    if kp_dict is None:
        # No keypoints available: mask a fixed band at the top of each resized frame
        for img in tqdm(input_frames):
            img = cv2.resize(img, (width, height))
            masked_img = cv2.rectangle(img, (0, 0), (width, 110), (0, 0, 0), -1)
            input_frames_masked.append(masked_img)
    else:
        # Face indices to extract the face-coordinates needed for masking
        face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
                         176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]

        input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
        print("Input keypoints: ", len(input_keypoints))

        for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):
            img = input_frames[i]
            face = frame_kp_dict["face"]

            if face is None:
                # No face detected in this frame: fall back to masking the fixed top band
                img = cv2.resize(img, (width, height))
                masked_img = cv2.rectangle(img, (0, 0), (width, 110), (0, 0, 0), -1)
            else:
                # Mask everything from the top of the frame to just below the face oval
                face_kps = []
                for idx in range(len(face)):
                    if idx in face_oval_idx:
                        x, y = int(face[idx]["x"] * resolution[1]), int(face[idx]["y"] * resolution[0])
                        face_kps.append((x, y))

                face_kps = np.array(face_kps)
                x1, y1 = min(face_kps[:, 0]), min(face_kps[:, 1])
                x2, y2 = max(face_kps[:, 0]), max(face_kps[:, 1])
                masked_img = cv2.rectangle(img, (0, 0), (resolution[1], y2 + 15), (0, 0, 0), -1)

            # Fix: image arrays are (height, width, channels) — compare shape[0] to height
            # and shape[1] to width. The original comparison was swapped, so frames that
            # were already the right size got redundantly resized while wrong-sized
            # frames could pass through unresized.
            if masked_img.shape[0] != height or masked_img.shape[1] != width:
                masked_img = cv2.resize(masked_img, (width, height))

            input_frames_masked.append(masked_img)

    orig_masked_frames = np.array(input_frames_masked)
    input_frames = np.array(input_frames_masked) / 255.
    if asd:
        input_frames = np.pad(input_frames, ((12, 12), (0, 0), (0, 0), (0, 0)), 'edge')

    # Slide a window of window_frames frames over the video with the given stride: Tx25xHxWx3
    input_frames = np.array([input_frames[i:i + window_frames, :, :] for i in range(0, input_frames.shape[0], stride) if (i + window_frames <= input_frames.shape[0])])

    num_frames = input_frames.shape[0]
    if num_frames < 10:
        msg = "Not enough frames to process! Please give a longer video as input."
        return None, None, None, msg

    return input_frames, num_frames, orig_masked_frames, "success"
def load_spectrograms(wav_file, asd=False, num_frames=None, window_frames=25, stride=4):
    '''
    This function extracts the spectrogram from the audio file

    Args:
        - wav_file (string) : Path of the extracted audio file
        - asd (bool) : If True, pad the spectrogram windows (active-speaker detection mode)
        - num_frames (int) : Number of frames to extract
        - window_frames (int) : Number of frames in each window that is given as input to the model
        - stride (int) : Stride to extract the audio frames
    Returns:
        - spec (array) : Spectrogram array window to be used as input to the model
        - orig_spec (array) : Spectrogram array extracted from the audio file
        - msg (string) : Message to be returned
    '''

    # Load the audio at 16 kHz
    try:
        wav = librosa.load(wav_file, sr=16000)[0]
    except:
        # Fix: error message previously read "Could extract" instead of "Could not extract"
        msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
        return None, None, msg

    # Convert to tensor and compute the mel filterbank features
    wav = torch.FloatTensor(wav).unsqueeze(0)
    mel, _, _, _ = wav2filterbanks(wav.to(device))
    spec = mel.squeeze(0).cpu().numpy()
    orig_spec = spec

    # Slide a window of (window_frames*stride) spectrogram steps with the given stride
    spec = np.array([spec[i:i + (window_frames * stride), :] for i in range(0, spec.shape[0], stride) if (i + (window_frames * stride) <= spec.shape[0])])

    if num_frames is not None:
        # Fix: measure the audio/video length mismatch BEFORE truncating — the original
        # computed frame_diff after the truncation, so the warning below could never
        # trigger when the audio was longer than the video.
        frame_diff = np.abs(len(spec) - num_frames)
        if len(spec) != num_frames:
            spec = spec[:num_frames]
        if frame_diff > 60:
            print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

    if asd:
        # Pad half a window on each side so every video frame has a centred audio window
        pad_frames = (window_frames // 2)
        spec = np.pad(spec, ((pad_frames, pad_frames), (0, 0), (0, 0)), 'edge')

    return spec, orig_spec, "success"
def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
    '''
    Calculate the audio-visual sync offset between the video and audio embeddings.

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - model (object) : Model object
    Returns:
        - offset (int) : Optimal audio-visual offset (None on failure)
        - msg (string) : "success" or an error message
    '''
    pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
    if status != "success":
        return None, status

    # Best-scoring audio window, rescaled by the stride and re-centred on the
    # positive window, gives the predicted offset
    scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
    offset = scores.argmax() * stride - pos_idx
    return offset.item(), "success"
def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):
    '''
    This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - num_avg_frames (int) : Number of frames to average the scores
        - stride (int) : Stride to extract the negative windows
    Returns:
        - vid_emb_pos (array) : Positive video embedding array
        - aud_emb_posneg (array) : All possible combinations of audio embedding array
        - pos_idx_frame (int) : Positive video embedding array frame
        - stride (int) : Stride used to extract the negative windows
        - msg (string) : Message to be returned
    '''
    slice_size = num_avg_frames

    # Unfold the audio embedding along its last (time) axis into overlapping
    # windows of slice_size frames, then keep at most
    # n_negative_samples/stride + 1 candidate windows
    aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride)
    aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
    aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]

    # The centre window is treated as the positive (in-sync) sample;
    # every other window acts as a negative at a known offset
    pos_idx = (aud_emb_posneg.shape[1]//2)
    pos_idx_frame = pos_idx*stride

    # Report the range of offsets that can be predicted with this configuration
    min_offset_frames = -(pos_idx)*stride
    max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
    print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))

    # Matching positive video window; if the video is shorter than slice_size,
    # the slice comes back short and we bail out with an error message
    vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
    if vid_emb_pos.shape[2] != slice_size:
        msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)
        return None, None, None, None, msg

    return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"
def calc_av_scores(vid_emb, aud_emb, model):
    '''
    Compute the audio-visual similarity scores and the corresponding attention map.

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object
    Returns:
        - scores (array) : Audio-visual similarity scores (averaged over the last axis)
        - att_map (array) : Attention map (log-softmax of the raw scores)
    '''
    raw_scores = calc_att_map(vid_emb, aud_emb, model)
    att_map = logsoftmax_2d(raw_scores)
    return raw_scores.mean(-1), att_map
def calc_att_map(vid_emb, aud_emb, model):
    '''
    This function calculates the similarity between the video and audio embeddings

    Args:
        - vid_emb (array) : Video embedding array
        - aud_emb (array) : Audio embedding array
        - model (object) : Model object
    Returns:
        - scores (array) : Audio-visual similarity scores
    '''
    # Insert a broadcast axis so the video embedding pairs with every audio window
    vid_emb = vid_emb[:, :, None]
    aud_emb = aud_emb.transpose(1, 2)

    # Dot product over the feature dimension, evaluated in chunks
    # (run_func_in_parts presumably bounds memory use — defined in project utils)
    scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
                               vid_emb,
                               aud_emb,
                               part_len=10,
                               dim=3,
                               device=device)

    # Learned scaling of the raw similarities via the model's logits_scale layer
    scores = model.logits_scale(scores[..., None]).squeeze(-1)

    return scores
def generate_video(frames, audio_file, video_fname):
    '''
    Write the frames to a video file and mux in the given audio track.

    Args:
        - frames (array) : Frames to be used to generate the video
        - audio_file (string) : Path of the audio file
        - video_fname (string) : Output path prefix (without extension)
    Returns:
        - video_output (string) : Path of the generated .mp4 file (None on failure)
        - msg (string) : "success" or an error message
    '''
    fname = 'inference.avi'

    # Dump the frames (channel-swapped for OpenCV) into a temporary AVI at 25 fps
    writer = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    writer.release()

    # Strip any audio stream from the temporary video
    no_sound_video = video_fname + '_nosound.mp4'
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
    if status != 0:
        return None, "Oops! Could not generate the video. Please check the input video and try again."

    # Mux the audio into the silent video, re-encoding with x264
    video_output = video_fname + '.mp4'
    status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -c:v libx264 -preset veryslow -crf 18 -pix_fmt yuv420p -strict -2 -q:v 1 -shortest %s' %
                             (audio_file, no_sound_video, video_output), shell=True)
    if status != 0:
        return None, "Oops! Could not generate the video. Please check the input video and try again."

    # Clean up the intermediates
    os.remove(fname)
    os.remove(no_sound_video)

    return video_output, "success"
def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
    '''
    Apply the predicted sync-offset to produce a sync-corrected video.

    Args:
        - video_path (string) : Path of the video file
        - frames (array) : Frames to be used to generate the video
        - wav_file (string) : Path of the audio file
        - offset (int) : Predicted sync-offset to be used to correct the video
        - result_folder (string) : Path of the result folder to save the output sync-corrected video
        - sample_rate (int) : Sample rate of the audio
        - fps (int) : Frames per second of the video
    Returns:
        - video_output (string) : Path of the video file (None on failure)
        - msg (string) : "success" or an error message
    '''
    # Nothing to do when the streams are already aligned
    if offset == 0:
        print("The input audio and video are in-sync! No need to perform sync correction.")
        return video_path, "success"

    print("Performing Sync Correction...")
    corrected_frames = np.zeros_like(frames)

    if offset > 0:
        # Positive offset: trim the start of the audio and keep the frames unchanged
        audio_offset = int(offset * (sample_rate / fps))
        wav = librosa.core.load(wav_file, sr=sample_rate)[0]
        corrected_wav = wav[audio_offset:]
        corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
        write(corrected_wav_file, sample_rate, corrected_wav)
        wav_file = corrected_wav_file
        corrected_frames = frames
    elif offset < 0:
        # Negative offset: shift the frames earlier and drop the overhang
        corrected_frames[0:len(frames) + offset] = frames[np.abs(offset):]
        corrected_frames = corrected_frames[:len(frames) - np.abs(offset)]

    corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
    video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path)
    if status != "success":
        return None, status

    return video_output, "success"
def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
    '''
    Load and face-mask the frames of every speaker track in one scene.

    Args:
        - test_videos (list) : List of videos to be processed (speaker-specific tracks)
        - spec (array) : Spectrogram of the audio
        - wav_file (string) : Path of the audio file
        - scene_num (int) : Scene number to be used to save the input masked video
        - result_folder (string) : Path of the folder to save the input masked video
    Returns:
        - all_frames (list) : Masked input frame windows, one tensor per speaker track
        - all_orig_frames (list) : Original masked frames, one array per speaker track
        - msg (string) : "success" or an error message
    '''
    all_frames, all_orig_frames = [], []
    for video in test_videos:
        print("Processing video: ", video)

        # Load the video frames
        track_frames, status = load_video_frames(video)
        if status != "success":
            return None, None, status
        print("Successfully loaded the video frames")

        # Extract the keypoints from the frames
        kp_dict, status = get_keypoints(track_frames)
        if status != "success":
            return None, None, status
        print("Successfully extracted the keypoints")

        # Mask the faces and build the sliding-window input for the model
        masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(track_frames, kp_dict, asd=True)
        if status != "success":
            return None, None, status
        print("Successfully loaded the masked frames")

        # Truncate the frames to the spectrogram length if they disagree
        if spec.shape[2] != masked_frames.shape[0]:
            num_frames = spec.shape[2]
            masked_frames = masked_frames[:num_frames]
            orig_masked_frames = orig_masked_frames[:num_frames]
            frame_diff = np.abs(spec.shape[2] - num_frames)
            if frame_diff > 60:
                print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

        # Move channels first and add a batch dimension
        model_frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))
        model_frames = torch.FloatTensor(np.array(model_frames)).unsqueeze(0)
        print("Successfully converted the frames to tensor")

        all_frames.append(model_frames)
        all_orig_frames.append(orig_masked_frames)

    return all_frames, all_orig_frames, "success"
def extract_audio(video, result_folder):
    '''
    Extract a 16 kHz mono WAV track from the given video using ffmpeg.

    Args:
        - video (string) : Path of the video file
        - result_folder (string) : Path of the folder to save the extracted audio file
    Returns:
        - wav_file (string) : Path of the extracted audio file (None on failure)
        - msg (string) : "success" or an error message
    '''
    wav_file = os.path.join(result_folder, "audio.wav")

    cmd = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s' % (video, wav_file)
    if subprocess.call(cmd, shell=True) != 0:
        return None, "Oops! Could not load the audio file in the given input video. Please check the input and try again"

    return wav_file, "success"
@spaces.GPU(duration=200)
def get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True):
    '''
    Compute video (and optionally audio) embeddings in mini-batches.

    Args:
        - video_sequences (array) : Array of video frames to be used as input to the model
        - audio_sequences (array) : Array of audio frames to be used as input to the model
        - model (object) : Model object
        - calc_aud_emb (bool) : Flag to calculate the audio embedding
    Returns:
        - video_emb (array) : Video embedding
        - audio_emb (array) : Audio embedding (returned only when calc_aud_emb is True)
    '''
    batch_size = 12
    video_emb = []
    audio_emb = []

    for start in range(0, len(video_sequences), batch_size):
        # Embed one batch of video windows; average over the temporal dimension
        video_inp = video_sequences[start:start + batch_size, ]
        vid_emb = model.forward_vid(video_inp.to(device), return_feats=False)
        vid_emb = torch.mean(vid_emb, axis=-1)
        video_emb.append(vid_emb.detach())

        if calc_aud_emb:
            # Embed the matching batch of audio windows
            audio_inp = audio_sequences[start:start + batch_size, ]
            aud_emb = model.forward_aud(audio_inp.to(device))
            audio_emb.append(aud_emb.detach())

        torch.cuda.empty_cache()

    video_emb = torch.cat(video_emb, dim=0)

    if calc_aud_emb:
        audio_emb = torch.cat(audio_emb, dim=0)
        return video_emb, audio_emb

    return video_emb
def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model):
    '''
    Predict the active speaker, either once for the whole video or per-frame.

    Args:
        - all_video_embeddings (array) : Array of video embeddings of all speakers
        - audio_embedding (array) : Audio embedding
        - global_score (string) : "True" to predict one speaker globally, else per-frame
        - num_avg_frames (int) : Number of frames to average the scores over
        - model (object) : Model object
    Returns:
        - pred_speaker : Predicted speaker index (global) or list of per-frame indices
        - num_avg_frames (int) : Averaging window actually used (possibly reduced)
    '''
    cos = nn.CosineSimilarity(dim=1)
    audio_embedding = audio_embedding.squeeze(2)

    scores = []
    for video_embedding in all_video_embeddings:
        # Compute the similarity of each speaker's video embeddings with the audio embedding
        sim = cos(video_embedding, audio_embedding)

        # Apply the logits scale to the similarity scores (scaling the scores)
        output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1)

        if global_score == "True":
            score = output.mean(0)
        else:
            # Average over a sliding window; shrink the window if the clip is shorter
            if output.shape[0] < num_avg_frames:
                num_avg_frames = output.shape[0]
            output_batch = output.unfold(0, num_avg_frames, 1)
            score = torch.mean(output_batch, axis=-1)

        scores.append(score.detach().cpu().numpy())

    if global_score == "True":
        print("Using global predictions")
        pred_speaker = np.argmax(scores)
    else:
        print("Using per-frame predictions")
        pred_speaker = []
        for frame_idx in range(len(scores[0])):
            frame_scores = [scores[spk][frame_idx] for spk in range(len(all_video_embeddings))]
            pred_speaker.append(np.argmax(frame_scores))

    return pred_speaker, num_avg_frames
def save_video(output_tracks, input_frames, wav_file, result_folder):
    '''
    Draw the active-speaker bounding boxes on the frames and write the output video.

    Args:
        - output_tracks (dict-like) : Per-frame active speaker bounding boxes
        - input_frames (array) : Frames to be used to generate the video
        - wav_file (string) : Path of the audio file
        - result_folder (string) : Path of the result folder to save the output video
    Returns:
        - video_output (string) : Path of the output video (None on failure)
        - msg (string) : "success" or an error message
    '''
    try:
        output_frames = []
        for idx in range(len(input_frames)):
            frame = input_frames[idx]
            # If the active speaker is found, draw a bounding box around the active speaker
            if idx in output_tracks:
                bbox = output_tracks[idx]
                x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
                frame = cv2.rectangle(frame.copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
            output_frames.append(frame)

        # Generate the output video
        output_video_fname = os.path.join(result_folder, "result_active_speaker_det")
        video_output, status = generate_video(output_frames, wav_file, output_video_fname)
        if status != "success":
            return None, status
    except Exception as e:
        return None, f"Error: {str(e)}"

    return video_output, "success"
@spaces.GPU(duration=200)
def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
    '''
    End-to-end sync-offset pipeline: preprocess the input video, extract GestSync
    video/audio embeddings, predict the audio-visual offset, and write a
    sync-corrected output video.

    Args:
        - video_path (string) : Path of the input video file
        - num_avg_frames (int) : Number of frames to average the sync scores over
        - apply_preprocess (string) : "True" to run person detection and cropping
    Returns:
        - video_output (string) : Path of the sync-corrected video (None on failure)
        - msg (string) : Predicted-offset message or an error message
    '''
    try:
        # Extract the video filename
        video_fname = os.path.basename(video_path.split(".")[0])

        # Create folders to save the inputs and results (wiping any previous run)
        result_folder = os.path.join("results", video_fname)
        result_folder_input = os.path.join(result_folder, "input")
        result_folder_output = os.path.join(result_folder, "output")

        if os.path.exists(result_folder):
            rmtree(result_folder)

        os.makedirs(result_folder)
        os.makedirs(result_folder_input)
        os.makedirs(result_folder_output)

        # Preprocess the video
        print("Applying preprocessing: ", apply_preprocess)
        wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
        if status != "success":
            return None, status
        print("Successfully preprocessed the video")

        # Resample the video to 25 fps if it is not already 25 fps
        print("FPS of video: ", fps)
        if fps!=25:
            vid_path, status = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
            if status != "success":
                return None, status
            orig_vid_path_25fps, status = resample_video(video_path, "input_video_25fps", result_folder_input)
            if status != "success":
                return None, status
        else:
            vid_path = vid_path_processed
            orig_vid_path_25fps = video_path

        # Load the original video frames (before pre-processing) - Needed for the final sync-correction
        orig_frames, status = load_video_frames(orig_vid_path_25fps)
        if status != "success":
            return None, status

        # Load the pre-processed video frames
        frames, status = load_video_frames(vid_path)
        if status != "success":
            return None, status
        print("Successfully extracted the video frames")

        if len(frames) < num_avg_frames:
            msg = "Error: The input video is too short. Please use a longer input video."
            return None, msg

        # Load keypoints and check if gestures are visible
        kp_dict, status = get_keypoints(frames)
        if status != "success":
            return None, status
        print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))

        status = check_visible_gestures(kp_dict)
        if status != "success":
            return None, status

        # Load RGB frames (face-masked, windowed for the model)
        rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=False, window_frames=25, width=480, height=270)
        if status != "success":
            return None, status
        print("Successfully loaded the RGB frames")

        # Convert frames to tensor (channels-first, with a batch dimension)
        rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
        rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
        B = rgb_frames.size(0)
        print("Successfully converted the frames to tensor")

        # Load spectrograms, truncated to match the number of video windows
        spec, orig_spec, status = load_spectrograms(wav_file, asd=False, num_frames=num_frames)
        if status != "success":
            return None, status
        spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
        print("Successfully loaded the spectrograms")

        # Create input windows by flattening the window axis into the batch axis
        video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
        audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

        # Load the trained model
        model = Transformer_RGB()
        model = load_checkpoint(CHECKPOINT_PATH, model)
        print("Successfully loaded the model")

        # Process in batches
        batch_size = 12
        video_emb = []
        audio_emb = []

        for i in tqdm(range(0, len(video_sequences), batch_size)):
            video_inp = video_sequences[i:i+batch_size, ]
            audio_inp = audio_sequences[i:i+batch_size, ]

            vid_emb = model.forward_vid(video_inp.to(device))
            vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
            aud_emb = model.forward_aud(audio_inp.to(device))

            video_emb.append(vid_emb.detach())
            audio_emb.append(aud_emb.detach())

            torch.cuda.empty_cache()

        audio_emb = torch.cat(audio_emb, dim=0)
        video_emb = torch.cat(video_emb, dim=0)

        # L2 normalize embeddings
        video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
        audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)

        # Regroup the flattened batch back into per-window sequences
        audio_emb = torch.split(audio_emb, B, dim=0)
        audio_emb = torch.stack(audio_emb, dim=2)
        audio_emb = audio_emb.squeeze(3)
        audio_emb = audio_emb[:, None]

        video_emb = torch.split(video_emb, B, dim=0)
        video_emb = torch.stack(video_emb, dim=2)
        video_emb = video_emb.squeeze(3)
        print("Successfully extracted GestSync embeddings")

        # Calculate sync offset
        pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
        if status != "success":
            return None, status
        print("Predicted offset: ", pred_offset)

        # Generate sync-corrected video from the ORIGINAL (uncropped) frames
        video_output, status = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
        if status != "success":
            return None, status
        print("Successfully generated the video:", video_output)

        return video_output, f"Predicted offset: {pred_offset}"

    except Exception as e:
        return None, f"Error: {str(e)}"
def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
    '''
    This function detects the active speaker (among multiple visible speakers) in each
    scene of the input video using gesture-audio synchronization, and saves an output
    video with a bounding box drawn around the predicted active speaker.

    Args:
        - video_path (string) : Path of the input video file
        - global_speaker (string) : "global-prediction" for a single active-speaker
                                    prediction per scene, "per-frame-prediction" for a
                                    sliding-window prediction over time
        - num_avg_frames (int) : Number of frames to average the sync scores over
                                 (used for per-frame predictions)

    Returns:
        - video_output (string) : Path of the saved output video (None on failure)
        - msg (string) : "success" on success, otherwise an error message
    '''
    try:
        # Extract the video filename
        video_fname = os.path.basename(video_path.split(".")[0])

        # Create folders to save the inputs and results
        result_folder = os.path.join("results", video_fname)
        result_folder_input = os.path.join(result_folder, "input")
        result_folder_output = os.path.join(result_folder, "output")

        # Start from a clean result directory for this video
        if os.path.exists(result_folder):
            rmtree(result_folder)

        os.makedirs(result_folder)
        os.makedirs(result_folder_input)
        os.makedirs(result_folder_output)

        # The model needs at least 1 second (25 frames at 25 FPS) of context
        if global_speaker=="per-frame-prediction" and num_avg_frames<25:
            msg = "Number of frames to average needs to be set to a minimum of 25 frames. At least 1-second context is needed for the model. Please change the num_avg_frames and try again..."
            return None, msg

        # Read the video
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
        except:
            msg = "Oops! Could not load the input video file"
            return None, msg

        # Get the FPS of the video
        fps = vr.get_avg_fps()
        print("FPS of video: ", fps)

        # Resample the video to 25 FPS if the original video is of a different frame-rate
        if fps!=25:
            test_video_25fps, status = resample_video(video_path, video_fname, result_folder_input)
            if status != "success":
                return None, status
        else:
            test_video_25fps = video_path

        # Load the video frames
        orig_frames, status = load_video_frames(test_video_25fps)
        if status != "success":
            return None, status

        # Extract and save the audio file
        orig_wav_file, status = extract_audio(video_path, result_folder)
        if status != "success":
            return None, status

        # Pre-process and extract per-speaker tracks in each scene
        print("Pre-processing the input video...")
        status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
        if status != 0:
            msg = "Error in pre-processing the input video, please check the input video and try again..."
            return None, msg

        # Load the tracks file saved during pre-processing
        with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file:
            tracks = pickle.load(file)

        # Create a dictionary of all tracks found along with the bounding-boxes
        # track_dict[scene][speaker][frame] -> bbox
        track_dict = {}
        for scene_num in range(len(tracks)):
            track_dict[scene_num] = {}
            for i in range(len(tracks[scene_num])):
                track_dict[scene_num][i] = {}
                for frame_num, bbox in zip(tracks[scene_num][i]['track']['frame'], tracks[scene_num][i]['track']['bbox']):
                    track_dict[scene_num][i][frame_num] = bbox

        # Get the total number of scenes
        test_scenes = os.listdir("{}/crops".format(result_folder_input))
        print("Total scenes found in the input video = ", len(test_scenes))

        # Load the trained model
        model = Transformer_RGB()
        model = load_checkpoint(CHECKPOINT_PATH, model)

        # Compute the active speaker in each scene
        output_tracks = {}
        for scene_num in tqdm(range(len(test_scenes))):
            test_videos = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.avi"))
            test_videos.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
            print("Scene {} -> Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos)))

            # Active-speaker detection needs at least two candidate speakers to choose from
            if len(test_videos)<=1:
                msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..."
                return None, msg

            # Load the audio file
            audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0]
            spec, _, status = load_spectrograms(audio_file, asd=True)
            if status != "success":
                return None, status
            spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3)
            print("Successfully loaded the spectrograms")

            # Load the masked input frames
            all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input)
            if status != "success":
                return None, status
            print("Successfully loaded the masked input frames")

            # Prepare the audio and video sequences for the model
            audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

            print("Obtaining audio and video embeddings...")
            all_video_embs = []
            for idx in tqdm(range(len(all_masked_frames))):
                with torch.no_grad():
                    video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0)

                    # The audio embedding is shared across all speaker tracks,
                    # so compute it only once (for the first track)
                    if idx==0:
                        video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True)
                    else:
                        video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False)
                    all_video_embs.append(video_emb)
            print("Successfully extracted GestSync embeddings")

            # Predict the active speaker in each scene
            if global_speaker=="per-frame-prediction":
                predictions, num_avg_frames = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model)
            else:
                predictions, _ = predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model)

            # Get the frames present in the scene
            frames_scene = tracks[scene_num][0]['track']['frame']

            # Prepare the active speakers list to draw the bounding boxes
            if global_speaker=="global-prediction":
                print("Aggregating scores using global predictions")
                active_speakers = [predictions]*len(frames_scene)
                start, end = 0, len(frames_scene)
            else:
                print("Aggregating scores using per-frame predictions")
                active_speakers = [0]*len(frames_scene)
                mid = num_avg_frames//2

                # The number of sliding-window predictions depends on the parity
                # of the averaging window size
                if num_avg_frames%2==0:
                    frame_pred = len(frames_scene)-(mid*2)+1
                    start, end = mid, len(frames_scene)-mid+1
                else:
                    frame_pred = len(frames_scene)-(mid*2)
                    start, end = mid, len(frames_scene)-mid
                print("Frame scene: {} | Avg frames: {} | Frame predictions: {}".format(len(frames_scene), num_avg_frames, frame_pred))

                if len(predictions) != frame_pred:
                    msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred)
                    return None, msg
                active_speakers[start:end] = predictions[0:]

                # Depending on the num_avg_frames, extrapolate the initial and final
                # frame predictions (majority vote over the boundary window) to get
                # a full video output
                initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count)
                active_speakers[0:start] = [initial_preds] * start

                final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count)
                active_speakers[end:] = [final_preds] * (len(frames_scene) - end)
                start, end = 0, len(active_speakers)

            # Get the output tracks for each frame
            pred_idx = 0
            for frame in frames_scene[start:end]:
                label = active_speakers[pred_idx]
                pred_idx += 1
                output_tracks[frame] = track_dict[scene_num][label][frame]

        # Save the output video with the predicted bounding boxes drawn
        video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
        if status != "success":
            return None, status
        print("Successfully saved the output video: ", video_output)

        return video_output, "success"

    except Exception as e:
        return None, f"Error: {str(e)}"
if __name__ == "__main__":

    # Custom CSS and HTML
    # Page-wide styling: header banner, slider container background, and
    # fixed-size example-video thumbnails (with Safari-specific sizing fixes).
    custom_css = """
<style>
body {
background-color: #ffffff;
color: #333333; /* Default text color */
}
.container {
max-width: 100% !important;
padding-left: 0 !important;
padding-right: 0 !important;
}
.header {
background-color: #f0f0f0;
color: #333333;
padding: 30px;
margin-bottom: 30px;
text-align: center;
font-family: 'Helvetica Neue', Arial, sans-serif;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.header h1 {
font-size: 36px;
margin-bottom: 15px;
font-weight: bold;
color: #333333; /* Explicitly set heading color */
}
.header h2 {
font-size: 24px;
margin-bottom: 10px;
color: #333333; /* Explicitly set subheading color */
}
.header p {
font-size: 18px;
margin: 5px 0;
color: #666666;
}
.blue-text {
color: #4a90e2;
}
/* Custom styles for slider container */
.slider-container {
background-color: white !important;
padding-top: 0.9em;
padding-bottom: 0.9em;
}
/* Add gap before examples */
.examples-holder {
margin-top: 2em;
}
/* Set fixed size for example videos */
.gradio-container .gradio-examples .gr-sample {
width: 240px !important;
height: 135px !important;
object-fit: cover;
display: inline-block;
margin-right: 10px;
}
.gradio-container .gradio-examples {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
/* Ensure the parent container does not stretch */
.gradio-container .gradio-examples {
max-width: 100%;
overflow: hidden;
}
/* Additional styles to ensure proper sizing in Safari */
.gradio-container .gradio-examples .gr-sample img {
width: 240px !important;
height: 135px !important;
object-fit: cover;
}
</style>
"""

    # Header banner markup; the CSS is prepended so both are injected together
    # via a single gr.HTML() call when building the interface below.
    custom_html = custom_css + """
<div class="header">
<h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
<h2>Synchronization and Active Speaker Detection Demo</h2>
<p><a href='https://www.robots.ox.ac.uk/~vgg/research/gestsync/'>Project Page</a> | <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> | <a href='https://arxiv.org/abs/2310.05304'>Paper</a></p>
</div>
"""
# Usage tips shown below the examples in the UI (rendered via gr.Markdown).
# Fixed typo: "synchornization" -> "synchronization".
tips = """
<div>
<br><br>
Please give us a 🌟 on <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> if you like our work!
Tips to get better results:
<ul>
<li>Number of Average Frames: Higher the number, better the results.</li>
<li>Clicking on "apply pre-processing" will give better results for synchronization, but this is an expensive operation and might take a while.</li>
<li>Input videos with clearly visible gestures work better.</li>
</ul>
</div>
"""
# Define functions
def toggle_slider(global_speaker):
    """Show the averaging-window slider only when per-frame prediction is selected."""
    show = global_speaker == "per-frame-prediction"
    return gr.update(visible=show)
def toggle_demo(demo_choice):
    """Reveal/reset the UI widgets relevant to the selected task.

    Returns one gr.update per output component, in this fixed order:
    video_input, num_avg_frames, apply_preprocess, global_speaker,
    output_video, result_text, submit_button, clear_button,
    sync_examples, asd_examples, tips.
    """
    # Sync-correction mode shows the preprocessing checkbox and sync examples;
    # active-speaker mode shows the global/per-frame selector and ASD examples.
    sync_mode = demo_choice == "Synchronization-correction"
    return (
        gr.update(value=None, visible=True),                           # video_input
        gr.update(value=75, visible=True),                             # num_avg_frames
        gr.update(value=None, visible=sync_mode),                      # apply_preprocess
        gr.update(value="global-prediction", visible=not sync_mode),   # global_speaker
        gr.update(value=None, visible=True),                           # output_video
        gr.update(value="", visible=True),                             # result_text
        gr.update(visible=True),                                       # submit_button
        gr.update(visible=True),                                       # clear_button
        gr.update(visible=sync_mode),                                  # sync_examples
        gr.update(visible=not sync_mode),                              # asd_examples
        gr.update(visible=True),                                       # tips
    )
def clear_inputs():
    """Return the default values used to reset the UI components.

    Order matches the clear_button.click outputs: demo_choice, video_input,
    global_speaker, num_avg_frames, apply_preprocess, result_text, output_video.
    """
    defaults = (None, None, "global-prediction", 75, None, "", None)
    return defaults
def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess):
    """Dispatch the uploaded video to the pipeline selected in the UI.

    Returns the (output_video, message) pair produced by the chosen pipeline.
    """
    if demo_choice != "Synchronization-correction":
        # Active-speaker-detection task
        return process_video_activespeaker(video_input, global_speaker, num_avg_frames)
    # Synchronization-correction task
    return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess)
# Define paths to sample videos
# Example clips for the synchronization-correction demo (shown in sync_examples)
sync_sample_videos = [
    ["samples/sync_sample_1.mp4"],
    ["samples/sync_sample_2.mp4"]
]
# Example clips for the active-speaker-detection demo (shown in asd_examples)
asd_sample_videos = [
    ["samples/asd_sample_1.mp4"],
    ["samples/asd_sample_2.mp4"]
]
# Define Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
    # Header banner (title, subtitle, project links)
    gr.HTML(custom_html)

    # Task selector: all other widgets start hidden and are revealed by toggle_demo
    demo_choice = gr.Radio(
        choices=["Synchronization-correction", "Active-speaker-detection"],
        label="Please select the task you want to perform"
    )
    with gr.Row():
        with gr.Column():
            # Input controls (hidden until a task is selected)
            video_input = gr.Video(label="Upload Video", height=400, visible=False)
            num_avg_frames = gr.Slider(
                minimum=50,
                maximum=150,
                step=5,
                value=75,
                label="Number of Average Frames",
                visible=False
            )
            apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False)
            global_speaker = gr.Radio(
                choices=["global-prediction", "per-frame-prediction"],
                value="global-prediction",
                label="Global Speaker Prediction",
                visible=False
            )
            # The averaging slider is only relevant for per-frame predictions
            global_speaker.change(
                fn=toggle_slider,
                inputs=global_speaker,
                outputs=num_avg_frames
            )
        with gr.Column():
            # Output widgets
            output_video = gr.Video(label="Output Video", height=400, visible=False)
            result_text = gr.Textbox(label="Result", visible=False)
    with gr.Row():
        submit_button = gr.Button("Submit", variant="primary", visible=False)
        clear_button = gr.Button("Clear", visible=False)

    # Add a gap before examples
    gr.HTML('<div class="examples-holder"></div>')

    # Add examples that only populate the video input
    sync_examples = gr.Dataset(
        samples=sync_sample_videos,
        components=[video_input],
        type="values",
        visible=False
    )
    asd_examples = gr.Dataset(
        samples=asd_sample_videos,
        components=[video_input],
        type="values",
        visible=False
    )

    # NOTE: rebinds the `tips` markdown string to its rendered component
    tips = gr.Markdown(tips, visible=False)

    # Reveal the widgets relevant to the chosen task
    demo_choice.change(
        fn=toggle_demo,
        inputs=demo_choice,
        outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, output_video, result_text, submit_button, clear_button, sync_examples, asd_examples, tips]
    )

    # Clicking an example row copies its video path into the input widget
    sync_examples.select(
        fn=lambda x: gr.update(value=x[0], visible=True),
        inputs=sync_examples,
        outputs=video_input
    )
    asd_examples.select(
        fn=lambda x: gr.update(value=x[0], visible=True),
        inputs=asd_examples,
        outputs=video_input
    )

    # Run the selected pipeline on submit
    submit_button.click(
        fn=process_video,
        inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess],
        outputs=[output_video, result_text]
    )

    # Reset all inputs and outputs to their defaults
    clear_button.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video]
    )

# Launch the interface
demo.launch(allowed_paths=["."], share=True)