wissemkarous commited on
Commit
b072050
·
verified ·
1 Parent(s): 1427b80
Files changed (10) hide show
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. README.md +109 -0
  4. cvtransforms.py +13 -0
  5. dataset.py +190 -0
  6. inference.py +350 -0
  7. model.py +113 -0
  8. options.py +21 -0
  9. requirements.txt +0 -0
  10. train.py +237 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.dat filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ venv
2
+ __pycache__
3
+ samples
4
+ output_videos
README.md ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ # LipCoordNet: Enhanced Lip Reading with Landmark Coordinates
6
+
7
+ ## Introduction
8
+
9
+ LipCoordNet is an advanced neural network model designed for accurate lip reading by incorporating lip landmark coordinates as a supplementary input to the traditional image sequence input. This enhancement to the original LipNet architecture aims to improve the precision of sentence predictions by providing additional geometric context to the model.
10
+
11
+ ## Features
12
+
13
+ - **Dual Input System**: Utilizes both raw image sequences and corresponding lip landmark coordinates for improved context.
14
+ - **Enhanced Spatial Resolution**: Improved spatial analysis of lip movements through detailed landmark tracking.
15
+ - **State-of-the-Art Performance**: Outperforms the original LipNet, as well as VIPL's [PyTorch implementation of LipNet](https://github.com/VIPL-Audio-Visual-Speech-Understanding/LipNet-PyTorch).
16
+ | Scenario | Image Size (W x H) | CER | WER |
17
+ | :-------------------------------: | :----------------: | :--: | :---: |
18
+ | Unseen speakers (Original) | 100 x 50 | 6.7% | 13.6% |
19
+ | Overlapped speakers (Original) | 100 x 50 | 2.0% | 5.6% |
20
+ | Unseen speakers (VIPL LipNet) | 128 x 64 | 6.7% | 13.3% |
21
+ | Overlapped speakers (VIPL LipNet) | 128 x 64 | 1.9% | 4.6% |
22
+ | Overlapped speakers (LipCoordNet) | 128 x 64 | 0.6% | 1.7% |
23
+
24
+ ## Getting Started
25
+
26
+ ### Prerequisites
27
+
28
+ - Python 3.10 or later
29
+ - Pytorch 2.0 or later
30
+ - OpenCV
31
+ - NumPy
32
+ - dlib (for landmark detection)
33
+ - The detailed list of dependencies can be found in `requirements.txt`.
34
+
35
+ ### Installation
36
+
37
+ 1. Clone the repository:
38
+
39
+ ```bash
40
+ git clone https://huggingface.co/SilentSpeak/LipCoordNet
41
+ ```
42
+
43
+ 2. Navigate to the project directory:
44
+ ```bash
45
+ cd LipCoordNet
46
+ ```
47
+ 3. Install the required dependencies:
48
+ ```bash
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
+ ### Usage
53
+
54
+ To train the LipCoordNet model with your dataset, first update the options.py file with the appropriate paths to your dataset and pretrained weights (comment out the weights if you want to start from scratch). Then, run the following command:
55
+
56
+ ```bash
57
+ python train.py
58
+ ```
59
+
60
+ To perform sentence prediction using the pre-trained model:
61
+
62
+ ```bash
63
+ python inference.py --input_video <path_to_video>
64
+ ```
65
+
66
+ note: ffmpeg is required to convert video to image sequence and run the inference script.
67
+
68
+ ## Model Architecture
69
+
70
+ ![LipCoordNet model architecture](./assets/LipCoordNet_model_architecture.png)
71
+
72
+ ## Training
73
+
74
+ This model is built on top of the [LipNet-Pytorch](https://github.com/VIPL-Audio-Visual-Speech-Understanding/LipNet-PyTorch) project on GitHub. The training process if similar to the original LipNet model, with the addition of landmark coordinates as a supplementary input. We used the pretrained weights from the original LipNet model as a starting point for training our model, froze the weights for the original LipNet layers, and trained the new layers for the landmark coordinates.
75
+
76
+ The dataset used to train this model is the [EGCLLC dataset](https://huggingface.co/datasets/SilentSpeak/EGCLLC). The dataset is not included in this repository, but can be downloaded from the link above.
77
+
78
+ Total training time: 2 days
79
+ Total epochs: 51
80
+ Training hardware: NVIDIA GeForce RTX 3080 12GB
81
+
82
+ ![LipCoordNet training curves](./assets/training_graphs.png)
83
+
84
+ For an interactive view of the training curves, please refer to the tensorboard logs in the `runs` directory.
85
+ Use this command to view the logs:
86
+
87
+ ```bash
88
+ tensorboard --logdir runs
89
+ ```
90
+
91
+ ## Evaluation
92
+
93
+ We achieved a lowest WER of 1.7%, CER of 0.6% and a loss of 0.0256 on the validation dataset.
94
+
95
+ ## License
96
+
97
+ This project is licensed under the MIT License.
98
+
99
+ ## Acknowledgments
100
+
101
+ This model, LipCoordNet, has been developed with reference to the LipNet-PyTorch implementation available at [VIPL-Audio-Visual-Speech-Understanding](https://github.com/VIPL-Audio-Visual-Speech-Understanding/LipNet-PyTorch). We extend our gratitude to the contributors of this repository for providing a solid foundation and insightful examples that greatly facilitated the development of our enhanced lip reading model. Their work has been instrumental in advancing the field of audio-visual speech understanding and has provided the community with valuable resources to build upon.
102
+
103
+ Alvarez Casado, C., Bordallo Lopez, M. Real-time face alignment: evaluation methods, training strategies and implementation optimization. Springer Journal of Real-time image processing, 2021
104
+
105
+ Assael, Y., Shillingford, B., Whiteson, S., & Freitas, N. (2017). LipNet: End-to-End Sentence-level Lipreading. GPU Technology Conference.
106
+
107
+ ## Contact
108
+
109
+ Project Link: https://github.com/ffeew/LipCoordNet
cvtransforms.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+
4
+ def HorizontalFlip(batch_img, p=0.5):
5
+ # (T, H, W, C)
6
+ if random.random() > p:
7
+ batch_img = batch_img[:, :, ::-1, ...]
8
+ return batch_img
9
+
10
+
11
+ def ColorNormalize(batch_img):
12
+ batch_img = batch_img / 255.0
13
+ return batch_img
dataset.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from cvtransforms import *
6
+ import torch
7
+ import editdistance
8
+ import json
9
+
10
+
11
+ class MyDataset(Dataset):
12
+ letters = [
13
+ " ",
14
+ "A",
15
+ "B",
16
+ "C",
17
+ "D",
18
+ "E",
19
+ "F",
20
+ "G",
21
+ "H",
22
+ "I",
23
+ "J",
24
+ "K",
25
+ "L",
26
+ "M",
27
+ "N",
28
+ "O",
29
+ "P",
30
+ "Q",
31
+ "R",
32
+ "S",
33
+ "T",
34
+ "U",
35
+ "V",
36
+ "W",
37
+ "X",
38
+ "Y",
39
+ "Z",
40
+ ]
41
+
42
+ def __init__(
43
+ self,
44
+ video_path,
45
+ anno_path,
46
+ coords_path,
47
+ file_list,
48
+ vid_pad,
49
+ txt_pad,
50
+ phase,
51
+ ):
52
+ self.anno_path = anno_path
53
+ self.coords_path = coords_path
54
+ self.vid_pad = vid_pad
55
+ self.txt_pad = txt_pad
56
+ self.phase = phase
57
+
58
+ with open(file_list, "r") as f:
59
+ self.videos = [
60
+ os.path.join(video_path, line.strip()) for line in f.readlines()
61
+ ]
62
+
63
+ self.data = []
64
+ for vid in self.videos:
65
+ items = vid.split("/")
66
+ self.data.append((vid, items[-4], items[-1]))
67
+
68
+ def __getitem__(self, idx):
69
+ (vid, spk, name) = self.data[idx]
70
+ vid = self._load_vid(vid)
71
+ anno = self._load_anno(
72
+ os.path.join(self.anno_path, spk, "align", name + ".align")
73
+ )
74
+ coord = self._load_coords(os.path.join(self.coords_path, spk, name + ".json"))
75
+
76
+ if self.phase == "train":
77
+ vid = HorizontalFlip(vid)
78
+
79
+ vid = ColorNormalize(vid)
80
+
81
+ vid_len = vid.shape[0]
82
+ anno_len = anno.shape[0]
83
+ vid = self._padding(vid, self.vid_pad)
84
+ anno = self._padding(anno, self.txt_pad)
85
+ coord = self._padding(coord, self.vid_pad)
86
+
87
+ return {
88
+ "vid": torch.FloatTensor(vid.transpose(3, 0, 1, 2)),
89
+ "txt": torch.LongTensor(anno),
90
+ "coord": torch.FloatTensor(coord),
91
+ "txt_len": anno_len,
92
+ "vid_len": vid_len,
93
+ }
94
+
95
+ def __len__(self):
96
+ return len(self.data)
97
+
98
+ def _load_vid(self, p):
99
+ files = os.listdir(p)
100
+ files = list(filter(lambda file: file.find(".jpg") != -1, files))
101
+ files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
102
+ array = [cv2.imread(os.path.join(p, file)) for file in files]
103
+ array = list(filter(lambda im: not im is None, array))
104
+ array = [
105
+ cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) for im in array
106
+ ]
107
+ array = np.stack(array, axis=0).astype(np.float32)
108
+
109
+ return array
110
+
111
+ def _load_anno(self, name):
112
+ with open(name, "r") as f:
113
+ lines = [line.strip().split(" ") for line in f.readlines()]
114
+ txt = [line[2] for line in lines]
115
+ txt = list(filter(lambda s: not s.upper() in ["SIL", "SP"], txt))
116
+ return MyDataset.txt2arr(" ".join(txt).upper(), 1)
117
+
118
+ def _load_coords(self, name):
119
+ # obtained from the resized image in the lip coordinate extraction
120
+ img_width = 600
121
+ img_height = 500
122
+ with open(name, "r") as f:
123
+ coords_data = json.load(f)
124
+
125
+ coords = []
126
+ for frame in sorted(coords_data.keys(), key=int):
127
+ frame_coords = coords_data[frame]
128
+
129
+ # Normalize the coordinates
130
+ normalized_coords = []
131
+ for x, y in zip(frame_coords[0], frame_coords[1]):
132
+ normalized_x = x / img_width
133
+ normalized_y = y / img_height
134
+ normalized_coords.append((normalized_x, normalized_y))
135
+
136
+ coords.append(normalized_coords)
137
+ coords_array = np.array(coords, dtype=np.float32)
138
+ return coords_array
139
+
140
+ def _padding(self, array, length):
141
+ array = [array[_] for _ in range(array.shape[0])]
142
+ size = array[0].shape
143
+ for i in range(length - len(array)):
144
+ array.append(np.zeros(size))
145
+ return np.stack(array, axis=0)
146
+
147
+ @staticmethod
148
+ def txt2arr(txt, start):
149
+ arr = []
150
+ for c in list(txt):
151
+ arr.append(MyDataset.letters.index(c) + start)
152
+ return np.array(arr)
153
+
154
+ @staticmethod
155
+ def arr2txt(arr, start):
156
+ txt = []
157
+ for n in arr:
158
+ if n >= start:
159
+ txt.append(MyDataset.letters[n - start])
160
+ return "".join(txt).strip()
161
+
162
+ @staticmethod
163
+ def ctc_arr2txt(arr, start):
164
+ pre = -1
165
+ txt = []
166
+ for n in arr:
167
+ if pre != n and n >= start:
168
+ if (
169
+ len(txt) > 0
170
+ and txt[-1] == " "
171
+ and MyDataset.letters[n - start] == " "
172
+ ):
173
+ pass
174
+ else:
175
+ txt.append(MyDataset.letters[n - start])
176
+ pre = n
177
+ return "".join(txt).strip()
178
+
179
+ @staticmethod
180
+ def wer(predict, truth):
181
+ word_pairs = [(p[0].split(" "), p[1].split(" ")) for p in zip(predict, truth)]
182
+ wer = [1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) for p in word_pairs]
183
+ return wer
184
+
185
+ @staticmethod
186
+ def cer(predict, truth):
187
+ cer = [
188
+ 1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) for p in zip(predict, truth)
189
+ ]
190
+ return cer
inference.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from model import LipCoordNet
4
+ from dataset import MyDataset
5
+ import torch
6
+ import cv2
7
+ import face_alignment
8
+ import numpy as np
9
+ import dlib
10
+ import glob
11
+
12
+
13
+ def get_position(size, padding=0.25):
14
+ x = [
15
+ 0.000213256,
16
+ 0.0752622,
17
+ 0.18113,
18
+ 0.29077,
19
+ 0.393397,
20
+ 0.586856,
21
+ 0.689483,
22
+ 0.799124,
23
+ 0.904991,
24
+ 0.98004,
25
+ 0.490127,
26
+ 0.490127,
27
+ 0.490127,
28
+ 0.490127,
29
+ 0.36688,
30
+ 0.426036,
31
+ 0.490127,
32
+ 0.554217,
33
+ 0.613373,
34
+ 0.121737,
35
+ 0.187122,
36
+ 0.265825,
37
+ 0.334606,
38
+ 0.260918,
39
+ 0.182743,
40
+ 0.645647,
41
+ 0.714428,
42
+ 0.793132,
43
+ 0.858516,
44
+ 0.79751,
45
+ 0.719335,
46
+ 0.254149,
47
+ 0.340985,
48
+ 0.428858,
49
+ 0.490127,
50
+ 0.551395,
51
+ 0.639268,
52
+ 0.726104,
53
+ 0.642159,
54
+ 0.556721,
55
+ 0.490127,
56
+ 0.423532,
57
+ 0.338094,
58
+ 0.290379,
59
+ 0.428096,
60
+ 0.490127,
61
+ 0.552157,
62
+ 0.689874,
63
+ 0.553364,
64
+ 0.490127,
65
+ 0.42689,
66
+ ]
67
+
68
+ y = [
69
+ 0.106454,
70
+ 0.038915,
71
+ 0.0187482,
72
+ 0.0344891,
73
+ 0.0773906,
74
+ 0.0773906,
75
+ 0.0344891,
76
+ 0.0187482,
77
+ 0.038915,
78
+ 0.106454,
79
+ 0.203352,
80
+ 0.307009,
81
+ 0.409805,
82
+ 0.515625,
83
+ 0.587326,
84
+ 0.609345,
85
+ 0.628106,
86
+ 0.609345,
87
+ 0.587326,
88
+ 0.216423,
89
+ 0.178758,
90
+ 0.179852,
91
+ 0.231733,
92
+ 0.245099,
93
+ 0.244077,
94
+ 0.231733,
95
+ 0.179852,
96
+ 0.178758,
97
+ 0.216423,
98
+ 0.244077,
99
+ 0.245099,
100
+ 0.780233,
101
+ 0.745405,
102
+ 0.727388,
103
+ 0.742578,
104
+ 0.727388,
105
+ 0.745405,
106
+ 0.780233,
107
+ 0.864805,
108
+ 0.902192,
109
+ 0.909281,
110
+ 0.902192,
111
+ 0.864805,
112
+ 0.784792,
113
+ 0.778746,
114
+ 0.785343,
115
+ 0.778746,
116
+ 0.784792,
117
+ 0.824182,
118
+ 0.831803,
119
+ 0.824182,
120
+ ]
121
+
122
+ x, y = np.array(x), np.array(y)
123
+
124
+ x = (x + padding) / (2 * padding + 1)
125
+ y = (y + padding) / (2 * padding + 1)
126
+ x = x * size
127
+ y = y * size
128
+ return np.array(list(zip(x, y)))
129
+
130
+
131
+ def transformation_from_points(points1, points2):
132
+ points1 = points1.astype(np.float64)
133
+ points2 = points2.astype(np.float64)
134
+
135
+ c1 = np.mean(points1, axis=0)
136
+ c2 = np.mean(points2, axis=0)
137
+ points1 -= c1
138
+ points2 -= c2
139
+ s1 = np.std(points1)
140
+ s2 = np.std(points2)
141
+ points1 /= s1
142
+ points2 /= s2
143
+
144
+ U, S, Vt = np.linalg.svd(points1.T * points2)
145
+ R = (U * Vt).T
146
+ return np.vstack(
147
+ [
148
+ np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)),
149
+ np.matrix([0.0, 0.0, 1.0]),
150
+ ]
151
+ )
152
+
153
+
154
+ def load_video(file, device: str):
155
+ # create the samples directory if it doesn't exist
156
+ if not os.path.exists("samples"):
157
+ os.makedirs("samples")
158
+
159
+ p = os.path.join("samples")
160
+ output = os.path.join("samples", "%04d.jpg")
161
+ cmd = "ffmpeg -hide_banner -loglevel error -i {} -qscale:v 2 -r 25 {}".format(
162
+ file, output
163
+ )
164
+ os.system(cmd)
165
+
166
+ files = os.listdir(p)
167
+ files = sorted(files, key=lambda x: int(os.path.splitext(x)[0]))
168
+
169
+ array = [cv2.imread(os.path.join(p, file)) for file in files]
170
+
171
+ array = list(filter(lambda im: not im is None, array))
172
+
173
+ fa = face_alignment.FaceAlignment(
174
+ face_alignment.LandmarksType._2D, flip_input=False, device=device
175
+ )
176
+ points = [fa.get_landmarks(I) for I in array]
177
+
178
+ front256 = get_position(256)
179
+ video = []
180
+ for point, scene in zip(points, array):
181
+ if point is not None:
182
+ shape = np.array(point[0])
183
+ shape = shape[17:]
184
+ M = transformation_from_points(np.matrix(shape), np.matrix(front256))
185
+
186
+ img = cv2.warpAffine(scene, M[:2], (256, 256))
187
+ (x, y) = front256[-20:].mean(0).astype(np.int32)
188
+ w = 160 // 2
189
+ img = img[y - w // 2 : y + w // 2, x - w : x + w, ...]
190
+ img = cv2.resize(img, (128, 64))
191
+ video.append(img)
192
+
193
+ video = np.stack(video, axis=0).astype(np.float32)
194
+ video = torch.FloatTensor(video.transpose(3, 0, 1, 2)) / 255.0
195
+
196
+ return video
197
+
198
+
199
+ def extract_lip_coordinates(detector, predictor, img_path):
200
+ image = cv2.imread(img_path)
201
+ image = cv2.resize(image, (600, 500))
202
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
203
+
204
+ rects = detector(gray)
205
+ retries = 3
206
+ while retries > 0:
207
+ try:
208
+ assert len(rects) == 1
209
+ break
210
+ except AssertionError as e:
211
+ retries -= 1
212
+
213
+ for rect in rects:
214
+ # apply the shape predictor to the face ROI
215
+ shape = predictor(gray, rect)
216
+ x = []
217
+ y = []
218
+ for n in range(48, 68):
219
+ x.append(shape.part(n).x)
220
+ y.append(shape.part(n).y)
221
+ return [x, y]
222
+
223
+
224
+ def generate_lip_coordinates(frame_images_directory, detector, predictor):
225
+ frames = glob.glob(frame_images_directory + "/*.jpg")
226
+ frames.sort()
227
+
228
+ img = cv2.imread(frames[0])
229
+ height, width, layers = img.shape
230
+
231
+ coords = []
232
+ for frame in frames:
233
+ x_coords, y_coords = extract_lip_coordinates(detector, predictor, frame)
234
+ normalized_coords = []
235
+ for x, y in zip(x_coords, y_coords):
236
+ normalized_x = x / width
237
+ normalized_y = y / height
238
+ normalized_coords.append((normalized_x, normalized_y))
239
+ coords.append(normalized_coords)
240
+ coords_array = np.array(coords, dtype=np.float32)
241
+ coords_array = torch.from_numpy(coords_array)
242
+ return coords_array
243
+
244
+
245
+ def ctc_decode(y):
246
+ y = y.argmax(-1)
247
+ t = y.size(0)
248
+ result = []
249
+ for i in range(t + 1):
250
+ result.append(MyDataset.ctc_arr2txt(y[:i], start=1))
251
+ return result
252
+
253
+
254
+ def output_video(p, txt, output_path):
255
+ files = os.listdir(p)
256
+ files = sorted(files, key=lambda x: int(os.path.splitext(x)[0]))
257
+
258
+ font = cv2.FONT_HERSHEY_SIMPLEX
259
+
260
+ for file, line in zip(files, txt):
261
+ img = cv2.imread(os.path.join(p, file))
262
+ h, w, _ = img.shape
263
+ img = cv2.putText(
264
+ img, line, (w // 8, 11 * h // 12), font, 1.2, (0, 0, 0), 3, cv2.LINE_AA
265
+ )
266
+ img = cv2.putText(
267
+ img,
268
+ line,
269
+ (w // 8, 11 * h // 12),
270
+ font,
271
+ 1.2,
272
+ (255, 255, 255),
273
+ 0,
274
+ cv2.LINE_AA,
275
+ )
276
+ h = h // 2
277
+ w = w // 2
278
+ img = cv2.resize(img, (w, h))
279
+ cv2.imwrite(os.path.join(p, file), img)
280
+
281
+ # create the output_videos directory if it doesn't exist
282
+ if not os.path.exists(output_path):
283
+ os.makedirs(output_path)
284
+
285
+ output = os.path.join(output_path, "output.mp4")
286
+ cmd = "ffmpeg -hide_banner -loglevel error -y -i {}/%04d.jpg -r 25 {}".format(
287
+ p, output
288
+ )
289
+ os.system(cmd)
290
+
291
+
292
+ def main():
293
+ parser = argparse.ArgumentParser()
294
+ parser.add_argument(
295
+ "--weights",
296
+ type=str,
297
+ default="pretrain/LipCoordNet_coords_loss_0.025581153109669685_wer_0.01746208431890914_cer_0.006488426950253695.pt",
298
+ help="path to the weights file",
299
+ )
300
+ parser.add_argument(
301
+ "--input_video",
302
+ type=str,
303
+ help="path to the input video frames",
304
+ )
305
+ parser.add_argument(
306
+ "--device",
307
+ type=str,
308
+ default="cuda",
309
+ help="device to run the model on",
310
+ )
311
+
312
+ parser.add_argument(
313
+ "--output_path",
314
+ type=str,
315
+ default="output_videos",
316
+ help="directory to save the output video",
317
+ )
318
+
319
+ args = parser.parse_args()
320
+
321
+ # validate if device is valid
322
+ if args.device not in ("cuda", "cpu"):
323
+ raise ValueError("Invalid device, must be either cuda or cpu")
324
+
325
+ device = args.device
326
+
327
+ # load model
328
+ model = LipCoordNet()
329
+ model.load_state_dict(torch.load(args.weights))
330
+ model = model.to(device)
331
+ model.eval()
332
+ detector = dlib.get_frontal_face_detector()
333
+ predictor = dlib.shape_predictor(
334
+ "lip_coordinate_extraction/shape_predictor_68_face_landmarks_GTX.dat"
335
+ )
336
+
337
+ # load video
338
+ video = load_video(args.input_video, device)
339
+
340
+ # generate lip coordinates
341
+ coords = generate_lip_coordinates("samples", detector, predictor)
342
+
343
+ pred = model(video[None, ...].to(device), coords[None, ...].to(device))
344
+ output = ctc_decode(pred[0])
345
+ print(output[-1])
346
+ output_video("samples", output, args.output_path)
347
+
348
+
349
+ if __name__ == "__main__":
350
+ main()
model.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import math
5
+
6
+
7
+ class LipCoordNet(torch.nn.Module):
8
+ def __init__(self, dropout_p=0.5, coord_input_dim=40, coord_hidden_dim=128):
9
+ super(LipCoordNet, self).__init__()
10
+ self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2))
11
+ self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
12
+
13
+ self.conv2 = nn.Conv3d(32, 64, (3, 5, 5), (1, 1, 1), (1, 2, 2))
14
+ self.pool2 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
15
+
16
+ self.conv3 = nn.Conv3d(64, 96, (3, 3, 3), (1, 1, 1), (1, 1, 1))
17
+ self.pool3 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))
18
+
19
+ self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True)
20
+ self.gru2 = nn.GRU(512, 256, 1, bidirectional=True)
21
+
22
+ self.FC = nn.Linear(512 + 2 * coord_hidden_dim, 27 + 1)
23
+ self.dropout_p = dropout_p
24
+
25
+ self.relu = nn.ReLU(inplace=True)
26
+ self.dropout = nn.Dropout(self.dropout_p)
27
+ self.dropout3d = nn.Dropout3d(self.dropout_p)
28
+
29
+ # New GRU layer for lip coordinates
30
+ self.coord_gru = nn.GRU(
31
+ coord_input_dim, coord_hidden_dim, 1, bidirectional=True
32
+ )
33
+
34
+ self._init()
35
+
36
+ def _init(self):
37
+ init.kaiming_normal_(self.conv1.weight, nonlinearity="relu")
38
+ init.constant_(self.conv1.bias, 0)
39
+
40
+ init.kaiming_normal_(self.conv2.weight, nonlinearity="relu")
41
+ init.constant_(self.conv2.bias, 0)
42
+
43
+ init.kaiming_normal_(self.conv3.weight, nonlinearity="relu")
44
+ init.constant_(self.conv3.bias, 0)
45
+
46
+ init.kaiming_normal_(self.FC.weight, nonlinearity="sigmoid")
47
+ init.constant_(self.FC.bias, 0)
48
+
49
+ for m in (self.gru1, self.gru2):
50
+ stdv = math.sqrt(2 / (96 * 3 * 6 + 256))
51
+ for i in range(0, 256 * 3, 256):
52
+ init.uniform_(
53
+ m.weight_ih_l0[i : i + 256],
54
+ -math.sqrt(3) * stdv,
55
+ math.sqrt(3) * stdv,
56
+ )
57
+ init.orthogonal_(m.weight_hh_l0[i : i + 256])
58
+ init.constant_(m.bias_ih_l0[i : i + 256], 0)
59
+ init.uniform_(
60
+ m.weight_ih_l0_reverse[i : i + 256],
61
+ -math.sqrt(3) * stdv,
62
+ math.sqrt(3) * stdv,
63
+ )
64
+ init.orthogonal_(m.weight_hh_l0_reverse[i : i + 256])
65
+ init.constant_(m.bias_ih_l0_reverse[i : i + 256], 0)
66
+
67
+ def forward(self, x, coords):
68
+ # branch 1
69
+ x = self.conv1(x)
70
+ x = self.relu(x)
71
+ x = self.dropout3d(x)
72
+ x = self.pool1(x)
73
+
74
+ x = self.conv2(x)
75
+ x = self.relu(x)
76
+ x = self.dropout3d(x)
77
+ x = self.pool2(x)
78
+
79
+ x = self.conv3(x)
80
+ x = self.relu(x)
81
+ x = self.dropout3d(x)
82
+ x = self.pool3(x)
83
+
84
+ # (B, C, T, H, W)->(T, B, C, H, W)
85
+ x = x.permute(2, 0, 1, 3, 4).contiguous()
86
+ # (B, C, T, H, W)->(T, B, C*H*W)
87
+ x = x.view(x.size(0), x.size(1), -1)
88
+
89
+ self.gru1.flatten_parameters()
90
+ self.gru2.flatten_parameters()
91
+
92
+ x, h = self.gru1(x)
93
+ x = self.dropout(x)
94
+ x, h = self.gru2(x)
95
+ x = self.dropout(x)
96
+
97
+ # branch 2
98
+ # Process lip coordinates through GRU
99
+ self.coord_gru.flatten_parameters()
100
+
101
+ # (B, T, N, C)->(T, B, C, N, C)
102
+ coords = coords.permute(1, 0, 2, 3).contiguous()
103
+ # (T, B, C, N, C)->(T, B, C, N*C)
104
+ coords = coords.view(coords.size(0), coords.size(1), -1)
105
+ coords, _ = self.coord_gru(coords)
106
+ coords = self.dropout(coords)
107
+
108
+ # combine the two branches
109
+ combined = torch.cat((x, coords), dim=2)
110
+
111
+ x = self.FC(combined)
112
+ x = x.permute(1, 0, 2).contiguous()
113
+ return x
options.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gpu = "0"
2
+ random_seed = 0
3
+ data_type = "coords"
4
+ video_path = "lip_images/"
5
+ train_list = f"data/{data_type}_train.txt"
6
+ val_list = f"data/{data_type}_val.txt"
7
+ anno_path = "GRID_alignments"
8
+ coords_path = "lip_coordinates"
9
+ vid_padding = 75
10
+ txt_padding = 200
11
+ batch_size = 40
12
+ base_lr = 2e-5
13
+ num_workers = 16
14
+ max_epoch = 10000
15
+ display = 50
16
+ test_step = 1000
17
+ save_prefix = f"weights/LipCoordNet_{data_type}"
18
+ is_optimize = True
19
+ pin_memory = True
20
+
21
+ weights = "pretrain/LipCoordNet_coords_loss_0.025581153109669685_wer_0.01746208431890914_cer_0.006488426950253695.pt"
requirements.txt ADDED
Binary file (1.17 kB). View file
 
train.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
+ import os
5
+ from dataset import MyDataset
6
+ import numpy as np
7
+ import time
8
+ from model import LipCoordNet
9
+ import torch.optim as optim
10
+ from tensorboardX import SummaryWriter
11
+ import options as opt
12
+ from tqdm import tqdm
13
+
14
+
15
+ def dataset2dataloader(dataset, num_workers=opt.num_workers, shuffle=True):
16
+ return DataLoader(
17
+ dataset,
18
+ batch_size=opt.batch_size,
19
+ shuffle=shuffle,
20
+ num_workers=num_workers,
21
+ drop_last=False,
22
+ pin_memory=opt.pin_memory,
23
+ )
24
+
25
+
26
+ def show_lr(optimizer):
27
+ lr = []
28
+ for param_group in optimizer.param_groups:
29
+ lr += [param_group["lr"]]
30
+ return np.array(lr).mean()
31
+
32
+
33
+ def ctc_decode(y):
34
+ y = y.argmax(-1)
35
+ return [MyDataset.ctc_arr2txt(y[_], start=1) for _ in range(y.size(0))]
36
+
37
+
38
+ def test(model, net):
39
+ with torch.no_grad():
40
+ dataset = MyDataset(
41
+ opt.video_path,
42
+ opt.anno_path,
43
+ opt.coords_path,
44
+ opt.val_list,
45
+ opt.vid_padding,
46
+ opt.txt_padding,
47
+ "test",
48
+ )
49
+
50
+ print("num_test_data:{}".format(len(dataset.data)))
51
+ model.eval()
52
+ loader = dataset2dataloader(dataset, shuffle=False)
53
+ loss_list = []
54
+ wer = []
55
+ cer = []
56
+ crit = nn.CTCLoss()
57
+ tic = time.time()
58
+ print("RUNNING VALIDATION")
59
+ pbar = tqdm(loader)
60
+ for i_iter, input in enumerate(pbar):
61
+ vid = input.get("vid").cuda(non_blocking=opt.pin_memory)
62
+ txt = input.get("txt").cuda(non_blocking=opt.pin_memory)
63
+ vid_len = input.get("vid_len").cuda(non_blocking=opt.pin_memory)
64
+ txt_len = input.get("txt_len").cuda(non_blocking=opt.pin_memory)
65
+ coord = input.get("coord").cuda(non_blocking=opt.pin_memory)
66
+
67
+ y = net(vid, coord)
68
+
69
+ loss = (
70
+ crit(
71
+ y.transpose(0, 1).log_softmax(-1),
72
+ txt,
73
+ vid_len.view(-1),
74
+ txt_len.view(-1),
75
+ )
76
+ .detach()
77
+ .cpu()
78
+ .numpy()
79
+ )
80
+ loss_list.append(loss)
81
+ pred_txt = ctc_decode(y)
82
+
83
+ truth_txt = [MyDataset.arr2txt(txt[_], start=1) for _ in range(txt.size(0))]
84
+ wer.extend(MyDataset.wer(pred_txt, truth_txt))
85
+ cer.extend(MyDataset.cer(pred_txt, truth_txt))
86
+ if i_iter % opt.display == 0:
87
+ v = 1.0 * (time.time() - tic) / (i_iter + 1)
88
+ eta = v * (len(loader) - i_iter) / 3600.0
89
+
90
+ print("".join(101 * "-"))
91
+ print("{:<50}|{:>50}".format("predict", "truth"))
92
+ print("".join(101 * "-"))
93
+ for predict, truth in list(zip(pred_txt, truth_txt))[:10]:
94
+ print("{:<50}|{:>50}".format(predict, truth))
95
+ print("".join(101 * "-"))
96
+ print(
97
+ "test_iter={},eta={},wer={},cer={}".format(
98
+ i_iter, eta, np.array(wer).mean(), np.array(cer).mean()
99
+ )
100
+ )
101
+ print("".join(101 * "-"))
102
+
103
+ return (np.array(loss_list).mean(), np.array(wer).mean(), np.array(cer).mean())
104
+
105
+
106
+ def train(model, net):
107
+ dataset = MyDataset(
108
+ opt.video_path,
109
+ opt.anno_path,
110
+ opt.coords_path,
111
+ opt.train_list,
112
+ opt.vid_padding,
113
+ opt.txt_padding,
114
+ "train",
115
+ )
116
+
117
+ loader = dataset2dataloader(dataset)
118
+ optimizer = optim.Adam(
119
+ model.parameters(), lr=opt.base_lr, weight_decay=0.0, amsgrad=True
120
+ )
121
+
122
+ print("num_train_data:{}".format(len(dataset.data)))
123
+ crit = nn.CTCLoss()
124
+ tic = time.time()
125
+
126
+ train_wer = []
127
+ for epoch in range(opt.max_epoch):
128
+ print(f"RUNNING EPOCH {epoch}")
129
+ pbar = tqdm(loader)
130
+
131
+ for i_iter, input in enumerate(pbar):
132
+ model.train()
133
+ vid = input.get("vid").cuda(non_blocking=opt.pin_memory)
134
+ txt = input.get("txt").cuda(non_blocking=opt.pin_memory)
135
+ vid_len = input.get("vid_len").cuda(non_blocking=opt.pin_memory)
136
+ txt_len = input.get("txt_len").cuda(non_blocking=opt.pin_memory)
137
+ coord = input.get("coord").cuda(non_blocking=opt.pin_memory)
138
+
139
+ optimizer.zero_grad()
140
+ y = net(vid, coord)
141
+ loss = crit(
142
+ y.transpose(0, 1).log_softmax(-1),
143
+ txt,
144
+ vid_len.view(-1),
145
+ txt_len.view(-1),
146
+ )
147
+ loss.backward()
148
+
149
+ if opt.is_optimize:
150
+ optimizer.step()
151
+
152
+ tot_iter = i_iter + epoch * len(loader)
153
+
154
+ pred_txt = ctc_decode(y)
155
+
156
+ truth_txt = [MyDataset.arr2txt(txt[_], start=1) for _ in range(txt.size(0))]
157
+ train_wer.extend(MyDataset.wer(pred_txt, truth_txt))
158
+
159
+ if tot_iter % opt.display == 0:
160
+ v = 1.0 * (time.time() - tic) / (tot_iter + 1)
161
+ eta = (len(loader) - i_iter) * v / 3600.0
162
+
163
+ writer.add_scalar("train loss", loss, tot_iter)
164
+ writer.add_scalar("train wer", np.array(train_wer).mean(), tot_iter)
165
+ print("".join(101 * "-"))
166
+ print("{:<50}|{:>50}".format("predict", "truth"))
167
+ print("".join(101 * "-"))
168
+
169
+ for predict, truth in list(zip(pred_txt, truth_txt))[:3]:
170
+ print("{:<50}|{:>50}".format(predict, truth))
171
+ print("".join(101 * "-"))
172
+ print(
173
+ "epoch={},tot_iter={},eta={},loss={},train_wer={}".format(
174
+ epoch, tot_iter, eta, loss, np.array(train_wer).mean()
175
+ )
176
+ )
177
+ print("".join(101 * "-"))
178
+
179
+ if tot_iter % opt.test_step == 0:
180
+ (loss, wer, cer) = test(model, net)
181
+ print(
182
+ "i_iter={},lr={},loss={},wer={},cer={}".format(
183
+ tot_iter, show_lr(optimizer), loss, wer, cer
184
+ )
185
+ )
186
+ writer.add_scalar("val loss", loss, tot_iter)
187
+ writer.add_scalar("wer", wer, tot_iter)
188
+ writer.add_scalar("cer", cer, tot_iter)
189
+ savename = "{}_loss_{}_wer_{}_cer_{}.pt".format(
190
+ opt.save_prefix, loss, wer, cer
191
+ )
192
+ (path, name) = os.path.split(savename)
193
+ if not os.path.exists(path):
194
+ os.makedirs(path)
195
+ torch.save(model.state_dict(), savename)
196
+ if not opt.is_optimize:
197
+ exit()
198
+
199
+
200
+ if __name__ == "__main__":
201
+ print("Loading options...")
202
+ os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
203
+ writer = SummaryWriter()
204
+ model = LipCoordNet()
205
+ model = model.cuda()
206
+ net = nn.DataParallel(model).cuda()
207
+
208
+ if hasattr(opt, "weights"):
209
+ pretrained_dict = torch.load(opt.weights)
210
+ model_dict = model.state_dict()
211
+ pretrained_dict = {
212
+ k: v
213
+ for k, v in pretrained_dict.items()
214
+ if k in model_dict.keys() and v.size() == model_dict[k].size()
215
+ }
216
+
217
+ # freeze the pretrained layers
218
+ for k, param in pretrained_dict.items():
219
+ param.requires_grad = False
220
+
221
+ missed_params = [
222
+ k for k, v in model_dict.items() if not k in pretrained_dict.keys()
223
+ ]
224
+ print(
225
+ "loaded params/tot params:{}/{}".format(
226
+ len(pretrained_dict), len(model_dict)
227
+ )
228
+ )
229
+ print("miss matched params:{}".format(missed_params))
230
+ model_dict.update(pretrained_dict)
231
+ model.load_state_dict(model_dict)
232
+
233
+ torch.manual_seed(opt.random_seed)
234
+ torch.cuda.manual_seed_all(opt.random_seed)
235
+ torch.backends.cudnn.benchmark = True
236
+
237
+ train(model, net)