|
import json |
|
from pathlib import Path |
|
from typing import Dict, List, Set, Union |
|
|
|
|
|
class Labels:
    """
    Bidirectional mapping between string labels and integer indices.

    Labels are grouped in namespaces (``labels`` by default), so that a single
    object can hold several independent label sets.

    Args:
        _labels_to_index (:obj:`Dict[str, Dict[str, int]]`, optional):
            For every namespace, a dictionary from label (:obj:`str`) to
            index (:obj:`int`). Defaults to a single empty ``labels`` namespace.
        _index_to_labels (:obj:`Dict[str, Dict[int, str]]`, optional):
            For every namespace, a dictionary from index (:obj:`int`) to
            label (:obj:`str`). When omitted, it is built by inverting
            ``_labels_to_index``.
        **kwargs:
            Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        _labels_to_index: Union[Dict[str, Dict[str, int]], None] = None,
        _index_to_labels: Union[Dict[str, Dict[int, str]], None] = None,
        **kwargs,
    ):
        self._labels_to_index = _labels_to_index or {"labels": {}}

        if _index_to_labels:
            self._index_to_labels = _index_to_labels
        else:
            # Derive the reverse mapping from the forward one so that both
            # dictionaries always expose exactly the same namespaces.
            self._index_to_labels = {
                namespace: {index: label for label, index in mapping.items()}
                for namespace, mapping in self._labels_to_index.items()
            }

    def get_index_from_label(self, label: str, namespace: str = "labels") -> int:
        """
        Returns the index of a literal label.

        Args:
            label (:obj:`str`):
                The string representation of the label.
            namespace (:obj:`str`, optional, defaults to ``labels``):
                The namespace where the label belongs, e.g. ``roles`` for a SRL task.

        Returns:
            :obj:`int`: The index of the label.

        Raises:
            ValueError: If the namespace or the label is unknown.
        """
        if namespace not in self._labels_to_index:
            raise ValueError(
                f"Provided namespace `{namespace}` is not in the label dictionary."
            )

        if label not in self._labels_to_index[namespace]:
            raise ValueError(f"Provided label {label} is not in the label dictionary.")

        return self._labels_to_index[namespace][label]

    def get_label_from_index(self, index: int, namespace: str = "labels") -> str:
        """
        Returns the string representation of the label index.

        Args:
            index (:obj:`int`):
                The index of the label.
            namespace (:obj:`str`, optional, defaults to ``labels``):
                The namespace where the label belongs, e.g. ``roles`` for a SRL task.

        Returns:
            :obj:`str`: The string representation of the label.

        Raises:
            ValueError: If the namespace or the index is unknown.
        """
        if namespace not in self._index_to_labels:
            raise ValueError(
                f"Provided namespace `{namespace}` is not in the label dictionary."
            )

        if index not in self._index_to_labels[namespace]:
            raise ValueError(
                f"Provided label `{index}` is not in the label dictionary."
            )

        return self._index_to_labels[namespace][index]

    def add_labels(
        self,
        labels: Union[str, List[str], Set[str], Dict[str, int]],
        namespace: str = "labels",
    ) -> List[int]:
        """
        Adds the labels in input in the label dictionary.

        Args:
            labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`, :obj:`Dict[str, int]`):
                The labels to add to the dictionary. A single string is added
                as one label. A list is added in order, duplicates being
                considered only once. A set is added in sorted order. A
                dictionary from label to index replaces the whole namespace.
            namespace (:obj:`str`, optional, defaults to ``labels``):
                Namespace where the labels belongs.

        Returns:
            :obj:`List[int]`: The index of each distinct label in input.
        """
        if isinstance(labels, dict):
            # Store a copy: later insertions must not mutate the caller's dict.
            self._labels_to_index[namespace] = dict(labels)
            self._index_to_labels[namespace] = {
                index: label for label, index in labels.items()
            }
            return list(labels.values())

        if isinstance(labels, str):
            # A string is one label, not an iterable of characters.
            candidates = [labels]
        elif isinstance(labels, (set, frozenset)):
            # Sets have no stable iteration order; sort so that the assigned
            # indices are reproducible across runs.
            candidates = sorted(labels)
        else:
            # De-duplicate while preserving the order given by the caller.
            candidates = list(dict.fromkeys(labels))

        self._labels_to_index.setdefault(namespace, {})
        self._index_to_labels.setdefault(namespace, {})

        return [self._add_label(label, namespace) for label in candidates]

    def _add_label(self, label: str, namespace: str = "labels") -> int:
        """
        Adds the label in input in the label dictionary.

        Args:
            label (:obj:`str`):
                The label to add to the dictionary.
            namespace (:obj:`str`, optional, defaults to ``labels``):
                Namespace where the label belongs. It must already exist.

        Returns:
            :obj:`int`: The index of the label, whether it was just inserted
            or was already present.
        """
        label_to_index = self._labels_to_index[namespace]
        if label in label_to_index:
            return label_to_index[label]

        index_to_label = self._index_to_labels[namespace]
        index = len(label_to_index)
        # With non-contiguous indices (e.g. loaded from a dict) the size of
        # the namespace may already be taken: move to the next free index
        # instead of overwriting an existing entry.
        while index in index_to_label:
            index += 1

        label_to_index[label] = index
        index_to_label[index] = label
        return index

    def get_labels(self, namespace: str = "labels") -> Dict[str, int]:
        """
        Returns all the labels that belongs to the input namespace.

        Args:
            namespace (:obj:`str`, optional, defaults to ``labels``):
                Labels namespace to retrieve.

        Returns:
            :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``.

        Raises:
            ValueError: If the namespace is unknown.
        """
        if namespace not in self._labels_to_index:
            raise ValueError(
                f"Provided namespace `{namespace}` is not in the label dictionary."
            )
        return self._labels_to_index[namespace]

    def get_label_size(self, namespace: str = "labels") -> int:
        """
        Returns the number of the labels in the namespace dictionary.

        Args:
            namespace (:obj:`str`, optional, defaults to ``labels``):
                Labels namespace to retrieve.

        Returns:
            :obj:`int`: Number of labels.

        Raises:
            ValueError: If the namespace is unknown.
        """
        if namespace not in self._labels_to_index:
            raise ValueError(
                f"Provided namespace `{namespace}` is not in the label dictionary."
            )
        return len(self._labels_to_index[namespace])

    def get_namespaces(self) -> List[str]:
        """
        Returns all the namespaces in the label dictionary.

        Returns:
            :obj:`List[str]`: The namespaces in the label dictionary.
        """
        return list(self._labels_to_index.keys())

    @classmethod
    def from_file(cls, file_path: Union[str, Path, dict], **kwargs):
        """
        Builds a :obj:`Labels` object from a JSON file.

        Args:
            file_path (:obj:`str`, :obj:`Path`):
                Path to a JSON file containing, for every namespace, a
                dictionary from label to index.
            **kwargs:
                Forwarded to the class constructor.

        Returns:
            The new instance, with the index-to-label mapping derived from
            the content of the file.
        """
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(file_path, "r", encoding="utf-8") as f:
            labels_to_index = json.load(f)
        return cls(labels_to_index, **kwargs)

    def save(self, file_path: Union[str, Path, dict], indent: int = 2, **kwargs):
        """
        Saves the label-to-index dictionary of every namespace as JSON.

        Args:
            file_path (:obj:`str`, :obj:`Path`):
                Destination file, overwritten if it already exists.
            indent (:obj:`int`, optional, defaults to ``2``):
                Indentation passed to :func:`json.dump`.
            **kwargs:
                Accepted for interface compatibility and ignored.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self._labels_to_index, f, indent=indent)
|
|