|
import copy |
|
import numbers |
|
import os |
|
import torch |
|
import numpy as np |
|
import random |
|
import pandas as pd |
|
from utils.common import cprint, Color, log_warn, log_error |
|
import matplotlib.pyplot as plt |
|
|
|
# Print whole label tables: no column/row truncation, wide cells and lines.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# Use the fully-qualified key: pandas resolves short keys by regex and raises
# OptionError if a later version adds another option matching 'max_colwidth'.
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.width', 5000)
# Disable SettingWithCopyWarning for chained assignments.
pd.options.mode.chained_assignment = None
# Draw negative tick labels with an ASCII hyphen instead of the Unicode minus
# glyph, which some (e.g. CJK) fonts do not contain.
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
|
def labels_to_class_weights(labels, nc=80):
    """Compute normalized inverse-frequency class weights for single-task labels.

    Args:
        labels: sequence of items whose element ``[1]`` is an integer class id,
            e.g. ``(name, label)`` pairs. Items labelled ``-1`` (invalid) are
            ignored. A sequence whose first item is ``None`` means "no labels".
        nc: number of classes.

    Returns:
        A float64 ``torch.Tensor`` of length ``nc`` (longer if a class id >= nc
        occurs) that sums to 1; rarer classes receive larger weights. An empty
        tensor is returned when there are no labels.
    """
    if len(labels) == 0 or labels[0] is None:
        return torch.Tensor()

    # np.int was removed in NumPy 1.24; use an explicit integer dtype.
    classes = np.array([l[1] for l in labels if l[1] != -1], dtype=np.int64)
    counts = np.bincount(classes, minlength=nc)

    # Classes absent from the data count as 1 to avoid division by zero.
    counts[counts == 0] = 1
    weights = 1.0 / counts
    weights /= weights.sum()
    return torch.from_numpy(weights)
|
|
|
|
|
def labels_to_image_weights(labels, nc=80, class_weights=None):
    """Compute one sampling weight per item from single-task labels.

    Each item's weight is ``class_weights[label]`` (computed through a one-hot
    count row, so a class id >= ``nc`` or a negative id raises).

    Args:
        labels: sequence of items whose element ``[1]`` is a class id in ``[0, nc)``.
        nc: number of classes.
        class_weights: array (numpy or torch) of ``nc`` weights; defaults to all ones.

    Returns:
        1-D array with one weight per item (empty for empty ``labels``).
    """
    if class_weights is None:
        # Built per call (never a shared default) and sized to nc.
        class_weights = np.ones(nc)
    if len(labels) == 0:
        return np.zeros(0)
    # np.int was removed in NumPy 1.24; use an explicit integer dtype.
    class_counts = np.array(
        [np.bincount(np.array([x[1]], dtype=np.int64), minlength=nc) for x in labels])
    return (class_weights.reshape(1, nc) * class_counts).sum(1)
|
|
|
|
|
def labels_to_class_weights_mtl(labels, num_classes=(3, 3)):
    """Compute normalized inverse-frequency class weights for every task.

    Args:
        labels: sequence of items whose element ``[1]`` is a list with one integer
            label per task; ``-1`` marks an invalid label and is ignored. A
            sequence whose first item is ``None`` means "no labels".
        num_classes: number of classes of each task, in task order.

    Returns:
        A list with one float64 numpy array per task (each summing to 1), or an
        empty ``torch.Tensor`` when there are no labels.
    """
    if len(labels) == 0 or labels[0] is None:
        return torch.Tensor()

    # (num_samples, num_tasks); np.int was removed in NumPy 1.24.
    label_array = np.array([l[1] for l in labels], dtype=np.int64).reshape(len(labels), -1)

    all_weights = []
    for i, nc in enumerate(num_classes):
        # Column i holds this task's labels for every sample.
        task_labels = label_array[:, i]
        task_labels = task_labels[task_labels != -1]
        counts = np.bincount(task_labels, minlength=nc)
        # Classes absent from the data count as 1 to avoid division by zero.
        counts[counts == 0] = 1
        weights = 1.0 / counts
        weights /= weights.sum()
        all_weights.append(weights)
    return all_weights
|
|
|
|
|
def labels_to_image_weights_mtl(labels, num_classes=(3, 3), class_weights=None, task_index=0):
    """Compute one sampling weight per item from multi-task labels.

    For every task, each item's weight is its label's normalized class weight;
    the weights of task ``task_index`` are returned (task 0 by default).

    Args:
        labels: sequence of items whose element ``[1]`` is a list with one
            non-negative class id per task (negative ids make ``np.bincount`` raise).
        num_classes: number of classes of each task.
        class_weights: per-task arrays of class weights (normalized here to sum
            to 1); defaults to uniform weights.
        task_index: which task's per-item weights to return.

    Returns:
        1-D numpy array with one weight per item (empty for empty ``labels``).
    """
    if len(labels) == 0:
        return np.zeros(0)
    if class_weights is None:
        class_weights = [np.ones(nc) for nc in num_classes]
    class_weights = [weights / weights.sum() for weights in class_weights]

    # (num_samples, num_tasks); no squeeze, so a single sample stays 1-D per task.
    label_array = np.array([l[1] for l in labels], dtype=np.int64).reshape(len(labels), -1)

    all_image_weights = []
    for i, nc in enumerate(num_classes):
        class_counts = np.array(
            [np.bincount(np.array([x], dtype=np.int64), minlength=nc) for x in label_array[:, i]])
        image_weights = (class_weights[i].reshape(1, nc) * class_counts).sum(1)
        all_image_weights.append(image_weights)
    return all_image_weights[task_index]
|
|
|
|
|
def sort_mtl_labels(inputs):
    """Shuffle ``inputs`` in place and return indices ranking samples by label coverage.

    ``inputs`` is a list of items whose element ``[1]`` holds one label per task;
    a negative label marks that task as unlabeled. Samples are ranked in
    descending order of

        score = 100 * (num_valid_tasks - 1) + sum over valid tasks i of (num_tasks - i)

    so samples with more valid tasks come first and, among those, samples whose
    valid tasks appear earlier in the task list come first. The 100x term
    dominates only while num_tasks * (num_tasks + 1) / 2 < 100 (<= 13 tasks).

    Returns:
        List of indices into the (now shuffled) ``inputs``; ``[]`` if ``inputs``
        is empty or its first item is ``None``.
    """
    if len(inputs) == 0 or inputs[0] is None:
        return []
    # The in-place shuffle is part of the contract: the returned indices refer
    # to the shuffled list, so equally-ranked samples are not kept in input order.
    random.shuffle(inputs)
    # np.float was removed in NumPy 1.24; reshape keeps a 2-D (samples, tasks) array.
    labels = np.array([item[1] for item in inputs], dtype=np.float64).reshape(len(inputs), -1)
    valid = (labels >= 0).astype(np.float64)
    num_tasks = valid.shape[1]
    task_rank = np.arange(num_tasks, 0, -1, dtype=np.float64)  # num_tasks, ..., 1
    score = (valid.sum(axis=1) - 1) * 100 + valid @ task_rank
    return np.argsort(-score).tolist()
|
|
|
|
|
def load_labels(label_file, classes=None, prints=True):
    """Read a whitespace-separated ``name label [label ...]`` file into a dict.

    Args:
        label_file: path of the label file; a missing file yields ``{}``.
        classes: if given, a summary is printed via ``summary_labels``.
        prints: print how many labels were loaded.

    Returns:
        Dict mapping name -> the single label value, or a list of values when a
        line carries zero or several labels. Blank lines are skipped.
    """
    label_dict = {}
    if not os.path.exists(label_file):
        return label_dict

    with open(label_file) as f:
        for line in f:
            # split() tolerates repeated spaces/tabs and yields [] for blank lines.
            fields = line.split()
            if not fields:
                continue
            name, raw_values = fields[0], fields[1:]
            # SECURITY NOTE(review): eval() executes arbitrary expressions from the
            # file; only load trusted label files (ast.literal_eval would suffice
            # for plain numeric labels).
            values = [eval(v) for v in raw_values]
            label_dict[name] = values[0] if len(values) == 1 else values

    if prints:
        print(f"load {len(label_dict)} labels from '{label_file}'")

    if classes is not None:
        summary_labels(label_dict, classes)

    return label_dict
|
|
|
|
|
def load_labels_batch(file_lst, classes=None):
    """Load one or several label files and merge them into a single dict.

    Files are read in order, so a later file overrides earlier ones for a
    repeated name. ``classes`` is forwarded to ``load_labels`` (summary per file)
    and, when given, also used to summarize the merged result.
    """
    paths = [file_lst] if isinstance(file_lst, str) else file_lst
    merged = {}
    for path in paths:
        merged.update(load_labels(path, classes))
    total = len(merged)
    print(f'total get {total} items')

    if classes is not None:
        summary_labels(merged, classes)

    return merged
|
|
|
|
|
def save_labels(label_file, label_dict, mod='w'):
    """Write ``label_dict`` as ``name label [label ...]`` lines.

    Args:
        label_file: output path.
        label_dict: mapping str name -> scalar label (int/float/str, including
            NumPy scalar types) or a sequence of labels (list/tuple/1-D ndarray).
        mod: file open mode, e.g. ``'w'`` to overwrite or ``'a'`` to append.

    Raises:
        NotImplementedError: for any other label type.
    """
    with open(label_file, mod) as f:
        for img_name, labels in label_dict.items():
            # numbers.Number also matches NumPy scalars such as np.int64 / np.float32.
            if isinstance(labels, (numbers.Number, str)):
                f.write(f"{img_name} {labels}\n")
            elif isinstance(labels, (tuple, list, np.ndarray)):
                f.write(' '.join([img_name] + [str(v) for v in labels]) + '\n')
            else:
                raise NotImplementedError(f"Unsupported label type {type(labels)} for '{img_name}'")
    print(f"Save {len(label_dict)} annotation in '{label_file}'")
|
|
|
|
|
def merge_labels(file_list, save_file):
    """Merge the label files in ``file_list`` into ``save_file``.

    A name present in several files keeps the labels of the last file.
    """
    save_labels(save_file, load_labels_batch(file_list))
|
|
|
|
|
def concat_labels(file_list, save_file, dim=0):
    """Combine several label files into one.

    :param file_list: list of at least two label file paths.
    :param save_file: output label file path.
    :param dim: 0 merges vertically (union of names, see ``merge_labels``);
        1 merges horizontally: for names present in every file, the labels of
        all files are appended side by side in file order.
    :return: None
    """
    assert isinstance(file_list, list) and len(file_list) >= 2
    if dim == 0:
        merge_labels(file_list, save_file)
        return

    dicts_list = [load_labels(label_file) for label_file in file_list]
    common_keys = set(dicts_list[0])
    for tmp_dict in dicts_list[1:]:
        common_keys.intersection_update(tmp_dict)
    print(f"{len(common_keys)} common keys")

    label_dict = {}
    # Follow the first file's order (not set order) so the output is deterministic.
    for name in dicts_list[0]:
        if name not in common_keys:
            continue
        merged = []
        for tmp_dict in dicts_list:
            label = tmp_dict[name]
            # load_labels may yield int or float scalars (it eval()s each field).
            if isinstance(label, (numbers.Number, str)):
                merged.append(label)
            elif isinstance(label, (list, tuple)):
                merged.extend(label)
            else:
                raise NotImplementedError(f"Unsupported label type {type(label)} for '{name}'")
        label_dict[name] = merged

    save_labels(save_file, label_dict)
|
|
|
|
|
def summary_labels(label_dict, classes=None):
    """Print per-class label counts of ``label_dict`` on one line.

    Two label layouts are handled, decided by the type of the first value:
      * int labels: ``-1`` entries are ignored; prints ``all: <n>`` followed by
        ``<class>: <count>`` for every class. Without ``classes`` the distinct
        label values found are used as class names.
      * list labels: prints the column-wise sum of the label lists, named by
        ``classes`` (which must match the list length) or by column position.
    Other label types print nothing.
    """
    labels = list(label_dict.values())
    if not labels:
        return

    if isinstance(labels[0], int):
        labels = [l for l in labels if l != -1]
        print(f"all: {len(labels)}", end=' ')
        if classes:
            assert not labels or max(labels) < len(classes)
            counts = [(category, labels.count(i)) for i, category in enumerate(classes)]
        else:
            counts = [(category, labels.count(category)) for category in sorted(set(labels))]
        if not counts:
            print()  # no valid labels: just terminate the "all: 0" line
            return
        for i, (category, count) in enumerate(counts):
            end = '\n' if i == len(counts) - 1 else ' '
            print(f"{category}: {count}", end=end)

    elif isinstance(labels[0], list):
        label_array = np.array(labels, dtype=int)
        label_count = list(np.sum(label_array, axis=0))
        if classes:
            assert len(classes) == len(labels[0])
            for i, (category, count) in enumerate(zip(classes, label_count)):
                end = '\n' if i == len(classes) - 1 else ' '
                print(f"{category}: {count}", end=end)
        else:
            for i, count in enumerate(label_count):
                end = '\n' if i == len(label_count) - 1 else ' '
                print(f"{i}: {count}", end=end)
|
|
|
|
|
def summary_label_file(label_file, classes=None):
    """Load label file(s) and print a class summary of their labels.

    Args:
        label_file: a path, or a list/tuple of paths (merged; later files win
            on duplicate names).
        classes: optional class names forwarded to ``summary_labels``.

    Raises:
        NotImplementedError: if ``label_file`` is neither a str nor a list/tuple.
    """
    if isinstance(label_file, str):
        label_dict = load_labels(label_file)
    elif isinstance(label_file, (list, tuple)):
        if len(label_file) == 1:
            # Load the single file's labels (not the path itself).
            label_dict = load_labels(label_file[0])
        else:
            label_dict = load_labels_batch(list(label_file))
    else:
        raise NotImplementedError
    summary_labels(label_dict, classes)
|
|
|
|
|
class MTLabel:
    """Multi-task label table backed by a pandas DataFrame.

    Rows are samples indexed by name (index name ``"name"``); columns are tasks.
    ``classes`` holds, per task, a list of class names or ``None``; ``summary``
    and ``plot`` treat ``None`` tasks as continuous values. The value
    ``_invalid`` (-1) marks a missing label.
    """

    _invalid = -1  # placeholder for a missing / invalid label

    def __init__(self, input_label=None, tasks=None, classes=None):
        """Build a label table.

        Args:
            input_label: None (empty), path of a space-separated label file, a
                DataFrame, a dict ``{name: label(s)}`` or a list of
                ``(name, label(s))`` pairs.
            tasks: task name(s); defaults to ``task0 .. taskN`` when data is given.
            classes: per-task class-name lists (``None`` entries allowed); a flat
                list of strings is taken as the classes of a single task.
        """
        label_data = self._load(input_label)
        tasks = self._check_tasks(tasks)
        classes = self._check_classes(classes)

        if not label_data.empty:
            if tasks is None:
                tasks = [f"task{i}" for i in range(label_data.shape[1])]
            assert len(tasks) == label_data.shape[1], \
                f"Tasks length {len(tasks)} not match to label length {label_data.shape[1]}"
            label_data.columns = tasks

        if tasks is not None:
            if classes is None:
                classes = [None] * len(tasks)
            assert len(tasks) == len(classes), f"Tasks {tasks} not match to {classes}"

        self.label_data = label_data
        self.tasks = copy.deepcopy(tasks)
        self.classes = copy.deepcopy(classes)

    @staticmethod
    def _load(input_label):
        """Convert any supported input into a new DataFrame indexed by sample name."""
        if input_label is None:
            return pd.DataFrame()
        if isinstance(input_label, str):
            if not os.path.isfile(input_label):
                raise FileNotFoundError(f'Check label file path: {input_label}')
            label_data = pd.read_csv(input_label, sep=' ', header=None, index_col=0)
            cprint(f"{label_data.shape if label_data.shape[1] > 1 else label_data.shape[0]} labels from '{input_label}'", prefix='Load')
        elif isinstance(input_label, pd.DataFrame):
            label_data = copy.deepcopy(input_label)
        elif isinstance(input_label, (dict, list)):
            items = dict(input_label)
            if items and all(np.isscalar(v) for v in items.values()):
                # One scalar label per name (as load_labels returns for
                # single-label files); pd.DataFrame(dict_of_scalars) would raise.
                label_data = pd.Series(items).to_frame()
            else:
                label_data = pd.DataFrame(items).transpose()
        else:
            raise TypeError(f'{type(input_label)} is not support')
        label_data.index.rename("name", inplace=True)
        return label_data

    @staticmethod
    def _save(label_data, save_file, filter_invalid=True):
        """Write ``label_data`` as a space-separated file without a header.

        NaN cells are written as ``_invalid``; with ``filter_invalid`` rows whose
        labels are all ``_invalid`` are dropped. ``label_data`` itself is not modified.
        """
        save_dir = os.path.dirname(os.path.abspath(save_file))
        os.makedirs(save_dir, exist_ok=True)
        # Not inplace: the caller's frame (e.g. self.label_data on export) must stay untouched.
        label_data = label_data.fillna(MTLabel._invalid)
        if filter_invalid:
            label_data = label_data[(label_data != MTLabel._invalid).any(axis=1)]
        label_data = label_data.astype(object)
        label_data.to_csv(save_file, sep=' ', index=True, header=False)
        cprint(f"{label_data.shape if label_data.shape[1] > 1 else label_data.shape[0]} labels in '{save_file}'", prefix='Save')

    @classmethod
    def _new(cls, input_label=None, tasks=None, classes=None):
        """Create another instance of the same class."""
        return cls(input_label, tasks, classes)

    @staticmethod
    def _check_tasks(tasks, int_ok=False):
        """Normalize ``tasks`` to a NEW list of str (ints allowed if ``int_ok``).

        Returns None for None or an empty str/sequence. An integer (including 0)
        is a valid single task index.
        """
        if tasks is None or (not isinstance(tasks, numbers.Integral) and not tasks):
            return None
        if isinstance(tasks, (str, numbers.Integral)):
            tasks = [tasks]
        elif isinstance(tasks, (tuple, list)):
            tasks = list(tasks)  # copy so callers' lists are never mutated
        assert tasks and isinstance(tasks, list), f"arg 'task' should be type (int, str, tuple, list)"
        for i, t in enumerate(tasks):
            if isinstance(t, numbers.Integral):
                if not int_ok:
                    raise TypeError(f"'int' type task {t} at {i} not support, set 'int_ok=True' ?")
            elif not isinstance(t, str):
                raise TypeError(f"{type(t)} task {t} at {i} not support")
        return tasks

    @staticmethod
    def _check_classes(classes):
        """Return ``classes`` as a per-task list; a flat list of str becomes ``[classes]``."""
        if classes is None:
            return None
        assert classes and isinstance(classes, list), f"classes {classes} should be a list"
        if isinstance(classes[0], str):
            classes = [classes]
        return classes

    @staticmethod
    def _check_names(names, int_ok=False):
        """Normalize ``names`` to a NEW list of str (ints allowed if ``int_ok``); None if empty."""
        if names is None or (not isinstance(names, numbers.Integral) and not names):
            return None
        if isinstance(names, (str, numbers.Integral)):
            names = [names]
        elif isinstance(names, (tuple, set, list)):
            names = list(names)  # copy so callers' lists are never mutated
        assert names and isinstance(names, list), f"arg 'names' should be type (int, str, tuple, set, list)"
        for i, t in enumerate(names):
            if isinstance(t, numbers.Integral):
                if not int_ok:
                    raise TypeError(f"'int' type name {t} at {i} not support, set 'int_ok=True' ?")
            elif not isinstance(t, str):
                raise TypeError(f"{type(t)} name {t} at {i} not support")
        return names

    @staticmethod
    def _type_data(values, types):
        """Cast each value with the matching type (e.g. a column's NumPy scalar type)."""
        return [t(v) for t, v in zip(types, values)]

    def _check_value(self, value, max_len):
        """Pad (with ``_invalid``) or truncate a list value to ``max_len``; other values pass through."""
        if isinstance(value, list):
            if len(value) < max_len:
                value = value + [self._invalid] * (max_len - len(value))
            elif len(value) > max_len:
                value = value[:max_len]
        return value

    def _convert_tasks(self, tasks):
        """Resolve task names / integer positions to a new list of task names (None if empty)."""
        tasks = self._check_tasks(tasks, int_ok=True)
        if tasks is None:
            return None
        for i, t in enumerate(tasks):
            if isinstance(t, numbers.Integral):
                assert t < len(self.tasks), f"index {t} out of range({len(self.tasks)})"
                tasks[i] = self.tasks[t]
            elif isinstance(t, str):
                assert t in self.tasks, f"'{t}' not in {self.tasks}"
        return tasks

    def _convert_names(self, names):
        """Resolve sample names / integer row positions to a new list of names (None if empty)."""
        names = self._check_names(names, int_ok=True)
        if names is None:
            return None
        all_names = None
        for i, t in enumerate(names):
            if isinstance(t, numbers.Integral):
                assert t < len(self), f"index {t} out of range({len(self)})"
                if all_names is None:
                    all_names = self.names  # built once, only if needed
                names[i] = all_names[t]
        return names

    @staticmethod
    def _map_task_classes(tasks, classes):
        """Return ``{task: classes}`` for two equally long lists."""
        assert isinstance(tasks, list) and isinstance(classes, list) and len(tasks) == len(classes)
        return dict(zip(tasks, classes))

    @property
    def shape(self):
        return self.label_data.shape

    @property
    def empty(self):
        return self.label_data.empty

    @property
    def index(self):
        return self.label_data.index

    @property
    def values(self):
        return self.label_data.values

    @property
    def names(self):
        """Sample names as a list."""
        return list(self.index)

    @property
    def columns(self):
        """Task names (same list object as ``self.tasks``)."""
        return self.tasks

    @property
    def dtypes(self):
        return self.label_data.dtypes

    def head(self, n=5):
        return self.label_data.astype(object).head(n)

    def tail(self, n=5):
        return self.label_data.astype(object).tail(n)

    def astype(self, dtype):
        """Cast the label data in place and return self."""
        self.label_data = self.label_data.astype(dtype)
        return self

    def summary(self, tasks=None, extra_info=''):
        """Print, per task, the number of valid labels and either per-class counts
        (tasks with class names) or descriptive statistics (tasks without)."""
        print("-" * 120)
        tasks = self._convert_tasks(tasks)
        if tasks is None and self.tasks is None:
            log_warn(f"Label data is empty or 'tasks' is None")
            return
        tasks = self.tasks if tasks is None else tasks
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        cprint(extra_info, prefix="Summary")
        for t in tasks:
            task_classes = tc_dict[t]
            cur_label = self.label_data[t]
            cur_label = cur_label[cur_label != self._invalid]
            cprint(f"Task {t}\t==> all: {len(cur_label)}", end=' ')

            if cur_label.empty:
                print()
                continue

            if task_classes is None:
                stats = ['mean', 'std', 'min', 'max', '1%', '50%', '99%']
                # astype(float): describe() on object/bool data returns count/unique/top
                # instead of the numeric statistics used below.
                res = cur_label.astype(float).describe(percentiles=[0.01, 0.99], include=[np.number])
                for i, category in enumerate(stats):
                    end = '\n' if i == len(stats) - 1 else ' '
                    cprint(f"{category}: {res[category]:.3f}", end=end)
            else:
                res = cur_label.value_counts()
                for i, category in enumerate(task_classes):
                    end = '\n' if i == len(task_classes) - 1 else '\t'
                    if i in res:
                        cprint(f"{category}: {res[i]}", end=end)
                    else:
                        cprint(f"{category}: 0", end=end)
        print("-"*120)

    def plot(self, tasks=None):
        """Plot class-count bars for tasks with class names and histograms for the others."""
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            tasks = self.tasks
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        cls_part = [t for t in tasks if tc_dict[t] is not None]
        reg_part = [t for t in tasks if tc_dict[t] is None]
        if cls_part:
            self._plt_bar(cls_part, in_one=False)
        if reg_part:
            self._hist(reg_part)

    def _hist(self, tasks=None, cols=3, in_one=False):
        """Plot an 11-bin histogram of the valid labels of each task, then show once."""
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            tasks = self.tasks
        if not tasks:
            return
        cols = min(cols, len(tasks))
        rows = -(-len(tasks) // cols)  # ceil(len(tasks) / cols)
        for i, t in enumerate(tasks):
            if in_one:
                plt.subplot(rows, cols, i + 1)
            else:
                plt.figure()
            cur_label = self.label_data[t]
            cur_label = cur_label[cur_label != self._invalid]
            if not cur_label.empty:
                # A bin count (not a computed bin width) handles constant data
                # (zero width) and keeps the maximum inside the last bin.
                cur_label.hist(bins=11, grid=False)
            plt.title(t)
        plt.tight_layout()
        plt.show()

    def _plt_bar(self, tasks=None, cols=3, in_one=False):
        """Plot a bar chart of per-class counts for each task, then show once."""
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            tasks = self.tasks
        if not tasks:
            return
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        cols = min(cols, len(tasks))
        rows = -(-len(tasks) // cols)  # ceil(len(tasks) / cols)

        if in_one:
            # squeeze=False keeps a 2-D axes array even for a single row or cell.
            _, axes = plt.subplots(rows, cols, squeeze=False)

        for i, t in enumerate(tasks):
            task_classes = tc_dict[t]
            cur_label = self.label_data[t]
            cur_label = cur_label[cur_label != self._invalid]
            res = cur_label.value_counts()
            # Classes with no sample are plotted as 0 instead of raising KeyError.
            values = [res.get(k, 0) for k in range(len(task_classes))]
            df = pd.DataFrame({"category": task_classes, "count": values})
            if in_one:
                r, c = divmod(i, cols)
                ax = axes[r, c]
            else:
                ax = None
            df.plot(kind='bar', x="category", y="count", title=t, grid=False, rot=30 if in_one else 0,
                    ax=ax, legend=False)
        plt.tight_layout()
        plt.show()

    def set_tasks(self, tasks):
        """Rename all tasks (the new list must have the current length)."""
        tasks = self._check_tasks(tasks)
        if self.tasks is not None and len(tasks) != len(self.tasks):
            log_error(f"new tasks length {len(tasks)} not equal to ori {len(self.tasks)}")
            return
        self.tasks = copy.deepcopy(tasks)
        self.label_data.columns = self.tasks
        if self.classes is None and self.tasks is not None:
            # Keep tasks and classes aligned so summary()/plot() keep working.
            self.classes = [None] * len(self.tasks)

    def set_classes(self, classes):
        """Replace the per-task class names (same number of tasks required)."""
        classes = self._check_classes(classes)
        if self.classes is not None and len(classes) != len(self.classes):
            log_error(f"new classes length {len(classes)} not equal to ori {len(self.classes)}")
            return
        self.classes = copy.deepcopy(classes)

    def insert(self, task, value=None, loc=None, category=None, dtype=None):
        """Add a new task column.

        Args:
            task: new task name; KeyError if it already exists.
            value: column value(s); defaults to ``_invalid`` for every row.
            loc: column position; appended at the end when None.
            category: class names of the task (str or list) or None.
            dtype: optional dtype for the new column.
        """
        if self.tasks is None:
            # Empty label without tasks: start the bookkeeping lists.
            self.tasks = []
        if self.classes is None:
            self.classes = [None] * len(self.tasks)
        if task in self.tasks:
            raise KeyError(f"{task} already exits")
        if value is None:
            value = self._invalid
        if isinstance(category, str):
            category = [category]
        assert category is None or isinstance(category, list)

        self.label_data.insert(len(self.tasks) if loc is None else loc, task, value)
        if dtype is not None:
            self.label_data[task] = self.label_data[task].astype(dtype)

        if loc is None:
            self.tasks.append(task)
            self.classes.append(category)
        else:
            self.tasks.insert(loc, task)
            self.classes.insert(loc, category)

    def remove(self, task):
        """Delete task column(s); same as ``del label[task]``."""
        self.__delitem__(task)

    def add(self, key, value, keep_dtypes=False):
        """Set row ``key`` to ``value``, adding the row if needed.

        Returns 1 if a new row was added, 0 if an existing row was overwritten.
        With ``keep_dtypes`` the previous column dtypes are restored if the
        assignment changed them.
        """
        ori_dtypes = self.label_data.dtypes
        code = 1 if key not in self.label_data.index else 0
        self.label_data.loc[key] = value
        if keep_dtypes and (ori_dtypes != self.label_data.dtypes).any():
            self.label_data = self.label_data.astype(ori_dtypes)
        return code

    def update(self, other_label, tasks=None, classes=None, inplace=False):
        """Overwrite common samples with ``other_label``'s values and append its new samples.

        Both labels must have the same tasks. Returns a new object unless ``inplace``;
        returns None (after a warning) if ``other_label`` is empty.
        """
        if not isinstance(other_label, type(self)):
            other_label = self._new(other_label, tasks, classes)
        if other_label.empty:
            log_warn(f"Empty label data, check path or data")
            return
        assert list(other_label.columns) == list(self.label_data.columns), \
            f"current label({self.label_data.shape}) not match to input({other_label.shape}) at columns"

        ori_dtypes = self.label_data.dtypes
        other_data = other_label.label_data

        common_index = self.label_data.index.intersection(other_data.index)
        if common_index.empty:
            label_data = pd.concat([self.label_data, other_data])
            cprint(f"{len(other_data)} add.", prefix='Update')
        else:
            common1 = self.label_data.loc[common_index]
            common2 = other_data.loc[common_index]
            no_equal = common1[common1.ne(common2).any(axis=1)]
            add_label = other_data[~other_data.index.isin(common_index)]
            # Explicit copy: DataFrame.update works in place and must not modify
            # self.label_data when inplace=False.
            label_data = self.label_data.copy()
            label_data.update(common2)
            label_data = pd.concat([label_data, add_label])
            label_data = label_data.astype(ori_dtypes)
            cprint(f"{len(common1)} common ({len(no_equal)} update), {len(add_label)} add.", prefix='Update')

        if inplace:
            self.label_data = label_data
        else:
            return self._new(label_data, self.tasks, self.classes)

    def join(self, other_label, tasks=None, classes=None, inplace=False):
        """Add ``other_label``'s task columns side by side (aligned on sample names).

        A task present in both keeps ``other_label``'s column. Returns a new object
        unless ``inplace``; returns None (after a warning) if ``other_label`` is empty.
        """
        if not isinstance(other_label, type(self)):
            other_label = self._new(other_label, tasks, classes)
        if other_label.empty:
            log_warn(f"Empty label data, check path or data")
            return
        assert other_label.shape[0] == self.label_data.shape[0], \
            f"current label({self.label_data.shape}) not match to input({other_label.shape}) at index"

        ori_tasks = self.tasks
        ori_classes = self.classes
        other_dict = self._map_task_classes(other_label.tasks, other_label.classes)
        use_tasks = [t for t in other_label.tasks if t not in ori_tasks]
        use_classes = [other_dict[t] for t in use_tasks]
        new_tasks = ori_tasks + use_tasks
        new_classes = ori_classes + use_classes

        label_data = pd.concat([self.label_data, other_label.label_data], axis=1)
        # Drop duplicated task columns, keeping the other label's version. Column
        # selection (rather than transposing twice) preserves each column's dtype.
        label_data = label_data.loc[:, ~label_data.columns.duplicated(keep='last')]
        label_data = label_data[new_tasks]

        if inplace:
            self.label_data = label_data
            self.tasks = new_tasks
            self.classes = new_classes
        else:
            return self._new(label_data, new_tasks, new_classes)

    def concat(self, other_label, tasks=None, classes=None, inplace=False, fill=True):
        """Merge ``other_label`` over the union of samples and tasks.

        Non-NaN cells of ``other_label`` override this label's values; new samples
        and tasks are appended (self's order first). With ``fill`` the cells that
        neither label provides are set to ``_invalid``. Returns a new object unless
        ``inplace``; returns None (after a warning) if ``other_label`` is empty.
        """
        if not isinstance(other_label, type(self)):
            other_label = self._new(other_label, tasks, classes)
        if other_label.empty:
            log_warn(f"Empty label data, check path or data")
            return

        ori_tasks = self.tasks
        ori_classes = self.classes
        other_dict = self._map_task_classes(other_label.tasks, other_label.classes)
        use_tasks = [t for t in other_label.tasks if t not in ori_tasks]
        use_classes = [other_dict[t] for t in use_tasks]
        new_tasks = ori_tasks + use_tasks
        new_classes = ori_classes + use_classes

        self_data = self.label_data
        other_data = other_label.label_data
        common_index = self_data.index.intersection(other_data.index)
        common_columns = self_data.columns.intersection(other_data.columns)
        other_index = other_data.index[~other_data.index.isin(common_index)]
        other_columns = other_data.columns[~other_data.columns.isin(common_columns)]

        # reindex builds a new frame covering every (sample, task) pair, including
        # pairs where both the sample and the task are new, so the in-place
        # update below never touches self.label_data.
        label_data = self_data.reindex(index=self_data.index.append(other_index), columns=new_tasks)
        label_data.update(other_data)

        if fill:
            label_data.fillna(MTLabel._invalid, inplace=True)

        cprint(f"update common ({len(common_index)}, {len(common_columns)}), "
               f"add ({len(other_index)}, {len(other_columns)})", prefix="Concat")

        if inplace:
            self.label_data = label_data
            self.tasks = new_tasks
            self.classes = new_classes
        else:
            return self._new(label_data, new_tasks, new_classes)

    def sample(self, num, random_state=None):
        """Return a new object with ``num`` randomly sampled rows (seedable via ``random_state``)."""
        return self._new(self.label_data.sample(num, random_state=random_state), self.tasks, self.classes)

    def pick_tasks(self, tasks, inplace=False):
        """Keep only ``tasks`` (names or positions); returns self when ``tasks`` is empty/None."""
        tasks = self._convert_tasks(tasks)
        if tasks is None:
            log_warn("'task' is None, use ori label")
            return self
        classes = [self.classes[self.tasks.index(t)] for t in tasks]
        if inplace:
            self.label_data = self.label_data[tasks]
            self.tasks = copy.deepcopy(tasks)
            self.classes = copy.deepcopy(classes)
            return self
        return self._new(self.label_data[tasks], tasks=tasks, classes=classes)

    def pick_names(self, name_list, inplace=False):
        """Keep only the given samples (names or row positions); returns self when empty/None."""
        names_list = self._convert_names(name_list)
        if names_list is None:
            log_warn("'name_list' is None, use ori label")
            return self
        if inplace:
            self.label_data = self.label_data.loc[names_list]
            return self
        return self._new(self.label_data.loc[names_list], tasks=self.tasks, classes=self.classes)

    def _typed_rows(self):
        """Return ``(names, rows)`` with each value cast to its column's NumPy scalar type."""
        names = self.index.tolist()
        rows = self.values.tolist()
        types = [t.type for t in self.dtypes.tolist()]
        return names, [self._type_data(row, types) for row in rows]

    def tolist(self):
        """Return ``[(name, [label, ...]), ...]``."""
        names, rows = self._typed_rows()
        return list(zip(names, rows))

    def todict(self):
        """Return ``{name: [label, ...]}``."""
        names, rows = self._typed_rows()
        return dict(zip(names, rows))

    def export(self, save_file, filter_invalid=True):
        """Write all labels to ``save_file`` (see ``_save``)."""
        self._save(self.label_data, save_file, filter_invalid)

    def export_tasks(self, task, save_file, filter_invalid=True):
        """Write only the given task(s) to ``save_file``."""
        obj = self.pick_tasks(task)
        self._save(obj.label_data, save_file, filter_invalid)

    def __getitem__(self, item):
        """Select rows and/or tasks; always returns a new object.

        * int / slice: rows by position (e.g. ``label[0]``, ``label[-1]``, ``label[2:5]``).
        * str: a task column if it is a task, otherwise a row by name.
        * list: tasks (names or positions).
        * tuple ``(rows, cols)``: row selector as above, then a column selector
          (int/slice over tasks, or str/list).
        """
        if isinstance(item, numbers.Integral):
            # For item == -1, item + 1 would be 0 and give an empty slice.
            item = slice(item, item + 1 if item != -1 else None)

        if isinstance(item, slice):
            return self._new(self.label_data[item], tasks=self.tasks, classes=self.classes)

        if isinstance(item, str):
            if item in self.tasks:
                return self._new(self.label_data[[item]], tasks=[item],
                                 classes=[self.classes[self.tasks.index(item)]])
            if item in self.label_data.index:
                return self._new(self.label_data.loc[[item]], tasks=self.tasks, classes=self.classes)
            raise KeyError(f"'{item}' not in tasks or index")

        if isinstance(item, list):
            return self.pick_tasks(item)

        if isinstance(item, tuple):
            assert len(item) == 2, f"tuple length {len(item)} != 2"
            row_slice, col_slice = item
            if isinstance(row_slice, (slice, numbers.Integral)) or \
                    (isinstance(row_slice, str) and row_slice not in self.tasks):
                obj = self.__getitem__(row_slice)
                if isinstance(col_slice, (slice, numbers.Integral)):
                    return obj.__getitem__(obj.tasks[col_slice])
                return obj.__getitem__(col_slice)
            raise TypeError(f"{item} first item type {type(row_slice)} is an invalid key")

        raise TypeError(f"{item} is an invalid key")

    def __setitem__(self, key, value):
        """Assign values to existing entries of ``label_data``.

        Key forms: int/slice (rows by position), str (task column or row name),
        list (tasks; list values are padded/truncated to the number of tasks),
        tuple ``(row, col)``.

        pandas accessors used:
            loc: label based; can assign a new index or column value
            iloc: position based; cannot assign a new index or column value
            at: fast scalar access by label (a fast loc)
            iat: fast scalar access by position (a fast iloc)
        """
        if isinstance(key, (numbers.Integral, slice)):
            self.label_data.iloc[key] = value

        elif isinstance(key, str):
            if key in self.tasks:
                self.label_data.loc[:, key] = value
            elif key in self.label_data.index:
                self.label_data.loc[key] = value
            else:
                raise KeyError(f"'{key}' not in tasks or index")

        elif isinstance(key, list):
            key = self._convert_tasks(key)
            value = self._check_value(value, len(key))
            self.label_data.loc[:, key] = value

        elif isinstance(key, tuple):
            assert len(key) == 2, f"tuple length {len(key)} != 2"
            row_slice, col_slice = key

            if isinstance(row_slice, numbers.Integral):
                row_slice = self.label_data.index[row_slice]
            elif isinstance(row_slice, str):
                assert row_slice in self.label_data.index, f"{row_slice} not in index"
            elif not isinstance(row_slice, slice):
                raise TypeError(f"{key} first item type {type(row_slice)} is an invalid key")

            if isinstance(col_slice, numbers.Integral):
                col_slice = self.tasks[col_slice]
            elif isinstance(col_slice, str):
                assert col_slice in self.tasks, f"{col_slice} not in {self.tasks}"
            elif isinstance(col_slice, list):
                col_slice = self._convert_tasks(col_slice)
            elif not isinstance(col_slice, slice):
                raise TypeError(f"{key} second item type {type(col_slice)} is an invalid key")

            if isinstance(row_slice, str) and isinstance(col_slice, str):
                self.label_data.at[row_slice, col_slice] = value
            else:
                self.label_data.loc[row_slice, col_slice] = value

    def __delitem__(self, key):
        """Delete task column(s) given by name(s) or position list, with their classes."""
        if not isinstance(key, (str, list)):
            raise KeyError(f"'{key}' is an invalid key")
        key = self._convert_tasks(key)
        if key is None:
            return
        self.label_data.drop(columns=key, inplace=True)
        tc_dict = self._map_task_classes(self.tasks, self.classes)
        removed = set(key)
        self.tasks = [t for t in self.tasks if t not in removed]
        self.classes = [tc_dict[t] for t in self.tasks]

    def __getattr__(self, item):
        """Expose tasks as attributes (``label.ems`` is ``label['ems']``).

        Called only when normal lookup fails. Anything that is not a task raises
        AttributeError, so ``hasattr``/``getattr(obj, name, default)`` behave
        normally. ``tasks`` is read via ``__dict__`` to avoid recursing into
        this method when it has not been set yet.
        """
        tasks = self.__dict__.get('tasks') or []
        if not item.startswith('__') and item in tasks:
            return self.__getitem__(item)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def __len__(self):
        return len(self.label_data)

    def __repr__(self):
        return self.label_data.astype(object).__repr__()

    def __copy__(self):
        # __init__ copies the frame and deep-copies tasks/classes, so this is independent of self.
        return self._new(self.label_data, self.tasks, self.classes)

    def __deepcopy__(self, memodict=None):
        return self._new(copy.deepcopy(self.label_data, memodict),
                         copy.deepcopy(self.tasks, memodict),
                         copy.deepcopy(self.classes, memodict))
|
|
|
|
|
if __name__ == '__main__':
    # Demo: summarize a DMS label file, then summarize and export its 'ems' task.
    dms_tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l"]
    dms_classes = [
        ['normal', 'look_left', 'look_down', 'look_right', 'invalid'],  # ems
        ['normal', 'close_eye', 'invalid'],                             # eye
        ['normal', 'yawn', 'invalid'],                                  # mouth
        ['normal', 'glass', 'invalid'],                                 # glass
        ['normal', 'mask', 'invalid'],                                  # mask
        ['normal', 'smoke'],                                            # smoke
        ['normal', 'phone'],                                            # phone
        None,  # eyelid_r: no class names, summarized with descriptive statistics
        None,  # eyelid_l
    ]
    test_label = MTLabel('../test_dms3_labels_v4_rec.txt', dms_tasks, dms_classes)
    test_label.summary()

    ems_label = test_label['ems']
    ems_label.summary()
    ems_label.export("ems_rec.txt")
|
|