File size: 30,090 Bytes
3860ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
import gradio as gr
import argparse
import os, subprocess
from shutil import rmtree

import numpy as np
import cv2
import librosa
import torch

from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *

import sys
if sys.version_info > (3, 0): long, unicode, basestring = int, str, str

from tqdm import tqdm
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
mp_holistic = mp.solutions.holistic
from ultralytics import YOLO
from decord import VideoReader, cpu

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=UserWarning) 

# Trained GestSync RGB checkpoint that process_video loads
CHECKPOINT_PATH = "model_rgb.pth" 

# Runtime configuration shared by the functions below:
#   use_cuda / device   -> whether a GPU is used and where tensors/models are placed
#   n_negative_samples  -> caps the number of candidate audio windows in create_online_sync_negatives
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
n_negative_samples = 100
print("Using CUDA: ", use_cuda, device)

def preprocess_video(path, result_folder, apply_preprocess, padding=20):

	'''
	Extract the audio track of the input video and, optionally, crop the frames around the single person detected by a YOLO model

	Args:
		- path (string) : Path of the input video file
		- result_folder (string) : Path of the folder to save the extracted audio and cropped video
		- apply_preprocess (bool or string) : Whether to run the YOLO person crop. Accepts the boolean delivered by
		  the Gradio checkbox as well as the legacy string "True"
		- padding (int) : Padding (in pixels) added to each side of the bounding box
	Returns:
		- wav_file (string) : Path of the extracted audio file (mono, 16 kHz, 16-bit PCM)
		- fps (number) : Average FPS of the input video as reported by decord
		- video_output (string) : Path of the cropped video file, or the input path when preprocessing is disabled
		- msg (string) : "success", or an error message (the other return values are then None)
	'''

	def _run_ffmpeg(*args):
		# Arguments are passed as a list (no shell), so paths containing spaces or shell
		# metacharacters are handled safely. Returns True when ffmpeg exits with status 0.
		try:
			return subprocess.call(["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y", *args]) == 0
		except OSError:
			# ffmpeg binary not found / not executable
			return False

	# Open the video and read its basic properties
	try:
		vr = VideoReader(path, ctx=cpu(0))
		fps = vr.get_avg_fps()
		frame_count = len(vr)
	except Exception:
		msg = "Oops! Could not load the video. Please check the input video and try again."
		return None, None, None, msg

	if frame_count < 25:
		msg = "Not enough frames to process! Please give a longer video as input"
		return None, None, None, msg

	# Extract the audio from the input video file using ffmpeg
	wav_file = os.path.join(result_folder, "audio.wav")
	if not _run_ffmpeg("-i", path, "-async", "1", "-ac", "1", "-vn",
					   "-acodec", "pcm_s16le", "-ar", "16000", wav_file):
		msg = "Oops! Could not load the audio file. Please check the input video and try again."
		return None, None, None, msg
	print("Extracted the audio from the video")

	# The Gradio checkbox delivers a bool; earlier callers passed the string "True"
	if isinstance(apply_preprocess, str):
		apply_preprocess = apply_preprocess == "True"
	if not apply_preprocess:
		return wav_file, fps, path, "success"

	all_frames = [vr[k].asnumpy() for k in range(frame_count)]
	print("Extracted the frames for pre-processing")

	# Load YOLOv9 model (pre-trained on COCO dataset)
	yolo_model = YOLO("yolov9s.pt")
	print("Loaded the YOLO model")

	# Person detections are grouped by their index within each frame's detection list.
	# This is a crude stand-in for real tracking: it assumes the same person keeps the same index.
	person_videos = {}
	person_tracks = {}

	print("Processing the frames...")
	for frame in tqdm(all_frames):
		detections = yolo_model(frame, verbose=False)[0].boxes
		for det_idx, det in enumerate(detections):
			if int(det.cls[0]) != 0:  # Class 0 is 'person' in COCO dataset
				continue

			x1, y1, x2, y2 = det.xyxy[0]
			x1 = max(0, int(x1) - padding)
			y1 = max(0, int(y1) - padding)
			x2 = min(frame.shape[1], int(x2) + padding)
			y2 = min(frame.shape[0], int(y2) + padding)

			person_videos.setdefault(det_idx, []).append(frame)
			person_tracks.setdefault(det_idx, []).append([x1, y1, x2, y2])

	# A "person" counts only if detected in at least half of the frames
	tracked_ids = [pid for pid, vids in person_videos.items() if len(vids) >= frame_count // 2]

	if not tracked_ids:
		msg = "No person detected in the video! Please give a video with one person as input"
		return None, None, None, msg
	if len(tracked_ids) > 1:
		msg = "More than one person detected in the video! Please give a video with only one person as input"
		return None, None, None, msg

	# Use the track that actually passed the threshold (not blindly index 0)
	person_id = tracked_ids[0]
	track_frames = person_videos[person_id]
	track_boxes = person_tracks[person_id]

	if len(track_frames) <= frame_count - 10:
		msg = "Could not track the person in the full video! Please give a single-speaker video as input"
		return None, None, None, msg

	# Crop every frame with the union of all bounding boxes so the crop size is constant
	crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
	fourcc = cv2.VideoWriter_fourcc(*'DIVX')

	min_x1 = min(box[0] for box in track_boxes)
	min_y1 = min(box[1] for box in track_boxes)
	max_x2 = max(box[2] for box in track_boxes)
	max_y2 = max(box[3] for box in track_boxes)

	out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_x2 - min_x1, max_y2 - min_y1))
	try:
		for frame in track_frames:
			crop = frame[min_y1:max_y2, min_x1:max_x2]
			# decord frames are RGB; OpenCV writers expect BGR (the swap is symmetric)
			out.write(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
	finally:
		out.release()

	# splitext only touches the extension, so dots in result_folder are preserved
	stem = os.path.splitext(crop_filename)[0]

	no_sound_video = stem + '_nosound.mp4'
	if not _run_ffmpeg("-i", crop_filename, "-c", "copy", "-an", "-strict", "-2", no_sound_video):
		msg = "Oops! Could not preprocess the video. Please check the input video and try again."
		return None, None, None, msg

	video_output = stem + '.mp4'
	if not _run_ffmpeg("-i", wav_file, "-i", no_sound_video, "-strict", "-2", "-q:v", "1", video_output):
		msg = "Oops! Could not preprocess the video. Please check the input video and try again."
		return None, None, None, msg

	os.remove(crop_filename)
	os.remove(no_sound_video)

	print("Successfully saved the pre-processed video: ", video_output)

	return wav_file, fps, video_output, "success"

def resample_video(video_file, video_fname, result_folder):

	'''
	Re-encode a video at 25 fps using ffmpeg

	Args:
		- video_file (string) : Path of the input video file
		- video_fname (string) : File name (without extension) to use for the resampled video
		- result_folder (string) : Path of the folder to save the resampled video
	Returns:
		- video_file_25fps (string) : Path of the resampled video file
	Raises:
		- RuntimeError : if ffmpeg exits with a non-zero status
		- OSError : if the ffmpeg binary cannot be executed
	'''
	video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))

	# Build the argument list directly (instead of splitting a string on spaces) so that
	# paths containing spaces are passed to ffmpeg intact
	cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y",
		   "-i", video_file, "-q:v", "1", "-filter:v", "fps=25", video_file_25fps]
	status = subprocess.call(cmd)
	if status != 0:
		raise RuntimeError("Oops! Could not resample the video to 25 fps. Please check the input video and try again.")

	print('Resampled the video to 25 fps: {}'.format(video_file_25fps))

	return video_file_25fps

def load_checkpoint(path, model):
	'''
	Restore trained weights from a checkpoint into ``model`` and switch it to evaluation mode

	Args:
		- path (string) : Path of the checkpoint file; it must contain a "state_dict" entry
		- model (object) : Model object whose parameters are overwritten
	Returns:
		- model (object) : The same model, moved to the GPU when CUDA is in use, in eval mode
	'''

	# Without CUDA, remap tensors that were saved on a GPU onto the CPU
	load_kwargs = {} if not use_cuda else None
	if load_kwargs is None:
		checkpoint = torch.load(path)
	else:
		checkpoint = torch.load(path, map_location="cpu")

	# Remove every "module." occurrence from the keys (presumably added by nn.DataParallel
	# wrapping during training)
	cleaned_state = {
		name.replace('module.', ''): tensor
		for name, tensor in checkpoint["state_dict"].items()
	}
	model.load_state_dict(cleaned_state)

	if use_cuda:
		model.cuda()

	print("Loaded checkpoint from: {}".format(path))

	return model.eval()


def load_video_frames(video_file):
	'''
	Decode every frame of a video into a single array

	Args:
		- video_file (string) : Path of the video file
	Returns:
		- frames (array) : Frames stacked along the first axis (None on failure)
		- msg (string) : "success", or an error message
	'''

	# Open the video; catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate
	try:
		vr = VideoReader(video_file, ctx=cpu(0))
	except Exception:
		msg = "Oops! Could not load the input video file"
		return None, msg

	# An empty container would otherwise yield a shapeless array that breaks later stages
	if len(vr) == 0:
		msg = "Oops! Could not load the input video file"
		return None, msg

	frames = np.asarray([vr[k].asnumpy() for k in range(len(vr))])

	return frames, "success"



def get_keypoints(frames):

	'''
	Extract pose, hand and face landmarks from every frame with the MediaPipe Holistic pipeline

	Args:
		- frames (list) : List of frames extracted from the video
	Returns:
		- kp_dict (dict) : {"kps": one dict per frame with keys "pose", "left_hand", "right_hand", "face"
		  (each a list of landmark dicts, or None when not detected), "resolution": shape of the first frame}
		- msg (string) : "success", or an error message (kp_dict is then None)
	'''

	# (key in the per-frame dict, attribute on the MediaPipe result)
	landmark_fields = (
		("pose", "pose_landmarks"),
		("left_hand", "left_hand_landmarks"),
		("right_hand", "right_hand_landmarks"),
		("face", "face_landmarks"),
	)

	try:
		resolution = frames[0].shape
		all_frame_kps = []

		# The context manager closes the MediaPipe graph when done, even if processing fails
		with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
			for frame in frames:
				results = holistic.process(frame)

				frame_dict = {}
				for key, attr in landmark_fields:
					landmarks = getattr(results, attr)
					frame_dict[key] = None if landmarks is None else protobuf_to_dict(landmarks)['landmark']

				all_frame_kps.append(frame_dict)

		kp_dict = {"kps": all_frame_kps, "resolution": resolution}
	except Exception as e:
		print("Error: ", e)
		return None, "Error: Could not extract keypoints from the frames"

	return kp_dict, "success"


def check_visible_gestures(kp_dict):

	'''
	Verify that the body pose and at least one hand are detected in enough frames

	Args:
		- kp_dict (dict) : Keypoint dictionary ("kps": per-frame landmark dicts, "resolution": frame shape)
	Returns:
		- msg (string) : "success" if the gestures are visible, otherwise a user-facing error message
	'''

	frame_kps = kp_dict['kps']
	total = len(frame_kps)

	if total < 25:
		return "Not enough keypoints to process! Please give a longer video as input"

	# Frames with no pose, and frames where neither hand was found
	missing_pose = sum(1 for kp in frame_kps if kp["pose"] is None)
	missing_hands = sum(1 for kp in frame_kps if kp["left_hand"] is None and kp["right_hand"] is None)

	# Reject when more than 70% of the frames lack hands or lack the pose
	hands_hidden = missing_hands / total > 0.7
	pose_hidden = missing_pose / total > 0.7
	if hands_hidden or pose_hidden:
		return "The gestures in the input video are not visible! Please give a video with visible gestures as input."

	print("Successfully verified the input video - Gestures are visible!")

	return "success"

def load_rgb_masked_frames(input_frames, kp_dict, stride=1, window_frames=25, width=480, height=270):

	'''
	Black out the face region of every frame (using the face keypoints) and cut the result into model input windows

	Args:
		- input_frames (list) : List of frames extracted from the video (not modified)
		- kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
		- stride (int) : Step between the start frames of consecutive windows
		- window_frames (int) : Number of frames in each window that is given as input to the model
		- width (int) : Width the frames are resized to
		- height (int) : Height the frames are resized to
	Returns:
		- input_frames (array) : Windows of masked frames scaled by 1/255, stacked along the first axis
		- num_frames (int) : Number of windows
		- orig_masked_frames (array) : Masked (unscaled) frames
		- msg (string) : "success", or an error message (the other values are then None)
	'''

	# Face-oval landmark indices: the lowest of these points sets the bottom edge of the mask
	face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172, 
					176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]

	input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
	print("Input keypoints: ", len(input_keypoints))

	print("Creating masked input frames...")
	input_frames_masked = []
	for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):

		face = frame_kp_dict["face"]

		if face is None:
			# No face found: resize (which returns a new array) and black out a fixed top band
			img = cv2.resize(input_frames[i], (width, height))
			masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
		else:
			# cv2.rectangle draws in place, so work on a copy to leave the caller's frame untouched
			img = input_frames[i].copy()
			# Landmark y is multiplied by the frame height, i.e. assumed normalized to [0, 1]
			y2 = max(int(face[idx]["y"]*resolution[0]) for idx in face_oval_idx if idx < len(face))
			# Black out everything from the top of the frame to 15 px below the lowest face-oval point
			masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1)

		# shape is (rows, cols, ...) = (height, width, ...)
		if masked_img.shape[:2] != (height, width):
			masked_img = cv2.resize(masked_img, (width, height))

		input_frames_masked.append(masked_img)

	orig_masked_frames = np.array(input_frames_masked)
	input_frames = orig_masked_frames / 255.
	print("Input images full: ", input_frames.shape)      	# num_framesx270x480x3 

	# Every window of `window_frames` consecutive frames, starting every `stride` frames
	last_start = input_frames.shape[0] - window_frames
	input_frames = np.array([input_frames[i:i+window_frames] for i in range(0, last_start + 1, stride)])
	print("Input images window: ", input_frames.shape)      	# Tx25x270x480x3
	
	num_frames = input_frames.shape[0]

	if num_frames<10:
		msg = "Not enough frames to process! Please give a longer video as input."
		return None, None, None, msg
	
	return input_frames, num_frames, orig_masked_frames, "success"

def load_spectrograms(wav_file, num_frames, window_frames=25, stride=4):

	'''
	Compute the filterbank spectrogram of the audio file and cut it into windows aligned with the video windows

	Args:
		- wav_file (string) : Path of the extracted audio file
		- num_frames (int) : Number of video windows; the spectrogram windows are truncated to this count
		- window_frames (int) : Number of video frames in each window that is given as input to the model
		- stride (int) : Number of spectrogram steps per video frame (also the step between windows)
	Returns:
		- spec (array) : Spectrogram windows to be used as input to the model
		- orig_spec (array) : Full spectrogram extracted from the audio file
		- msg (string) : "success", or an error message (the other values are then None)
	'''

	# Load the audio at 16 kHz with librosa
	try:
		wav = librosa.load(wav_file, sr=16000)[0]
	except Exception:
		msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
		return None, None, msg
	
	# Convert to tensor and compute the filterbanks on the configured device
	wav = torch.FloatTensor(wav).unsqueeze(0)
	mel, _, _, _ = wav2filterbanks(wav.to(device))
	spec = mel.squeeze(0).cpu().numpy()
	orig_spec = spec

	# Every full window of window_frames*stride spectrogram steps, starting every `stride` steps
	win_len = window_frames * stride
	spec = np.array([spec[i:i+win_len, :] for i in range(0, spec.shape[0] - win_len + 1, stride)])

	# Measure the mismatch BEFORE truncating, otherwise a longer audio track always reports 0
	frame_diff = abs(len(spec) - num_frames)
	if frame_diff != 0:
		spec = spec[:num_frames]
		if frame_diff > 60:
			print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")

	return spec, orig_spec, "success"


def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
	'''
	Predict the audio-visual offset as the shift of the best-scoring audio candidate window

	Args:
		- vid_emb (tensor) : Video embedding tensor
		- aud_emb (tensor) : Audio embedding tensor
		- num_avg_frames (int) : Number of frames over which the scores are averaged
		- model (object) : Model object
	Returns:
		- offset (int) : Predicted offset in frames (None on failure)
		- msg (string) : "success", or the error message from create_online_sync_negatives
	'''

	vid_pos, aud_candidates, pos_frame, step, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
	if status != "success":
		return None, status

	candidate_scores = calc_av_scores(vid_pos, aud_candidates, model)[0]

	# Candidate k starts k*step frames into the audio; the positive window starts at pos_frame
	best_candidate = candidate_scores.argmax()
	offset = best_candidate * step - pos_frame

	return offset.item(), "success"

def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):

	'''
	Build the positive video window and every shifted audio window it will be compared against

	Args:
		- vid_emb (tensor) : Video embedding tensor (time on the last axis)
		- aud_emb (tensor) : Audio embedding tensor with a singleton dim 1 (time on the last axis)
		- num_avg_frames (int) : Length of each window, in frames
		- stride (int) : Shift between consecutive audio candidate windows
	Returns:
		- vid_emb_pos (tensor) : Video window taken at the centre candidate position
		- aud_emb_posneg (tensor) : All audio candidate windows (at most n_negative_samples/stride + 1)
		- pos_idx_frame (int) : Start frame of the positive (centre) window
		- stride (int) : Stride used to extract the candidate windows
		- msg (string) : "success", or an error message (the other values are then None)
	'''

	window = num_avg_frames

	# Slide a `window`-long view with step `stride` along the time axis, then move the
	# candidate axis to dim 1 and keep only the first n_negative_samples/stride + 1 candidates
	candidates = aud_emb.squeeze(1).unfold(-1, window, stride).permute([0, 2, 1, 3])
	candidates = candidates[:, :int(n_negative_samples/stride)+1]

	n_candidates = candidates.shape[1]
	centre = n_candidates // 2
	centre_frame = centre * stride

	lowest = -centre * stride
	highest = (n_candidates - centre - 1) * stride
	print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(lowest, highest))

	positive = vid_emb[:, :, centre_frame:centre_frame + window]
	if positive.shape[2] == window:
		return positive, candidates, centre_frame, stride, "success"

	msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(window)
	return None, None, None, None, msg

def calc_av_scores(vid_emb, aud_emb, model):

	'''
	Score the video window against every audio candidate and build the attention map

	Args:
		- vid_emb (tensor) : Video embedding tensor
		- aud_emb (tensor) : Audio candidate embedding tensor
		- model (object) : Model object
	Returns:
		- scores (tensor) : Similarity scores averaged over the last (time) axis
		- att_map (tensor) : logsoftmax_2d of the un-averaged scores
	'''

	pair_scores = calc_att_map(vid_emb, aud_emb, model)
	att_map = logsoftmax_2d(pair_scores)

	return pair_scores.mean(-1), att_map

def calc_att_map(vid_emb, aud_emb, model):

	'''
	Dot-product similarity between the video window and each audio candidate, scaled by the model's logit scale

	Args:
		- vid_emb (tensor) : Video embedding tensor
		- aud_emb (tensor) : Audio candidate tensor (presumably laid out as produced by
		  create_online_sync_negatives — TODO confirm)
		- model (object) : Model object providing ``logits_scale``
	Returns:
		- scores (tensor) : Scaled audio-visual similarity scores
	'''

	# Singleton axis at dim 2 lets the video window broadcast against every audio candidate
	vid_expanded = vid_emb.unsqueeze(2)
	aud_swapped = aud_emb.transpose(1, 2)

	def channel_dot(v, a):
		# Element-wise product summed over dim 1
		return (v * a).sum(1)

	# Evaluated in chunks of 10 along dim 3 by run_func_in_parts (presumably to bound memory)
	raw_scores = run_func_in_parts(channel_dot,
								   vid_expanded,
								   aud_swapped,
								   part_len=10,
								   dim=3,
								   device=device)

	# logits_scale is applied over a trailing singleton axis that is removed again afterwards
	return model.logits_scale(raw_scores.unsqueeze(-1)).squeeze(-1)

def generate_video(frames, audio_file, video_fname, fps=25):
	
	'''
	Write frames to a video file and mux the given audio track into it

	Args:
		- frames (array) : Frames to be used to generate the video (written RGB -> BGR for OpenCV)
		- audio_file (string) : Path of the audio file
		- video_fname (string) : Output path without extension; the result is ``video_fname + '.mp4'``
		- fps (int) : Frame rate of the written video (default 25, as before)
	Returns:
		- video_output (string) : Path of the video file
	Raises:
		- ValueError : if ``frames`` is empty
		- RuntimeError : if an ffmpeg step fails
		- OSError : if the ffmpeg binary cannot be executed
	'''

	if len(frames) == 0:
		raise ValueError("No frames available to generate the video.")

	height, width = frames[0].shape[:2]

	# Intermediate files live next to the output (not in the CWD) so concurrent runs cannot collide
	fname = video_fname + '.avi'
	no_sound_video = video_fname + '_nosound.mp4'
	video_output = video_fname + '.mp4'
	error_msg = "Oops! Could not generate the video. Please check the input video and try again."

	try:
		video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))
		try:
			for frame in frames:
				video.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
		finally:
			video.release()

		# Argument lists (no shell) keep paths with spaces/metacharacters intact
		strip_cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y",
					 "-i", fname, "-c", "copy", "-an", "-strict", "-2", no_sound_video]
		if subprocess.call(strip_cmd) != 0:
			raise RuntimeError(error_msg)

		mux_cmd = ["ffmpeg", "-hide_banner", "-loglevel", "panic", "-y",
				   "-i", audio_file, "-i", no_sound_video, "-strict", "-2", "-q:v", "1", "-shortest", video_output]
		if subprocess.call(mux_cmd) != 0:
			raise RuntimeError(error_msg)
	finally:
		# Always clean up the intermediates, including on failure
		for tmp_file in (fname, no_sound_video):
			if os.path.exists(tmp_file):
				os.remove(tmp_file)

	return video_output

def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):

	'''
	Shift the audio or the video by the predicted offset and render a sync-corrected video

	Args:
		- video_path (string) : Path of the video file (returned unchanged when offset == 0)
		- frames (array) : Frames to be used to generate the video
		- wav_file (string) : Path of the audio file
		- offset (int) : Predicted sync-offset in frames; > 0 trims the start of the audio, < 0 drops leading frames
		- result_folder (string) : Path of the result folder to save the output sync-corrected video
		- sample_rate (int) : Sample rate of the audio
		- fps (int) : Frame rate at which `offset` was measured (used to convert frames to audio samples)
	Returns:
		- video_output (string) : Path of the video file
	'''

	if offset == 0:
		print("The input audio and video are in-sync! No need to perform sync correction.")
		return video_path
	
	print("Performing Sync Correction...")
	if offset > 0:
		# Audio is late relative to the video: drop offset frames' worth of samples from its start
		audio_offset = int(offset*(sample_rate/fps))
		wav = librosa.load(wav_file, sr=sample_rate)[0]
		corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
		write(corrected_wav_file, sample_rate, wav[audio_offset:])
		wav_file = corrected_wav_file
		corrected_frames = frames
	else:
		# Drop the first |offset| frames; a slice avoids allocating a full-size zero buffer
		corrected_frames = frames[-offset:]

	corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
	video_output = generate_video(corrected_frames, wav_file, corrected_video_path)

	return video_output

class Logger:
	'''
	Tee for ``sys.stdout``: every write goes to the original terminal stream and to a log file

	Args:
		- filename (string) : Path of the log file (truncated on open, written as UTF-8)
	'''

	def __init__(self, filename):
		self.terminal = sys.stdout
		# Explicit UTF-8 so non-ASCII text (e.g. file names) can be logged regardless of the locale
		self.log = open(filename, "w", encoding="utf-8")

	def write(self, message):
		self.terminal.write(message)
		if not self.log.closed:
			self.log.write(message)
		
	def flush(self):
		self.terminal.flush()
		if not self.log.closed:
			self.log.flush()
		
	def isatty(self):
		return False

	def close(self):
		'''Close the log file; the wrapped terminal stream is left open.'''
		if not self.log.closed:
			self.log.close()

	def __getattr__(self, name):
		# Delegate any other stream attribute (encoding, fileno, ...) to the wrapped terminal,
		# so code that inspects sys.stdout keeps working. Guard against lookups before __init__.
		terminal = self.__dict__.get("terminal")
		if terminal is None:
			raise AttributeError(name)
		return getattr(terminal, name)


def _extract_gestsync_embeddings(model, video_sequences, audio_sequences, batch_size=12):
	'''
	Run the GestSync video and audio encoders over all input windows in mini-batches

	Args:
		- model (object) : Loaded model providing forward_vid / forward_aud
		- video_sequences (tensor) : Video windows stacked along dim 0
		- audio_sequences (tensor) : Spectrogram windows stacked along dim 0
		- batch_size (int) : Number of windows per forward pass
	Returns:
		- video_emb (tensor) : Video embeddings, averaged over their last axis (kept as size 1)
		- audio_emb (tensor) : Audio embeddings
	'''
	video_emb, audio_emb = [], []

	# Inference only: no_grad skips building the autograd graph, saving time and memory
	with torch.no_grad():
		for start in tqdm(range(0, len(video_sequences), batch_size)):
			video_inp = video_sequences[start:start + batch_size]
			audio_inp = audio_sequences[start:start + batch_size]

			vid_emb = model.forward_vid(video_inp.to(device))
			vid_emb = torch.mean(vid_emb, dim=-1).unsqueeze(-1)
			aud_emb = model.forward_aud(audio_inp.to(device))

			video_emb.append(vid_emb)
			audio_emb.append(aud_emb)

			if use_cuda:
				torch.cuda.empty_cache()

	return torch.cat(video_emb, dim=0), torch.cat(audio_emb, dim=0)


def process_video(video_path, num_avg_frames, apply_preprocess):
	'''
	Gradio callback: predict the audio-visual offset of the uploaded video and build a sync-corrected copy

	Args:
		- video_path (string) : Path of the uploaded video (None/empty when nothing was uploaded)
		- num_avg_frames (int) : Number of frames over which the sync scores are averaged
		- apply_preprocess (bool) : Whether to crop the video around the detected person first
	Returns:
		- msg (string) : "Predicted offset: <n>" on success, otherwise an error message
		- video_output (string) : Path of the sync-corrected video, or None on failure
	'''
	try:
		if not video_path:
			return "Error: Please upload a video.", None

		# Slider values may arrive as floats; tensor.unfold requires an int window size
		num_avg_frames = int(num_avg_frames)

		# splitext on the basename only, so dots in parent directories don't truncate the name
		video_fname = os.path.splitext(os.path.basename(video_path))[0]
		
		# Create folders to save the inputs and results
		result_folder = os.path.join("results", video_fname)
		result_folder_input = os.path.join(result_folder, "input")
		result_folder_output = os.path.join(result_folder, "output")

		if os.path.exists(result_folder):
			rmtree(result_folder)

		os.makedirs(result_folder)
		os.makedirs(result_folder_input)
		os.makedirs(result_folder_output)

		# Preprocess the video
		print("Applying preprocessing: ", apply_preprocess)
		wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
		if status != "success":
			return status, None
		print("Successfully preprocessed the video")

		# Resample the video to 25 fps if it is not already 25 fps
		print("FPS of video: ", fps)
		if fps!=25:
			vid_path = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
			orig_vid_path_25fps = resample_video(video_path, "input_video_25fps", result_folder_input)
		else:
			vid_path = vid_path_processed
			orig_vid_path_25fps = video_path

		# Load the original video frames (before pre-processing) - Needed for the final sync-correction 
		orig_frames, status = load_video_frames(orig_vid_path_25fps)
		if status != "success":
			return status, None
			
		# Load the pre-processed video frames
		frames, status = load_video_frames(vid_path)
		if status != "success":
			return status, None
		print("Successfully extracted the video frames")

		if len(frames) < num_avg_frames:
			return "Error: The input video is too short. Please use a longer input video.", None

		# Load keypoints and check if gestures are visible
		kp_dict, status = get_keypoints(frames)
		if status != "success":
			return status, None
		print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))

		status = check_visible_gestures(kp_dict)
		if status != "success":
			return status, None

		# Load RGB frames
		rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, window_frames=25, width=480, height=270)
		if status != "success":
			return status, None
		print("Successfully loaded the RGB frames")

		# Convert frames to tensor (channel axis first, then a batch axis)
		rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
		rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
		B = rgb_frames.size(0)
		print("Successfully converted the frames to tensor")

		# Load spectrograms
		spec, orig_spec, status = load_spectrograms(wav_file, num_frames, window_frames=25)
		if status != "success":
			return status, None
		spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
		print("Successfully loaded the spectrograms")

		# Create input windows
		video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
		audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)

		# Load the trained model
		model = Transformer_RGB()
		model = load_checkpoint(CHECKPOINT_PATH, model)
		print("Successfully loaded the model")

		video_emb, audio_emb = _extract_gestsync_embeddings(model, video_sequences, audio_sequences, batch_size=12)

		# L2 normalize embeddings
		video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
		audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)

		audio_emb = torch.split(audio_emb, B, dim=0)
		audio_emb = torch.stack(audio_emb, dim=2)
		audio_emb = audio_emb.squeeze(3)
		audio_emb = audio_emb[:, None]

		video_emb = torch.split(video_emb, B, dim=0)
		video_emb = torch.stack(video_emb, dim=2)
		video_emb = video_emb.squeeze(3)
		print("Successfully extracted GestSync embeddings")

		# Calculate sync offset
		pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
		if status != "success":
			return status, None
		print("Predicted offset: ", pred_offset)

		# Generate sync-corrected video. The offset was measured on the 25 fps frames (orig_frames
		# are the 25 fps version), so convert it with fps=25, not the original input fps.
		video_output = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=25)
		print("Successfully generated the video:", video_output)

		return f"Predicted offset: {pred_offset}", video_output

	except Exception as e:
		return f"Error: {str(e)}", None

def read_logs():
	'''
	Return the current contents of "output.log" for display in the UI

	Flushes sys.stdout first so that, when it is the Logger tee installed under __main__, buffered
	output reaches the file before it is read. Returns an empty string if the log does not exist yet.
	'''
	sys.stdout.flush()
	try:
		# errors="replace" keeps the UI polling alive even if the file holds undecodable bytes
		with open("output.log", "r", encoding="utf-8", errors="replace") as f:
			return f.read()
	except FileNotFoundError:
		return ""


if __name__ == "__main__":

	# Mirror everything printed to stdout into output.log, which read_logs() shows in the UI
	sys.stdout = Logger("output.log")


	# Plain CSS for gr.Blocks(css=...). It must not contain <style> tags: that argument is used as a
	# stylesheet, and an HTML tag inside it is not valid CSS and invalidates the rule it precedes.
	custom_css = """
		body {
			background-color: #ffffff;
			color: #333333;  /* Default text color */
		}
		.container {
			max-width: 100% !important;
			padding-left: 0 !important;
			padding-right: 0 !important;
		}
		.header {
			background-color: #f0f0f0;
			color: #333333;
			padding: 30px;
			margin-bottom: 30px;
			text-align: center;
			font-family: 'Helvetica Neue', Arial, sans-serif;
			box-shadow: 0 2px 4px rgba(0,0,0,0.1);
		}
		.header h1 {
			font-size: 36px;
			margin-bottom: 15px;
			font-weight: bold;
			color: #333333;  /* Explicitly set heading color */
		}
		.header h2 {
			font-size: 24px;
			margin-bottom: 10px;
			color: #333333;  /* Explicitly set subheading color */
		}
		.header p {
			font-size: 18px;
			margin: 5px 0;
			color: #666666;
		}
		.blue-text {
			color: #4a90e2;
		}
		/* Custom styles for slider container */
		.slider-container {
			background-color: white !important;
			padding-top: 0.9em;
			padding-bottom: 0.9em;
		}
		/* Add gap before examples */
		.examples-holder {
			margin-top: 2em;
		}
		/* Set fixed size for example videos */
		.gradio-container .gradio-examples .gr-sample {
			width: 240px !important;
			height: 135px !important;
			object-fit: cover;
			display: inline-block;
			margin-right: 10px;
		}

		/* Wrap the example thumbnails and ensure the parent container does not stretch */
		.gradio-container .gradio-examples {
			display: flex;
			flex-wrap: wrap;
			gap: 10px;
			max-width: 100%;
			overflow: hidden;
		}

		/* Additional styles to ensure proper sizing in Safari */
		.gradio-container .gradio-examples .gr-sample img {
			width: 240px !important;
			height: 135px !important;
			object-fit: cover;
		}
	"""

	# Page header; the same CSS is also embedded here inside a <style> element
	custom_html = "<style>" + custom_css + "</style>" + """
	<div class="header">
		<h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
		<h2>Upload any video to predict the synchronization offset and generate a sync-corrected video</h2>
		<p>Sindhu Hegde and Andrew Zisserman</p>
		<p>VGG, University of Oxford</p>
	</div>
	"""

	# Define paths to sample videos
	sample_videos = [
						"samples/sync_sample_1.mp4",
						"samples/sync_sample_2.mp4",
					]
	
	# Define Gradio interface
	with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
		gr.HTML(custom_html)
		with gr.Row():
			with gr.Column():
				with gr.Group(elem_classes="slider-container"):
					num_avg_frames = gr.Slider(
						minimum=50,
						maximum=150,
						step=5,
						value=75,
						label="Number of Average Frames",
					)
				apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False)
				video_input = gr.Video(label="Upload Video", height=400)
			
			with gr.Column():
				result_text = gr.Textbox(label="Result")
				output_video = gr.Video(label="Sync Corrected Video", height=400)
		
		with gr.Row():
			submit_button = gr.Button("Submit", variant="primary")
			clear_button = gr.Button("Clear")

		submit_button.click(
			fn=process_video,
			inputs=[video_input, num_avg_frames, apply_preprocess],
			outputs=[result_text, output_video]
		)
		
		# Reset every input/output to its initial value
		clear_button.click(
			fn=lambda: (None, 75, False, "", None),
			inputs=[],
			outputs=[video_input, num_avg_frames, apply_preprocess, result_text, output_video]
		)

		gr.HTML('<div class="examples-holder"></div>')

		# Add examples 
		gr.Examples(
			examples=sample_videos,
			inputs=video_input,
			outputs=None,
			fn=None,
			cache_examples=False,
		)

		# Poll the log file every second and show it in the UI
		logs = gr.Textbox(label="Logs")
		demo.load(read_logs, None, logs, every=1)

	# Launch the interface
	demo.queue().launch(allowed_paths=["."], show_error=True)