| | import torch |
| | import evaluate |
| | from tqdm import tqdm |
| | import logging |
| |
|
class Tester:
    """Evaluate a binary classifier over one or more test domains.

    Computes accuracy / F1 / precision / recall (via HuggingFace ``evaluate``
    streaming metrics) and mean BCE loss per domain.  ``validate`` averages
    metrics over the out-of-training domains; ``test`` additionally applies
    ``_bagging`` to the model outputs before scoring.
    """

    def __init__(self, test_dataset_dict, model, train_domain) -> None:
        """
        Args:
            test_dataset_dict: mapping of domain/dataset name -> iterable of
                batches with ``input_ids``, ``attention_mask`` and ``label``
                tensors (e.g. a DataLoader per domain).
            model: callable as ``model(input_ids, attention_mask=...)``.
                NOTE(review): BCELoss below requires outputs already in
                [0, 1] (sigmoid-activated) — confirm against the model.
            train_domain: key of the in-domain split, excluded from the
                cross-domain average in ``validate``.
        """
        self.test_dataset_dict = test_dataset_dict
        self.model = model
        self.train_domain = train_domain

        # Streaming metric accumulators: fed with add_batch() per batch,
        # drained with compute() once per dataset.
        self.accuracy = evaluate.load("accuracy")
        self.f1 = evaluate.load("f1")
        self.precision = evaluate.load("precision")
        self.recall = evaluate.load("recall")
        self.loss_fn = torch.nn.BCELoss()

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def _run_eval(self, test_dataset, bagging=False):
        """Shared evaluation loop for ``_validate`` and ``_test``.

        Args:
            test_dataset: iterable of batches (dicts of tensors).
            bagging: when True, average model outputs with ``_bagging``
                before computing loss/predictions.

        Returns:
            Tuple ``(accuracy, f1, precision, recall, mean_loss)``.
        """
        with torch.no_grad():
            total_loss = 0.0

            for batch in tqdm(test_dataset):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = self.model(
                    input_ids, attention_mask=attention_mask).squeeze(dim=1)
                if bagging:
                    outputs = self._bagging(outputs)

                loss = self.loss_fn(outputs, labels.float())

                # Threshold at 0.5; metrics expect integer class labels on CPU.
                predictions = (outputs > 0.5).long().cpu()
                labels = labels.cpu()

                # add_batch() accumulates in place and returns None — the
                # original code pointlessly bound its return value.
                self.accuracy.add_batch(
                    predictions=predictions, references=labels)
                self.f1.add_batch(
                    predictions=predictions, references=labels)
                self.precision.add_batch(
                    predictions=predictions, references=labels)
                self.recall.add_batch(
                    predictions=predictions, references=labels)

                total_loss += loss.item()

            accuracy = self.accuracy.compute()['accuracy']
            f1 = self.f1.compute()['f1']
            precision = self.precision.compute()['precision']
            recall = self.recall.compute()['recall']
            # max(..., 1) guards against ZeroDivisionError on an empty loader.
            total_loss = total_loss / max(len(test_dataset), 1)

        return accuracy, f1, precision, recall, total_loss

    def _validate(self, test_dataset):
        """Score ``test_dataset`` without bagging."""
        return self._run_eval(test_dataset, bagging=False)

    def _test(self, test_dataset):
        """Score ``test_dataset`` with bagged (averaged) model outputs."""
        return self._run_eval(test_dataset, bagging=True)

    def validate(self):
        """Evaluate every domain and average metrics over non-training domains.

        Returns:
            ``(results, average_results)`` in the usual case, where
            ``results`` maps each out-of-training domain to its metric dict.
            NOTE(review): when the training domain is the only domain, a bare
            (empty) ``results`` dict is returned instead of a tuple — kept
            for backward compatibility; callers must handle both shapes.
        """
        self.model.eval()
        self.model.to(self.device)

        results = {}
        average_results = {}

        for domain in self.test_dataset_dict.keys():
            logging.info(f"Testing {domain} domain...")
            accuracy, f1, precision, recall, total_loss = self._validate(
                self.test_dataset_dict[domain])

            results[domain] = {
                'accuracy': accuracy,
                'f1': f1,
                'precision': precision,
                'recall': recall,
                'loss': total_loss,
            }

        # The in-domain score is excluded from the cross-domain average.
        if self.train_domain in results:
            results.pop(self.train_domain)

        if not results:
            logging.info("Only one domain to test, returning results")
            return results

        for metric in ['accuracy', 'f1', 'precision', 'recall', 'loss']:
            average_results[metric] = sum(
                results[domain][metric] for domain in results) / len(results)

        return results, average_results

    def _bagging(self, logits):
        """Average model outputs along dim 0.

        NOTE(review): dim 0 is presumably an ensemble-member axis produced
        by the model after ``squeeze(dim=1)`` — confirm the output shape;
        averaging over a plain batch axis would collapse per-example scores.
        """
        return torch.mean(logits, dim=0)

    def test(self):
        """Evaluate every dataset with bagging; return per-dataset metrics.

        Returns:
            Dict mapping dataset name -> metric dict
            (accuracy/f1/precision/recall/loss).
        """
        # Bug fix: test() previously never switched the model to eval mode
        # nor moved it to the device, unlike validate() — dropout/batch-norm
        # stayed in training mode unless validate() had already run.
        self.model.eval()
        self.model.to(self.device)

        results = {}

        # No outer no_grad() needed: _run_eval already disables gradients.
        for test_set in self.test_dataset_dict.keys():
            logging.info(f"Testing {test_set} dataset")
            accuracy, f1, precision, recall, total_loss = self._test(
                self.test_dataset_dict[test_set])

            results[test_set] = {
                'accuracy': accuracy,
                'f1': f1,
                'precision': precision,
                'recall': recall,
                'loss': total_loss,
            }

            logging.info(f"Results for {test_set} dataset: {results[test_set]}")

        return results