|
import torch |
|
import random |
|
import os |
|
import sys |
|
import logging |
|
import numpy as np |
|
import pandas as pd |
|
from shutil import copy |
|
from datetime import datetime |
|
import matplotlib.pyplot as plt |
|
import collections |
|
import umap |
|
import umap.plot |
|
from matplotlib.colors import ListedColormap |
|
|
|
|
|
from sklearn.metrics import classification_report, accuracy_score |
|
|
|
|
|
def count_parameters(model):
    """Return the number of trainable (gradient-tracked) parameters in *model*."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)
|
|
|
|
|
class AverageMeter(object):
    """Tracks the latest value, running sum, count and mean of a metric.

    Attributes:
        val:   most recent value passed to ``update``
        sum:   weighted running sum (each value counted ``n`` times)
        count: total number of samples seen
        avg:   running mean, ``sum / count``
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all statistics so the meter can be reused."""
        self.val, self.avg = 0, 0
        self.sum, self.count = 0, 0

    def update(self, val, n=1):
        """Record *val* observed over *n* samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
|
|
|
|
def fix_randomness(SEED):
    """Seed every RNG used in the project so runs are reproducible.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and CUDA), and
    forces cuDNN into deterministic mode.

    :param SEED: integer seed applied to all generators
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    # FIX: seed every visible GPU, not only the current device
    # (a no-op on CPU-only machines, so backward compatible)
    torch.cuda.manual_seed_all(SEED)
    # trade speed for determinism in cuDNN algorithm selection
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
def _logger(logger_name, level=logging.DEBUG): |
|
""" |
|
Method to return a custom logger with the given name and level |
|
:param logger_name: |
|
:param level: |
|
:return: |
|
""" |
|
logger = logging.getLogger(logger_name) |
|
logger.setLevel(level) |
|
format_string = "%(message)s" |
|
log_format = logging.Formatter(format_string) |
|
|
|
console_handler = logging.StreamHandler(sys.stdout) |
|
console_handler.setFormatter(log_format) |
|
logger.addHandler(console_handler) |
|
|
|
file_handler = logging.FileHandler(logger_name, mode='a') |
|
file_handler.setFormatter(log_format) |
|
logger.addHandler(file_handler) |
|
return logger |
|
|
|
|
|
def starting_logs(data_type, exp_log_dir, seed_id):
    """Create a per-seed log directory, open a timestamped logger and emit a header.

    :param data_type: dataset name recorded in the log header
    :param exp_log_dir: experiment directory the seed subfolder is created under
    :param seed_id: seed used to name the subfolder and logged in the header
    :return: ``(logger, log_dir)`` tuple
    """
    log_dir = os.path.join(exp_log_dir, "_seed_" + str(seed_id))
    os.makedirs(log_dir, exist_ok=True)
    stamp = datetime.now().strftime('%d_%m_%Y_%H_%M_%S')
    log_file_name = os.path.join(log_dir, f"logs_{stamp}.log")
    logger = _logger(log_file_name)
    banner = '=' * 45
    for line in (banner,
                 f'Logging data type {data_type}',
                 banner,
                 f'Logging seed id {seed_id}',
                 banner):
        logger.debug(line)
    return logger, log_dir
|
|
|
|
|
def save_checkpoint(exp_log_dir, model, dataset, dataset_configs, hparams, status):
    """Serialize the model weights plus run metadata to ``checkpoint_{status}.pt``.

    :param exp_log_dir: directory the checkpoint file is written into
    :param dataset: dataset identifier stored alongside the weights
    :param dataset_configs: config object; its ``__dict__`` is stored
    :param hparams: mapping of hyper-parameters, stored as a plain dict
    :param status: tag embedded in the checkpoint file name (e.g. "best", "last")
    """
    checkpoint = dict(
        dataset=dataset,
        configs=dataset_configs.__dict__,
        hparams=dict(hparams),
        model=model.state_dict(),
    )
    torch.save(checkpoint, os.path.join(exp_log_dir, f"checkpoint_{status}.pt"))
|
|
|
|
|
def _calc_metrics(pred_labels, true_labels, classes_names):
    """Compute accuracy and macro-F1 (both as percentages).

    :param pred_labels: predicted class ids (any int-castable sequence)
    :param true_labels: ground-truth class ids
    :param classes_names: display names passed to the classification report
    :return: ``(accuracy * 100, macro_f1 * 100)``
    """
    preds = np.array(pred_labels).astype(int)
    trues = np.array(true_labels).astype(int)

    report = classification_report(trues, preds, target_names=classes_names,
                                   digits=6, output_dict=True)
    accuracy = accuracy_score(trues, preds)
    macro_f1 = report["macro avg"]["f1-score"]

    return accuracy * 100, macro_f1 * 100
|
|
|
|
|
def _save_metrics(pred_labels, true_labels, log_dir, status):
    """Write the per-class classification report (values in %) to an Excel file.

    The report is saved as ``classification_report_{status}.xlsx`` inside
    *log_dir*; the overall accuracy is broadcast into its own column.
    """
    preds = np.array(pred_labels).astype(int)
    trues = np.array(true_labels).astype(int)

    report = classification_report(trues, preds, digits=6, output_dict=True)
    report_df = pd.DataFrame(report)
    # store accuracy as a full column alongside the per-class metrics
    report_df["accuracy"] = accuracy_score(trues, preds)
    report_df = report_df * 100

    save_path = os.path.join(log_dir, f"classification_report_{status}.xlsx")
    report_df.to_excel(save_path)
|
|
|
|
|
def to_device(input, device):
    """Recursively move every tensor in *input* to *device*.

    Handles bare tensors, strings (returned untouched), mappings and
    sequences; nested containers are rebuilt as dicts/lists.

    :param input: tensor, str, Mapping, or Sequence (arbitrarily nested)
    :param device: target ``torch.device`` or device string
    :return: same structure with all tensors on *device*
    :raises TypeError: for any other leaf type
    """
    if torch.is_tensor(input):
        return input.to(device=device)
    # str is itself a Sequence, so it must be checked before the generic branch
    elif isinstance(input, str):
        return input
    elif isinstance(input, collections.abc.Mapping):
        return {k: to_device(sample, device=device) for k, sample in input.items()}
    elif isinstance(input, collections.abc.Sequence):
        return [to_device(sample, device=device) for sample in input]
    else:
        # FIX: the message was a plain string literal, so "{type(input)}" was
        # emitted verbatim instead of the offending type — make it an f-string.
        raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}")
|
|
|
|
|
def copy_files(destination):
    """Back up the project's key source files under ``destination/MODEL_BACKUP_FILES``.

    Source paths are relative to the current working directory, so this is
    expected to run from the repository root.
    """
    backup_dir = os.path.join(destination, "MODEL_BACKUP_FILES")
    os.makedirs(backup_dir, exist_ok=True)
    # (source path, backup file name) pairs
    files_to_backup = [
        ("main.py", "main.py"),
        ("data/dataloader.py", "dataloader.py"),
        ("data/dataset.py", "dataset.py"),
        ("models/MoE_ECGFormer.py", "models.py"),
        ("configs/data_configs.py", "data_configs.py"),
        ("configs/hparams.py", "hparams.py"),
        ("train.py", "train.py"),
        ("utils.py", "utils.py"),
    ]
    for src, dst in files_to_backup:
        copy(src, os.path.join(backup_dir, dst))
|
|
|
|
|
def _plot_umap(model, data_loader, device, save_dir):
    """Embed the model's features in 2-D with UMAP and save a class-colored scatter plot.

    :param model: indexable pair of modules — model[0] produces an intermediate
        representation of the raw data and model[1] maps it to the feature
        vectors that get embedded (inferred from the two chained calls below;
        confirm against the caller)
    :param data_loader: loader whose dataset exposes ``x_data`` / ``y_data`` tensors
    :param device: torch device the inputs are moved to for the forward pass
    :param save_dir: directory under which ``umap_plots/umap_.png`` is written
    """

    # legend labels for the five classes — presumably the AAMI ECG beat
    # categories (Normal/Supraventricular/Ventricular/Fusion/Unknown); verify
    classes_names = ['N', 'S', 'V', 'F', 'Q']

    font = {'family': 'Times New Roman',
            'weight': 'bold',
            'size': 17}
    plt.rc('font', **font)

    # forward the whole dataset in a single batch without building autograd graphs
    with torch.no_grad():
        data = data_loader.dataset.x_data.float().to(device)
        labels = data_loader.dataset.y_data.view((-1)).long()
        out = model[0](data)
        features = model[1](out)

    if not os.path.exists(os.path.join(save_dir, "umap_plots")):
        os.mkdir(os.path.join(save_dir, "umap_plots"))

    # 2-D UMAP embedding computed on CPU from the detached features
    model_reducer = umap.UMAP()
    embedding = model_reducer.fit_transform(features.detach().cpu().numpy())

    # scale labels into [0, 1] so they index the 5-color map below
    # (assumes integer labels in {0..4} — TODO confirm)
    norm_labels = labels / 4.0

    # pick 5 visually distinct colors from the 12-entry 'Paired' palette
    # NOTE(review): plt.cm.get_cmap is deprecated in recent matplotlib releases
    paired = plt.cm.get_cmap('Paired', 12)
    new_colors = [paired(0), paired(1), paired(2), paired(4),
                  paired(6)]
    new_cmap = ListedColormap(new_colors)

    print("Plotting UMAP ...")
    plt.figure(figsize=(16, 10))

    scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=norm_labels, cmap=new_cmap, s=15)

    # build one legend handle per color and label it with the class names
    handles, _ = scatter.legend_elements(prop='colors')
    plt.legend(handles, classes_names, title="Classes")
    file_name = "umap_.png"
    fig_save_name = os.path.join(save_dir, "umap_plots", file_name)
    # hide axis ticks — UMAP coordinates carry no interpretable scale
    plt.xticks([])
    plt.yticks([])
    plt.savefig(fig_save_name, bbox_inches='tight')
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|