|
r"""Compute active speaker detection performance for the AVA dataset.
|
|
Please send any questions about this code to the Google Group ava-dataset-users:
|
|
https://groups.google.com/forum/#!forum/ava-dataset-users
|
|
Example usage:
|
|
python -O get_ava_active_speaker_performance.py \
|
|
-g testdata/eval.csv \
|
|
-p testdata/predictions.csv \
|
|
-v
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import logging
|
|
import time, warnings
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
warnings.filterwarnings("ignore")
|
|
|
|
|
|
def parse_arguments():
|
|
"""Parses command-line flags.
|
|
Returns:
|
|
args: a named tuple containing three file objects args.labelmap,
|
|
args.groundtruth, and args.detections.
|
|
"""
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("-g",
|
|
"--groundtruth",
|
|
help="CSV file containing ground truth.",
|
|
type=argparse.FileType("r"),
|
|
required=True)
|
|
parser.add_argument("-p",
|
|
"--predictions",
|
|
help="CSV file containing active speaker predictions.",
|
|
type=argparse.FileType("r"),
|
|
required=True)
|
|
parser.add_argument("-v", "--verbose", help="Increase output verbosity.", action="store_true")
|
|
return parser.parse_args()
|
|
|
|
|
|
def run_evaluation(groundtruth, predictions):
|
|
prediction = pd.read_csv(predictions)
|
|
groundtruth = pd.read_csv(groundtruth)
|
|
wrong_list = []
|
|
num = 0
|
|
audible_num = 0
|
|
total = 0
|
|
for i, row in prediction.iterrows():
|
|
entity_id = row['entity_id']
|
|
ts = row['frame_timestamp']
|
|
if row['score'] < 0.5:
|
|
label = "NOT_SPEAKING"
|
|
else:
|
|
label = "SPEAKING_AUDIBLE"
|
|
|
|
true_label = groundtruth.loc[(groundtruth['entity_id'] == entity_id) &
|
|
(groundtruth['frame_timestamp'] == ts)].iloc[0]["label"]
|
|
if true_label != label:
|
|
wrong_list.append([entity_id, ts, true_label, label])
|
|
|
|
if label == "SPEAKING_AUDIBLE":
|
|
num += 1
|
|
if true_label == "SPEAKING_AUDIBLE":
|
|
audible_num += 1
|
|
total += 1
|
|
print(num, audible_num, total)
|
|
|
|
df = pd.DataFrame(wrong_list, columns=['entity_id', 'frame_timestamp', "gt", "prediction"])
|
|
df = df.sort_values(by=["frame_timestamp"])
|
|
df.to_csv("wrong_list.csv")
|
|
|
|
|
|
def main():
|
|
start = time.time()
|
|
args = parse_arguments()
|
|
if args.verbose:
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
del args.verbose
|
|
run_evaluation(**vars(args))
|
|
logging.info("Computed in %s seconds", time.time() - start)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
|