|
import os |
|
from typing import Dict |
|
|
|
# Names exported by `from <module> import *`.
__all__ = [
    "get_class2id_map",
    "get_id2class_map",
    "get_n_classes",
]
|
|
|
|
|
def get_class2id_map(dataset: str, dataset_dir: str = "./dataset") -> Dict[str, int]:
    """Load the class-name -> class-id mapping for a dataset.

    Reads ``<dataset_dir>/<dataset>/mapping.txt``. Each non-blank line must
    contain an integer id followed by a class name, separated by whitespace
    (e.g. ``"0 background"``). Only the first two whitespace-separated tokens
    of each line are used; blank lines are ignored.

    Args:
        dataset: name of the dataset sub-directory that holds ``mapping.txt``.
        dataset_dir: the path to the dataset directory.

    Returns:
        A dict mapping each class name to its integer id. If a class name
        appears on more than one line, the id from the last such line wins.

    Raises:
        FileNotFoundError: if the mapping file does not exist.
        ValueError: if the first token of a line is not an integer.
        IndexError: if a non-blank line has fewer than two tokens.
    """
    path = os.path.join(dataset_dir, dataset, "mapping.txt")

    class2id_map: Dict[str, int] = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            tokens = line.split()
            # Skip blank lines rather than assuming the file ends with exactly
            # one "\n". A missing final newline no longer drops the last
            # class, and extra trailing newlines no longer crash.
            if not tokens:
                continue
            class2id_map[tokens[1]] = int(tokens[0])

    return class2id_map
|
|
|
|
|
def get_id2class_map(dataset: str, dataset_dir: str = "./dataset") -> Dict[int, str]:
    """Build the inverse of ``get_class2id_map``: class id -> class name.

    Args:
        dataset: name of the dataset sub-directory that holds ``mapping.txt``.
        dataset_dir: the path to the dataset directory.

    Returns:
        A dict mapping each integer class id to its class name. If several
        class names share one id, the name visited last in the forward
        mapping is kept.
    """
    id2class_map: Dict[int, str] = {}
    for class_name, class_id in get_class2id_map(dataset, dataset_dir).items():
        id2class_map[class_id] = class_name
    return id2class_map
|
|
|
|
|
def get_n_classes(dataset: str, dataset_dir: str = "./dataset") -> int:
    """Count the distinct class names listed in the dataset's mapping file.

    Args:
        dataset: name of the dataset sub-directory that holds ``mapping.txt``.
        dataset_dir: the path to the dataset directory.

    Returns:
        The number of entries in the class-name -> id mapping.
    """
    class2id_map = get_class2id_map(dataset, dataset_dir)
    n_classes = len(class2id_map)
    return n_classes
|
|