#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import os
import torch
import hydra
from pipelines.metrics.measures import get_wer
from pipelines.metrics.measures import get_cer
from pipelines.pipeline import InferencePipeline


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.total = 0
        self.count = 0

    def update(self, val, n=1):
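        # n weights this update, e.g. the number of reference words (WER) or characters (CER).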
        self.val = val
        self.total += val * n
        self.count += n
        self.avg = self.total / self.count


def benchmark_inference(inference_pipeline, data_dir, landmarks_dir, lines, data_ext=".mp4", landmarks_ext=".pkl"):
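    """Decode every sample listed in `lines` and report running/average WER and CER."""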
    wer, cer = AverageMeter(), AverageMeter()
    for idx, line in enumerate(lines):
        # Each label line is "<basename> <transcript words...>"; the transcript may be empty.
        parts = line.split()
        basename, groundtruth = parts[0], " ".join(parts[1:])
        data_filename = os.path.join(data_dir, f"{basename}{data_ext}")
        landmarks_filename = os.path.join(landmarks_dir, f"{basename}{landmarks_ext}") if landmarks_dir else None
        output = inference_pipeline(data_filename, landmarks_filename)

        # groundtruth is always a string here; an empty string means the label line had no transcript.
        print(f"hyp: {output}\nref: {groundtruth}" if groundtruth else f"hyp: {output}")
        if groundtruth:
            wer.update(get_wer(output, groundtruth), len(groundtruth.split()))
            cer.update(get_cer(output, groundtruth), len(groundtruth))
            print(f"progress: {idx+1}/{len(lines)}\tcur WER: {wer.val*100:.2f}\tcur CER: {cer.val*100:.2f}\tavg WER: {wer.avg*100:.2f}\tavg CER: {cer.avg*100:.2f}")


@hydra.main(version_base=None, config_path="hydra_configs", config_name="default")
def main(cfg):
    use_cuda = torch.cuda.is_available() and cfg.gpu_idx >= 0
    device = torch.device(f"cuda:{cfg.gpu_idx}" if use_cuda else "cpu")
    # Only run the built-in face tracker when no pre-computed landmarks are provided.
    face_track = not cfg.landmarks_filename and not cfg.landmarks_dir
    inference_pipeline = InferencePipeline(config_filename=cfg.config_filename, detector=cfg.detector, face_track=face_track, device=device)
    assert os.path.isdir(cfg.data_dir), f"{cfg.data_dir} is not a directory."
    assert os.path.isfile(cfg.labels_filename), f"{cfg.labels_filename} does not exist."
    with open(cfg.labels_filename) as f:
        lines = f.read().splitlines()
    benchmark_inference(inference_pipeline, cfg.data_dir, cfg.landmarks_dir, lines, cfg.data_ext, cfg.landmarks_ext)


if __name__ == '__main__':
    main()
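
# Example invocation (a sketch; the script filename and the exact override keys defined in
# hydra_configs/default.yaml are assumptions inferred from the cfg fields referenced above):
#   python <this_script>.py data_dir=/path/to/videos \
#       labels_filename=/path/to/labels.txt \
#       config_filename=/path/to/pipeline_config.ini \
#       gpu_idx=0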