# workshop/LaSA/libs/class_id_map.py
# Uploaded by qiushuocheng ("Upload 173 files", commit 41e3185).
import os
from typing import Dict
# Public API of this module; also governs what ``from <module> import *`` exports.
__all__ = [
    "get_class2id_map",
    "get_id2class_map",
    "get_n_classes",
]
def get_class2id_map(dataset: str, dataset_dir: str = "./dataset") -> Dict[str, int]:
    """Load the class-name -> class-id mapping for a dataset.

    Reads ``<dataset_dir>/<dataset>/mapping.txt``. Each non-blank line is
    expected to start with an integer id followed by a class name, separated
    by whitespace (e.g. ``0 SIL``). Only the first two whitespace-separated
    tokens of a line are used; any further tokens are ignored.

    Args:
        dataset: name of the dataset sub-directory under ``dataset_dir``.
        dataset_dir: the path to the dataset directory.

    Returns:
        A dict mapping each class name to its integer id, in file order.

    Raises:
        OSError: if the mapping file cannot be opened.
        ValueError: if a non-blank line has fewer than two tokens or its
            first token is not an integer.
    """
    path = os.path.join(dataset_dir, dataset, "mapping.txt")
    # Explicit encoding so decoding does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        # splitlines() + skipping blank lines keeps the final entry even when
        # the file lacks a trailing newline (``split("\n")[:-1]`` used to drop
        # it silently) and tolerates blank/trailing lines.
        lines = f.read().splitlines()

    class2id_map: Dict[str, int] = {}
    for lineno, line in enumerate(lines, start=1):
        tokens = line.split()
        if not tokens:
            continue
        if len(tokens) < 2:
            raise ValueError(
                "{}:{}: expected '<id> <class_name>', got {!r}".format(
                    path, lineno, line
                )
            )
        class2id_map[tokens[1]] = int(tokens[0])
    return class2id_map
def get_id2class_map(dataset: str, dataset_dir: str = "./dataset") -> Dict[int, str]:
    """Return the inverse of ``get_class2id_map``: class id -> class name.

    Args:
        dataset: dataset name, passed through to ``get_class2id_map``.
        dataset_dir: dataset root directory, passed through to
            ``get_class2id_map``.

    Returns:
        A dict mapping each integer class id to its class name. If several
        names share an id, the one appearing last in the forward mapping wins.
    """
    id2class: Dict[int, str] = {}
    for name, idx in get_class2id_map(dataset, dataset_dir).items():
        id2class[idx] = name
    return id2class
def get_n_classes(dataset: str, dataset_dir: str = "./dataset") -> int:
    """Return the number of classes defined for a dataset.

    Args:
        dataset: dataset name, passed through to ``get_class2id_map``.
        dataset_dir: dataset root directory, passed through to
            ``get_class2id_map``.

    Returns:
        The number of entries (distinct class names) in the mapping returned
        by ``get_class2id_map``.
    """
    class2id = get_class2id_map(dataset, dataset_dir)
    return len(class2id)