dms3_demo / utils /labels.py
qilongyu
Add application file
446f9ef
import copy
import numbers
import os
import torch
import numpy as np
import random
import pandas as pd
from utils.common import cprint, Color, log_warn, log_error
import matplotlib.pyplot as plt
# Module-level display configuration applied at import time:
# pandas prints every column and row, allows cells up to 100 characters and
# lines up to 5000 characters wide, and the chained-assignment warning is disabled.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('max_colwidth', 100)
pd.set_option('display.width', 5000)
pd.options.mode.chained_assignment = None # default='warn'
plt.rcParams['axes.unicode_minus'] = False # render plus/minus signs correctly on plot axes
def labels_to_class_weights(labels, nc=80):
    """Compute normalized inverse-frequency class weights from training labels.

    The rarer a class is, the larger its weight. Classes that never occur are
    treated as if they occurred once.

    Args:
        labels: sequence of items where ``item[1]`` is the integer class id of
            the sample. Items whose class id is ``-1`` are ignored.
        nc: minimum number of classes (length of the returned weight vector
            unless a larger class id is present).

    Returns:
        A ``torch.Tensor`` of weights that sums to 1, or an empty
        ``torch.Tensor`` when no labels are loaded.
    """
    if len(labels) == 0 or labels[0] is None:  # no labels loaded
        return torch.Tensor()

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is equivalent.
    classes = np.array([l[1] for l in labels if l[1] != -1], dtype=int)
    counts = np.bincount(classes, minlength=nc)  # occurrences per class
    counts[counts == 0] = 1  # replace empty bins with 1 to avoid division by zero
    weights = 1 / counts  # inverse frequency
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """Produce one sampling weight per image from the class weights.

    Args:
        labels: sequence of items where ``item[1]`` is the integer class id of
            the image.
        nc: number of classes; ``class_weights`` is reshaped to ``(1, nc)``.
        class_weights: per-class weights (read only, never modified here).

    Returns:
        A 1-D array with one weight per item of ``labels``: the class weight
        of that item's class. An empty array when ``labels`` is empty.
    """
    if len(labels) == 0:
        return np.zeros(0)

    # One-hot class count per image. ``np.int`` was removed in NumPy 1.24.
    class_counts = np.array([np.bincount(np.array([x[1]], dtype=int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    return image_weights
def labels_to_class_weights_mtl(labels, num_classes=[3, 3]):
    """Compute normalized inverse-frequency class weights for every task.

    The rarer a class is, the larger its weight. Classes that never occur are
    treated as if they occurred once.

    Args:
        labels: sequence of items where ``item[1]`` is the list of integer
            class ids of the sample, one per task. A class id of ``-1`` marks
            a missing label and is ignored for that task.
        num_classes: minimum number of classes of each task, in task order
            (read only, never modified here).

    Returns:
        A list with one 1-D NumPy array of weights (summing to 1) per task, or
        an empty ``torch.Tensor`` when no labels are loaded.
    """
    if len(labels) == 0 or labels[0] is None:  # no labels loaded
        return torch.Tensor()

    # Shape (num_samples, num_tasks). ``np.int`` was removed in NumPy 1.24.
    label_matrix = np.array([l[1] for l in labels], dtype=int)
    all_weights = []
    for i, nc in enumerate(num_classes):
        # Select the task column from the matrix; indexing ``row[1][i]`` on the
        # converted array would index a scalar and fail.
        column = label_matrix[:, i]
        counts = np.bincount(column[column != -1], minlength=nc)  # occurrences per class
        counts[counts == 0] = 1  # replace empty bins with 1
        weights = 1 / counts  # inverse frequency
        weights /= weights.sum()  # normalize
        all_weights.append(weights)
    return all_weights
def labels_to_image_weights_mtl(labels, num_classes=[3, 3], class_weights=[]):
    """Produce one sampling weight per image from per-task class weights.

    Weights are computed for every task, but only the weights derived from the
    first task are returned.

    Args:
        labels: sequence of items where ``item[1]`` is the list of integer
            class ids of the image, one per task.
        num_classes: number of classes of each task, in task order.
        class_weights: one weight array per task; each is normalized to sum
            to 1 before use (the inputs themselves are not modified).

    Returns:
        A 1-D array with one weight per image, based on the first task.

    NOTE(review): labels equal to -1 are not filtered here and ``np.bincount``
    does not accept negative values; confirm callers never pass invalid labels.
    """
    # Shape (num_samples, num_tasks). ``np.int`` was removed in NumPy 1.24.
    label_matrix = np.array([l[1] for l in labels], dtype=int)
    class_weights = [weights / weights.sum() for weights in class_weights]
    all_image_weights = []
    for i, nc in enumerate(num_classes):
        # No ``np.squeeze`` on the column: with a single sample it would
        # collapse to a 0-d array that cannot be iterated.
        class_counts = np.array(
            [np.bincount(np.array([x], dtype=int), minlength=nc) for x in label_matrix[:, i]])
        image_weights = (class_weights[i].reshape(1, nc) * class_counts).sum(1)
        all_image_weights.append(image_weights)
    return all_image_weights[0]
def sort_mtl_labels(inputs):
    """Shuffle ``inputs`` in place and return indices ordering them by label coverage.

    Samples with more valid (``>= 0``) task labels come first; among samples
    with the same number of valid labels, those with valid labels in earlier
    task columns come first.

    Args:
        inputs: mutable sequence of items where ``item[1]`` is the list of
            per-task labels. It is shuffled in place, and the returned indices
            refer to the shuffled order.

    Returns:
        A list of indices into the (shuffled) ``inputs``, or ``[]`` when no
        labels are loaded.
    """
    if len(inputs) == 0 or inputs[0] is None:  # no labels loaded
        return []
    random.shuffle(inputs)

    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is equivalent.
    labels = np.array([l[1] for l in inputs], dtype=float)
    # Turn the labels into a validity mask: 1 for a valid label, 0 for invalid.
    labels[labels >= 0] = 1
    labels[labels < 0] = 0
    # The number of valid labels dominates the score (steps of 100).
    counts = (np.sum(labels, axis=1) - 1) * 100
    n, c = labels.shape[:2]
    # Earlier task columns get a larger per-column score (c, c-1, ..., 1).
    for i in range(c):
        labels[labels[:, i] > 0, i] = c - i
    # Already 1-D; no squeeze, which would break argsort for a single sample.
    res = np.sum(labels, axis=1) + counts
    return np.argsort(-res).tolist()
def load_labels(label_file, classes=None, prints=True):
    """Load a whitespace-separated label file into a dict.

    Each non-blank line is ``<name> <label> [<label> ...]``. A single label is
    stored as a scalar, anything else as a list.

    Args:
        label_file: path of the label file. A missing file yields ``{}``.
        classes: when not ``None``, a summary of the loaded labels is printed
            with ``summary_labels``.
        prints: whether to print how many labels were loaded.

    Returns:
        Dict mapping each name to its label or list of labels.
    """
    label_dict = {}
    if not os.path.exists(label_file):
        return label_dict
    with open(label_file) as f:
        for line in f:
            # split() tolerates repeated spaces, tabs and trailing whitespace.
            info = line.split()
            if not info:  # skip blank lines instead of storing an empty name
                continue
            name = info[0]
            # SECURITY: eval() executes arbitrary expressions found in the
            # file; only load label files from trusted sources.
            labels = list(map(eval, info[1:]))
            if len(labels) == 1:
                label_dict[name] = labels[0]
            else:
                label_dict[name] = labels
    if prints:
        print(f"load {len(label_dict)} labels from '{label_file}'")
    if classes is not None:
        summary_labels(label_dict, classes)
    return label_dict
def load_labels_batch(file_lst, classes=None):
    """Load several label files and merge them into a single dict.

    Args:
        file_lst: one path or an iterable of paths. Later files override
            entries of earlier files that share the same name.
        classes: forwarded to ``load_labels`` for every file and, when not
            ``None``, also used to print a summary of the merged result.

    Returns:
        The merged dict of labels.
    """
    paths = [file_lst] if isinstance(file_lst, str) else file_lst
    merged = {}
    for path in paths:
        merged.update(load_labels(path, classes))
    print(f'total get {len(merged.keys())} items')
    if classes is not None:
        summary_labels(merged, classes)
    return merged
def save_labels(label_file, label_dict, mod='w'):
    """Write a label dict to a space-separated text file, one entry per line.

    Args:
        label_file: destination path.
        label_dict: mapping of name to a scalar (int, float, str) or to a
            list/tuple of labels.
        mod: file mode passed to ``open`` (e.g. ``'w'`` or ``'a'``).

    Raises:
        NotImplementedError: if a value is neither a scalar nor a list/tuple.
    """
    with open(label_file, mod) as out:
        for name, value in label_dict.items():
            if isinstance(value, (int, float, str)):
                record = f"{name} {value}\n"
            elif isinstance(value, (tuple, list)):
                record = ' '.join([name] + [str(v) for v in value]) + '\n'
            else:
                raise NotImplementedError
            out.write(record)
    print(f"Save {len(label_dict)} annotation in '{label_file}'")
def merge_labels(file_list, save_file):
    """Merge the label files in ``file_list`` and write the result to ``save_file``."""
    save_labels(save_file, load_labels_batch(file_list))
def concat_labels(file_list, save_file, dim=0):
    """Combine several label files into one.

    :param file_list: list of at least two label file paths
    :param save_file: path of the combined label file
    :param dim: 0 merges vertically (union of the entries of all files);
        any other value merges horizontally (for every name present in all
        files, the labels of each file are concatenated in file order)
    :return: None
    :raises NotImplementedError: if a label is neither a number nor a list/tuple
    """
    assert isinstance(file_list, list) and len(file_list) >= 2
    if dim == 0:
        merge_labels(file_list, save_file)
        return

    dicts_list = [load_labels(label_file) for label_file in file_list]
    common_keys = set(dicts_list[0])
    for tmp_dict in dicts_list[1:]:
        common_keys &= set(tmp_dict)
    print(f"{len(common_keys)} common keys")

    label_dict = {}
    # Walk the first file instead of the set so the output order is stable.
    for name in dicts_list[0]:
        if name not in common_keys:
            continue
        tmp_label = []
        for tmp_dict in dicts_list:
            label = tmp_dict[name]
            # Floats are accepted too: load_labels yields them for real-valued labels.
            if isinstance(label, (int, float)):
                tmp_label.append(label)
            elif isinstance(label, (list, tuple)):
                tmp_label += list(label)
            else:
                raise NotImplementedError
        label_dict[name] = tmp_label
    save_labels(save_file, label_dict)
def summary_labels(label_dict, classes=None):
    """Print a one-line summary of the labels in ``label_dict``.

    The layout is chosen from the type of the first value:

    * ``int`` labels: prints the number of valid labels (``-1`` is skipped)
      followed by the count of every class.
    * ``list`` labels: prints the column-wise sum of the labels.

    Nothing is printed for an empty dict or for other label types.

    Args:
        label_dict: mapping of name to label (int) or labels (list).
        classes: optional class names. For int labels they name class ids
            ``0..len(classes)-1``; for list labels they name the columns.
    """
    from collections import Counter

    labels = list(label_dict.values())
    if not labels:
        return
    if isinstance(labels[0], int):
        labels = [l for l in labels if l != -1]
        print(f"all: {len(labels)}", end=' ')
        counts = Counter(labels)
        if classes:
            # Guard max() against the case where every label was invalid.
            assert not labels or max(labels) < len(classes)
            pairs = [(category, counts[i]) for i, category in enumerate(classes)]
        else:
            pairs = [(category, counts[category]) for category in sorted(counts)]
        # Always terminates the line, even when there is nothing to list.
        print(' '.join(f"{category}: {count}" for category, count in pairs))
    elif isinstance(labels[0], list):
        label_array = np.array(labels, dtype=int)
        label_count = list(np.sum(label_array, axis=0))
        if classes:
            assert len(classes) == len(labels[0])
            for i, (category, count) in enumerate(zip(classes, label_count)):
                end = '\n' if i == len(classes) - 1 else ' '
                print(f"{category}: {count}", end=end)
        else:
            for i, count in enumerate(label_count):
                end = '\n' if i == len(label_count) - 1 else ' '
                print(f"{i}: {count}", end=end)
def summary_label_file(label_file, classes=None):
    """Load one or several label files and print a summary of their labels.

    Args:
        label_file: a path, or a list/tuple of paths whose labels are merged
            before being summarized.
        classes: optional class names forwarded to ``summary_labels``.

    Raises:
        NotImplementedError: if ``label_file`` is neither a str nor a list/tuple.
    """
    if isinstance(label_file, str):
        label_dict = load_labels(label_file)
    elif isinstance(label_file, (list, tuple)):
        if len(label_file) == 1:
            # A single path must still be loaded; passing the path string on
            # to summary_labels would fail.
            label_dict = load_labels(label_file[0])
        else:
            label_dict = load_labels_batch(list(label_file))
    else:
        raise NotImplementedError
    summary_labels(label_dict, classes)
class MTLabel:
    """Multi-task label table backed by a pandas DataFrame.

    Rows are indexed by sample name and every column holds the labels of one
    task. ``tasks`` is the list of column names and ``classes`` is a parallel
    list whose entries are either a list of category names or ``None``
    (``summary`` and ``plot`` treat a ``None`` entry as a numeric task).
    The value ``_invalid`` (-1) marks a missing label.
    """

    _invalid = -1

    def __init__(self, input_label=None, tasks=None, classes=None):
        """Build the table.

        Args:
            input_label: ``None``, a label file path, a DataFrame, a dict of
                ``name -> labels`` or a list of ``(name, labels)`` pairs.
            tasks: task (column) names; defaults to ``task0, task1, ...`` when
                data is present.
            classes: category names per task; defaults to ``None`` per task.
        """
        label_data = self._load(input_label)
        tasks = self._check_tasks(tasks)
        classes = self._check_classes(classes)
        if not label_data.empty:
            tasks = [f"task{i}" for i in range(label_data.shape[1])] if tasks is None else tasks
            assert len(tasks) == label_data.shape[1], \
                f"Tasks length {len(tasks)} not match to label length {label_data.shape[1]}"
            label_data.columns = tasks
        if tasks is not None:
            classes = [None] * len(tasks) if classes is None else classes
            assert len(tasks) == len(classes), f"Tasks {tasks} not match to {classes}"
        self.label_data = label_data
        self.tasks = copy.deepcopy(tasks)
        self.classes = copy.deepcopy(classes)

    @staticmethod
    def _load(input_label):
        """Convert ``input_label`` into a DataFrame indexed by name.

        Raises:
            FileNotFoundError: if a path is given that is not a file.
            TypeError: for unsupported input types.
        """
        if input_label is None:
            return pd.DataFrame()
        if isinstance(input_label, str):
            if not os.path.isfile(input_label):
                raise FileNotFoundError(f'Check label file path: {input_label}')
            label_data = pd.read_csv(input_label, sep=' ', header=None, index_col=0)
            cprint(f"{label_data.shape if label_data.shape[1] > 1 else label_data.shape[0]} labels from '{input_label}'", prefix='Load')
        elif isinstance(input_label, pd.core.frame.DataFrame):
            label_data = copy.deepcopy(input_label)
        elif isinstance(input_label, dict):
            label_data = pd.DataFrame(input_label).transpose()
        elif isinstance(input_label, list):
            label_data = pd.DataFrame(dict(input_label)).transpose()
        else:
            raise TypeError(f'{type(input_label)} is not support')
        label_data.index.rename("name", inplace=True)
        return label_data

    @staticmethod
    def _save(label_data, save_file, filter_invalid=True):
        """Write ``label_data`` to ``save_file`` as space-separated text.

        Missing values are written as ``_invalid``. With ``filter_invalid``
        rows whose labels are all invalid are dropped. The frame passed in is
        not modified.
        """
        save_dir = os.path.dirname(os.path.abspath(save_file))
        os.makedirs(save_dir, exist_ok=True)
        # Non-inplace fillna so exporting never alters the caller's data.
        label_data = label_data.fillna(MTLabel._invalid)
        if filter_invalid:
            label_data = label_data[(label_data != MTLabel._invalid).any(axis=1)]
        label_data = label_data.astype(object)
        label_data.to_csv(save_file, sep=' ', index=True, header=False)
        cprint(f"{label_data.shape if label_data.shape[1] > 1 else label_data.shape[0]} labels in '{save_file}'", prefix='Save')

    @classmethod
    def _new(cls, input_label=None, tasks=None, classes=None):
        """Create another instance of this class."""
        return cls(input_label, tasks, classes)

    @staticmethod
    def _check_tasks(tasks, int_ok=False):
        """Normalize ``tasks`` to a list of names (and indices if ``int_ok``).

        Returns ``None`` when nothing was given (``None`` or an empty value).

        Raises:
            TypeError: for an int entry when ``int_ok`` is false, or for an
                entry that is neither int nor str.
        """
        # Integer 0 is a valid task index and must not be read as "nothing given".
        is_index = isinstance(tasks, numbers.Integral) and not isinstance(tasks, bool)
        if tasks is None or (not is_index and not tasks):
            return None
        if isinstance(tasks, (str, numbers.Integral)):
            tasks = [tasks]
        elif isinstance(tasks, tuple):
            tasks = list(tasks)
        assert tasks and isinstance(tasks, list), f"arg 'task' should be type (int, str, tuple, list)"
        for i, t in enumerate(tasks):
            if isinstance(t, numbers.Integral):
                if not int_ok:
                    raise TypeError(f"'int' type task {t} at {i} not support, set 'int_ok=True' ?")
            elif not isinstance(t, str):
                raise TypeError(f"{type(t)} task {t} at {i} not support")
        return tasks

    @staticmethod
    def _check_classes(classes):
        """Normalize ``classes`` to a list with one entry per task.

        A flat list of strings is treated as the categories of a single task.
        """
        if classes is None:
            return None
        assert classes and isinstance(classes, list), f"classes {classes} should be a list"
        if isinstance(classes[0], str):
            classes = [classes]
        return classes

    @staticmethod
    def _check_names(names, int_ok=False):
        """Normalize ``names`` to a list of row names (and indices if ``int_ok``).

        Returns ``None`` when nothing was given (``None`` or an empty value).

        Raises:
            TypeError: for an int entry when ``int_ok`` is false, or for an
                entry that is neither int nor str.
        """
        # Integer 0 is a valid row index and must not be read as "nothing given".
        is_index = isinstance(names, numbers.Integral) and not isinstance(names, bool)
        if names is None or (not is_index and not names):
            return None
        if isinstance(names, (str, numbers.Integral)):
            names = [names]
        elif isinstance(names, (tuple, set)):
            names = list(names)
        assert names and isinstance(names, list), f"arg 'names' should be type (int, str, tuple, set, list)"
        for i, t in enumerate(names):
            if isinstance(t, numbers.Integral):
                if not int_ok:
                    raise TypeError(f"'int' type name {t} at {i} not support, set 'int_ok=True' ?")
            elif not isinstance(t, str):
                raise TypeError(f"{type(t)} name {t} at {i} not support")
        return names

    @staticmethod
    def _type_data(values, types):
        """Cast each value with the type at the same position."""
        return [t(v) for t, v in zip(types, values)]

    def _check_value(self, value, max_len):
        """Pad a list value with ``_invalid`` or truncate it to ``max_len``."""
        if isinstance(value, list):
            if len(value) < max_len:
                value = value + [self._invalid] * (max_len - len(value))
            elif len(value) > max_len:
                value = value[:max_len]
        return value

    def _convert_tasks(self, tasks):
        """Resolve task indices to task names and validate task names.

        Returns a new list of task names, or ``None`` when nothing was given.
        """
        tasks = self._check_tasks(tasks, int_ok=True)
        if tasks is None:
            return None
        tasks = list(tasks)  # never rewrite the caller's list in place
        for i, t in enumerate(tasks):
            if isinstance(t, numbers.Integral):
                assert t < len(self.tasks), f"index {t} out of range({len(self.tasks)})"
                tasks[i] = self.tasks[t]
            elif isinstance(t, str):
                assert t in self.tasks, f"'{t}' not in {self.tasks}"
        return tasks

    def _convert_names(self, names):
        """Resolve row indices to row names.

        Returns a new list of names, or ``None`` when nothing was given.
        String names are not checked against the index.
        """
        names = self._check_names(names, int_ok=True)
        if names is None:
            return None
        names = list(names)  # never rewrite the caller's list in place
        for i, t in enumerate(names):
            if isinstance(t, numbers.Integral):
                assert t < len(self), f"index {t} out of range({len(self)})"
                names[i] = self.names[t]
        return names

    @staticmethod
    def _map_task_classes(tasks, classes):
        """Return a dict mapping every task to its classes entry."""
        assert isinstance(tasks, list) and isinstance(classes, list) and len(tasks) == len(classes)
        map_dict = {t: c for t, c in zip(tasks, classes)}
        return map_dict

    @property
    def shape(self):
        """Shape of the underlying DataFrame."""
        return self.label_data.shape

    @property
    def empty(self):
        """Whether the underlying DataFrame is empty."""
        return self.label_data.empty

    @property
    def index(self):
        """Index (row names) of the underlying DataFrame."""
        return self.label_data.index

    @property
    def values(self):
        """Values of the underlying DataFrame."""
        return self.label_data.values

    @property
    def names(self):
        """Row names as a list."""
        return list(self.index)

    @property
    def columns(self):
        """Task names (same object as ``tasks``)."""
        return self.tasks

    @property
    def dtypes(self):
        """Column dtypes of the underlying DataFrame."""
        return self.label_data.dtypes

    def head(self, n=5):
        """Return the first ``n`` rows as an object-typed DataFrame."""
        return self.label_data.astype(object).head(n)

    def tail(self, n=5):
        """Return the last ``n`` rows as an object-typed DataFrame."""
        return self.label_data.astype(object).tail(n)

    def astype(self, dtype):
        """Cast the underlying data in place and return ``self``."""
        self.label_data = self.label_data.astype(dtype)
        return self

    def summary(self, tasks=None, extra_info=''):
        """Print per-task statistics of the valid labels.

        Tasks with categories print the count of every category; tasks whose
        classes entry is ``None`` print descriptive statistics.

        Args:
            tasks: tasks (names or indices) to summarize; all tasks by default.
            extra_info: text printed in the summary header.
        """
        print("-" * 120)
        tasks = self._convert_tasks(tasks)
        if tasks is None and self.tasks is None:
            log_warn(f"Label data is empty or 'tasks' is None")
            return
        tasks = self.tasks if tasks is None else tasks
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        cprint(extra_info, prefix="Summary")
        for t in tasks:
            classes = tc_dict[t]
            cur_label = self.label_data[t]
            cur_label = cur_label[cur_label != self._invalid]
            cprint(f"Task {t}\t==> all: {len(cur_label)}", end=' ')
            if cur_label.empty:
                print()
                continue
            if classes is None:
                classes = ['mean', 'std', 'min', 'max', '1%', '50%', '99%']
                res = cur_label.describe(percentiles=[0.01, 0.99], include=[np.number])
                for i, category in enumerate(classes):
                    end = '\n' if i == len(classes) - 1 else ' '
                    cprint(f"{category}: {res[category]:.3f}", end=end)
            else:
                res = cur_label.value_counts()
                for i, category in enumerate(classes):
                    end = '\n' if i == len(classes) - 1 else '\t'
                    if i in res:
                        cprint(f"{category}: {res[i]}", end=end)
                    else:
                        cprint(f"{category}: 0", end=end)
        print("-"*120)

    def plot(self, tasks=None):
        """Plot bar charts for tasks with categories and histograms for the others."""
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            tasks = self.tasks
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        cls_part = []
        reg_part = []
        for t in tasks:
            classes = tc_dict[t]
            if classes is None:
                reg_part.append(t)
            else:
                cls_part.append(t)
        if cls_part:
            self._plt_bar(cls_part, in_one=False)
        if reg_part:
            self._hist(reg_part)

    def _hist(self, tasks=None, cols=3, in_one=False):
        """Draw a histogram of the valid labels of every task.

        Args:
            tasks: tasks to draw; all tasks by default.
            cols: number of subplot columns when ``in_one`` is true.
            in_one: draw all tasks in one figure instead of one figure each.
        """
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            tasks = self.tasks
        cols = cols if len(tasks) >= cols else len(tasks)
        rows, mod = divmod(len(tasks), cols)
        rows += mod != 0
        for i, t in enumerate(tasks):
            plt.subplot(rows, cols, i + 1) if in_one else plt.figure()
            cur_label = self.label_data[t]
            cur_label = cur_label[cur_label != self._invalid]
            low, high = cur_label.min(), cur_label.max()
            bin_size = (high - low) / 11
            if bin_size > 0:
                bins = np.arange(low, 1.01 * high, bin_size)
            else:
                # Constant (or empty) column: a zero step is not a valid
                # np.arange step, so fall back to a fixed number of bins.
                bins = 10
            cur_label.hist(bins=bins)
            plt.title(t)
            plt.tight_layout()
            plt.grid(False)
        plt.show()

    def _plt_bar(self, tasks=None, cols=3, in_one=False):
        """Draw a bar chart of the category counts of every task.

        Args:
            tasks: tasks to draw; all tasks by default.
            cols: number of subplot columns when ``in_one`` is true.
            in_one: draw all tasks in one figure instead of one figure each.
        """
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            tasks = self.tasks
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        cols = cols if len(tasks) >= cols else len(tasks)
        rows, mod = divmod(len(tasks), cols)
        rows += mod != 0
        if in_one:
            # squeeze=False keeps ``axes`` two-dimensional for every grid size.
            _, axes = plt.subplots(rows, cols, squeeze=False)
        for i, t in enumerate(tasks):
            classes = tc_dict[t]
            cur_label = self.label_data[t]
            cur_label = cur_label[cur_label != self._invalid]
            res = cur_label.value_counts()
            # A category without samples is absent from value_counts: count it as 0.
            values = [res.get(k, 0) for k in range(len(classes))]
            df = pd.DataFrame({"category": classes, "count": values})
            if in_one:
                r, c = divmod(i, cols)
                ax = axes[r, c]
            else:
                ax = None
            df.plot(kind='bar', x="category", y="count", title=t, grid=False, rot=30 if in_one else 0,
                    ax=ax, legend=False)
            plt.tight_layout()
        plt.show()

    def set_tasks(self, tasks):
        """Rename all tasks; the number of tasks must stay the same."""
        tasks = self._check_tasks(tasks)
        if self.tasks is not None and len(tasks) != len(self.tasks):
            log_error(f"new tasks length {len(tasks)} not equal to ori {len(self.tasks)}")
            return
        self.tasks = copy.deepcopy(tasks)
        self.label_data.columns = self.tasks

    def set_classes(self, classes):
        """Replace the classes of all tasks; the number of entries must stay the same."""
        classes = self._check_classes(classes)
        if self.classes is not None and len(classes) != len(self.classes):
            log_error(f"new tasks length {len(classes)} not equal to ori {len(self.classes)}")
            return
        self.classes = copy.deepcopy(classes)

    def insert(self, task, value=None, loc=None, category=None, dtype=None):
        """Insert a new task column.

        Args:
            task: name of the new task.
            value: column value(s); ``_invalid`` by default.
            loc: column position; appended at the end by default.
            category: category names of the task (a single str is wrapped in a list).
            dtype: optional dtype to cast the new column to.

        Raises:
            KeyError: if the task already exists.
        """
        if task in self.tasks:
            raise KeyError(f"{task} already exits")
        if value is None:
            value = self._invalid
        if isinstance(category, str):
            category = [category]
        assert category is None or isinstance(category, list)
        self.label_data.insert(len(self.tasks) if loc is None else loc, task, value)
        if dtype is not None:
            self.label_data[task] = self.label_data[task].astype(dtype)
        if loc is None:
            self.tasks.append(task)
            self.classes.append(category)
        else:
            self.tasks.insert(loc, task)
            self.classes.insert(loc, category)

    def remove(self, task):
        """Remove the given task column(s)."""
        self.__delitem__(task)

    def add(self, key, value, keep_dtypes=False):
        """Set the row ``key`` to ``value``.

        Args:
            keep_dtypes: cast back to the previous dtypes if they changed.

        Returns:
            1 if the row is new, 0 if an existing row was overwritten.
        """
        ori_dtypes = self.label_data.dtypes
        code = 1 if key not in self.label_data.index else 0
        self.label_data.loc[key] = value
        if (ori_dtypes != self.label_data.dtypes).any() and keep_dtypes:
            self.label_data = self.label_data.astype(ori_dtypes)
        return code

    def update(self, other_label, tasks=None, classes=None, inplace=False):
        """Overwrite shared rows with ``other_label`` and append its new rows.

        ``other_label`` must have the same tasks. Anything that is not already
        an instance of this class is first wrapped using ``tasks``/``classes``.

        Returns:
            A new instance when ``inplace`` is false; ``None`` when ``inplace``
            is true or ``other_label`` is empty.
        """
        if not isinstance(other_label, type(self)):
            other_label = self._new(other_label, tasks, classes)
        if other_label.empty:
            log_warn(f"Empty label data, check path or data")
            return
        assert (other_label.columns == self.label_data.columns).all(), \
            f"current label({self.label_data.shape}) not match to input({other_label.shape}) at columns"
        ori_dtypes = self.label_data.dtypes
        other_label = other_label.label_data
        other_label.columns = self.label_data.columns
        common_index = self.label_data.index.intersection(other_label.index)
        if common_index.empty:
            label_data = pd.concat([self.label_data, other_label])
            cprint(f"{len(other_label)} add.", prefix='Update')
        else:
            common1 = self.label_data.loc[common_index]
            common2 = other_label.loc[common_index]
            no_equal = common1[common1.ne(common2).any(axis=1)]
            add_label = other_label[~other_label.index.isin(common_index)]
            # Real copy: a ``[:]`` slice may share memory with self.label_data,
            # so the in-place update below could change it even when
            # inplace is false.
            label_data = self.label_data.copy()
            label_data.update(common2)
            label_data = pd.concat([label_data, add_label])
            label_data = label_data.astype(ori_dtypes)
            cprint(f"{len(common1)} common ({len(no_equal)} update), {len(add_label)} add.", prefix='Update')
        if inplace:
            self.label_data = label_data
        else:
            return self._new(label_data, self.tasks, self.classes)

    def join(self, other_label, tasks=None, classes=None, inplace=False):
        """Add the task columns of ``other_label`` (same number of rows).

        For tasks present in both, the column of ``other_label`` is kept.

        Returns:
            A new instance when ``inplace`` is false; ``None`` when ``inplace``
            is true or ``other_label`` is empty.
        """
        if not isinstance(other_label, type(self)):
            other_label = self._new(other_label, tasks, classes)
        if other_label.empty:
            log_warn(f"Empty label data, check path or data")
            return
        assert other_label.shape[0] == self.label_data.shape[0], \
            f"current label({self.label_data.shape}) not match to input({other_label.shape}) at index"
        ori_tasks = self.tasks
        ori_classes = self.classes
        other_tasks = other_label.tasks
        other_classes = other_label.classes
        other_dict = self._map_task_classes(other_tasks, other_classes)
        use_tasks = [t for t in other_tasks if t not in ori_tasks]
        use_classes = [other_dict[t] for t in use_tasks]
        new_tasks = ori_tasks + use_tasks
        new_classes = ori_classes + use_classes
        label_data = pd.concat([self.label_data, other_label.label_data], axis=1)
        # Drop duplicated columns (keeping the last one) via the transposed index.
        label_data = label_data.T
        label_data = label_data[~label_data.index.duplicated(keep='last')].T
        label_data = label_data[new_tasks]
        if inplace:
            self.label_data = label_data
            self.tasks = new_tasks
            self.classes = new_classes
        else:
            return self._new(label_data, new_tasks, new_classes)

    def concat(self, other_label, tasks=None, classes=None, inplace=False, fill=True):
        """Combine with ``other_label`` along both rows and columns.

        Cells present in both are overwritten by ``other_label``; rows and
        columns only present in ``other_label`` are added.

        Args:
            fill: replace the remaining missing cells with ``_invalid``.

        Returns:
            A new instance when ``inplace`` is false; ``None`` when ``inplace``
            is true or ``other_label`` is empty.
        """
        if not isinstance(other_label, type(self)):
            other_label = self._new(other_label, tasks, classes)
        if other_label.empty:
            log_warn(f"Empty label data, check path or data")
            return
        ori_tasks = self.tasks
        ori_classes = self.classes
        other_tasks = other_label.tasks
        other_classes = other_label.classes
        other_dict = self._map_task_classes(other_tasks, other_classes)
        use_tasks = [t for t in other_tasks if t not in ori_tasks]
        use_classes = [other_dict[t] for t in use_tasks]
        new_tasks = ori_tasks + use_tasks
        new_classes = ori_classes + use_classes
        # Real copy: a ``[:]`` slice may share memory with self.label_data,
        # so the in-place updates below could change it even when inplace is false.
        label_data = self.label_data.copy()
        other_label = other_label.label_data
        common_index = label_data.index.intersection(other_label.index)
        common_columns = label_data.columns.intersection(other_label.columns)
        other_index = other_label.index[~other_label.index.isin(common_index)]
        other_columns = other_label.columns[~other_label.columns.isin(common_columns)]
        if not common_index.empty and not common_columns.empty:
            label_data.update(other_label.loc[common_index, common_columns])  # overlapping part
        if not common_columns.empty:
            label_data = pd.concat([label_data, other_label.loc[other_index, common_columns]])  # same columns
        if not common_index.empty:
            label_data = pd.concat([label_data, other_label.loc[common_index, other_columns]], axis=1)  # same rows
        # non-overlapping part
        if common_index.empty and common_columns.empty:
            label_data = pd.concat([label_data, other_label])
        else:
            label_data.update(other_label)
        if fill:
            label_data.fillna(MTLabel._invalid, inplace=True)
        cprint(f"update common ({len(common_index)}, {len(common_columns)}), "
               f"add ({len(other_index)}, {len(other_columns)})", prefix="Concat")
        if inplace:
            self.label_data = label_data
            self.tasks = new_tasks
            self.classes = new_classes
        else:
            return self._new(label_data, new_tasks, new_classes)

    def sample(self, num):
        """Return a new instance holding ``num`` randomly sampled rows."""
        obj = self._new(self.label_data.sample(num), self.tasks, self.classes)
        return obj

    def pick_tasks(self, tasks, inplace=False):
        """Keep only the given tasks (names or indices).

        Returns ``self`` when ``inplace`` is true or no task was given,
        otherwise a new instance.
        """
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            log_warn("'task' is None, use ori label")
            return self
        classes = [self.classes[self.tasks.index(t)] for t in tasks]
        if inplace:
            self.label_data = self.label_data[tasks]
            self.tasks = copy.deepcopy(tasks)
            self.classes = copy.deepcopy(classes)
            return self
        else:
            obj = self._new(self.label_data[tasks], tasks=tasks, classes=classes)
            return obj

    def pick_names(self, name_list, inplace=False):
        """Keep only the rows with the given names.

        Returns ``self`` when ``inplace`` is true or no name was given,
        otherwise a new instance.
        """
        names_list = self._check_names(name_list)
        if names_list is None:
            log_warn("'name_list' is None, use ori label")
            return self
        if inplace:
            self.label_data = self.label_data.loc[names_list]
            return self
        else:
            obj = self._new(self.label_data.loc[names_list], tasks=self.tasks, classes=self.classes)
            return obj

    def tolist(self):
        """Return the rows as a list of ``(name, values)`` pairs."""
        names = self.index.tolist()
        values = self.values.tolist()
        # Cast every value back to the scalar type of its column.
        dtypes = [t.type for t in self.dtypes.tolist()]
        values = [self._type_data(v, dtypes) for v in values]
        return list(zip(names, values))

    def todict(self):
        """Return the rows as a dict of ``name -> values``."""
        names = self.index.tolist()
        values = self.values.tolist()
        # Cast every value back to the scalar type of its column.
        dtypes = [t.type for t in self.dtypes.tolist()]
        values = [self._type_data(v, dtypes) for v in values]
        return dict(zip(names, values))

    def export(self, save_file, filter_invalid=True):
        """Write all tasks to ``save_file``."""
        self._save(self.label_data, save_file, filter_invalid)

    def export_tasks(self, task, save_file, filter_invalid=True):
        """Write only the given task(s) to ``save_file``."""
        obj = self.pick_tasks(task)
        self._save(obj.label_data, save_file, filter_invalid)

    def __getitem__(self, item):
        """Select rows and/or tasks and return them as a new instance.

        Supported keys: int or slice (rows by position), str (task name or row
        name), list (tasks), and a ``(rows, tasks)`` tuple.

        Raises:
            KeyError: if a str key is neither a task nor a row name.
            TypeError: for unsupported key types.
        """
        # plain slicing
        if isinstance(item, numbers.Integral):
            # -1 needs an open end: slice(-1, 0) would select nothing.
            item = slice(item, item + 1 if item != -1 else None)
        if isinstance(item, slice):
            obj = self._new(self.label_data[item], tasks=self.tasks, classes=self.classes)
            return obj
        # column name / image name
        elif isinstance(item, str):
            if item in self.tasks:
                obj = self._new(self.label_data[[item]], tasks=[item], classes=[self.classes[self.tasks.index(item)]])
            elif item in self.label_data.index:
                obj = self._new(self.label_data.loc[[item]], tasks=self.tasks, classes=self.classes)
            else:
                raise KeyError(f"'{item}' not in tasks or index")
            return obj
        # currently only a list of tasks is supported
        elif isinstance(item, list):
            return self.pick_tasks(item)
        # two-dimensional slicing
        elif isinstance(item, tuple):
            assert len(item) == 2, f"tuple length {len(item)} != 2"
            row_slice, col_slice = item
            if isinstance(row_slice, (slice, numbers.Integral)) or \
                    (isinstance(row_slice, str) and row_slice not in self.tasks):
                obj = self.__getitem__(row_slice)
                if isinstance(col_slice, (slice, numbers.Integral)):
                    obj = obj.__getitem__(obj.tasks[col_slice])
                    return obj
                else:
                    return obj.__getitem__(col_slice)
            else:
                raise TypeError(f"{item} first item type {type(row_slice)} is an invalid key")
        else:
            raise TypeError(f"{item} is an invalid key")

    def __setitem__(self, key, value):
        """Assign ``value`` to existing rows, tasks or cells.

        Only accepts keys that already exist in ``label_data``.
        loc: only work on index, can assign a new index or column value
        iloc: work on position, can not assign a new index or column value
        at: get scalar values. It's a very fast loc
        iat: Get scalar values. It's a very fast iloc

        Raises:
            KeyError: if a str key is neither a task nor a row name.
            TypeError: for unsupported key types.
        """
        if isinstance(key, (numbers.Integral, slice)):
            self.label_data.iloc[key] = value
        # column name / index
        elif isinstance(key, str):
            if key in self.tasks:
                self.label_data.loc[:, key] = value
            elif key in self.label_data.index:
                self.label_data.loc[key] = value
            else:
                raise KeyError(f"'{key}' not in tasks or index")
        # list of tasks
        elif isinstance(key, list):
            key = self._convert_tasks(key)
            value = self._check_value(value, len(key))
            self.label_data.loc[:, key] = value
        # two-dimensional slicing
        elif isinstance(key, tuple):
            assert len(key) == 2, f"tuple length {len(key)} != 2"
            row_slice, col_slice = key
            if isinstance(row_slice, numbers.Integral):
                row_slice = self.label_data.index[row_slice]
            elif isinstance(row_slice, str):
                assert row_slice in self.label_data.index, f"{row_slice} not in index"
            elif not isinstance(row_slice, slice):
                raise TypeError(f"{key} first item type {type(row_slice)} is an invalid key")
            if isinstance(col_slice, numbers.Integral):
                col_slice = self.tasks[col_slice]
            elif isinstance(col_slice, str):
                assert col_slice in self.tasks, f"{col_slice} not in {self.tasks}"
            elif isinstance(col_slice, list):
                col_slice = self._convert_tasks(col_slice)
            elif not isinstance(col_slice, slice):
                raise TypeError(f"{key} second item type {type(col_slice)} is an invalid key")
            if isinstance(row_slice, str) and isinstance(col_slice, str):
                self.label_data.at[row_slice, col_slice] = value
            else:
                self.label_data.loc[row_slice, col_slice] = value
        else:
            # Report unsupported keys instead of silently ignoring the assignment.
            raise TypeError(f"{key} is an invalid key")

    def __delitem__(self, key):
        """Remove the given task column(s) together with their classes.

        Raises:
            KeyError: if ``key`` is neither a str nor a list.
        """
        if not isinstance(key, (str, list)):
            raise KeyError(f"'{key}' is an invalid key")
        key = self._convert_tasks(key)
        self.label_data.drop(columns=key, axis=1, inplace=True)
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        self.tasks = [t for t in self.tasks if t not in set(key)]
        self.classes = [tc_dict[t] for t in self.tasks]

    def __getattr__(self, item):
        """Expose tasks as attributes, e.g. ``label.eye``.

        Raises:
            AttributeError: if ``item`` is not a task name.
        """
        # Read ``tasks`` from the instance dict: going through ``self.tasks``
        # would recurse endlessly while the attribute is not set yet.
        tasks = self.__dict__.get('tasks')
        if tasks is not None and item in tasks:
            return self.__getitem__(item)
        # Unknown attributes must raise so that hasattr()/getattr() work as expected.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def __len__(self):
        """Number of rows."""
        return len(self.label_data)

    def __repr__(self):
        """Representation of the underlying data as an object-typed DataFrame."""
        return self.label_data.astype(object).__repr__()

    def __copy__(self):
        """Return a new instance built from the same data, tasks and classes."""
        return MTLabel(self.label_data, self.tasks, self.classes)

    def __deepcopy__(self, memodict=None):
        """Return a new instance built from deep copies of data, tasks and classes."""
        return MTLabel(copy.deepcopy(self.label_data, memodict),
                       copy.deepcopy(self.tasks, memodict),
                       copy.deepcopy(self.classes, memodict))
if __name__ == '__main__':
    # Demo: load a 9-task label file, print a summary of all tasks, then
    # summarize and export the 'ems' task on its own.
    dms_tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l"]
    # One list of category names per task; None for the last two tasks
    # (MTLabel.summary prints descriptive statistics for those).
    dms_classes = [['normal', 'look_left', 'look_down', 'look_right', 'invalid'],
                   ['normal', 'close_eye', 'invalid'],
                   ['normal', 'yawn', 'invalid'],
                   ['normal', 'glass', 'invalid'],
                   ['normal', 'mask', 'invalid'],
                   ['normal', 'smoke'],
                   ['normal', 'phone'], None, None]
    test_label = MTLabel('../test_dms3_labels_v4_rec.txt', dms_tasks, dms_classes)
    test_label.summary()
    # test_label.plot(["eyelid_r", "eyelid_l"])
    # print(test_label.head())
    # a = test_label[:3, :5]
    # b = test_label[2:5, 4:]
    # b[0] = 8
    # print(a)
    # print(b)
    #
    # c = a.concat(b, fill=True)
    # print(c)
    test_label['ems'].summary()
    test_label['ems'].export("ems_rec.txt")