print("|| RAM ||")
import yaml
import argparse
import torch
import random
import os
import numpy as np
from tqdm import tqdm
from models.vae import VAE
from torch.utils.data.dataloader import DataLoader
from torch.optim import Adam
from models.discriminator import Discriminator
import imageio.v3 as iio
import lpips
import gc
#import time
import OpenEXR
import Imath
import torchvision.utils as vutils
def strip_prefix_if_present(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with `prefix` stripped from any key that has it."""
    new_state_dict = {}
for k, v in state_dict.items():
if k.startswith(prefix):
new_state_dict[k[len(prefix):]] = v
else:
new_state_dict[k] = v
return new_state_dict
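# Example: a checkpoint saved from a torch.compile()-wrapped module stores
# keys like "_orig_mod.encoder.conv1.weight" (the inner key name here is
# illustrative); stripping the prefix yields "encoder.conv1.weight", which
# loads into the uncompiled module.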
# OpenCV only honors OPENCV_IO_ENABLE_OPENEXR if it is set before cv2 is
# first imported, so set it before importing the dataloader module (which
# may pull in cv2 for EXR reading).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
from dataloader_image_hyperism import HDRGrayscaleEXRDataset_new, ImageDataset, ImageDataset_d
#device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
print(f"Using device: {torch.cuda.get_device_name(device)} (CUDA:{torch.cuda.current_device()})")
else:
print("Using device: CPU")
# Note: despite the variable name, this LPIPS instance uses the VGG backbone,
# not AlexNet.
loss_fn_alex = lpips.LPIPS(net='vgg').to(device)
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb
wandb.init(project="ldr_sh_vae_training_trial")
from torch.utils.data import Dataset
def check_nan(tensor, name="tensor"):
"""Check if tensor contains NaN values and print debugging info"""
if torch.isnan(tensor).any():
print(f"NaN detected in {name}")
non_nan_mask = ~torch.isnan(tensor)
if non_nan_mask.any():
print(f"Non-NaN values stats - Min: {tensor[non_nan_mask].min().item()}, Max: {tensor[non_nan_mask].max().item()}")
return True
return False
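# Typical use inside the training loop below: skip a batch when the forward
# pass produced NaNs, e.g.
#   if check_nan(output, "raw_model_output"): continue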
def to_rgb(image):
    """Repeat a single-channel batch (N, 1, H, W) to 3 channels for LPIPS."""
    if image.shape[1] == 1:
        return image.repeat(1, 3, 1, 1)
    return image
def save_training_samples(output, gt_image, scene_infos, train_config, step_count, img_save_count):
    # OpenEXR/Imath are imported at module level; re-import here so the
    # except branch can actually fire if those imports are ever made optional
    try:
        import OpenEXR
        import Imath
        has_openexr = True
    except ImportError:
        has_openexr = False
        print("Warning: OpenEXR not available, falling back to imageio")
sample_size = min(8, output.shape[0])
gt_image = gt_image[:sample_size].detach().cpu()
    save_output = output[:sample_size].detach().cpu()
# Apply normalization
# epsilon = 1e-3 # Small value to prevent division by zero
# save_output = np.where(save_output == -1, save_output + epsilon, save_output)
# save_output = (1 - save_output) / (1 + save_output) # Normalize to [0, 1]
# save_output = torch.clip((save_output + 1) / 2.0, 0.0, 1.0)
# Base save path
base_save_path = os.path.join('/home/project/dataset/Hyperism', train_config['task_name'], 'vae_autoencoder_samples_1')
def save_exr(filepath, data):
"""Helper function to save EXR files using either OpenEXR or imageio
Args:
filepath (str): Path to save the EXR file
data (ndarray): Image data in shape (H, W), (H, W, C) or (C, H, W) format
Returns:
bool: True if save successful, False otherwise
"""
# Handle (C, H, W) format by converting to (H, W, C)
if len(data.shape) == 3 and (data.shape[0] == 3 or data.shape[0] == 1):
if data.shape[0] == 3 and data.shape[1] > 3 and data.shape[2] > 3: # Likely (C, H, W) format
data = np.transpose(data, (1, 2, 0))
if has_openexr and len(data.shape) == 2:
# For grayscale images
data_flat = data.astype(np.float32).tobytes()
header = OpenEXR.Header(data.shape[1], data.shape[0])
header['channels'] = {'Y': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))}
exr = OpenEXR.OutputFile(filepath, header)
exr.writePixels({'Y': data_flat})
exr.close()
return True
elif has_openexr and len(data.shape) == 3 and data.shape[2] == 3:
# For RGB images in (H, W, C) format
R = data[:,:,0].astype(np.float32).tobytes()
G = data[:,:,1].astype(np.float32).tobytes()
B = data[:,:,2].astype(np.float32).tobytes()
header = OpenEXR.Header(data.shape[1], data.shape[0])
header['channels'] = {
'R': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT)),
'G': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT)),
'B': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
}
exr = OpenEXR.OutputFile(filepath, header)
exr.writePixels({'R': R, 'G': G, 'B': B})
exr.close()
return True
else:
# Fall back to imageio with tifffile plugin (which can handle EXR)
try:
# Ensure we're using a format imageio can handle
if len(data.shape) == 3 and data.shape[0] == 3 and data.shape[1] > 3 and data.shape[2] > 3:
# Convert from (C, H, W) to (H, W, C) for imageio
data = np.transpose(data, (1, 2, 0))
iio.imwrite(filepath, data, plugin='tifffile', photometric='rgb')
return True
except Exception as e:
print(f"Error saving with tifffile plugin: {e}")
try:
# Last resort - try PNG instead of EXR
png_path = filepath.replace('.exr', '.png')
# Ensure data is in correct shape for PNG
if len(data.shape) == 3 and data.shape[0] == 3 and data.shape[1] > 3 and data.shape[2] > 3:
data = np.transpose(data, (1, 2, 0))
iio.imwrite(png_path, data)
print(f"Saved as PNG instead: {png_path}")
return True
except Exception as e2:
print(f"Failed to save image: {e2}")
return False
    # Stack reconstructions on top of ground truth: with nrow=4 and up to 8
    # samples, the grid shows two rows of outputs above two rows of targets
    collage = torch.cat([save_output, gt_image], dim=0)
    os.makedirs("/home/project/dataset/Hyperism/ldr_to_sh_1/vae_autoencoder_samples/", exist_ok=True)
    output_path = f"/home/project/dataset/Hyperism/ldr_to_sh_1/vae_autoencoder_samples/{step_count}.png"
    vutils.save_image(collage, output_path, nrow=4, normalize=True)
    # Directory for simple numbered outputs (created here; nothing is written
    # to it yet)
    simple_save_path = os.path.join(base_save_path, 'numbered_samples')
    os.makedirs(simple_save_path, exist_ok=True)
return img_save_count + 1
# Create a combined dataset class
class CombinedDataset(Dataset):
    def __init__(self, sh_dataset):
        """
        Wraps the shading dataset and indexes it by scene metadata.
        (Earlier versions also matched albedo and LDR datasets here; only the
        spherical-harmonics shading dataset is used now.)
        Args:
            sh_dataset: HDRGrayscaleEXRDataset_new for spherical-harmonics shading
        """
        self.sh_dataset = sh_dataset
        # Map scene info to dataset indices
        self.matching_indices = self._find_matching_indices()
    def _find_matching_indices(self):
        """Index the shading dataset by (ai_folder, scene_folder, frame_num)."""
        sh_indices = {}
        for idx in range(len(self.sh_dataset)):
            info = self.sh_dataset.get_scene_info(idx)
            key = (info['ai_folder'], info['scene_folder'], info['frame_num'])
            sh_indices[key] = idx
        # Iterate the dict directly (insertion-ordered) rather than a set, so
        # the sequential train/val split below is deterministic across runs
        return list(sh_indices.values())
def __len__(self):
return len(self.matching_indices)
    def __getitem__(self, idx):
        sh_idx = self.matching_indices[idx]
        sh_image = self.sh_dataset[sh_idx]
        # Return the scene info alongside the image so saved samples can be
        # traced back to their source frame
        info = self.sh_dataset.get_scene_info(sh_idx)
        return sh_image, info
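# Each CombinedDataset item is a (shading tensor, scene-info dict) pair; the
# dict carries 'ai_folder', 'scene_folder' and 'frame_num', per the keys used
# in _find_matching_indices above.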
def kl_divergence(mu, logvar):
    """
    KL divergence between the encoded Gaussian and a standard normal,
    averaged over every element of the latent tensor. logvar is clamped
    to keep exp() from blowing up numerically.
    """
    logvar = torch.clamp(logvar, min=-10, max=10)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalize by the total number of latent elements (batch * C * H * W)
    kl_loss = kl_loss / (logvar.size(0) * logvar.size(1) * logvar.size(2) * logvar.size(3))
    return kl_loss
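# Closed form used above, per latent element:
#   KL( N(mu, sigma^2) || N(0, 1) ) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
# with logvar = log(sigma^2); the sum is then normalized by N*C*H*W.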
# Function to evaluate the model on validation set
def validate(model, val_loader, discriminator, recon_criterion, disc_criterion, train_config, kl_weight, step_count, disc_step_start):
    """Run one pass over the validation set and return averaged loss metrics."""
model.eval()
discriminator.eval()
val_recon_losses_sh = []
val_kl_losses = []
val_perceptual_losses = []
val_disc_losses = []
val_gen_losses = []
val_total_losses = []
# For discriminator predictions
val_real_preds = []
val_fake_preds = []
with torch.no_grad():
for batch in val_loader:
            sh_im, _ = batch
            # Convert to float and move to device
            sh_im = sh_im.float().to(device)
# Get model output
model_output = model(sh_im)
output, z, _ = model_output
mean, logvar = torch.chunk(z, 2, dim=1)
# Calculate reconstruction loss for shading
recon_loss = recon_criterion(output, sh_im)
recon_loss = recon_loss / train_config['autoencoder_acc_steps']
val_recon_losses_sh.append(train_config['sh_weight']*recon_loss.item())
# Calculate KL loss
kl_loss = kl_divergence(mean, logvar)
kl_loss = kl_loss / train_config['autoencoder_acc_steps']
val_kl_losses.append(kl_weight * kl_loss.item())
output_rgb = to_rgb(output)
sh_im_rgb = to_rgb(sh_im)
# Calculate perceptual loss
lpips_loss = (loss_fn_alex(output_rgb.to(device), sh_im_rgb.to(device)).mean()) # Ensure lpips_loss is a scalar
lpips_loss = lpips_loss / train_config['autoencoder_acc_steps']
val_perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item())
            gen_loss = 0
            # Adversarial terms only once the discriminator has started training
            if step_count > disc_step_start:
# Discriminator predictions
disc_fake_pred = discriminator(output)
disc_real_pred = discriminator(sh_im)
# Store predictions
real_probs = torch.sigmoid(disc_real_pred).mean().item()
fake_probs = torch.sigmoid(disc_fake_pred).mean().item()
val_real_preds.append(real_probs)
val_fake_preds.append(fake_probs)
# Generator adversarial loss
gen_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
gen_loss = gen_loss / train_config['autoencoder_acc_steps']
val_gen_losses.append(train_config['disc_weight'] * gen_loss.item())
# Discriminator adversarial loss
disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
disc_loss = disc_loss / train_config['autoencoder_acc_steps']
val_disc_losses.append(train_config['disc_weight'] * disc_loss.item())
# Calculate total loss
total_loss = (train_config['sh_weight'] * recon_loss + #(train_config['gradient_weight'] * grad_loss) +
(kl_weight * kl_loss) +
(train_config['perceptual_weight'] * lpips_loss) +
(train_config['disc_weight'] * gen_loss))
val_total_losses.append(total_loss.item())
model.train()
discriminator.train()
# Return average losses
return {
'recon_loss_sh': np.mean(val_recon_losses_sh),
'kl_loss': np.mean(val_kl_losses),
#'gradient_loss_albedo': np.mean(val_gradient_losses_albedo),
'perceptual_loss': np.mean(val_perceptual_losses),
'gen_loss': np.mean(val_gen_losses) if val_gen_losses else 0,
'disc_loss': np.mean(val_disc_losses) if val_disc_losses else 0,
'total_loss': np.mean(val_total_losses),
'real_prediction': np.mean(val_real_preds) if val_real_preds else 0,
'fake_prediction': np.mean(val_fake_preds) if val_fake_preds else 0,
}
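# The returned values are means over validation batches; the adversarial
# entries remain 0 until step_count passes disc_step_start.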
def train(args):
    # Read the config file
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            raise  # config is unusable past this point
#print(config)
dataset_config = config['dataset_params_shading']
autoencoder_config = config['autoencoder_params']
train_config = config['train_params']
#albedo_config = config['albedo_params']
#ldr_input_config = config['dataset_params_input']
# Set the desired seed value #
seed = train_config['seed']
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
    if device.type == 'cuda':
torch.cuda.manual_seed_all(seed)
#############################
# Create the model and dataset #
model = VAE(latent_dim=8).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters in VAE: {total_params:,}")
#model.apply(weights_init)
discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device)
#discriminator.apply(weights_init)
sh_dataset = HDRGrayscaleEXRDataset_new(im_path=dataset_config['im_path'],
im_size=dataset_config['im_size'])
# Create the combined dataset
combined_dataset = CombinedDataset(sh_dataset)
    # Split dataset into train and validation (95:5 ratio)
    dataset_size = len(combined_dataset)
    train_size = int(0.95 * dataset_size)
    val_size = dataset_size - train_size
    # Sequential (non-shuffled) split over the deterministic index order
    indices = np.arange(dataset_size)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]
train_dataset = torch.utils.data.Subset(combined_dataset, train_indices)
val_dataset = torch.utils.data.Subset(combined_dataset, val_indices)
print(f"Total dataset size: {dataset_size}")
print(f"Training set size: {train_size}")
print(f"Validation set size: {val_size}")
wandb.config.update({
"learning_rate_autoencoder": train_config['autoencoder_lr'],
"learning_rate_discriminator": train_config['discriminator_lr'],
"batch_size": train_config['autoencoder_batch_size'],
"gradient_weight": train_config['gradient_weight'],
"sh_weight": train_config['sh_weight'],
"kl_weight": train_config['kl_weight'],
"perceptual_weight": train_config['perceptual_weight'],
"disc_weight": train_config['disc_weight'],
"disc_start": train_config['disc_start'],
"autoencoder_acc_steps": train_config['autoencoder_acc_steps']
})
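    # Hyperparameters are read back through wandb.config below so that a
    # wandb sweep can override the YAML values.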
    train_loader = DataLoader(train_dataset,
                              batch_size=wandb.config['batch_size'],
                              shuffle=True, num_workers=16, pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=wandb.config['batch_size'],
                            shuffle=False, num_workers=16, pin_memory=False)
# Create output directories
    os.makedirs(train_config['task_name'], exist_ok=True)
num_epochs = train_config['autoencoder_epochs']
# L1/L2 loss for Reconstruction
recon_criterion = torch.nn.MSELoss()
# Disc Loss can even be BCEWithLogits
disc_criterion = torch.nn.BCEWithLogitsLoss()
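    # BCEWithLogitsLoss consumes raw discriminator logits; torch.sigmoid is
    # applied elsewhere only when logging real/fake probabilities.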
optimizer_d = Adam(discriminator.parameters(), lr=wandb.config['learning_rate_discriminator'], betas=(0.5, 0.999))
optimizer_g = Adam(model.parameters(), lr=wandb.config['learning_rate_autoencoder'], betas=(0.5, 0.999) )
    # Use AMP loss scaling only when training on CUDA
    scaler = GradScaler('cuda') if device.type == 'cuda' else None
# Setup schedulers
    scheduler_g = ReduceLROnPlateau(optimizer_g, mode='min', factor=0.9, patience=5, min_lr=1e-5)
    scheduler_d = ReduceLROnPlateau(optimizer_d, mode='min', factor=0.9, patience=5, min_lr=1e-6)
disc_step_start = wandb.config['disc_start']
step_count = 0
start_epoch = 0
best_val_loss = float('inf')
    # Gradient accumulation: for when images are large and one can't afford
    # a bigger batch size
acc_steps = wandb.config['autoencoder_acc_steps']
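    # Losses are divided by acc_steps before backward(), so the effective
    # batch size per optimizer step is batch_size * acc_steps.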
image_save_steps = train_config['autoencoder_img_save_steps']
img_save_count = 0
# Lists to store epoch metrics
train_losses_history = []
val_losses_history = []
    # Check if a checkpoint exists and load it to resume training
    checkpoint_path = os.path.join(train_config['task_name'], 'epoch_95_best_autoencoder_model_checkpoint.pth')
if os.path.exists(checkpoint_path):
logging.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, weights_only=False)
# Remove _orig_mod. prefix from model and discriminator state_dicts
model_state_dict = strip_prefix_if_present(checkpoint['model_state_dict'], '_orig_mod.')
discriminator_state_dict = strip_prefix_if_present(checkpoint['discriminator_state_dict'], '_orig_mod.')
model.load_state_dict(model_state_dict)
discriminator.load_state_dict(discriminator_state_dict)
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
start_epoch = checkpoint['epoch']
step_count = checkpoint.get('step_count', 0)
img_save_count = checkpoint.get('img_save_count', 0)
best_val_loss = checkpoint['best_val_loss']
train_losses_history = checkpoint.get('train_losses_history', [])
val_losses_history = checkpoint.get('val_losses_history', [])
logging.info(f"Resuming from epoch {start_epoch} with best validation loss: {best_val_loss}")
    # Compile the model and discriminator when torch.compile is available (PyTorch 2.x)
    if hasattr(torch, 'compile'):
        model = torch.compile(model)
        discriminator = torch.compile(discriminator)
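    # compile() is applied after the checkpoint load so the raw state_dict keys
    # match; checkpoints saved from here on carry the "_orig_mod." prefix that
    # strip_prefix_if_present() removes on the next resume.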
#logging.info(f"Learning rates updated: Generator -> {new_lr_g}, Discriminator -> {new_lr_d}")
for epoch_idx in range(start_epoch, num_epochs):
# Training metrics
#recon_losses_shading = []
recon_losses_sh = []
kl_losses = []
perceptual_losses = []
disc_losses = []
gen_losses = []
train_real_preds = []
train_fake_preds = []
#grad_losses_shading = []
# grad_losses_albedo = []
losses = []
optimizer_g.zero_grad()
optimizer_d.zero_grad()
# Training loop
        for batch in tqdm(train_loader):
            step_count += 1
            # Each batch is (shading image tensor, scene-info dict)
            sh_im, scene_infos = batch
            sh_im = sh_im.float().to(device)
with autocast(device_type='cuda'):
# Fetch autoencoders output(reconstructions)
model_output = model(sh_im)
output, h,_ = model_output
mean, logvar = torch.chunk(h, 2, dim=1)
if check_nan(output, "raw_model_output"):
print("NaN values detected in model output! Skipping this batch.")
continue
# Image Saving Logic
if step_count % image_save_steps == 0 or step_count == 1:
img_save_count = save_training_samples(
output, sh_im, scene_infos, train_config, step_count, img_save_count
)
######### Optimize Generator ##########
# L2 Loss for shading
recon_loss = recon_criterion(output, sh_im)
recon_loss = recon_loss / acc_steps
recon_losses_sh.append(wandb.config['sh_weight'] * recon_loss.item()) # add average loss for 1 image
kl_weight = wandb.config['kl_weight']
kl_loss = kl_divergence(mean, logvar)
kl_loss = kl_loss / acc_steps
kl_losses.append(kl_weight * kl_loss.item())
# total_loss_generator
g_loss = (wandb.config['sh_weight'] * recon_loss +
# (wandb.config['gradient_weight'] * grad_loss) +
(kl_weight * kl_loss ))
# Adversarial loss only if disc_step_start steps passed
                if step_count > disc_step_start:
                    disc_fake_pred = discriminator(output)
                    disc_fake_loss = disc_criterion(disc_fake_pred,
                                                    torch.ones_like(disc_fake_pred))
disc_fake_loss = disc_fake_loss / acc_steps
gen_losses.append(wandb.config['disc_weight'] * disc_fake_loss.item())
g_loss += wandb.config['disc_weight'] * disc_fake_loss
# LPIPS Loss
output_rgb = to_rgb(output)
sh_im_rgb = to_rgb(sh_im)
# Calculate perceptual loss
lpips_loss = (loss_fn_alex(output_rgb, sh_im_rgb).mean()) # Ensure lpips_loss is a scalar
lpips_loss = lpips_loss / acc_steps
perceptual_losses.append(wandb.config['perceptual_weight'] * lpips_loss.item())
g_loss += wandb.config['perceptual_weight'] * lpips_loss
losses.append(g_loss.item())
#g_loss.backward()
#####################################
if scaler is not None:
scaler.scale(g_loss).backward()
else:
g_loss.backward()
######### Optimize Discriminator #######
if step_count > disc_step_start and step_count % 2 == 0:
with autocast(device_type='cuda'):
fake = output
disc_fake_pred = discriminator(fake.detach())
disc_real_pred = discriminator(sh_im)
disc_fake_loss = disc_criterion(disc_fake_pred,
torch.zeros(disc_fake_pred.shape,
device=disc_fake_pred.device))
disc_real_loss = disc_criterion(disc_real_pred,
torch.ones(disc_real_pred.shape,
device=disc_real_pred.device))
disc_loss = wandb.config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2
disc_loss = disc_loss / acc_steps
disc_losses.append(disc_loss.item())
with torch.no_grad():
# Convert logits to probabilities using sigmoid
real_probs = torch.sigmoid(disc_real_pred).mean().item()
fake_probs = torch.sigmoid(disc_fake_pred).mean().item()
train_real_preds.append(real_probs)
train_fake_preds.append(fake_probs)
# Scale the discriminator loss and backward
if scaler is not None:
scaler.scale(disc_loss).backward()
else:
disc_loss.backward()
                if step_count % acc_steps == 0:
                    if scaler is not None:
                        # Unscale before clipping; a single scaler.update() is
                        # issued after the generator step below, following the
                        # multi-optimizer GradScaler pattern
                        scaler.unscale_(optimizer_d)
                        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                        scaler.step(optimizer_d)
                    else:
                        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                        optimizer_d.step()
                    optimizer_d.zero_grad()
#####################################
            if step_count % acc_steps == 0:
                if scaler is not None:
                    # Unscale before clipping, then step and update the loss scale
                    scaler.unscale_(optimizer_g)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer_g)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer_g.step()
                optimizer_g.zero_grad()
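            # Gradient-accumulation pattern: each loss above is pre-divided by
            # acc_steps, gradients from acc_steps batches are summed, and both
            # optimizers step once per accumulation window.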
        # Flush any gradients left over from an incomplete accumulation window
        if step_count > disc_step_start:
            optimizer_d.step()
            optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()
# Calculate validation metrics
        val_metrics = validate(model, val_loader, discriminator, recon_criterion, disc_criterion, train_config, kl_weight, step_count, disc_step_start)
# Store epoch metrics for plotting
train_loss = np.mean(losses)
val_loss = val_metrics['total_loss']
train_losses_history.append(train_loss)
val_losses_history.append(val_loss)
        epochs_since_disc_start = max(0, epoch_idx - (disc_step_start // len(train_loader)))
        if step_count <= disc_step_start:
            # Before the discriminator starts: schedule the generator normally
            scheduler_g.step(val_loss)
        elif epochs_since_disc_start <= 20:
            # Stabilization period: hold the generator LR and only adjust
            # the discriminator schedule
            scheduler_d.step(val_metrics["disc_loss"])
        else:
            # After stabilization: resume scheduling both
            scheduler_g.step(val_loss)
            scheduler_d.step(val_metrics["disc_loss"])
# After validation and calculating epoch metrics
wandb.log({
"epoch": epoch_idx + 1,
"train/recon_loss_sh": np.mean(recon_losses_sh),
#"train/gradient_loss_albedo": np.mean(grad_losses_albedo),
"train/kl_loss": np.mean(kl_losses),
"train/perceptual_loss": np.mean(perceptual_losses),
"train/gen_loss": np.mean(gen_losses) if len(gen_losses) > 0 else 0,
"train/disc_loss": np.mean(disc_losses) if len(disc_losses) > 0 else 0,
"train/total_loss": train_loss,
"train/real_prediction": np.mean(train_real_preds) if len(train_real_preds) > 0 else 0,
"train/fake_prediction": np.mean(train_fake_preds) if len(train_fake_preds) > 0 else 0,
"val/recon_loss_sh": val_metrics["recon_loss_sh"],
#"val/gradient_loss_albedo": val_metrics["gradient_loss_albedo"],
"val/kl_loss": val_metrics["kl_loss"],
"val/perceptual_loss": val_metrics["perceptual_loss"],
"val/gen_loss": val_metrics["gen_loss"],
"val/disc_loss": val_metrics["disc_loss"],
"val/total_loss": val_metrics["total_loss"],
"val/real_prediction": val_metrics["real_prediction"],
"val/fake_prediction": val_metrics["fake_prediction"],
"learning_rate/generator": optimizer_g.param_groups[0]['lr'],
"learning_rate/discriminator": optimizer_d.param_groups[0]['lr']
})
# Print epoch results
print('\n' + '=' * 80)
print(f'Epoch {epoch_idx + 1}/{num_epochs}')
print('-' * 80)
print('TRAINING:')
print(f'Recon Loss_sh: {np.mean(recon_losses_sh):.4f} | '
#f'Gradient Loss_albedo: {np.mean(grad_losses_albedo):.4f} | '
f'KL Loss: {np.mean(kl_losses):.4f} | '
f'Perceptual Loss: {np.mean(perceptual_losses):.4f}')
if len(disc_losses) > 0 and len(gen_losses) > 0:
print(f'Generator Loss: {np.mean(gen_losses):.4f} | '
f'Discriminator Loss: {np.mean(disc_losses):.4f}')
print(f'Total Training Loss: {train_loss:.4f}')
print('\nVALIDATION:')
print(f'Recon Loss_shading: {val_metrics["recon_loss_sh"]:.4f} | '
#f'Gradient Loss_shading: {val_metrics["gradient_loss_albedo"]:.4f} | '
f'KL Loss: {val_metrics["kl_loss"]:.4f} | '
f'Perceptual Loss: {val_metrics["perceptual_loss"]:.4f}')
print(f'Generator Loss: {val_metrics["gen_loss"]:.4f} | '
f'Discriminator Loss: {val_metrics["disc_loss"]:.4f}')
print(f'Total Validation Loss: {val_metrics["total_loss"]:.4f}')
        # Track the best validation loss (best-only checkpointing is currently
        # disabled; periodic checkpoints are written every 5 epochs below)
        if val_loss < best_val_loss:
            print(f"\nValidation loss improved from {best_val_loss:.4f} to {val_loss:.4f}")
            best_val_loss = val_loss
        else:
            print(f"Validation loss did not improve from {best_val_loss:.4f}")
if (epoch_idx + 1) % 5 == 0:
checkpoint = {
'epoch': epoch_idx + 1,
'model_state_dict': model.state_dict(),
'discriminator_state_dict': discriminator.state_dict(),
'optimizer_g_state_dict': optimizer_g.state_dict(),
'optimizer_d_state_dict': optimizer_d.state_dict(),
'best_val_loss': best_val_loss,
'step_count': step_count,
'img_save_count': img_save_count,
'train_losses_history': train_losses_history,
'val_losses_history': val_losses_history
}
torch.save(checkpoint, os.path.join(train_config['task_name'], f'epoch_{epoch_idx + 1}_best_autoencoder_model_checkpoint.pth'))
            # Save individual model files for compatibility with the original code
            torch.save(model.state_dict(), os.path.join(train_config['task_name'],
                                                        train_config['vae_autoencoder_ckpt_name']))
            torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'],
                                                                train_config['vae_discriminator_ckpt_name']))
        # Release cached memory every epoch
        gc.collect()
        torch.cuda.empty_cache()
print('=' * 80 + '\n')
print('Done Training...')
# Save final training history
np.savez(os.path.join(train_config['task_name'], 'training_history.npz'),
train_losses=np.array(train_losses_history),
val_losses=np.array(val_losses_history))
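    # The saved history can be reloaded later with, e.g.:
    #   hist = np.load(os.path.join(train_config['task_name'], 'training_history.npz'))
    #   hist['train_losses'], hist['val_losses']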
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for VAE training')
    parser.add_argument('--config', dest='config_path',
                        default='config/autoen_alb_1.yaml', type=str)
    # parse_known_args tolerates the extra argv entries that Jupyter/IPython injects
    import sys
    args, unknown = parser.parse_known_args(sys.argv[1:])
    train(args)