vumichien commited on
Commit
68f9039
1 Parent(s): 904c6e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -1
app.py CHANGED
@@ -34,10 +34,109 @@ import fairseq
34
  from fairseq import checkpoint_utils, options, tasks, utils
35
  from fairseq.dataclass.configs import GenerationConfig
36
  from huggingface_hub import hf_hub_download
 
37
 
38
  ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')
39
  user_dir = "/home/user/app/av_hubert/avhubert"
40
  face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
41
  face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
42
  mean_face_path = "/home/user/app/20words_mean_face.npy"
43
- mouth_roi_path = "/home/user/app/roi.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  from fairseq import checkpoint_utils, options, tasks, utils
35
  from fairseq.dataclass.configs import GenerationConfig
36
  from huggingface_hub import hf_hub_download
37
+ import gradio as gr
38
 
39
  ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')
40
  user_dir = "/home/user/app/av_hubert/avhubert"
41
  face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
42
  face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
43
  mean_face_path = "/home/user/app/20words_mean_face.npy"
44
+ mouth_roi_path = "/home/user/app/roi.mp4"
45
+
46
+ def detect_landmark(image, detector, predictor):
47
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
48
+ face_locations = detector(gray, 1)
49
+ coords = None
50
+ for (_, face_location) in enumerate(face_locations):
51
+ if torch.cuda.is_available():
52
+ rect = face_location.rect
53
+ else:
54
+ rect = face_location
55
+ shape = predictor(gray, rect)
56
+ coords = np.zeros((68, 2), dtype=np.int32)
57
+ for i in range(0, 68):
58
+ coords[i] = (shape.part(i).x, shape.part(i).y)
59
+ return coords
60
+
61
+ def preprocess_video(input_video_path):
62
+ if torch.cuda.is_available():
63
+ detector = dlib.cnn_face_detection_model_v1(face_detector_path)
64
+ else:
65
+ detector = dlib.get_frontal_face_detector()
66
+
67
+ predictor = dlib.shape_predictor(face_predictor_path)
68
+ STD_SIZE = (256, 256)
69
+ mean_face_landmarks = np.load(mean_face_path)
70
+ stablePntsIDs = [33, 36, 39, 42, 45]
71
+ videogen = skvideo.io.vread(input_video_path)
72
+ frames = np.array([frame for frame in videogen])
73
+ landmarks = []
74
+ for frame in tqdm(frames):
75
+ landmark = detect_landmark(frame, detector, predictor)
76
+ landmarks.append(landmark)
77
+ preprocessed_landmarks = landmarks_interpolate(landmarks)
78
+ rois = crop_patch(input_video_path, preprocessed_landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE,
79
+ window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96)
80
+ write_video_ffmpeg(rois, mouth_roi_path, "/usr/bin/ffmpeg")
81
+ return mouth_roi_path
82
+
83
+ def predict(process_video):
84
+ num_frames = int(cv2.VideoCapture(process_video).get(cv2.CAP_PROP_FRAME_COUNT))
85
+ data_dir = tempfile.mkdtemp()
86
+ tsv_cont = ["/\n", f"test-0\t{process_video}\t{None}\t{num_frames}\t{int(16_000*num_frames/25)}\n"]
87
+ label_cont = ["DUMMY\n"]
88
+ with open(f"{data_dir}/test.tsv", "w") as fo:
89
+ fo.write("".join(tsv_cont))
90
+ with open(f"{data_dir}/test.wrd", "w") as fo:
91
+ fo.write("".join(label_cont))
92
+ utils.import_user_module(Namespace(user_dir=user_dir))
93
+ modalities = ["video"]
94
+ gen_subset = "test"
95
+ gen_cfg = GenerationConfig(beam=20)
96
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
97
+ models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
98
+ saved_cfg.task.modalities = modalities
99
+ saved_cfg.task.data = data_dir
100
+ saved_cfg.task.label_dir = data_dir
101
+ task = tasks.setup_task(saved_cfg.task)
102
+ task.load_dataset(gen_subset, task_cfg=saved_cfg.task)
103
+ generator = task.build_generator(models, gen_cfg)
104
+
105
+ def decode_fn(x):
106
+ dictionary = task.target_dictionary
107
+ symbols_ignore = generator.symbols_to_strip_from_output
108
+ symbols_ignore.add(dictionary.pad())
109
+ return task.datasets[gen_subset].label_processors[0].decode(x, symbols_ignore)
110
+
111
+ itr = task.get_batch_iterator(dataset=task.dataset(gen_subset)).next_epoch_itr(shuffle=False)
112
+ sample = next(itr)
113
+ if torch.cuda.is_available():
114
+ sample = utils.move_to_cuda(sample)
115
+ hypos = task.inference_step(generator, models, sample)
116
+ ref = decode_fn(sample['target'][0].int().cpu())
117
+ hypo = hypos[0][0]['tokens'].int().cpu()
118
+ hypo = decode_fn(hypo)
119
+ return hypo
120
+
121
+
122
+ # ---- Gradio Layout -----
123
+ demo = gr.Blocks()
124
+ demo.encrypt = False
125
+ text_output = gr.Textbox()
126
+ with demo:
127
+ with gr.Row():
128
+ video_in = gr.Video(label="Input Video", mirror_webcam=False, interactive=True)
129
+ video_out = gr.Video(label="Audio Visual Video", mirror_webcam=False, interactive=True)
130
+ with gr.Row():
131
+ detect_landmark_btn = gr.Button("Detect landmark")
132
+ detect_landmark_btn.click(preprocess_video, [video_in], [
133
+ video_out])
134
+ predict_btn = gr.Button("Predict")
135
+ predict_btn.click(predict, [video_out], [
136
+ text_output])
137
+ with gr.Row():
138
+ # video_lip = gr.Video(label="Audio Visual Video", mirror_webcam=False)
139
+ text_output.render()
140
+
141
+
142
+ demo.launch(debug=True)