| import dataclasses |
| from collections import defaultdict |
| import pandas as pd |
| import torch |
| from typing import Dict, List, Optional, Union |
|
|
| from barista.data.dataframe_wrapper import DataframeWrapper |
| from barista.data.metadata_spatial_groups import ( |
| MetadataSpatialGroupRow, |
| MetadataSpatialGroups, |
| ) |
|
|
|
|
@dataclasses.dataclass
class MetadataRow:
    """
    Typed record describing a single metadata entry.

    Presumably one row of the table wrapped by ``Metadata`` (one entry per
    data segment) -- confirm against the code that builds these rows.

    The declaration order below is also the positional-argument order of the
    generated ``__init__``, so fields must not be rearranged.
    """

    # --- provenance identifiers ---
    dataset: str
    subject: str
    session: str
    # Presumably a combined subject/session key -- TODO confirm its format.
    subject_session: str
    experiment: str

    # --- data dimensions ---
    d_input: int
    # Shape of the stored data.
    d_data: torch.Size

    # --- split / storage location ---
    split: str
    path: str
    filename: str
    processing_str: str
    seq_len: int

    # --- supervision target; ``None`` is an accepted value ---
    label: Optional[float]
|
|
|
|
class Metadata(DataframeWrapper):
    """
    Metadata class to keep track of all segment meta information.

    In addition to the table handled by ``DataframeWrapper``, an optional
    ``MetadataSpatialGroups`` table is kept in ``self._spatial_groups``. It is
    ``None`` when no spatial groups were supplied at construction time or
    when no spatial-groups file was found next to ``load_path``; every method
    below tolerates that state.
    """

    def __init__(self, df=None, load_path=None, spatial_group_df=None):
        """
        Build the metadata from a dataframe or from a saved file.

        Args:
            df: Segment metadata table, forwarded to ``DataframeWrapper``.
            load_path: Path of a previously saved metadata file, forwarded to
                ``DataframeWrapper``. When given, the spatial groups are
                loaded from the companion file derived from this path and
                ``spatial_group_df`` is not used.
            spatial_group_df: Spatial-group table; only allowed together
                with ``df``.

        Raises:
            AssertionError: If ``spatial_group_df`` is passed without ``df``.
        """
        if df is None:
            assert (
                spatial_group_df is None
            ), "spatial_group_df can only be passed together with df"

        super().__init__(df, load_path)

        self._spatial_groups = None
        if load_path is not None:
            try:
                self._spatial_groups = MetadataSpatialGroups(
                    load_path=self._get_spatial_group_path(load_path)
                )
            except FileNotFoundError:
                # Deliberate best effort: the metadata may have been saved
                # without any spatial groups, so a missing companion file
                # simply leaves ``_spatial_groups`` as None.
                pass
        elif spatial_group_df is not None:
            self._spatial_groups = MetadataSpatialGroups(df=spatial_group_df)

    def _get_spatial_group_path(self, path: str) -> str:
        """
        Derive the spatial-groups file path that accompanies ``path``.

        ``"<stem>.csv"`` maps to ``"<stem>_spatial_groups.csv"``. The
        ``.csv`` suffix is only stripped when it is actually present, so a
        path with another (or no) extension keeps its full name as the stem
        instead of losing its last four characters.
        """
        suffix = ".csv"
        path = str(path)
        stem = path[: -len(suffix)] if path.endswith(suffix) else path
        return f"{stem}_spatial_groups{suffix}"

    def save(self, path: str) -> None:
        """
        Save the metadata to ``path`` and, if present, the spatial groups to
        the companion path returned by ``_get_spatial_group_path``.
        """
        super().save(path)
        # NOTE: when there are no spatial groups nothing is written, so a
        # stale companion file already on disk is left untouched.
        if self._spatial_groups is not None:
            self._spatial_groups.save(self._get_spatial_group_path(path))

    @classmethod
    def merge(
        cls,
        metadatas: List["Metadata"],
        drop_duplicate: bool = False,
        merge_columns: Union[str, List[str], None] = None,
        keep="first",
    ) -> "Metadata":
        """
        Merge several ``Metadata`` objects, including their spatial groups.

        The metadata tables are merged by ``DataframeWrapper.merge`` using
        the given arguments. Spatial groups are merged separately and always
        de-duplicated on ``("dataset", "subject_session", "name")``. Inputs
        without spatial groups are skipped; if none of the inputs has any,
        the result has none either.
        """
        new_metadata = super().merge(metadatas, drop_duplicate, merge_columns, keep)

        spatial_groups = [
            m._spatial_groups for m in metadatas if m._spatial_groups is not None
        ]
        if spatial_groups:
            new_metadata._spatial_groups = MetadataSpatialGroups.merge(
                spatial_groups,
                drop_duplicate=True,
                merge_columns=[
                    "dataset",
                    "subject_session",
                    "name",
                ],
            )
        else:
            new_metadata._spatial_groups = None
        return new_metadata

    def get_subject_session_d_input(self) -> dict:
        """
        Returns a dict mapping subject_session to its ``d_input`` value.
        """
        return self._get_column_mapping_dict_from_dataframe(
            key_col="subject_session",
            value_col="d_input",
        )

    def get_subjects(self) -> dict:
        """
        Returns the unique values of the ``subject`` column.
        """
        # NOTE(review): the ``dict`` annotation looks inaccurate for a set of
        # unique column values -- confirm against
        # DataframeWrapper.get_unique_values_in_col before tightening it.
        return self.get_unique_values_in_col("subject")

    def _shape_str_to_list(self, value):
        """
        Convert a comma-separated shape string (e.g. ``"4,10"``) to a list of
        ints. Non-string values are returned unchanged.
        """
        if not isinstance(value, str):
            return value
        return [int(a) for a in value.split(",")]

    def get_subject_session_full_d_data(self) -> Dict[str, List[int]]:
        """
        Returns a dict containing subject_session to data shape
        """
        my_dict = self._get_column_mapping_dict_from_dataframe(
            key_col="subject_session",
            value_col="d_data",
        )
        return {k: self._shape_str_to_list(v) for k, v in my_dict.items()}

    def get_labels_count_summary(self) -> dict:
        """
        Count rows for every (split, label) combination.

        Returns:
            A nested mapping ``{split: {label: count}}`` covering every pair
            of unique ``split`` and unique ``label`` values, including pairs
            with a count of zero.
        """
        splits = self.get_unique_values_in_col("split")
        labels = self.get_unique_values_in_col("label")

        labels_count = defaultdict(dict)
        for split in splits:
            for label in labels:
                count = len(
                    self.get_indices_matching_cols_values(
                        ["split", "label"],
                        [split, label],
                    )
                )
                labels_count[split][label] = count
        return labels_count

    def get_summary_str(self) -> str:
        """
        Returns a one-line human-readable summary: the subjects followed by
        the row count of every (split, label) combination.
        """
        subjects = self.get_unique_values_in_col("subject")
        labels_count = self.get_labels_count_summary()

        parts = [f"Metadata for {len(subjects)} subjects ({subjects})"]
        parts.extend(
            f", {count} {split} segments with label {label}"
            for split, labels in labels_count.items()
            for label, count in labels.items()
        )
        return "".join(parts)

    def add_spatial_group(self, spatial_group_row: MetadataSpatialGroupRow):
        """
        Add (or overwrite) the spatial group

        Any existing group with the same subject_session and name is removed
        before the new row is appended. If no spatial groups exist yet, the
        table is created from this row.
        """
        new_rows = pd.DataFrame([spatial_group_row])
        if self._spatial_groups is None:
            # Nothing to overwrite yet: start the table with this row.
            # Assumes MetadataSpatialGroups(df=...) accepts a one-row frame,
            # as it accepts ``spatial_group_df`` in __init__ -- TODO confirm.
            self._spatial_groups = MetadataSpatialGroups(df=new_rows)
            return

        self._spatial_groups.remove_spatial_group(
            spatial_group_row.subject_session, spatial_group_row.name
        )
        self._spatial_groups.concat(new_rows)

    def get_spatial_grouping(
        self, subject_session: str, name: str
    ) -> Optional[MetadataSpatialGroupRow]:
        """
        Return the spatial grouping called `name` for subject_session `subject_session`.

        A spatial grouping is a MetadataSpatialGroupRow whose most important properties are
        group_components, a list of tuples containing the group info for each channel of the data,
        and group_ids, a list of integers specifying which group each channel belongs to.

        Returns None when this metadata holds no spatial groups at all.
        """
        if self._spatial_groups is None:
            return None
        return self._spatial_groups.get_spatial_grouping(subject_session, name)

    def get_spatial_grouping_id_hashmap(self, name: str) -> Dict[str, List[int]]:
        """
        Return spatial grouping dictionary which maps each subject_session to list of group ids which is a list of
        length channels specifying which group each channel belongs to.

        Returns an empty dict when this metadata holds no spatial groups.

        # NOTE Don't use during forward because of the copy
        """
        if self._spatial_groups is None:
            return {}

        # Filter a copy so the stored spatial groups are not modified.
        temp_copy = self._spatial_groups.copy()
        temp_copy.reduce_based_on_col_value(col_name="name", value=name, keep=True)
        return temp_copy._get_column_mapping_dict_from_dataframe(
            "subject_session", "group_ids"
        )
|
|