|
|
|
from transformers import AutoModel, AutoTokenizer, BatchEncoding
|
from functools import partial |
|
from huggingface_hub import snapshot_download |
|
from huggingface_hub.constants import HF_HUB_CACHE |
|
from accelerate import Accelerator |
|
from accelerate.utils import find_executable_batch_size as auto_find_batch_size |
|
from datasets import load_dataset, Dataset |
|
from torch.utils.data import DataLoader |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import numpy as np |
|
import json |
|
import os |
|
from tqdm import tqdm |
|
import pandas as pd |
|
|
|
import matplotlib.pyplot as plt |
|
from sklearn.metrics import ( |
|
ConfusionMatrixDisplay, |
|
accuracy_score, |
|
classification_report, |
|
confusion_matrix, |
|
f1_score, |
|
recall_score |
|
) |
|
|
|
BASE_PATH = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
|
|
class MultiHeadClassification(nn.Module): |
|
""" |
|
MultiHeadClassification |
|
|
|
An easy-to-use multi-head classification model. It takes a backbone model and a dictionary of head configurations,
and can train multiple classification tasks at once on a single backbone model.

Apart from joint training, it also supports training individual heads separately, providing simple methods to freeze
and unfreeze the backbone and individual heads.
|
|
|
Example: |
|
>>> from transformers import AutoModel, AutoTokenizer |
|
>>> from torch.optim import AdamW |
|
>>> import torch |
|
|
>>> import torch.nn as nn |
|
>>> |
|
>>> # Manually load backbone model to create model |
|
>>> backbone = AutoModel.from_pretrained('BAAI/bge-m3') |
|
>>> model = MultiHeadClassification(backbone, {'binary': 2, 'sentiment': 3, 'something': 4}).to('cuda') |
|
>>> print(model) |
|
>>> # Load tokenizer for data preprocessing |
|
>>> tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3') |
|
>>> optimizer = AdamW(model.parameters(), lr=5e-4)
>>> # some training data
>>> samples = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute", "I like turtles"], return_tensors="pt", padding=True, truncation=True).to('cuda')
|
>>> labels = {'binary': torch.tensor([0, 0, 1]), 'sentiment': torch.tensor([0, 1, 2]), 'something': torch.tensor([0, 1, 2])} |
|
>>> model.freeze_backbone() |
|
>>> model.train(True) |
|
>>> for i in range(10): |
|
... optimizer.zero_grad() |
|
... outputs = model(samples) |
|
... loss = sum([nn.CrossEntropyLoss()(outputs[name].cpu(), labels[name]) for name in model.heads.keys()]) |
|
... loss.backward() |
|
... optimizer.step() |
|
... print(loss.item()) |
|
|
... print(model(samples)) |
|
>>> # Save full model |
|
>>> model.save('model.pth') |
|
>>> # Save head only |
|
>>> model.save_head('binary', 'binary.pth') |
|
>>> # Load full model |
|
>>> model = MultiHeadClassification(backbone, {'binary': 2, 'sentiment': 3, 'something': 4}).to('cuda')
|
>>> model.load('model.pth') |
|
>>> # Load head only |
|
>>> model = MultiHeadClassification(backbone, {}).to('cuda') |
|
>>> model.load_head('binary', 'binary.pth') |
|
>>> # Adding new head |
|
>>> model.add_head('new_head', 3) |
|
>>> print(model) |
|
>>> # extend dataset with data for new head |
|
>>> labels['new_head'] = torch.tensor([0, 1, 2]) |
|
>>> # Freeze all heads and backbone |
|
>>> model.freeze_all() |
|
>>> # Only unfreeze new head
>>> model.unfreeze_head('new_head')
>>> # Recreate the optimizer so the new head's parameters are included
>>> optimizer = AdamW(model.parameters(), lr=5e-4)
>>> model.train(True)
|
>>> for i in range(10): |
|
... optimizer.zero_grad() |
|
... outputs = model(samples) |
|
... loss = sum([nn.CrossEntropyLoss()(outputs[name].cpu(), labels[name]) for name in model.heads.keys()]) |
|
... loss.backward() |
|
... optimizer.step() |
|
... print(loss.item()) |
|
>>> print(model(samples)) |
|
|
|
Args: |
|
backbone (transformers.PreTrainedModel): A pretrained transformer model |
|
head_config (dict): A dictionary with head configurations. The key is the head name and the value is the number |
|
of classes for that head. |
|
""" |
|
def __init__(self, backbone, head_config, dropout=0.1, l2_reg=0.01): |
|
super().__init__() |
|
self.backbone = backbone |
|
self.num_heads = len(head_config) |
|
self.heads = nn.ModuleDict({ |
|
name: nn.Linear(backbone.config.hidden_size, num_classes) |
|
for name, num_classes in head_config.items() |
|
}) |
|
self.do = nn.Dropout(dropout) |
|
self.l2_reg = l2_reg |
|
self.device = 'cpu' |
|
self.torch_dtype = torch.float16 |
|
self.head_config = head_config |
|
|
|
def forward(self, x, head_names=None) -> dict: |
|
""" |
|
Forward pass of the model. |
|
|
|
Requires tokenizer output as input. The input should be a dict (or BatchEncoding) containing at least the keys 'input_ids' and 'attention_mask'.
|
|
|
Args: |
|
x (dict): Tokenizer output |
|
head_names (list): (optional) List of head names to return logits for. If None, returns logits for all heads. |
|
|
|
Returns: |
|
dict: A dictionary with head names as keys and logits as values |
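Example (a minimal sketch; the model id 'BAAI/bge-m3' and the 'binary' head are illustrative):
    >>> tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
    >>> model = MultiHeadClassification(AutoModel.from_pretrained('BAAI/bge-m3'), {'binary': 2})
    >>> batch = tokenizer(["a short example"], return_tensors="pt", padding=True, truncation=True)
    >>> model(batch)                          # logits for every head
    >>> model(batch, head_names=['binary'])   # logits for the selected heads only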
|
""" |
|
x = self.backbone(**x, return_dict=True, output_hidden_states=True).last_hidden_state[:, 0, :] |
|
x = self.do(x) |
|
if head_names is None: |
|
return {name: head(x) for name, head in self.heads.items()} |
|
return {name: head(x) for name, head in self.heads.items() if name in head_names} |
|
|
|
def get_l2_loss(self): |
|
""" |
|
Getter for L2 regularization loss |
|
|
|
Returns: |
|
torch.Tensor: L2 regularization loss |
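Example (illustrative; `outputs` and `labels` are assumed to come from a forward pass and your dataset):
    >>> task_loss = nn.CrossEntropyLoss()(outputs['binary'], labels['binary'])
    >>> loss = task_loss + model.get_l2_loss()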
|
""" |
|
l2_loss = torch.tensor(0.).to(self.device) |
|
for param in self.parameters(): |
|
if param.requires_grad: |
|
l2_loss += torch.norm(param, 2) |
|
return (self.l2_reg * l2_loss).to(self.device) |
|
|
|
def to(self, *args, **kwargs):
    super().to(*args, **kwargs)
    # Track the requested dtype/device so newly created heads can be moved to match later.
    for arg in args:
        if isinstance(arg, torch.dtype):
            self.torch_dtype = arg
        elif isinstance(arg, (str, torch.device)):
            self.device = arg
    if 'dtype' in kwargs:
        self.torch_dtype = kwargs['dtype']
    if 'device' in kwargs:
        self.device = kwargs['device']
    return self
|
|
|
def load_head(self, head_name, path): |
|
""" |
|
Load head from a file |
|
|
|
Args: |
|
head_name (str): Name of the head |
|
path (str): Path to the file |
|
|
|
Returns: |
|
None |
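Example (a sketch; 'binary.pth' and 'new_task.pth' are assumed to be checkpoints previously written by save_head):
    >>> model.load_head('binary', 'binary.pth')      # replace the weights of an existing head
    >>> model.load_head('new_task', 'new_task.pth')  # or register a brand-new head from a checkpoint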
|
""" |
|
state_dict = torch.load(path)
num_classes = state_dict['weight'].shape[0]
if head_name in self.heads:
    self.heads[head_name].load_state_dict(state_dict)
    self.to(self.torch_dtype).to(self.device)
    self.head_config[head_name] = num_classes
    return

assert state_dict['weight'].shape[1] == self.backbone.config.hidden_size, 'Head input size does not match backbone hidden size'
self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
self.heads[head_name].load_state_dict(state_dict)
self.head_config[head_name] = num_classes

self.to(self.torch_dtype).to(self.device)
|
|
|
def save_head(self, head_name, path): |
|
""" |
|
Save head to a file |
|
|
|
Args: |
|
head_name (str): Name of the head |
|
path (str): Path to the file |
|
""" |
|
torch.save(self.heads[head_name].state_dict(), path) |
|
|
|
def save(self, path): |
|
""" |
|
Save the full model to a file |
|
|
|
Args: |
|
path (str): Path to the file |
|
""" |
|
torch.save(self.state_dict(), path) |
|
|
|
def load(self, path): |
|
""" |
|
Load the full model from a file |
|
|
|
Args: |
|
path (str): Path to the file |
|
""" |
|
self.load_state_dict(torch.load(path)) |
|
self.to(self.torch_dtype).to(self.device) |
|
|
|
def save_backbone(self, path): |
|
""" |
|
Save the backbone to a file |
|
|
|
Args: |
|
path (str): Path to the file |
|
""" |
|
self.backbone.save_pretrained(path) |
|
|
|
def load_backbone(self, path): |
|
""" |
|
Load the backbone from a file |
|
|
|
Args: |
|
path (str): Path to the file |
|
""" |
|
self.backbone = AutoModel.from_pretrained(path) |
|
self.to(self.torch_dtype).to(self.device) |
|
|
|
def freeze_backbone(self): |
|
""" Freeze the backbone """ |
|
for param in self.backbone.parameters(): |
|
param.requires_grad = False |
|
|
|
def unfreeze_backbone(self): |
|
""" Unfreeze the backbone """ |
|
for param in self.backbone.parameters(): |
|
param.requires_grad = True |
|
|
|
def freeze_head(self, head_name): |
|
""" |
|
Freeze a head by name |
|
|
|
Args: |
|
head_name (str): Name of the head |
|
""" |
|
for param in self.heads[head_name].parameters(): |
|
param.requires_grad = False |
|
|
|
def unfreeze_head(self, head_name): |
|
""" |
|
Unfreeze a head by name |
|
|
|
Args: |
|
head_name (str): Name of the head |
|
""" |
|
for param in self.heads[head_name].parameters(): |
|
param.requires_grad = True |
|
|
|
def freeze_all_heads(self): |
|
""" Freeze all heads """ |
|
for head_name in self.heads.keys(): |
|
self.freeze_head(head_name) |
|
|
|
def unfreeze_all_heads(self): |
|
""" Unfreeze all heads """ |
|
for head_name in self.heads.keys(): |
|
self.unfreeze_head(head_name) |
|
|
|
def freeze_all(self): |
|
""" Freeze all """ |
|
self.freeze_backbone() |
|
self.freeze_all_heads() |
|
|
|
def unfreeze_all(self): |
|
""" Unfreeze all """ |
|
self.unfreeze_backbone() |
|
self.unfreeze_all_heads() |
|
|
|
def add_head(self, head_name, num_classes): |
|
""" |
|
Add a new head to the model |
|
|
|
Args: |
|
head_name (str): Name of the head |
|
num_classes (int): Number of classes for the head |
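Note:
    If an optimizer was created before adding the head, recreate it (or add a new parameter
    group) so the new head's parameters are actually optimized.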
|
""" |
|
self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes) |
|
self.heads[head_name].to(self.torch_dtype).to(self.device) |
|
self.head_config[head_name] = num_classes |
|
|
|
def remove_head(self, head_name): |
|
""" |
|
Remove a head from the model |
|
""" |
|
if head_name not in self.heads: |
|
raise ValueError(f'Head {head_name} not found') |
|
del self.heads[head_name] |
|
del self.head_config[head_name] |
|
|
|
@classmethod |
|
def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01): |
|
""" |
|
Load a pretrained model from Huggingface model hub |
|
|
|
Args: |
|
model_name (str): Name of the model |
|
head_config (dict): Head configuration |
|
dropout (float): Dropout rate |
|
l2_reg (float): L2 regularization rate |
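Example (a sketch; 'your-org/your-model' is a placeholder repo id whose snapshot is expected to
provide the pretrained/backbone.pth and pretrained/model.pth files that _from_directory loads):
    >>> model = MultiHeadClassification.from_pretrained('your-org/your-model', head_config={'binary': 2})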
|
""" |
|
if head_config is None: |
|
head_config = {} |
|
|
|
hf_cache_dir = HF_HUB_CACHE |
|
model_path = os.path.join(hf_cache_dir, model_name) |
|
if os.path.exists(model_path): |
|
return cls._from_directory(model_path, head_config, dropout, l2_reg) |
|
|
|
model_path = snapshot_download(repo_id=model_name, cache_dir=hf_cache_dir) |
|
return cls._from_directory(model_path, head_config, dropout, l2_reg) |
|
|
|
@classmethod |
|
def _from_directory(cls, model_path, head_config, dropout=0.1, l2_reg=0.01): |
|
""" |
|
Load a model from a directory |
|
|
|
Args: |
|
model_path (str): Path to the model directory |
|
head_config (dict): Head configuration |
|
dropout (float): Dropout rate |
|
l2_reg (float): L2 regularization rate |
|
""" |
|
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth')) |
|
instance = cls(backbone, head_config, dropout, l2_reg) |
|
instance.load(os.path.join(model_path, 'pretrained/model.pth')) |
|
instance.head_config = {name: head.out_features for name, head in instance.heads.items()}
|
return instance |
|
|
|
class MultiHeadClassificationTrainer: |
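"""
Training harness for MultiHeadClassification.

Configured entirely through keyword arguments, e.g. model_conf (passed to
MultiHeadClassification, so it takes 'backbone' and 'head_config'), optimizer_conf
(may contain 'optimizer': 'sgd'|'adam' and 'loss': 'crossentropy'|'bce' plus optimizer
kwargs such as 'lr'), scheduler_conf (may contain 'scheduler': 'plateau'|'step' plus
scheduler kwargs), num_epochs, batch_size, device, gradient_accumulation_steps,
train_test_split, load_best, auto_find_batch_size and name_prefix.

Example (a minimal sketch; the backbone id and dataset name are illustrative):
    >>> trainer = MultiHeadClassificationTrainer(
    ...     model_conf={'backbone': AutoModel.from_pretrained('BAAI/bge-m3'),
    ...                 'head_config': {'binary': 2}},
    ...     optimizer_conf={'optimizer': 'adam', 'lr': 1e-3},
    ...     num_epochs=10,
    ... )
    >>> model, history = trainer.train(dataset_name='some-user/some-dataset', target_heads=['binary'])
"""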
|
def __init__(self, **kwargs): |
|
self.model_conf = kwargs.get('model_conf', {}) |
|
self.optimizer_conf = kwargs.get('optimizer_conf', {}) |
|
self.scheduler_conf = kwargs.get('scheduler_conf', {}) |
|
self.dropout = kwargs.get('dropout', 0.1) |
|
self.l2_loss_weight = kwargs.get('l2_loss_weight', 0.01) |
|
self.num_epochs = kwargs.get('num_epochs', 100) |
|
self.device = kwargs.get('device', 'cuda') |
|
self.train_run = kwargs.get('train_run', 0) |
|
self.name_prefix = kwargs.get('name_prefix', 'multihead-classification') |
|
self.use_lr_scheduler = kwargs.get('use_lr_scheduler', True) |
|
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1) |
|
self.batch_size = kwargs.get('batch_size', 4) |
|
self.train_test_split = kwargs.get('train_test_split', 0.2) |
|
self.load_best = kwargs.get('load_best', True) |
|
self.auto_find_batch_size = kwargs.get('auto_find_batch_size', False) |
|
self.test_data = None |
|
self.accelerator = Accelerator() |
|
|
|
# Pop the optional tokenizer id so it is not forwarded to MultiHeadClassification.__init__
tokenizer_name = self.model_conf.pop('tokenizer', None)
self.classifier = MultiHeadClassification(
    **self.model_conf
).to(torch.float16)
self.classifier.freeze_backbone()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or self.classifier.backbone.name_or_path, model_max_length=128)
|
|
|
def _batch_data(self, batch_size, data): |
|
return DataLoader(data, shuffle=True, batch_size=batch_size) |
|
|
|
def train(self, dataset_name: str = None, train_data: DataLoader = None, val_data: DataLoader = None, lr: float = None, num_epochs: int = None, target_heads: list[str] = None, batch_size: int = 4, sample_key=None, label_key=None): |
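"""
Train the classifier either on a Hugging Face dataset (by name) or on prepared DataLoaders.

Exactly one of dataset_name or train_data/val_data must be provided. When dataset_name is
given, target_heads is required; the dataset's 'train' split is tokenized and split into
train/validation/test portions using train_test_split. sample_key and label_key default to
'sample' and 'label'.

Returns:
    tuple: (trained MultiHeadClassification, training-history dict)
"""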
|
has_dataset = train_data is not None
assert (dataset_name is not None) != has_dataset, 'Provide either dataset_name or train_data/val_data, not both'
|
if dataset_name is not None: |
|
assert target_heads is not None, 'target_heads must be provided when using dataset_name' |
|
|
|
if sample_key is None: |
|
sample_key = 'sample' |
|
if label_key is None: |
|
label_key = 'label' |
|
|
|
self.accelerator.free_memory() |
|
self.classifier = self.accelerator.prepare(self.classifier) |
|
|
|
if dataset_name is not None: |
|
dataset = load_dataset(dataset_name)['train'].train_test_split(test_size=self.train_test_split) |
|
train_data = dataset['train'] |
|
val_data = dataset['test'].train_test_split(test_size=0.5) |
|
self.test_data = val_data['test'] |
|
val_data = val_data['train'] |
|
|
|
if batch_size is not None: |
|
self.batch_size = batch_size |
|
|
|
if isinstance(train_data, Dataset): |
|
sample = next(iter(train_data)) |
|
print('Tokenizing dataset...', sample, type(sample)) |
|
is_string_dataset = isinstance(sample[0], str) if not isinstance(sample, dict) else isinstance(sample[sample_key], str) |
|
|
|
if is_string_dataset: |
|
if isinstance(sample, list): |
|
train_data = train_data.map(lambda x: self.tokenizer([x[0]] if isinstance(x[0], str) else x[0], return_tensors="pt", padding=True, truncation=True), batched=True) |
|
val_data = val_data.map(lambda x: self.tokenizer([x[0]] if isinstance(x[0], str) else x[0], return_tensors="pt", padding=True, truncation=True), batched=True) |
|
elif isinstance(sample, dict): |
|
assert sample_key in sample and label_key in sample, 'Invalid dataset format' |
|
train_data = train_data.map(lambda x: self.tokenizer([x[sample_key]] if isinstance(x[sample_key], str) else x[sample_key], return_tensors="pt", padding=True, truncation=True), batched=True) |
|
val_data = val_data.map(lambda x: self.tokenizer([x[sample_key]] if isinstance(x[sample_key], str) else x[sample_key], return_tensors="pt", padding=True, truncation=True), batched=True) |
|
else: |
|
raise ValueError('Invalid dataset format') |
|
|
|
create_train_data = partial(self._batch_data, data=train_data) |
|
create_val_data = partial(self._batch_data, data=val_data) |
|
|
|
if self.auto_find_batch_size: |
|
train_data = auto_find_batch_size(create_train_data)() |
|
val_data = auto_find_batch_size(create_val_data)() |
|
else: |
|
train_data = create_train_data(self.batch_size) |
|
val_data = create_val_data(self.batch_size) |
|
|
|
else: |
|
assert train_data is not None and val_data is not None, 'train_data and val_data must be provided' |
|
assert isinstance(train_data, DataLoader) and isinstance(val_data, DataLoader), 'train_data and val_data must be DataLoader instances' |
|
|
|
optimizer_name = self.optimizer_conf.pop('optimizer', 'sgd') |
|
loss_name = self.optimizer_conf.pop('loss', 'crossentropy') |
|
if lr: |
|
self.optimizer_conf['lr'] = lr |
|
if num_epochs: |
|
self.num_epochs = num_epochs |
|
|
|
self.classifier.unfreeze_all() |
|
|
|
print('Freezing backbone') |
|
self.classifier.freeze_backbone() |
|
|
|
|
|
if target_heads is None: |
|
sample = next(iter(train_data)) |
|
if isinstance(sample, dict): |
|
train_heads = list(sample[label_key].keys()) |
|
elif isinstance(sample, list): |
|
train_heads = list(sample[1].keys()) |
|
else: |
|
raise ValueError('Invalid dataset format') |
|
else: |
|
train_heads = target_heads |
|
|
|
for head_name in self.classifier.heads.keys(): |
|
if head_name not in train_heads: |
|
print(f'Freezing head {head_name}') |
|
self.classifier.freeze_head(head_name) |
|
|
|
self.classifier.to(self.device) |
|
self.classifier.train(True) |
|
loss_func = {'crossentropy': nn.CrossEntropyLoss, 'bce': nn.BCELoss}.get(loss_name, nn.CrossEntropyLoss) |
|
optimizer_class = {'sgd': optim.SGD, 'adam': optim.Adam}.get(optimizer_name, optim.SGD) |
|
optimizer = optimizer_class(self.classifier.parameters(), **self.optimizer_conf) |
|
|
|
scheduler = None
if self.use_lr_scheduler:
    # Copy the config so the 'scheduler' selector is not forwarded as a scheduler kwarg.
    scheduler_conf = dict(self.scheduler_conf)
    scheduler_name = scheduler_conf.pop('scheduler', 'plateau')
    if scheduler_name == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, **scheduler_conf)
    else:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **scheduler_conf)
|
|
|
history = self._train(loss_func(), optimizer, scheduler, self.accelerator.prepare(train_data), self.accelerator.prepare(val_data), train_heads, sample_key, label_key) |
|
if self.load_best: |
|
self.classifier.load(os.path.join(BASE_PATH, f'../train_runs/{self.name_prefix}-run-{self.train_run-1}-best-model.pth')) |
|
return self.classifier, history |
|
|
|
def _train(self, criterion, optimizer, scheduler, dataloader, val_dataloader, head_names, sample_key, label_key): |
|
average_acc = 0 |
|
losses = [] |
|
precisions = [] |
|
best_prec = 0.0 |
|
|
|
val_losses = [] |
|
val_accs = [] |
|
avg_val_acc = 0.0 |
|
|
|
patience = 50 |
|
reset_patience = 25 |
|
patience_reset_counter = 0 |
|
patience_counter = 0 |
|
current_max = 0 |
|
total_max = 0 |
|
num_samples = len(dataloader) |
|
pbar = tqdm(total=self.num_epochs * num_samples, desc='Training model...') |
|
for epoch in range(self.num_epochs): |
|
self.classifier.train() |
|
running_loss = 0.0 |
|
all_preds = {name: [] for name in head_names} |
|
all_labels = {name: [] for name in head_names} |
|
|
|
for step, sample in enumerate(dataloader): |
|
labels = {name: sample[label_key] for name in head_names} |
|
embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]}).to(self.device) |
|
outputs = self.classifier(embeddings, head_names=head_names) |
|
loss = sum([criterion(outputs[name].to(self.device), labels[name].to(self.device)) for name in labels.keys()]) |
|
loss += self.l2_loss_weight * self.classifier.get_l2_loss().to(self.device) |
|
running_loss += loss.item() |
|
loss.backward() |
|
if (step + 1) % self.gradient_accumulation_steps == 0: |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
|
|
for name in labels.keys(): |
|
preds = outputs[name][0].argmax().item() |
|
all_labels[name].append(labels[name][0].cpu().numpy()) |
|
all_preds[name].append(preds) |
|
|
|
pbar.update(1) |
|
|
|
torch.cuda.empty_cache() |
|
|
|
epoch_loss = running_loss / num_samples |
|
if scheduler: |
|
scheduler.step(epoch_loss) |
|
|
|
average_acc += np.mean([np.mean(np.abs(np.array(all_labels[name]) - np.array(all_preds[name])) == 0) for name in head_names]) |
|
average_acc /= 2.0 |
|
if val_dataloader: |
|
val_loss, val_acc = self.validate(self.classifier, criterion, val_dataloader, head_names, sample_key, label_key) |
|
avg_val_acc += val_acc.item() |
|
avg_val_acc /= 2.0 |
|
val_losses.append(val_loss) |
|
val_accs.append(val_acc) |
|
losses.append(epoch_loss) |
|
precisions.append(average_acc) |
|
if avg_val_acc > current_max: |
|
current_max = avg_val_acc |
|
self.classifier.save(os.path.join(BASE_PATH, f'../train_runs/{self.name_prefix}-run-{self.train_run}-best-model.pth')) |
|
best_prec = max(average_acc, best_prec) |
|
|
|
pbar_data = { |
|
'epoch': epoch + 1, |
|
'loss': epoch_loss, |
|
'avg_acc': average_acc, |
|
'acc_max': best_prec |
|
} |
|
if scheduler:
    pbar_data['lr'] = optimizer.param_groups[0]['lr']
|
if val_dataloader: |
|
pbar_data['val_loss'] = val_loss |
|
pbar_data['val_acc'] = val_acc |
|
pbar_data['avg_val_acc'] = avg_val_acc |
|
pbar.set_postfix(pbar_data) |
|
|
|
torch.cuda.empty_cache() |
|
|
|
pbar.close() |
|
param_dict = { |
|
'dropout': self.dropout, |
|
'model_conf': {k:v for k, v in self.model_conf.items() if k not in ['tokenizer', 'backbone']}, |
|
'optimizer_conf': self.optimizer_conf, |
|
'scheduler_conf': self.scheduler_conf, |
|
'l2_loss_weight': self.l2_loss_weight, |
|
'num_epochs': self.num_epochs, |
|
'device': self.device, |
|
'train_run': self.train_run, |
|
'name_prefix': self.name_prefix, |
|
'use_lr_scheduler': self.use_lr_scheduler, |
|
'metrics': { |
|
'loss': losses, |
|
'val_loss': val_losses, |
|
'precision': precisions, |
|
'val_precision': val_accs |
|
} |
|
} |
|
with open(os.path.join(BASE_PATH, f'../train_runs/{self.name_prefix}-train-run-{self.train_run}.json'), 'w') as f: |
|
json.dump(param_dict, f) |
|
print("Training complete!") |
|
self.train_run += 1 |
|
|
|
return param_dict |
|
|
|
def _plot_history(self, loss, val_loss, precision, val_precision): |
|
fig = plt.figure(figsize=(15,7)) |
|
ax = plt.subplot(1,2, 1) |
|
ax.set_title('loss') |
|
plt.plot(range(len(loss)), loss, 'g--', label='train_loss') |
|
plt.plot(range(len(loss)), val_loss, 'r--', label='val_loss') |
|
plt.yscale('log') |
|
plt.legend() |
|
ax = plt.subplot(1,2, 2) |
|
ax.set_title('accuracy') |
|
plt.plot(range(len(precision)), precision, 'g--', label='prec') |
|
plt.plot(range(len(precision)), val_precision, 'r--',label='val_prec') |
|
plt.legend() |
|
return fig |
|
|
|
def validate(self, model, criterion, dataloader, head_names=None, sample_key='sample', label_key='label'): |
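"""
Run one validation pass over dataloader.

Returns:
    tuple: (average validation loss, mean accuracy over the requested heads, computed from
    the first example of each batch, mirroring the training-loop bookkeeping)
"""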
|
running_loss = 0 |
|
num_samples = len(dataloader) |
|
if head_names is None: |
|
sample = next(iter(dataloader))[1] |
|
head_names = list(sample.keys()) |
|
|
|
all_labels = {name: [] for name in head_names} |
|
all_preds = {name: [] for name in head_names} |
|
|
|
num_labels = {name: model.heads[name].out_features for name in head_names} |
|
|
|
model.train(False) |
|
for sample in dataloader: |
|
labels = {name: sample[label_key] for name in head_names} |
|
embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]}) |
|
outputs = model(embeddings) |
|
loss = sum([criterion(outputs[name].to(self.device), labels[name].to(self.device)) for name in head_names]) |
|
loss += self.l2_loss_weight * model.get_l2_loss().to(self.device) |
|
running_loss += loss.item() |
|
|
|
for name in head_names: |
|
preds = outputs[name][0].argmax().item() |
|
all_labels[name].append(labels[name][0].cpu().numpy()) |
|
all_preds[name].append(preds) |
|
torch.cuda.empty_cache() |
|
return running_loss / num_samples, np.mean([np.mean(np.abs(np.array(all_labels[name]) - np.array(all_preds[name])) == 0) for name in head_names]) |
|
|
|
def eval(self, label_map, test_set=None, sample_key='sample', label_key='label'): |
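"""
Evaluate the classifier on a test set and build a per-head report.

label_map maps each head name to a {class_id: class_name} dict; the class ids are used as
confusion-matrix labels and the class names as display/target names.

Example of the expected shape (head and class names are illustrative):
    >>> label_map = {'sentiment': {0: 'negative', 1: 'neutral', 2: 'positive'}}
    >>> report = trainer.eval(label_map)

Returns:
    dict: per head, a classification-report DataFrame, a ConfusionMatrixDisplay and
    accuracy/F1/recall metrics.
"""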
|
if test_set is None: |
|
assert self.test_data is not None, 'No test data provided' |
|
test_set = self.test_data |
|
sample = next(iter(test_set)) |
|
is_string_dataset = isinstance(sample[0], str) if not isinstance(sample, dict) else isinstance(sample[sample_key], str) |
|
|
|
if is_string_dataset: |
|
if isinstance(sample, list): |
|
test_set = test_set.map(lambda x: self.tokenizer([x[0]] if isinstance(x[0], str) else x[0], return_tensors="pt", padding=True, truncation=True), batched=True) |
|
elif isinstance(sample, dict): |
|
assert sample_key in sample and label_key in sample, 'Invalid dataset format' |
|
test_set = test_set.map(lambda x: self.tokenizer([x[sample_key]] if isinstance(x[sample_key], str) else x[sample_key], return_tensors="pt", padding=True, truncation=True), batched=True) |
|
else: |
|
raise ValueError('Invalid dataset format') |
|
|
|
test_set = DataLoader(test_set, shuffle=True, batch_size=self.batch_size) |
|
self.classifier.to(self.device) |
|
return self._eval_model(test_set, label_map, sample_key, label_key) |
|
|
|
def _eval_model(self, dataloader, label_map, sample_key, label_key): |
|
self.classifier.train(False) |
|
eval_heads = list(label_map.keys()) |
|
y_pred = {h: [] for h in eval_heads} |
|
y_test = {h: [] for h in eval_heads} |
|
for sample in tqdm(dataloader, total=len(dataloader), desc='Evaluating model...'): |
|
labels = {name: sample[label_key] for name in eval_heads} |
|
embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]}) |
|
output = self.classifier(embeddings.to(self.device), head_names=eval_heads)
|
for head in eval_heads: |
|
y_pred[head].extend(output[head].argmax(dim=1).cpu()) |
|
y_test[head].extend(labels[head]) |
|
torch.cuda.empty_cache() |
|
|
|
accuracies = {h: accuracy_score(y_test[h], y_pred[h]) for h in eval_heads} |
|
f1_scores = {h: f1_score(y_test[h], y_pred[h], average="macro") for h in eval_heads} |
|
recalls = {h: recall_score(y_test[h], y_pred[h], average='macro') for h in eval_heads} |
|
|
|
report = {} |
|
for head in eval_heads: |
|
cm = confusion_matrix(y_test[head], y_pred[head], labels=list(label_map[head].keys())) |
|
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(label_map[head].values())) |
|
clf_report = classification_report( |
|
y_test[head], y_pred[head], output_dict=True, target_names=list(label_map[head].values()) |
|
) |
|
del clf_report["accuracy"] |
|
clf_report = pd.DataFrame(clf_report).T.reset_index() |
|
report[head] = dict( |
|
clf_report=clf_report, confusion_matrix=disp, metrics={'accuracy': accuracies[head], 'f1': f1_scores[head], 'recall': recalls[head]} |
|
) |
|
return report |
|
|