Spaces:
Running
Running
File size: 4,278 Bytes
8c79f36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# encoding: utf-8
import numpy as np
import cv2
import os
from torch.utils.data import Dataset
from cvtransforms import *
import torch
import editdistance
class MyDataset(Dataset):
    """Dataset pairing directories of video frames with text annotations.

    Each sample is a directory of numerically named ``.jpg`` frames plus an
    ``.align`` annotation file.  Text is encoded as integer indices into
    ``letters`` shifted by a ``start`` offset; this class encodes with
    ``start=1`` and pads with zeros, so index 0 never denotes a character
    (presumably it doubles as the CTC blank -- confirm against the model).
    """

    # Symbol table: a space followed by the 26 upper-case ASCII letters.
    letters = [" "] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    def __init__(self, video_path, anno_path, file_list, vid_pad, txt_pad, phase):
        """Read the sample list and pre-compute per-sample metadata.

        Args:
            video_path: Root directory joined in front of every list entry.
            anno_path: Root directory of the annotation files.
            file_list: Text file with one relative frame directory per line.
            vid_pad: Number of frames every video is zero-padded up to.
            txt_pad: Number of labels every annotation is zero-padded up to.
            phase: Data augmentation is applied only when this is ``"train"``.
        """
        self.anno_path = anno_path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase

        with open(file_list, "r") as f:
            # Skip blank lines: they would otherwise become samples that
            # point at ``video_path`` itself and fail when loaded.
            self.videos = [
                os.path.join(video_path, entry)
                for entry in (line.strip() for line in f)
                if entry
            ]

        self.data = []
        for vid in self.videos:
            # Paths are expected to use "/" as separator.  The 4th component
            # from the end is used as the speaker directory and the last one
            # as the sample name.
            items = vid.split("/")
            self.data.append((vid, items[-4], items[-1]))

    def __getitem__(self, idx):
        """Load, augment and pad the sample at position ``idx``.

        Returns:
            A dict with keys ``"vid"`` (float tensor with the channel axis
            moved to the front), ``"txt"`` (long tensor of padded labels),
            ``"txt_len"`` and ``"vid_len"`` (the lengths before padding).
        """
        (vid, spk, name) = self.data[idx]
        vid = self._load_vid(vid)
        anno = self._load_anno(
            os.path.join(self.anno_path, spk, "align", name + ".align")
        )

        if self.phase == "train":
            # Augmentation helper; not defined in this class (presumably
            # provided by the cvtransforms star import).
            vid = HorizontalFlip(vid)

        vid = ColorNormalize(vid)

        # Record the true lengths before padding hides them.
        vid_len = vid.shape[0]
        anno_len = anno.shape[0]
        vid = self._padding(vid, self.vid_pad)
        anno = self._padding(anno, self.txt_pad)

        return {
            "vid": torch.FloatTensor(vid.transpose(3, 0, 1, 2)),
            "txt": torch.LongTensor(anno),
            "txt_len": anno_len,
            "vid_len": vid_len,
        }

    def __len__(self):
        """Return the number of samples listed in the file list."""
        return len(self.data)

    def _load_vid(self, p):
        """Load all ``<number>.jpg`` frames of directory ``p`` in numeric order.

        Frames that OpenCV cannot decode are dropped.  Every remaining frame
        is resized to 128x64 (width x height).

        Returns:
            A float32 array with one frame per entry along axis 0.

        Raises:
            ValueError: If a ``.jpg`` file name is not an integer, or if the
                directory yields no readable frame.
        """
        # Match on the extension only; a substring test would also accept
        # names such as "1.jpg.bak" and then crash in int() below.
        files = [name for name in os.listdir(p) if name.endswith(".jpg")]
        files.sort(key=lambda name: int(os.path.splitext(name)[0]))

        frames = [cv2.imread(os.path.join(p, name)) for name in files]
        # cv2.imread signals an unreadable file by returning None.
        frames = [im for im in frames if im is not None]
        if not frames:
            raise ValueError("no readable .jpg frames found in {!r}".format(p))

        frames = [
            cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
            for im in frames
        ]
        return np.stack(frames, axis=0).astype(np.float32)

    def _load_anno(self, name):
        """Read an annotation file and encode its transcript.

        Each useful line carries at least three whitespace-separated fields;
        the third is the token.  ``sil`` and ``sp`` tokens (any case) are
        discarded, the rest are joined with single spaces and upper-cased.

        Returns:
            The transcript encoded by ``txt2arr`` with ``start=1``.
        """
        with open(name, "r") as f:
            rows = [line.split() for line in f]
        # Ignore blank or truncated lines instead of failing on row[2].
        tokens = [row[2] for row in rows if len(row) >= 3]
        tokens = [tok for tok in tokens if tok.upper() not in ("SIL", "SP")]
        return MyDataset.txt2arr(" ".join(tokens).upper(), 1)

    def _padding(self, array, length):
        """Zero-pad ``array`` along axis 0 up to ``length`` entries.

        The input dtype is preserved, so integer labels stay integers and
        float32 video is not upcast.  Arrays that already have ``length`` or
        more entries are returned as an unpadded copy (never truncated).
        """
        missing = max(length - array.shape[0], 0)
        # Take the per-entry shape from the array itself so that zero-length
        # inputs work too.
        filler = np.zeros((missing,) + array.shape[1:], dtype=array.dtype)
        return np.concatenate([array, filler], axis=0)

    @staticmethod
    def txt2arr(txt, start):
        """Encode ``txt`` as an integer array of ``letters`` indices plus ``start``.

        Raises:
            ValueError: If ``txt`` contains a character not in ``letters``.
        """
        codes = [MyDataset.letters.index(c) + start for c in txt]
        # Explicit dtype so that an empty string still yields an integer array.
        return np.array(codes, dtype=np.int64)

    @staticmethod
    def arr2txt(arr, start):
        """Decode label indices to text, skipping every value below ``start``."""
        chars = [MyDataset.letters[n - start] for n in arr if n >= start]
        return "".join(chars).strip()

    @staticmethod
    def ctc_arr2txt(arr, start):
        """Decode a frame-wise label sequence by collapsing repeats.

        A value is emitted only when it differs from the immediately preceding
        value and is at least ``start``.  Consecutive spaces in the output are
        merged into one, and the result is stripped.
        """
        prev = -1
        chars = []
        for n in arr:
            if prev != n and n >= start:
                ch = MyDataset.letters[n - start]
                # Never emit two spaces in a row.
                if not (ch == " " and chars and chars[-1] == " "):
                    chars.append(ch)
            prev = n
        return "".join(chars).strip()

    @staticmethod
    def wer(predict, truth):
        """Return the per-pair word error rate of ``predict`` against ``truth``.

        Both arguments are parallel sequences of strings.  Each string is
        split on single spaces and the word-level edit distance is divided by
        the number of reference words.
        """
        word_pairs = [(p.split(" "), t.split(" ")) for p, t in zip(predict, truth)]
        return [1.0 * editdistance.eval(p, t) / len(t) for p, t in word_pairs]

    @staticmethod
    def cer(predict, truth):
        """Return the per-pair character error rate of ``predict`` against ``truth``.

        Raises:
            ZeroDivisionError: If a reference string in ``truth`` is empty.
        """
        return [
            1.0 * editdistance.eval(p, t) / len(t) for p, t in zip(predict, truth)
        ]
|