| | import torch
|
| | from torchmetrics import MeanAbsoluteError, RelativeSquaredError, PearsonCorrCoef, KendallRankCorrCoef, F1Score, Accuracy, AveragePrecision, AUROC
|
| |
|
| |
|
def move_to_device(batch, device, non_blocking=False):
    """Recursively move a (possibly nested) batch of tensors to *device*.

    Args:
        batch: a tensor, or a list / tuple / dict nesting tensors (nesting
            may be arbitrarily deep).
        device: target device (``torch.device`` or device string).
        non_blocking: forwarded to ``Tensor.to`` to allow asynchronous
            host-to-device copies from pinned memory.

    Returns:
        The same container structure with every leaf tensor on ``device``.
    """
    if isinstance(batch, (list, tuple)):
        # Preserve the concrete container type (list stays list, tuple stays tuple).
        return type(batch)(move_to_device(item, device, non_blocking) for item in batch)
    if isinstance(batch, dict):
        # Generalization: DataLoader batches are often dicts of tensors.
        return {key: move_to_device(value, device, non_blocking) for key, value in batch.items()}
    return batch.to(device, non_blocking=non_blocking)
|
| |
|
| |
|
def train(args, epoch, model, train_loader, valid_loader, device, criterion, optimizer):
    """Run one regression training epoch (optional) plus a validation pass.

    Args:
        args: experiment config; only ``args.dir`` is read here — it enables
            the FDS (feature distribution smoothing) branch of the model.
        epoch: current epoch index, forwarded to the model and FDS statistics.
        model: network exposing ``classes`` and, when ``args.dir`` is set,
            an ``FDS`` module with running-statistics updates.
        train_loader: training DataLoader; pass ``None`` to skip training and
            only evaluate on ``valid_loader``.
        valid_loader: validation DataLoader yielding ``(x, gt)`` pairs.
        device: computation device.
        criterion: regression loss function.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        Tuple ``(train_loss, mae, rse, pcc, kcc)``; ``train_loss`` is 0 when
        ``train_loader`` is ``None``.
    """
    train_loss = 0
    num_labels = model.classes
    metric_mae = MeanAbsoluteError().to(device)
    metric_rse = RelativeSquaredError(num_outputs=num_labels).to(device)
    metric_pcc = PearsonCorrCoef(num_outputs=num_labels).to(device)
    metric_kcc = KendallRankCorrCoef(num_outputs=num_labels).to(device)

    if args.dir:
        encodings, labels = [], []

    if train_loader is not None:
        model.train()
        for data in train_loader:
            x, gt = data
            x = move_to_device(x, device)
            if args.dir:
                # FDS path: the model also returns intermediate features,
                # collected on CPU to refresh the running statistics below.
                out, features = model(x, gt.to(device), epoch)
                encodings.append(features.detach().cpu())
                labels.append(gt.cpu())
            else:
                out = model(x)
            loss = criterion(out, gt.to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Refresh FDS statistics only after an actual training pass — with an
        # eval-only call (train_loader is None) there are no encodings to cat.
        if args.dir:
            encodings, labels = torch.cat(encodings), torch.cat(labels)
            model.FDS.update_last_epoch_stats(epoch)
            model.FDS.update_running_stats(encodings, labels, epoch)
            encodings, labels = [], []

    model.eval()
    preds = []
    gt_list_valid = []
    with torch.no_grad():
        for data in valid_loader:
            x, gt = data
            x = move_to_device(x, device)
            gt_list_valid.append(gt.to(device))
            out = model(x)
            if args.dir:
                # In the FDS branch the model returns (prediction, features).
                out, _ = out
            preds.append(out)

    preds = torch.cat(preds, dim=0)
    gt_list_valid = torch.cat(gt_list_valid, dim=0)

    mae = metric_mae(preds, gt_list_valid).item()
    rse = metric_rse(preds, gt_list_valid).item()
    # squeeze(-1) rather than squeeze(): a single-sample validation set must
    # not have its batch dimension collapsed before the correlation metrics.
    pcc = metric_pcc(preds.squeeze(-1), gt_list_valid.squeeze(-1)).mean().item()
    kcc = metric_kcc(preds.squeeze(-1), gt_list_valid.squeeze(-1)).mean().item()
    return train_loss, mae, rse, pcc, kcc
|
| |
|
| |
|
def update_ce_loss_weight(loss_fn: torch.nn.CrossEntropyLoss, gt: torch.Tensor, num_classes: int, device):
    """Refresh the class weights of *loss_fn* from the current batch's labels.

    Inverse-frequency weights are computed from ``gt`` and written back into
    the loss module's ``weight`` buffer via ``register_buffer`` — re-registering
    an existing buffer replaces it in place (``_WeightedLoss`` registers the
    ``weight`` buffer at construction time).

    Args:
        loss_fn: an initialized ``nn.CrossEntropyLoss`` whose ``weight``
            buffer will be overwritten.
        gt: 1-D integer label tensor with values in ``[0, num_classes - 1]``.
        num_classes: total number of classes; fixes the weight vector length
            even when some classes are absent from this batch.
        device: device the new weight tensor is moved to (should match the
            device of the logits passed to ``loss_fn``).
    """
    class_counts = torch.bincount(gt, minlength=num_classes).float()
    # Epsilon keeps classes absent from the batch from producing an
    # infinite weight (1 / 0).
    epsilon = 1e-6
    new_weights = 1.0 / (class_counts + epsilon)
    # Normalize so the weights sum to num_classes, i.e. a mean weight of 1 —
    # this keeps the overall loss scale comparable to the unweighted case.
    new_weights = new_weights / new_weights.sum() * num_classes

    loss_fn.register_buffer('weight', new_weights.to(device))
|
| |
|
def train_cls(args, epoch, model, train_loader, valid_loader, device, criterion, optimizer):
    """Run one classification training epoch (optional) plus a validation pass.

    Args:
        args: experiment config; reads ``args.metric_avg`` (torchmetrics
            averaging mode, e.g. ``'macro'``).
        epoch: current epoch index (unused here; kept for a uniform signature
            with ``train`` / ``train_base``).
        model: network exposing ``classes`` (number of output logits).
        train_loader: training DataLoader; pass ``None`` to skip training.
        valid_loader: validation DataLoader yielding ``(x, gt)`` pairs.
        device: computation device.
        criterion: ``nn.CrossEntropyLoss`` — its class weights are refreshed
            per batch via ``update_ce_loss_weight``.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        Tuple ``(train_loss, ap, auc, f1, acc)``; ``train_loss`` is 0 when
        ``train_loader`` is ``None``.
    """
    train_loss = 0
    num_labels = model.classes
    avg = args.metric_avg
    # Single-logit and two-class heads are both evaluated as a binary task.
    if num_labels == 1 or num_labels == 2:
        task = 'binary'
    else:
        task = 'multiclass'
    metric_acc = Accuracy(average=avg, task=task, num_classes=num_labels).to(device)
    metric_f1 = F1Score(average=avg, task=task, num_classes=num_labels).to(device)
    metric_ap = AveragePrecision(average=avg, task=task, num_classes=num_labels).to(device)
    metric_auc = AUROC(average=avg, task=task, num_classes=num_labels).to(device)

    if train_loader is not None:
        model.train()
        for data in train_loader:
            x, gt = data
            x = move_to_device(x, device)
            out = model(x)
            # Re-balance the loss using the label frequencies of this batch.
            update_ce_loss_weight(criterion, gt, num_classes=num_labels, device=device)
            loss = criterion(out, gt.to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
        train_loss /= len(train_loader)

    model.eval()
    preds = []
    gt_list_valid = []
    with torch.no_grad():
        for data in valid_loader:
            x, gt = data
            x = move_to_device(x, device)
            gt_list_valid.append(gt.to(device))
            out = model(x)
            preds.append(out)

    preds = torch.softmax(torch.cat(preds, dim=0), dim=-1)
    # squeeze(-1) rather than squeeze(): a one-sample validation set must not
    # have its batch dimension collapsed.
    gt_list_valid = torch.cat(gt_list_valid, dim=0).int().squeeze(-1)

    if num_labels == 2:
        # Binary torchmetrics expect the positive-class probability. Select
        # it BEFORE any squeeze, so a single-sample batch cannot collapse the
        # class dimension first and make [:, 1] hit the wrong axis.
        preds = preds[:, 1]
    else:
        preds = preds.squeeze(-1)

    ap = metric_ap(preds, gt_list_valid).item()
    auc = metric_auc(preds, gt_list_valid).item()
    f1 = metric_f1(preds, gt_list_valid).item()
    acc = metric_acc(preds, gt_list_valid).item()
    return train_loss, ap, auc, f1, acc
|
| |
|
def train_base(args, epoch, model, train_loader, valid_loader, device, criterion, optimizer):
    """Run one regression epoch for the single-sequence baseline model.

    Mirrors ``train`` but the batch input is a single sequence tensor rather
    than a nested structure.

    Args:
        args: experiment config; only ``args.dir`` is read here — it enables
            the FDS (feature distribution smoothing) branch of the model.
        epoch: current epoch index, forwarded to the model and FDS statistics.
        model: network exposing ``classes`` and, when ``args.dir`` is set,
            an ``FDS`` module with running-statistics updates.
        train_loader: training DataLoader; pass ``None`` to skip training.
        valid_loader: validation DataLoader yielding ``(seq, gt)`` pairs.
        device: computation device.
        criterion: regression loss function.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        Tuple ``(train_loss, mae, rse, pcc, kcc)``; ``train_loss`` is 0 when
        ``train_loader`` is ``None``.
    """
    train_loss = 0
    num_labels = model.classes
    metric_mae = MeanAbsoluteError().to(device)
    metric_rse = RelativeSquaredError(num_outputs=num_labels).to(device)
    metric_pcc = PearsonCorrCoef(num_outputs=num_labels).to(device)
    metric_kcc = KendallRankCorrCoef(num_outputs=num_labels).to(device)

    if args.dir:
        encodings, labels = [], []

    if train_loader is not None:
        model.train()
        for data in train_loader:
            seq1, gt = data
            # Use the shared helper for device placement, consistent with train().
            seq1 = move_to_device(seq1, device)
            if args.dir:
                # FDS path: the model also returns intermediate features,
                # collected on CPU to refresh the running statistics below.
                out, features = model(seq1, gt.to(device), epoch)
                encodings.append(features.detach().cpu())
                labels.append(gt.cpu())
            else:
                out = model(seq1)
            loss = criterion(out, gt.to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Refresh FDS statistics only after an actual training pass — with an
        # eval-only call (train_loader is None) there are no encodings to cat.
        if args.dir:
            encodings, labels = torch.cat(encodings), torch.cat(labels)
            model.FDS.update_last_epoch_stats(epoch)
            model.FDS.update_running_stats(encodings, labels, epoch)
            encodings, labels = [], []

    model.eval()
    preds = []
    gt_list_valid = []
    with torch.no_grad():
        for data in valid_loader:
            seq1, gt = data
            gt_list_valid.append(gt.to(device))
            out = model(move_to_device(seq1, device))
            if args.dir:
                # In the FDS branch the model returns (prediction, features).
                out, _ = out
            preds.append(out)

    preds = torch.cat(preds, dim=0)
    gt_list_valid = torch.cat(gt_list_valid, dim=0)

    mae = metric_mae(preds, gt_list_valid).item()
    rse = metric_rse(preds, gt_list_valid).item()
    # squeeze(-1) rather than squeeze(): a single-sample validation set must
    # not have its batch dimension collapsed before the correlation metrics.
    pcc = metric_pcc(preds.squeeze(-1), gt_list_valid.squeeze(-1)).mean().item()
    kcc = metric_kcc(preds.squeeze(-1), gt_list_valid.squeeze(-1)).mean().item()
    return train_loss, mae, rse, pcc, kcc