Spaces:
Runtime error
Runtime error
| """ | |
| @author: Caglar Aytekin | |
| contact: caglar@deepcause.ai | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from sklearn.metrics import accuracy_score as accuracy | |
| from sklearn.metrics import roc_auc_score | |
| from torch.optim.lr_scheduler import StepLR | |
| import numpy as np | |
| import copy | |
class Trainer:
    """Train, checkpoint and evaluate a PyTorch model on in-memory tensors.

    ``problem_type`` selects the task, its loss and the checkpoint metric:
        0 -- regression, ``nn.MSELoss``; best checkpoint = lowest val loss.
        1 -- binary classification, ``nn.BCEWithLogitsLoss`` on raw model
             outputs; best checkpoint = highest val ROC AUC.
        2 -- multi-class classification, ``nn.CrossEntropyLoss``; targets are
             squeezed and cast to int64 class indices; best checkpoint =
             highest val accuracy.

    The model must expose ``set_alpha(value)``; ``train`` calls it every epoch.
    NOTE(review): X tensors are presumably (N, features) and y tensors (N, 1)
    or (N,) -- confirm against callers.
    """

    def __init__(self, model, X_train, X_val, y_train, y_val, lr, batch_size, epochs, problem_type, verbose=True):
        """Build loaders, optimizer, loss and LR schedule.

        Args:
            model: ``nn.Module`` trained in place; must provide ``set_alpha``.
            X_train, X_val, y_train, y_val: training / validation tensors.
            lr: Adam learning rate.
            batch_size: training mini-batch size (validation uses one batch).
            epochs: number of training epochs.
            problem_type: 0, 1 or 2 (see the class docstring).
            verbose: print per-epoch metrics when True.

        Raises:
            ValueError: if ``problem_type`` is not 0, 1 or 2.
        """
        if problem_type == 0:
            criterion = nn.MSELoss()
        elif problem_type == 1:
            criterion = nn.BCEWithLogitsLoss()
        elif problem_type == 2:
            criterion = nn.CrossEntropyLoss()
            y_train = self._as_class_indices(y_train)
            y_val = self._as_class_indices(y_val)
        else:
            # Fail fast here rather than with an AttributeError on the missing
            # criterion in the middle of training.
            raise ValueError(f"problem_type must be 0, 1 or 2, got {problem_type!r}")
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.problem_type = problem_type
        self.verbose = verbose
        self.criterion = criterion
        train_dataset = TensorDataset(X_train, y_train)
        val_dataset = TensorDataset(X_val, y_val)
        self.train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        # The whole validation set is evaluated as a single batch.
        self.val_loader = DataLoader(dataset=val_dataset, batch_size=len(val_dataset), shuffle=False)
        self.batch_size = batch_size
        self.epochs = epochs
        # Lower is better for the regression loss; higher for ROC AUC / accuracy.
        self.best_metric = float('inf') if problem_type == 0 else float('-inf')
        # state_dict of the best checkpoint; filled in by train().
        self.best_model = None
        # Decay the LR by 0.2 at every third of training. StepLR computes
        # last_epoch % step_size, so step_size must be >= 1 even when epochs < 3.
        self.scheduler = StepLR(self.optimizer, step_size=max(1, epochs // 3), gamma=0.2)

    @staticmethod
    def _as_class_indices(y):
        """Return ``y`` squeezed and cast to int64, keeping at least one dimension."""
        y = y.squeeze().long()
        # squeeze() turns a single-sample target into a 0-d tensor, which
        # TensorDataset cannot index; restore the batch dimension.
        return y.reshape(1) if y.dim() == 0 else y

    def train_epoch(self):
        """Run one optimisation pass over the training set.

        Returns:
            (mean loss per batch, training accuracy). Accuracy is 0 for
            regression (problem_type 0), where nothing is counted as correct.
        """
        self.model.train()
        total_loss = 0.0
        total = 0
        correct = 0
        for inputs, labels in self.train_loader:
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            # size(0) rather than len(labels.squeeze()): squeezing a 1-sample
            # batch gives a 0-d tensor, which has no len().
            total += labels.size(0)
            logits = outputs.detach()
            if self.problem_type == 1:
                correct += (torch.round(torch.sigmoid(logits)).squeeze() == labels.squeeze()).sum().item()
            elif self.problem_type == 2:
                correct += (torch.max(logits, 1)[1] == labels.squeeze()).sum().item()
        return total_loss / len(self.train_loader), correct / total

    def _collect(self, loader):
        """Run the model over ``loader`` in eval mode without gradients.

        Returns:
            (mean loss per batch, predictions, targets). Predictions are sigmoid
            probabilities for problem_type 1, argmax class indices for
            problem_type 2, and an empty list for regression.
        """
        self.model.eval()
        loss_sum = 0.0
        predictions = []
        targets = []
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = self.model(inputs)
                loss_sum += self.criterion(outputs, labels).item()
                if self.problem_type == 1:
                    predictions.extend(torch.sigmoid(outputs).reshape(-1).cpu().numpy())
                elif self.problem_type == 2:
                    predictions.extend(torch.max(outputs, 1)[1].reshape(-1).cpu().numpy())
                targets.extend(labels.reshape(-1).cpu().numpy())
        return loss_sum / len(loader), predictions, targets

    def validate(self):
        """Evaluate on the validation set.

        Returns:
            (val loss, val accuracy, val ROC AUC). Accuracy and AUC are 0 where
            they do not apply (AUC for problem_type 2, both for regression).
        """
        val_loss, val_predictions, val_targets = self._collect(self.val_loader)
        val_acc, val_roc_auc = 0, 0
        if self.problem_type == 1:
            val_roc_auc = roc_auc_score(val_targets, val_predictions)
            val_acc = accuracy(val_targets, np.round(val_predictions))
        elif self.problem_type == 2:
            val_acc = accuracy(val_targets, val_predictions)
        return val_loss, val_acc, val_roc_auc

    def _is_improvement(self, metric):
        """Return True if ``metric`` beats ``self.best_metric`` for this task."""
        if self.problem_type == 0:
            return metric < self.best_metric
        return metric > self.best_metric

    def _save_checkpoint(self):
        """Snapshot the current model weights into ``self.best_model``."""
        self.model.nninput = None  # Delete data remaining from training
        # NOTE(review): these clear attributes on the trainer, not the model;
        # possibly self.model.encodings / self.model.taus was intended -- confirm.
        self.encodings = None
        self.taus = None
        self.best_model = copy.deepcopy(self.model.state_dict())

    def train(self):
        """Train for ``self.epochs`` epochs, then load the best checkpoint.

        Checkpoints are only considered after the first tenth of the epochs,
        while ``set_alpha`` is still ramping up. If no epoch qualifies (too few
        epochs, or a NaN metric), the final weights are kept as the result.
        """
        for epoch in range(self.epochs):
            # Increase alpha linearly up to 1 over the first tenth of the epochs.
            alpha_now = np.minimum(1.0, float(epoch) / float(self.epochs / 10))
            self.model.set_alpha(alpha_now)
            save_permit = epoch > self.epochs // 10
            tr_loss, tr_acc = self.train_epoch()
            val_loss, val_acc, val_roc_auc = self.validate()
            if self.problem_type == 0:
                if self.verbose:
                    print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Val Loss {val_loss:.4f}')
                metric = val_loss
            elif self.problem_type == 1:
                if self.verbose:
                    print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Train Acc {tr_acc:.4f}, Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}, Val ROC AUC {val_roc_auc:.4f}')
                metric = val_roc_auc
            else:
                if self.verbose:
                    print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Train Acc {tr_acc:.4f}, Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}')
                metric = val_acc
            if save_permit and self._is_improvement(metric):
                self.best_metric = metric
                self._save_checkpoint()
            self.scheduler.step()
        if self.best_model is None:
            # No epoch produced a checkpoint; keep the final weights instead of
            # failing on a missing best_model.
            self._save_checkpoint()
        # Load best validation model
        self.model.load_state_dict(self.best_model)

    def evaluate(self, X_test, y_test, verbose=True):
        """Score the model on a test set.

        Returns:
            ROC AUC for problem_type 1, accuracy for problem_type 2, and mean
            MSE for regression. The value is printed when ``verbose`` is True.
        """
        if self.problem_type == 2:
            # Same target conversion as applied to train/val data in __init__.
            y_test = self._as_class_indices(y_test)
        test_dataset = TensorDataset(X_test, y_test)
        # Single full-size batch; order does not affect any reported metric.
        test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=False)
        test_loss, test_predictions, test_targets = self._collect(test_loader)
        if self.problem_type == 1:
            test_roc_auc = roc_auc_score(test_targets, test_predictions)
            if verbose:
                print('ROC-AUC: ', test_roc_auc)
            return test_roc_auc
        if self.problem_type == 2:
            test_acc = accuracy(test_targets, test_predictions)
            if verbose:
                print('ACC: ', test_acc)
            return test_acc
        if verbose:
            print('MSE: ', test_loss)
        return test_loss