Upload 8 files

- .gitattributes +1 -0
- 20words_mean_face.npy +3 -0
- TestVisual.sh +7 -0
- app.py +204 -0
- main.py +369 -0
- mmod_human_face_detector.dat +0 -0
- requirements.txt +21 -0
- shape_predictor_68_face_landmarks.dat +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 preprocessing/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
+shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
20words_mean_face.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dbf68b2044171e1160716df7c53e8bbfaa0ee8c61fb41171d04cb6092bb81422
size 1168
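
This is a Git LFS pointer; the actual payload is the small NumPy array that app.py loads as the reference mean-face landmarks via np.load(mean_face_path). A minimal inspection sketch, assuming the real file has been fetched with git lfs pull:

import numpy as np

# Inspect the mean-face landmark array that preprocess_video() in app.py
# aligns each frame's 68 detected landmarks against.
mean_face = np.load("20words_mean_face.npy")
print(mean_face.shape, mean_face.dtype)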
TestVisual.sh
ADDED
@@ -0,0 +1,7 @@
python main.py \
    --config-path ./configs/lrw_resnet18_mstcn.json \
    --model-path ./train_logs/tcn/2022-06-06T19:09:00/ckpt.best.pth.tar \
    --data-dir ./video \
    --label-path ./labels/30VietnameseSort.txt \
    --save-dir ./result \
    --test
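
For reference, each flag in this script maps onto an option declared in load_args() in main.py below. A minimal sketch of the same invocation from Python, using the same placeholder paths as the script:

import subprocess

# Mirrors TestVisual.sh; every flag corresponds to a parser.add_argument()
# entry in main.py's load_args().
subprocess.run([
    "python", "main.py",
    "--config-path", "./configs/lrw_resnet18_mstcn.json",
    "--model-path", "./train_logs/tcn/2022-06-06T19:09:00/ckpt.best.pth.tar",
    "--data-dir", "./video",
    "--label-path", "./labels/30VietnameseSort.txt",
    "--save-dir", "./result",
    "--test",
], check=True)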
app.py
ADDED
@@ -0,0 +1,204 @@
import os
import sys

os.system('git clone https://github.com/facebookresearch/av_hubert.git')
os.chdir('/home/user/app/av_hubert')
os.system('git submodule init')
os.system('git submodule update')
os.chdir('/home/user/app/av_hubert/fairseq')
os.system('pip install ./')
os.system('pip install scipy')
os.system('pip install sentencepiece')
os.system('pip install python_speech_features')
os.system('pip install scikit-video')
os.system('pip install transformers')
os.system('pip install gradio==3.12')
os.system('pip install numpy==1.23.3')


# sys.path.append('/home/user/app/av_hubert')
sys.path.append('/home/user/app/av_hubert/avhubert')

print(sys.path)
print(os.listdir())
print(sys.argv, type(sys.argv))
sys.argv.append('dummy')


import dlib, cv2
import numpy as np
import skvideo
import skvideo.io
from tqdm import tqdm
from preparation.align_mouth import landmarks_interpolate, crop_patch, write_video_ffmpeg
from base64 import b64encode
import torch
import tempfile
from argparse import Namespace
import fairseq
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import GenerationConfig
from huggingface_hub import hf_hub_download
import gradio as gr
from pytube import YouTube

# os.chdir('/home/user/app/av_hubert/avhubert')

user_dir = "/home/user/app/av_hubert/avhubert"
utils.import_user_module(Namespace(user_dir=user_dir))
data_dir = "/home/user/app/video"

ckpt_path = hf_hub_download('vumichien/AV-HuBERT', 'model.pt')  # needed below by load_model_ensemble_and_task
face_detector_path = "/home/user/app/mmod_human_face_detector.dat"
face_predictor_path = "/home/user/app/shape_predictor_68_face_landmarks.dat"
mean_face_path = "/home/user/app/20words_mean_face.npy"
mouth_roi_path = "/home/user/app/roi.mp4"
output_video_path = "/home/user/app/video/và/test"
modalities = ["video"]
gen_subset = "test"
gen_cfg = GenerationConfig(beam=20)
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
models = [model.eval().cuda() if torch.cuda.is_available() else model.eval() for model in models]
saved_cfg.task.modalities = modalities
saved_cfg.task.data = data_dir
saved_cfg.task.label_dir = data_dir
task = tasks.setup_task(saved_cfg.task)
generator = task.build_generator(models, gen_cfg)


def get_youtube(video_url):
    yt = YouTube(video_url)
    abs_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
    print("Successfully downloaded video")
    print(abs_video_path)
    return abs_video_path


def convert_bgr2gray(data):
    # np.stack(arr_1, arr_2, axis=0): stacks the arrays along a brand-new axis
    return np.stack([cv2.cvtColor(_, cv2.COLOR_BGR2GRAY) for _ in data], axis=0)

def save2npz(filename, data=None):
    """save2npz.
    :param filename: str, the filename where the data will be saved.
    :param data: ndarray, arrays to save to the file.
    """
    assert data is not None, "data is {}".format(data)
    if not os.path.exists(os.path.dirname(filename)):
        os.makedirs(os.path.dirname(filename))
    np.savez_compressed(filename, data=data)

def detect_landmark(image, detector, predictor):
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    face_locations = detector(gray, 1)
    coords = None
    for (_, face_location) in enumerate(face_locations):
        if torch.cuda.is_available():
            rect = face_location.rect  # the CNN detector wraps the rectangle
        else:
            rect = face_location       # the HOG detector returns it directly
        shape = predictor(gray, rect)
        coords = np.zeros((68, 2), dtype=np.int32)
        for i in range(0, 68):
            coords[i] = (shape.part(i).x, shape.part(i).y)
    return coords

def preprocess_video(input_video_path):
    if torch.cuda.is_available():
        detector = dlib.cnn_face_detection_model_v1(face_detector_path)
    else:
        detector = dlib.get_frontal_face_detector()

    predictor = dlib.shape_predictor(face_predictor_path)
    STD_SIZE = (256, 256)
    mean_face_landmarks = np.load(mean_face_path)
    stablePntsIDs = [33, 36, 39, 42, 45]
    videogen = skvideo.io.vread(input_video_path)
    frames = np.array([frame for frame in videogen])
    landmarks = []
    for frame in tqdm(frames):
        landmark = detect_landmark(frame, detector, predictor)
        landmarks.append(landmark)
    preprocessed_landmarks = landmarks_interpolate(landmarks)
    rois = crop_patch(input_video_path, preprocessed_landmarks, mean_face_landmarks, stablePntsIDs, STD_SIZE,
                      window_margin=12, start_idx=48, stop_idx=68, crop_height=96, crop_width=96)
    rois_gray = convert_bgr2gray(rois)
    save2npz(output_video_path, data=rois_gray)
    write_video_ffmpeg(rois, mouth_roi_path, "/usr/bin/ffmpeg")
    return mouth_roi_path


def predict(process_video):
    os.chdir('/home/user/app')
    return os.system('bash TestVisual.sh')


# ---- Gradio Layout -----
youtube_url_in = gr.Textbox(label="Youtube url", lines=1, interactive=True)
video_in = gr.Video(label="Input Video", mirror_webcam=False, interactive=True)
video_out = gr.Video(label="Audio Visual Video", mirror_webcam=False, interactive=True)
demo = gr.Blocks()
demo.encrypt = False
text_output = gr.Textbox()

with demo:
    # gr.Markdown('''
    # <div>
    # <h1 style='text-align: center'>Speech Recognition from Visual Lip Movement by Audio-Visual Hidden Unit BERT Model (AV-HuBERT)</h1>
    # This space uses AV-HuBERT models from <a href='https://github.com/facebookresearch' target='_blank'><b>Meta Research</b></a> to recognize speech from lip movement 🤗
    # <figure>
    # <img src="https://huggingface.co/vumichien/AV-HuBERT/resolve/main/lipreading.gif" alt="Audio-Visual Speech Recognition">
    # <figcaption> Speech Recognition from visual lip movement
    # </figcaption>
    # </figure>
    # </div>
    # ''')
    # with gr.Row():
    #     gr.Markdown('''
    #     ### Reading lip movement from a YouTube link using AV-HuBERT
    #     ##### Step 1a. Download a video from YouTube (note: the video should be shorter than 10 seconds, otherwise it will be cut, and the face should be stable for a better result)
    #     ##### Step 1b. You can also upload a video directly
    #     ##### Step 2. Generate landmarks surrounding the mouth area
    #     ##### Step 3. Read the lip movement.
    #     ''')
    with gr.Row():
        gr.Markdown('''
        ### You can test with the following examples:
        ''')
        examples = gr.Examples(examples=
            ["https://www.youtube.com/watch?v=ZXVDnuepW2s",
             "https://www.youtube.com/watch?v=X8_glJn1B8o",
             "https://www.youtube.com/watch?v=80yqL2KzBVw"],
            label="Examples", inputs=[youtube_url_in])
        with gr.Column():
            youtube_url_in.render()
            download_youtube_btn = gr.Button("Download Youtube video")
            download_youtube_btn.click(get_youtube, [youtube_url_in], [video_in])
            print(video_in)
    with gr.Row():
        video_in.render()
        video_out.render()
    with gr.Row():
        detect_landmark_btn = gr.Button("Phát hiện mốc/cắt môi")  # "Detect landmarks / crop lips"
        detect_landmark_btn.click(preprocess_video, [video_in], [video_out])
        predict_btn = gr.Button("Dự đoán")  # "Predict"
        predict_btn.click(predict, [video_out], [text_output])
    with gr.Row():
        # video_lip = gr.Video(label="Audio Visual Video", mirror_webcam=False)
        text_output.render()


demo.launch(debug=True)
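
To sanity-check the dlib landmark stage of preprocess_video() in isolation, a sketch along these lines can help. It exercises the CPU (HOG detector) path with the predictor file shipped in this commit; frame.jpg is a hypothetical test frame:

import cv2
import dlib
import numpy as np

# CPU path of detect_landmark()/preprocess_video() from app.py:
# HOG face detector plus the 68-point shape predictor.
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

image = cv2.imread("frame.jpg")  # hypothetical test frame
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
for rect in detector(gray, 1):
    shape = predictor(gray, rect)
    coords = np.array([(shape.part(i).x, shape.part(i).y) for i in range(68)])
    # crop_patch() in app.py uses start_idx=48, stop_idx=68: the mouth landmarks
    print("mouth landmarks:", coords[48:68])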
main.py
ADDED
@@ -0,0 +1,369 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2020 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

""" TCN for lipreading"""

import os
import time
import random
import argparse  # command-line argument parsing module
import numpy as np
from tqdm import tqdm  # progress-bar library

import torch  # PyTorch
import torch.nn as nn  # class-based API; stores and reuses state via attributes
import torch.nn.functional as F  # functional API; usable without instantiation

from lipreading.utils import get_save_folder
from lipreading.utils import load_json, save2npz
from lipreading.utils import load_model, CheckpointSaver
from lipreading.utils import get_logger, update_logger_batch
from lipreading.utils import showLR, calculateNorm2, AverageMeter
from lipreading.model import Lipreading
from lipreading.mixup import mixup_data, mixup_criterion
from lipreading.optim_utils import get_optimizer, CosineScheduler
from lipreading.dataloaders import get_data_loaders, get_preprocessing_pipelines

from pathlib import Path
import wandb  # experiment-tracking tool (logs loss and accuracy automatically)


# parse and process command-line arguments
def load_args(default_config=None):
    # create the argument parser instance
    parser = argparse.ArgumentParser(description='Pytorch Lipreading ')

    # list of accepted arguments
    # -- dataset config
    parser.add_argument('--dataset', default='lrw', help='dataset selection')
    parser.add_argument('--num-classes', type=int, default=30, help='Number of classes')
    parser.add_argument('--modality', default='video', choices=['video', 'raw_audio'], help='choose the modality')
    # -- directory
    parser.add_argument('--data-dir', default='./datasets/visual', help='Loaded data directory')
    parser.add_argument('--label-path', type=str, default='./labels/30VietnameseSort.txt', help='Path to txt file with labels')
    parser.add_argument('--annonation-direc', default=None, help='Loaded data directory')
    # -- model config
    parser.add_argument('--backbone-type', type=str, default='resnet', choices=['resnet', 'shufflenet'], help='Architecture used for backbone')
    parser.add_argument('--relu-type', type=str, default='relu', choices=['relu', 'prelu'], help='what relu to use')
    parser.add_argument('--width-mult', type=float, default=1.0, help='Width multiplier for mobilenets and shufflenets')
    # -- TCN config
    parser.add_argument('--tcn-kernel-size', type=int, nargs="+", help='Kernel to be used for the TCN module')
    parser.add_argument('--tcn-num-layers', type=int, default=4, help='Number of layers on the TCN module')
    parser.add_argument('--tcn-dropout', type=float, default=0.2, help='Dropout value for the TCN module')
    parser.add_argument('--tcn-dwpw', default=False, action='store_true', help='If True, use the depthwise separable convolution in TCN architecture')
    parser.add_argument('--tcn-width-mult', type=int, default=1, help='TCN width multiplier')
    # -- train
    parser.add_argument('--training-mode', default='tcn', help='tcn')
    parser.add_argument('--batch-size', type=int, default=8, help='Mini-batch size')  # changed from default=32 to default=8 (to avoid OOM)
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adam', 'sgd', 'adamw'])
    parser.add_argument('--lr', default=3e-4, type=float, help='initial learning rate')
    parser.add_argument('--init-epoch', default=0, type=int, help='epoch to start at')
    parser.add_argument('--epochs', default=100, type=int, help='number of epochs')  # changed from default=80 (for testing)
    parser.add_argument('--test', default=False, action='store_true', help='training mode')
    parser.add_argument('--save-dir', type=Path, default=Path('/kaggle/working/result/'))
    # -- mixup
    parser.add_argument('--alpha', default=0.4, type=float, help='interpolation strength (uniform=1., ERM=0.)')
    # -- test
    parser.add_argument('--model-path', type=str, default=None, help='Pretrained model pathname')
    parser.add_argument('--allow-size-mismatch', default=False, action='store_true',
                        help='If True, allows to init from model with mismatching weight tensors. Useful to init from model with diff. number of classes')
    # -- feature extractor
    parser.add_argument('--extract-feats', default=False, action='store_true', help='Feature extractor')
    parser.add_argument('--mouth-patch-path', type=str, default=None, help='Path to the mouth ROIs, assuming the file is saved as numpy.array')
    parser.add_argument('--mouth-embedding-out-path', type=str, default=None, help='Save mouth embeddings to a specified path')
    # -- json pathname
    parser.add_argument('--config-path', type=str, default=None, help='Model configuration with json format')
    # -- other vars
    parser.add_argument('--interval', default=50, type=int, help='display interval')
    parser.add_argument('--workers', default=2, type=int, help='number of data loading workers')  # changed from default=8 to default=2 (half of the 4 GCP cores)
    # paths
    parser.add_argument('--logging-dir', type=str, default='/kaggle/working/train_logs', help='path to the directory in which to save the log file')

    # store the parsed arguments in args (type: Namespace)
    args = parser.parse_args()
    return args


args = load_args()  # parse and load args

# fix random seeds so experiments are reproducible
torch.manual_seed(1)  # fix the random seed in PyTorch, the main framework
np.random.seed(1)     # fix the random seed in numpy
random.seed(1)        # fix the random seed in Python's random library

# note: full reproducibility also requires torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False
torch.backends.cudnn.benchmark = True  # enable the built-in cudnn auto-tuner to pick the best algorithm for the hardware (tensor sizes, conv ops)


# feature extraction
def extract_feats(model):
    """
    :rtype: FloatTensor
    """
    model.eval()  # switch off layers that must not run during evaluation
    preprocessing_func = get_preprocessing_pipelines()['test']  # test-time preprocessing

    mouth_patch_path = args.mouth_patch_path.replace('.', '')
    dir_name = os.path.dirname(os.path.abspath(__file__))
    dir_name = dir_name + mouth_patch_path

    data_paths = [os.path.join(pth, f) for pth, dirs, files in os.walk(dir_name) for f in files]

    npz_files = np.load(data_paths[0])['data']

    data = preprocessing_func(npz_files)  # data: TxHxW
    # data = preprocessing_func(np.load(args.mouth_patch_path)['data'])  # data: TxHxW
    return data_paths[0], model(torch.FloatTensor(data)[None, None, :, :, :].cuda(), lengths=[data.shape[0]])
    # return model(torch.FloatTensor(data)[None, None, :, :, :].cuda(), lengths=[data.shape[0]])


# evaluation
def evaluate(model, dset_loader, criterion, is_print=False):
    model.eval()  # switch off layers that must not run during evaluation
    # running_loss = 0.
    # running_corrects = 0.
    prediction = ''
    # model.eval() and torch.no_grad() are normally used together for evaluation/validation
    with torch.no_grad():
        inferences = []
        for batch_idx, (input, lengths, labels) in enumerate(tqdm(dset_loader)):
            # run the model
            # add one more dimension to the input tensor and move it to the GPU
            logits = model(input.unsqueeze(1).cuda(), lengths=lengths)
            # _, preds = torch.max(F.softmax(logits, dim=1).data, dim=1)  # apply softmax, then take the max element
            # running_corrects += preds.eq(labels.cuda().view_as(preds)).sum().item()  # accuracy

            # loss = criterion(logits, labels.cuda())  # compute the loss
            # running_loss += loss.item() * input.size(0)  # loss.item(): the scalar value held by the loss
            # # ------------ print prediction and confidence ------------

            probs = torch.nn.functional.softmax(logits, dim=-1)
            probs = probs[0].detach().cpu().numpy()

            label_path = args.label_path
            with Path(label_path).open() as fp:
                vocab = fp.readlines()

            top = np.argmax(probs)
            prediction = vocab[top].strip()
            # confidence = np.round(probs[top], 3)
            # inferences.append({
            #     'prediction': prediction,
            #     'confidence': confidence
            # })

            if is_print:
                print()
                print(f'Prediction: {prediction}')
                # print(f'Confidence: {confidence}')
                print()
    return prediction
    # ------------ save prediction and confidence to a text file ------------
    # txt_save_path = str(args.save_dir) + f'/predict.txt'
    # # if the file does not exist
    # if not os.path.exists(os.path.dirname(txt_save_path)):
    #     os.makedirs(os.path.dirname(txt_save_path))  # create the directory
    # with open(txt_save_path, 'w') as f:
    #     for inference in inferences:
    #         prediction = inference['prediction']
    #         confidence = inference['confidence']
    #         f.writelines(f'Prediction: {prediction}, Confidence: {confidence}\n')

    # print('Test Dataset {} In Total \t CR: {}'.format(len(dset_loader.dataset), running_corrects/len(dset_loader.dataset)))  # print dataset size and accuracy
    # return running_corrects/len(dset_loader.dataset), running_loss/len(dset_loader.dataset), inferences  # return accuracy, loss, inferences


# model training
# def train(wandb, model, dset_loader, criterion, epoch, optimizer, logger):
#     data_time = AverageMeter()   # tracks average and current value
#     batch_time = AverageMeter()  # tracks average and current value

#     lr = showLR(optimizer)  # current learning rate

#     # write logger INFO
#     logger.info('-' * 10)
#     logger.info('Epoch {}/{}'.format(epoch, args.epochs - 1))  # log the epoch
#     logger.info('Current learning rate: {}'.format(lr))        # log the learning rate

#     model.train()  # train mode
#     running_loss = 0.
#     running_corrects = 0.
#     running_all = 0.

#     end = time.time()  # current time
#     for batch_idx, (input, lengths, labels) in enumerate(dset_loader):
#         # measure data loading time
#         data_time.update(time.time() - end)  # update average and current value

#         # --
#         # mixup augmentation
#         input, labels_a, labels_b, lam = mixup_data(input, labels, args.alpha)
#         labels_a, labels_b = labels_a.cuda(), labels_b.cuda()  # move the tensors to the GPU

#         # PyTorch accumulates gradients on subsequent backward passes,
#         # so gradients must always be zeroed before backpropagation
#         optimizer.zero_grad()

#         # run the model
#         # add one more dimension to the input tensor and move it to the GPU
#         logits = model(input.unsqueeze(1).cuda(), lengths=lengths)

#         loss_func = mixup_criterion(labels_a, labels_b, lam)  # apply mixup
#         loss = loss_func(criterion, logits)                   # compute the loss

#         loss.backward()   # compute gradients
#         optimizer.step()  # update the parameters using the stored gradients

#         # measure elapsed time
#         batch_time.update(time.time() - end)  # update average and current value
#         end = time.time()                     # current time
#         # -- compute running performance
#         _, predicted = torch.max(F.softmax(logits, dim=1).data, dim=1)  # apply softmax, then take the max element
#         running_loss += loss.item()*input.size(0)  # loss.item(): the scalar value held by the loss
#         running_corrects += lam * predicted.eq(labels_a.view_as(predicted)).sum().item() + (1 - lam) * predicted.eq(labels_b.view_as(predicted)).sum().item()  # accuracy
#         running_all += input.size(0)


#         # ------------------ wandb logging ------------------
#         wandb.log({'loss': running_loss, 'acc': running_corrects}, step=epoch)


#         # -- log intermediate results
#         if batch_idx % args.interval == 0 or (batch_idx == len(dset_loader)-1):
#             # write logger INFO
#             update_logger_batch(args, logger, dset_loader, batch_idx, running_loss, running_corrects, running_all, batch_time, data_time)

#     return model  # return the model


# build the model from the json config
def get_model_from_json():
    # check that the json file exists; raise an AssertionError otherwise
    assert args.config_path.endswith('.json') and os.path.isfile(args.config_path), \
        "'.json' config path does not exist. Path input: {}".format(args.config_path)  # guard the expected variable values

    args_loaded = load_json(args.config_path)           # read the json
    args.backbone_type = args_loaded['backbone_type']   # backbone_type from the json
    args.width_mult = args_loaded['width_mult']         # width_mult from the json
    args.relu_type = args_loaded['relu_type']           # relu_type from the json

    # TCN options
    tcn_options = {'num_layers': args_loaded['tcn_num_layers'],
                   'kernel_size': args_loaded['tcn_kernel_size'],
                   'dropout': args_loaded['tcn_dropout'],
                   'dwpw': args_loaded['tcn_dwpw'],
                   'width_mult': args_loaded['tcn_width_mult'],
                   }

    # build the lipreading model
    model = Lipreading(modality=args.modality,
                       num_classes=args.num_classes,
                       tcn_options=tcn_options,
                       backbone_type=args.backbone_type,
                       relu_type=args.relu_type,
                       width_mult=args.width_mult,
                       extract_feats=args.extract_feats).cuda()
    calculateNorm2(model)  # sanity-check training progress: the parameter norm (L2) should generally grow as training proceeds
    return model  # return the model


def main():

    # connect to wandb
    # wandb.init(project="Lipreading_using_TCN_running")
    # wandb.config = {
    #     "learning_rate": args.lr,
    #     "epochs": args.epochs,
    #     "batch_size": args.batch_size
    # }


    # os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # GPU selection

    # -- logging
    save_path = get_save_folder(args)  # save directory
    print("Model and log being saved in: {}".format(save_path))  # print the save directory path
    logger = get_logger(args, save_path)     # create and configure the logger
    ckpt_saver = CheckpointSaver(save_path)  # checkpoint saving setup

    # -- get model
    model = get_model_from_json()
    # -- get dataset iterators
    dset_loaders = get_data_loaders(args)
    # -- get loss function
    criterion = nn.CrossEntropyLoss()
    # -- get optimizer
    optimizer = get_optimizer(args, optim_policies=model.parameters())
    # -- get learning rate scheduler
    scheduler = CosineScheduler(args.lr, args.epochs)  # cosine scheduler

    if args.model_path:
        # check that the .tar file exists; raise an AssertionError otherwise
        assert args.model_path.endswith('.tar') and os.path.isfile(args.model_path), \
            "'.tar' model path does not exist. Path input: {}".format(args.model_path)  # guard the expected variable values
        # resume from checkpoint
        if args.init_epoch > 0:
            model, optimizer, epoch_idx, ckpt_dict = load_model(args.model_path, model, optimizer)  # load the model
            args.init_epoch = epoch_idx                 # set the epoch
            ckpt_saver.set_best_from_ckpt(ckpt_dict)    # record the best checkpoint
            logger.info('Model and states have been successfully loaded from {}'.format(args.model_path))
        # init from trained model
        else:
            model = load_model(args.model_path, model, allow_size_mismatch=args.allow_size_mismatch)  # load the model
            logger.info('Model has been successfully loaded from {}'.format(args.model_path))
        # feature extraction
        if args.mouth_patch_path:

            filename, embeddings = extract_feats(model)
            filename = filename.split('/')[-1]
            save_npz_path = os.path.join(args.mouth_embedding_out_path, filename)

            # ExtractEmbedding still needs a code fix!
            save2npz(save_npz_path, data=embeddings.cpu().detach().numpy())  # save the npz file
            # save2npz(args.mouth_embedding_out_path, data=extract_feats(model).cpu().detach().numpy())  # save the npz file
            return
        # if test-time, report performance on the test partition and exit. Otherwise, report performance on validation and continue (sanity check for reload)
        if args.test:
            prediction = evaluate(model, dset_loaders['test'], criterion, is_print=False)  # evaluate the model

            # logging_sentence = 'Test-time performance on partition {}: Loss: {:.4f}\tAcc:{:.4f}'.format('test', loss_avg_test, acc_avg_test)
            # logger.info(logging_sentence)

            return prediction

    # -- fix learning rate after loading the checkpoint (latency)
    if args.model_path and args.init_epoch > 0:
        scheduler.adjust_lr(optimizer, args.init_epoch - 1)  # update the learning rate


    epoch = args.init_epoch  # initialize the epoch
    while epoch < args.epochs:
        model = train(wandb, model, dset_loaders['train'], criterion, epoch, optimizer, logger)  # train the model
        acc_avg_val, loss_avg_val, inferences = evaluate(model, dset_loaders['val'], criterion)  # evaluate the model
        logger.info('{} Epoch:\t{:2}\tLoss val: {:.4f}\tAcc val:{:.4f}, LR: {}'.format('val', epoch, loss_avg_val, acc_avg_val, showLR(optimizer)))
        # -- save checkpoint
        save_dict = {
            'epoch_idx': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        ckpt_saver.save(save_dict, acc_avg_val)  # save the checkpoint
        scheduler.adjust_lr(optimizer, epoch)    # update the learning rate
        epoch += 1

    # -- evaluate the best-performing epoch on the test partition
    best_fp = os.path.join(ckpt_saver.save_dir, ckpt_saver.best_fn)  # path to the best checkpoint
    _ = load_model(best_fp, model)  # load the model
    acc_avg_test, loss_avg_test, inferences = evaluate(model, dset_loaders['test'], criterion)  # evaluate the model
    logger.info('Test time performance of best epoch: {} (loss: {})'.format(acc_avg_test, loss_avg_test))
    torch.cuda.empty_cache()  # free cached GPU memory


# run the code below only when this script is executed directly by the interpreter, not when the module is imported
# => the entry point when main.py is run
if __name__ == '__main__':
    main()
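
The decoding step inside evaluate() boils down to a softmax over the 30-class logits followed by a lookup into the label file. A standalone sketch, with a dummy logits tensor standing in for the model output:

import numpy as np
import torch

# Same decode as evaluate() in main.py: softmax over class logits,
# argmax, then a lookup into the label vocabulary.
logits = torch.randn(1, 30)  # dummy output for the 30-word vocabulary
probs = torch.nn.functional.softmax(logits, dim=-1)[0].numpy()
top = int(np.argmax(probs))

with open("labels/30VietnameseSort.txt") as fp:  # path from TestVisual.sh
    vocab = fp.readlines()
print(f"Prediction: {vocab[top].strip()} (confidence: {probs[top]:.3f})")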
mmod_human_face_detector.dat
ADDED
Binary file (730 kB)
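
This is dlib's CNN face-detection model; app.py loads it only when a GPU is available. A minimal load check, assuming the file is present locally:

import dlib

# GPU branch of preprocess_video() in app.py; detections from this model
# carry the rectangle in a .rect attribute, which detect_landmark() unwraps.
cnn_detector = dlib.cnn_face_detection_model_v1("mmod_human_face_detector.dat")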
requirements.txt
ADDED
@@ -0,0 +1,21 @@
torch >= 1.3.0
numpy >= 1.16.4
scipy >= 1.3.0
opencv-python >= 4.1.0
matplotlib >= 3.0.3
tqdm >= 4.35.0
scikit-image >= 0.13.0
librosa >= 0.7.0
git+https://github.com/facebookresearch/fairseq.git
scipy
sentencepiece
python_speech_features
scikit-video
scikit-image
opencv-python
pytube==12.1.0
ffmpeg-python
cmake
dlib
face-alignment
torchvision==0.2.0
shape_predictor_68_face_landmarks.dat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
size 99693937