# TestingViscosity/models/training.py
import os
from typing import List, Optional, Tuple

import numpy as np
import torch
from sklearn.metrics import r2_score
from torch import float32
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.viscosity_models import CNN3D
from pytorchtools import EarlyStopping
from utils.datastruct import history, metrics


def train(model: CNN3D,
          data_loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: torch.nn.modules.loss._Loss,
          device: torch.device) -> float:
    """Run one training epoch and return the mean batch loss."""
    train_loss = []
    model.train()
    for X, y in data_loader:
        X = X.to(device)
        y = y.to(device).to(float32)
        # zero the gradients accumulated from the previous step
        optimizer.zero_grad()
        # forward pass
        out = model(X)
        # compute the loss, backpropagate and update the weights
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    return float(np.mean(train_loss))
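

# Usage sketch (illustrative only, kept in comments so it is not executed on import):
# `train` only requires a torch.nn.Module whose output matches the targets under
# `criterion`; the stand-in 3D conv model and synthetic data below are assumptions,
# not taken from this project.
#
#   from torch.utils.data import TensorDataset
#   stand_in = torch.nn.Sequential(
#       torch.nn.Conv3d(1, 4, kernel_size=3),
#       torch.nn.AdaptiveAvgPool3d(1),
#       torch.nn.Flatten(),
#       torch.nn.Linear(4, 1))
#   loader = DataLoader(TensorDataset(torch.randn(8, 1, 16, 16, 16), torch.randn(8, 1)),
#                       batch_size=4)
#   mean_loss = train(stand_in, loader, torch.optim.Adam(stand_in.parameters()),
#                     torch.nn.MSELoss(), torch.device('cpu'))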


def test(model: CNN3D,
         data_loader: DataLoader,
         criterion: torch.nn.modules.loss._Loss,
         device: torch.device) -> Tuple[float, float]:
    """Evaluate the model and return (mean loss, R^2 score) over the loader."""
    model.eval()
    y_h_all = []
    y_all = []
    test_loss = []
    with torch.no_grad():
        for X, y in data_loader:
            X = X.to(device)
            y = y.to(float32)
            # move predictions back to the CPU so the loss and R^2
            # are computed against the (CPU) targets
            y_h = model(X).detach().cpu()
            loss = criterion(y_h, y)
            test_loss.append(loss.item())
            y_h_all.extend(y_h.numpy())
            y_all.extend(y.numpy())
    return float(np.mean(test_loss)), r2_score(np.array(y_all), np.array(y_h_all))
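

# Usage sketch (continuing the illustrative stand-in above): `test` returns the mean
# loss and the R^2 score over a loader, e.g.
#
#   mse, r2 = test(stand_in, loader, torch.nn.MSELoss(), torch.device('cpu'))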


def train_epochs(model: CNN3D,
                 dataloaders: List[DataLoader],
                 optimizer: torch.optim.Optimizer,
                 criterion: torch.nn.modules.loss._Loss,
                 epochs: int,
                 early_stop: Optional[int],
                 device: torch.device,
                 path: str,
                 save_weights_frequency: int) -> Tuple[CNN3D, history]:
    """Train for up to `epochs` epochs with optional early stopping on the validation loss.

    `dataloaders` is expected to be (train_loader, val_loader, test_loader).
    Checkpoints are written to `path` every `save_weights_frequency` epochs.
    """
    # parse dataloaders
    (train_loader, val_loader, test_loader) = dataloaders

    if early_stop:
        early_stopping = EarlyStopping(patience=early_stop, verbose=True)

    train_loss_list = []
    val_loss_list = []
    test_loss_list = []
    train_r2_list = []
    val_r2_list = []
    test_r2_list = []

    for epoch in tqdm(range(epochs)):
        loss = train(model,
                     data_loader=train_loader,
                     optimizer=optimizer,
                     criterion=criterion,
                     device=device)

        # performance evaluation on all three splits
        (_, r2_train) = test(model=model, data_loader=train_loader, criterion=criterion, device=device)
        (val_loss, r2_val) = test(model=model, data_loader=val_loader, criterion=criterion, device=device)
        (test_loss, r2_test) = test(model=model, data_loader=test_loader, criterion=criterion, device=device)

        # early stop
        if early_stop:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        # mse
        train_loss_list.append(loss)
        val_loss_list.append(val_loss)
        test_loss_list.append(test_loss)
        # r2
        train_r2_list.append(r2_train)
        val_r2_list.append(r2_val)
        test_r2_list.append(r2_test)

        # save params
        if (epoch + 1) % save_weights_frequency == 0:
            torch.save(model.state_dict(), os.path.join(path, 'cnn3d_epoch_' + str(epoch + 1) + '.pt'))

        print(f'Epoch: {epoch:03d}, train loss: {loss:.4f}, val loss: {val_loss:.4f}, test loss: {test_loss:.4f}')
        print(f'Epoch: {epoch:03d}, train r2: {r2_train:.4f}, val r2: {r2_val:.4f}, test r2: {r2_test:.4f}')
    # assumption: `metrics` holds an r2 series and a loss series, so the full
    # per-epoch histories (rather than only the last-epoch r2) are returned
    return model, history(metrics(train_r2_list, train_loss_list),
                          metrics(val_r2_list, val_loss_list),
                          metrics(test_r2_list, test_loss_list))
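

# End-to-end usage sketch. CNN3D's constructor arguments and the construction of the
# three DataLoaders are project-specific and assumed here, not taken from this file.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = CNN3D().to(device)  # hypothetical default constructor
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   criterion = torch.nn.MSELoss()
#   model, hist = train_epochs(model,
#                              dataloaders=[train_loader, val_loader, test_loader],
#                              optimizer=optimizer,
#                              criterion=criterion,
#                              epochs=100,
#                              early_stop=10,
#                              device=device,
#                              path='checkpoints',
#                              save_weights_frequency=10)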