# flake8: noqa
import argparse
import csv
import os
from collections import OrderedDict


def init_config(config, default_config, name=None):
    """Initialise non-given config values with defaults."""
    if config is None:
        config = default_config
    else:
        for k in default_config.keys():
            if k not in config.keys():
                config[k] = default_config[k]
    if name and config["PRINT_CONFIG"]:
        print("\n%s Config:" % name)
        for c in config.keys():
            print("%-20s : %-30s" % (c, config[c]))
    return config
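

# A minimal usage sketch for init_config (the keys below are hypothetical,
# except PRINT_CONFIG, which the function itself reads):
#
#   defaults = {"PRINT_CONFIG": True, "THRESHOLD": 0.5}
#   cfg = init_config({"THRESHOLD": 0.9}, defaults, name="Demo")
#   # cfg == {"THRESHOLD": 0.9, "PRINT_CONFIG": True}; the merged config is
#   # printed because PRINT_CONFIG is True.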


def update_config(config):
    """
    Parses a script's command-line arguments and updates the config with any
    values specified there.
    :param config: the config to update
    :return: the updated config
    """
    parser = argparse.ArgumentParser()
    for setting in config.keys():
        # list- and None-valued settings may take multiple values on the CLI
        if isinstance(config[setting], list) or config[setting] is None:
            parser.add_argument("--" + setting, nargs="+")
        else:
            parser.add_argument("--" + setting)
    args = parser.parse_args().__dict__
    for setting in args.keys():
        if args[setting] is not None:
            if isinstance(config[setting], bool):
                if args[setting] == "True":
                    x = True
                elif args[setting] == "False":
                    x = False
                else:
                    raise Exception(
                        "Command line parameter " + setting + " must be True or False"
                    )
            elif isinstance(config[setting], int):
                x = int(args[setting])
            elif args[setting] is None:
                x = None
            else:
                x = args[setting]
            config[setting] = x
    return config
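

# A sketch of the intended round trip (the settings below are hypothetical):
#
#   config = {"PRINT_CONFIG": True, "NUM_CORES": 4, "SEQ_LIST": None}
#   # invoked as: python script.py --NUM_CORES 8 --PRINT_CONFIG False
#   update_config(config)
#   # config == {"PRINT_CONFIG": False, "NUM_CORES": 8, "SEQ_LIST": None}
#
# List- and None-typed settings are declared with nargs="+", so e.g.
# --SEQ_LIST seq01 seq02 arrives as the list ["seq01", "seq02"].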


def get_code_path():
    """Get base path where code is."""
    return os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
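

# get_code_path() resolves to the parent of the directory containing this file;
# e.g., if this file lived at <repo>/trackeval/utils.py (path hypothetical), it
# would return the absolute path of <repo>.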


def validate_metrics_list(metrics_list):
    """Gets the names of the metric classes and ensures they are unique;
    further checks that the fields within each metric class do not have
    overlapping names.
    """
    metric_names = [metric.get_name() for metric in metrics_list]
    # check metric names are unique
    if len(metric_names) != len(set(metric_names)):
        raise TrackEvalException(
            "Code being run with multiple metrics of the same name"
        )
    fields = []
    for m in metrics_list:
        fields += m.fields
    # check metric fields are unique
    if len(fields) != len(set(fields)):
        raise TrackEvalException(
            "Code being run with multiple metrics with fields of the same name"
        )
    return metric_names
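

# A minimal sketch of the interface this function assumes: each metric object
# exposes get_name() and a `fields` attribute (the class below is hypothetical):
#
#   class DummyMetric:
#       fields = ["DummyScore"]
#
#       @staticmethod
#       def get_name():
#           return "DummyMetric"
#
#   validate_metrics_list([DummyMetric()])  # -> ["DummyMetric"]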


def write_summary_results(summaries, cls, output_folder):
    """Write summary results to file."""
    fields = sum([list(s.keys()) for s in summaries], [])
    values = sum([list(s.values()) for s in summaries], [])

    # In order to remain consistent when new fields are added, each of the
    # following fields, if present, is output in the summary first, in the
    # order below. Any further fields are output in the order each metric
    # family is called, and within each family either in the order they were
    # added to the dict (python >= 3.6) or randomly (python < 3.6).
    default_order = [
        "HOTA", "DetA", "AssA", "DetRe", "DetPr", "AssRe", "AssPr", "LocA",
        "OWTA", "HOTA(0)", "LocA(0)", "HOTALocA(0)",
        "MOTA", "MOTP", "MODA", "CLR_Re", "CLR_Pr", "MTR", "PTR", "MLR",
        "CLR_TP", "CLR_FN", "CLR_FP", "IDSW", "MT", "PT", "ML", "Frag", "sMOTA",
        "IDF1", "IDR", "IDP", "IDTP", "IDFN", "IDFP",
        "Dets", "GT_Dets", "IDs", "GT_IDs",
    ]
    default_ordered_dict = OrderedDict.fromkeys(default_order)
    for f, v in zip(fields, values):
        default_ordered_dict[f] = v
    # drop default fields that no metric actually produced
    for df in default_order:
        if default_ordered_dict[df] is None:
            del default_ordered_dict[df]
    fields = list(default_ordered_dict.keys())
    values = list(default_ordered_dict.values())

    out_file = os.path.join(output_folder, cls + "_summary.txt")
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, "w", newline="") as f:
        writer = csv.writer(f, delimiter=" ")
        writer.writerow(fields)
        writer.writerow(values)
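

# The resulting <cls>_summary.txt is space-delimited, with a header row and a
# single values row, e.g. (values hypothetical):
#
#   HOTA DetA AssA
#   45.2 44.1 46.7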


def write_detailed_results(details, cls, output_folder):
    """Write detailed results to file."""
    sequences = details[0].keys()
    fields = ["seq"] + sum([list(s["COMBINED_SEQ"].keys()) for s in details], [])
    out_file = os.path.join(output_folder, cls + "_detailed.csv")
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(fields)
        for seq in sorted(sequences):
            if seq == "COMBINED_SEQ":
                continue
            writer.writerow([seq] + sum([list(s[seq].values()) for s in details], []))
        writer.writerow(
            ["COMBINED"] + sum([list(s["COMBINED_SEQ"].values()) for s in details], [])
        )
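

# The resulting <cls>_detailed.csv has one row per sequence plus a final
# "COMBINED" row, e.g. (field names and values hypothetical):
#
#   seq,HOTA,DetA
#   seq01,45.2,44.1
#   seq02,47.0,45.3
#   COMBINED,46.1,44.7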


def load_detail(file):
    """Loads detailed data for a tracker."""
    data = {}
    with open(file) as f:
        for i, row_text in enumerate(f):
            row = row_text.replace("\r", "").replace("\n", "").split(",")
            if i == 0:
                # header row: first column is the sequence name, rest are keys
                keys = row[1:]
                continue
            current_values = row[1:]
            seq = row[0]
            if seq == "COMBINED":
                seq = "COMBINED_SEQ"
            if (len(current_values) == len(keys)) and seq != "":
                data[seq] = {}
                for key, value in zip(keys, current_values):
                    data[seq][key] = float(value)
    return data
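

# load_detail is the inverse of write_detailed_results: reading back a file
# written above recovers per-sequence dicts keyed by field name, with the
# "COMBINED" row stored under "COMBINED_SEQ" (path hypothetical):
#
#   detail = load_detail("output/pedestrian_detailed.csv")
#   detail["COMBINED_SEQ"]["HOTA"]  # -> a float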


class TrackEvalException(Exception):
    """Custom exception for catching expected errors."""
    ...