print("|| RAM ||") import yaml import argparse import torch import random import os import numpy as np from tqdm import tqdm from models.vae import VAE from torch.utils.data.dataloader import DataLoader from torch.optim import Adam from models.discriminator import Discriminator import imageio.v3 as iio import lpips import gc #import time import OpenEXR import Imath import torchvision.utils as vutils def strip_prefix_if_present(state_dict, prefix="_orig_mod."): new_state_dict = {} for k, v in state_dict.items(): if k.startswith(prefix): new_state_dict[k[len(prefix):]] = v else: new_state_dict[k] = v return new_state_dict # Add this at the top of your script, before importing imageio from dataloader_image_hyperism import HDRGrayscaleEXRDataset_new,ImageDataset,ImageDataset_d os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" #print("imported all the libraries") import torch #device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if torch.cuda.is_available(): print(f"Using device: {torch.cuda.get_device_name(device)} (CUDA:{torch.cuda.current_device()})") else: print("Using device: CPU") loss_fn_alex = lpips.LPIPS(net='vgg').to(device) # Set the device to GPU if available import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') from torch.amp import autocast, GradScaler from torch.optim.lr_scheduler import ReduceLROnPlateau import wandb wandb.init(project="ldr_sh_vae_training_trial") from torch.utils.data import Dataset def check_nan(tensor, name="tensor"): """Check if tensor contains NaN values and print debugging info""" if torch.isnan(tensor).any(): print(f"NaN detected in {name}") non_nan_mask = ~torch.isnan(tensor) if non_nan_mask.any(): print(f"Non-NaN values stats - Min: {tensor[non_nan_mask].min().item()}, Max: {tensor[non_nan_mask].max().item()}") return True return False def to_rgb(image): if image.shape[1] == 1: return image.repeat(1, 3, 1, 1) return image def save_training_samples(output, gt_image, scene_infos, train_config, step_count, img_save_count): # Try to import OpenEXR if available try: has_openexr = True except ImportError: has_openexr = False print("Warning: OpenEXR not available, falling back to imageio") sample_size = min(8, output.shape[0]) gt_image = gt_image[:sample_size].detach().cpu() save_output = output[:sample_size].detach().cpu()#.numpy() # Apply normalization # epsilon = 1e-3 # Small value to prevent division by zero # save_output = np.where(save_output == -1, save_output + epsilon, save_output) # save_output = (1 - save_output) / (1 + save_output) # Normalize to [0, 1] # save_output = torch.clip((save_output + 1) / 2.0, 0.0, 1.0) # Base save path base_save_path = os.path.join('/home/project/dataset/Hyperism', train_config['task_name'], 'vae_autoencoder_samples_1') def save_exr(filepath, data): """Helper function to save EXR files using either OpenEXR or imageio Args: filepath (str): Path to save the EXR file data (ndarray): Image data in shape (H, W), (H, W, C) or (C, H, W) format Returns: bool: True if save successful, False otherwise """ # Handle (C, H, W) format by converting to (H, W, C) if len(data.shape) == 3 and (data.shape[0] == 3 or data.shape[0] == 1): if data.shape[0] == 3 and data.shape[1] > 3 and data.shape[2] > 3: # Likely (C, H, W) format data = np.transpose(data, (1, 2, 0)) if has_openexr and len(data.shape) == 2: # For grayscale images data_flat = data.astype(np.float32).tobytes() header = OpenEXR.Header(data.shape[1], data.shape[0]) header['channels'] = {'Y': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))} exr = OpenEXR.OutputFile(filepath, header) exr.writePixels({'Y': data_flat}) exr.close() return True elif has_openexr and len(data.shape) == 3 and data.shape[2] == 3: # For RGB images in (H, W, C) format R = data[:,:,0].astype(np.float32).tobytes() G = data[:,:,1].astype(np.float32).tobytes() B = data[:,:,2].astype(np.float32).tobytes() header = OpenEXR.Header(data.shape[1], data.shape[0]) header['channels'] = { 'R': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT)), 'G': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT)), 'B': Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT)) } exr = OpenEXR.OutputFile(filepath, header) exr.writePixels({'R': R, 'G': G, 'B': B}) exr.close() return True else: # Fall back to imageio with tifffile plugin (which can handle EXR) try: # Ensure we're using a format imageio can handle if len(data.shape) == 3 and data.shape[0] == 3 and data.shape[1] > 3 and data.shape[2] > 3: # Convert from (C, H, W) to (H, W, C) for imageio data = np.transpose(data, (1, 2, 0)) iio.imwrite(filepath, data, plugin='tifffile', photometric='rgb') return True except Exception as e: print(f"Error saving with tifffile plugin: {e}") try: # Last resort - try PNG instead of EXR png_path = filepath.replace('.exr', '.png') # Ensure data is in correct shape for PNG if len(data.shape) == 3 and data.shape[0] == 3 and data.shape[1] > 3 and data.shape[2] > 3: data = np.transpose(data, (1, 2, 0)) iio.imwrite(png_path, data) print(f"Saved as PNG instead: {png_path}") return True except Exception as e2: print(f"Failed to save image: {e2}") return False collage = torch.cat([save_output, gt_image], dim=0) os.makedirs("/home/project/dataset/Hyperism/ldr_to_sh_1/vae_autoencoder_samples/", exist_ok=True) output_path = f"/home/project/dataset/Hyperism/ldr_to_sh_1/vae_autoencoder_samples/{step_count}.png" vutils.save_image(collage, output_path, nrow=4, normalize=True) # Also save a simple numbered output for easy viewing simple_save_path = os.path.join(base_save_path, 'numbered_samples') os.makedirs(simple_save_path, exist_ok=True) return img_save_count + 1 # Create a combined dataset class class CombinedDataset(Dataset): def __init__(self, sh_dataset): """ A dataset that matches corresponding images across the three datasets based on scene metadata. Args: sh_dataset: The HDRGrayscaleEXRDataset for spherical harmonics shading albedo_dataset: The ImageDataset for albedo (diffuse_reflectance.exr) ldr_dataset: The ImageDataset for LDR input (dequantize.exr) """ self.sh_dataset = sh_dataset #self.albedo_dataset = albedo_dataset # self.ldr_dataset = ldr_dataset # Create a mapping from scene info to indices for each dataset self.matching_indices = self._find_matching_indices() def _find_matching_indices(self): """Find matching indices across all three datasets based on scene info""" # Create dictionaries to map scene info to indices for each dataset sh_indices = {} #sh_indices = {} for idx in range(len(self.sh_dataset)): info = self.sh_dataset.get_scene_info(idx) key = (info['ai_folder'], info['scene_folder'], info['frame_num']) sh_indices[key] = idx sh_keys = set(sh_indices.keys()) # ldr_keys = set(ldr_indices.keys()) common_keys = sh_keys # Create a list of matching indices matching_indices = [ (sh_indices[key]) for key in common_keys ] return matching_indices def __len__(self): return len(self.matching_indices) def __getitem__(self, idx): # Get the matching indices for all three datasets sh_idx= self.matching_indices[idx] # Get the items from each dataset sh_image = self.sh_dataset[sh_idx] #albedo_image = self.sh_dataset[sh_idx] # ldr_image = self.ldr_dataset[ldr_idx] # Also store the scene info for saving output images info = self.sh_dataset.get_scene_info(sh_idx) return sh_image, info def kl_divergence(mu, logvar): """ Compute the KL divergence between the encoded distribution and a standard normal distribution. This version includes proper batch averaging and clipping to prevent numerical issues. """ # Clamp logvar to prevent extreme values logvar = torch.clamp(logvar, min=-10, max=10) # Calculate KL divergence term by term kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) kl_loss = kl_loss / (logvar.size(0) * logvar.size(1)* logvar.size(2) * logvar.size(3)) # Average over batch dimension return kl_loss.mean() # Function to evaluate the model on validation set def validate(model, val_loader, discriminator, recon_criterion, disc_criterion, train_config, kl_weight,step_count, disc_step_start): model.eval() discriminator.eval() val_recon_losses_sh = [] val_kl_losses = [] val_perceptual_losses = [] val_disc_losses = [] val_gen_losses = [] val_total_losses = [] # For discriminator predictions val_real_preds = [] val_fake_preds = [] with torch.no_grad(): for batch in val_loader: sh_im,_ = batch # Convert to float and move to device sh_im = sh_im.float().to(device) #sh_im = sh_im.float().to(device) #ldr_im = ldr_im.float().to(device) # Get model output model_output = model(sh_im) output, z, _ = model_output mean, logvar = torch.chunk(z, 2, dim=1) # Calculate reconstruction loss for shading recon_loss = recon_criterion(output, sh_im) recon_loss = recon_loss / train_config['autoencoder_acc_steps'] val_recon_losses_sh.append(train_config['sh_weight']*recon_loss.item()) # Calculate KL loss kl_loss = kl_divergence(mean, logvar) kl_loss = kl_loss / train_config['autoencoder_acc_steps'] val_kl_losses.append(kl_weight * kl_loss.item()) output_rgb = to_rgb(output) sh_im_rgb = to_rgb(sh_im) # Calculate perceptual loss lpips_loss = (loss_fn_alex(output_rgb.to(device), sh_im_rgb.to(device)).mean()) # Ensure lpips_loss is a scalar lpips_loss = lpips_loss / train_config['autoencoder_acc_steps'] val_perceptual_losses.append(train_config['perceptual_weight'] * lpips_loss.item()) gen_loss = 0 if step_count > disc_step_start : # Discriminator predictions disc_fake_pred = discriminator(output) disc_real_pred = discriminator(sh_im) # Store predictions real_probs = torch.sigmoid(disc_real_pred).mean().item() fake_probs = torch.sigmoid(disc_fake_pred).mean().item() val_real_preds.append(real_probs) val_fake_preds.append(fake_probs) # Generator adversarial loss gen_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) gen_loss = gen_loss / train_config['autoencoder_acc_steps'] val_gen_losses.append(train_config['disc_weight'] * gen_loss.item()) # Discriminator adversarial loss disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred)) disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred)) disc_loss = (disc_fake_loss + disc_real_loss) / 2 disc_loss = disc_loss / train_config['autoencoder_acc_steps'] val_disc_losses.append(train_config['disc_weight'] * disc_loss.item()) # Calculate total loss total_loss = (train_config['sh_weight'] * recon_loss + #(train_config['gradient_weight'] * grad_loss) + (kl_weight * kl_loss) + (train_config['perceptual_weight'] * lpips_loss) + (train_config['disc_weight'] * gen_loss)) val_total_losses.append(total_loss.item()) model.train() discriminator.train() # Return average losses return { 'recon_loss_sh': np.mean(val_recon_losses_sh), 'kl_loss': np.mean(val_kl_losses), #'gradient_loss_albedo': np.mean(val_gradient_losses_albedo), 'perceptual_loss': np.mean(val_perceptual_losses), 'gen_loss': np.mean(val_gen_losses) if val_gen_losses else 0, 'disc_loss': np.mean(val_disc_losses) if val_disc_losses else 0, 'total_loss': np.mean(val_total_losses), 'real_prediction': np.mean(val_real_preds) if val_real_preds else 0, 'fake_prediction': np.mean(val_fake_preds) if val_fake_preds else 0, } def train(args): # Read the config file # with open(args.config_path, 'r') as file: try: config = yaml.safe_load(file) except yaml.YAMLError as exc: print(exc) #print(config) dataset_config = config['dataset_params_shading'] autoencoder_config = config['autoencoder_params'] train_config = config['train_params'] #albedo_config = config['albedo_params'] #ldr_input_config = config['dataset_params_input'] # Set the desired seed value # seed = train_config['seed'] torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if device == 'cuda': torch.cuda.manual_seed_all(seed) ############################# # Create the model and dataset # model = VAE(latent_dim=8).to(device) total_params = sum(p.numel() for p in model.parameters()) print(f"Total parameters in VAE: {total_params:,}") #model.apply(weights_init) discriminator = Discriminator(im_channels=dataset_config['im_channels']).to(device) #discriminator.apply(weights_init) sh_dataset = HDRGrayscaleEXRDataset_new(im_path=dataset_config['im_path'], im_size=dataset_config['im_size']) # Create the combined dataset combined_dataset = CombinedDataset(sh_dataset) # Split dataset into train and validation (90:10 ratio) dataset_size = len(combined_dataset) train_size = int(0.95 * dataset_size) val_size = dataset_size - train_size indices = np.arange(len(combined_dataset)) train_indices = indices[:train_size] val_indices = indices[train_size:] train_dataset = torch.utils.data.Subset(combined_dataset, train_indices) val_dataset = torch.utils.data.Subset(combined_dataset, val_indices) print(f"Total dataset size: {dataset_size}") print(f"Training set size: {train_size}") print(f"Validation set size: {val_size}") wandb.config.update({ "learning_rate_autoencoder": train_config['autoencoder_lr'], "learning_rate_discriminator": train_config['discriminator_lr'], "batch_size": train_config['autoencoder_batch_size'], "gradient_weight": train_config['gradient_weight'], "sh_weight": train_config['sh_weight'], "kl_weight": train_config['kl_weight'], "perceptual_weight": train_config['perceptual_weight'], "disc_weight": train_config['disc_weight'], "disc_start": train_config['disc_start'], "autoencoder_acc_steps": train_config['autoencoder_acc_steps'] }) train_loader = DataLoader(train_dataset, batch_size=wandb.config['batch_size'], shuffle=True , num_workers=16 , pin_memory=False) val_loader = DataLoader(val_dataset, batch_size=wandb.config['batch_size'], shuffle=False , num_workers=16 , pin_memory=False) # Create output directories if not os.path.exists(train_config['task_name']): os.makedirs(train_config['task_name'], exist_ok=True) num_epochs = train_config['autoencoder_epochs'] # L1/L2 loss for Reconstruction recon_criterion = torch.nn.MSELoss() # Disc Loss can even be BCEWithLogits disc_criterion = torch.nn.BCEWithLogitsLoss() optimizer_d = Adam(discriminator.parameters(), lr=wandb.config['learning_rate_discriminator'], betas=(0.5, 0.999)) optimizer_g = Adam(model.parameters(), lr=wandb.config['learning_rate_autoencoder'], betas=(0.5, 0.999) ) scaler = GradScaler('cuda') if device == 'cuda' else None # Setup schedulers scheduler_g = ReduceLROnPlateau(optimizer_g, mode='min', factor=0.9, patience=5, min_lr=0.00001) scheduler_d = ReduceLROnPlateau(optimizer_d, mode='min', factor=0.9, patience=5, min_lr=0.000001) disc_step_start = wandb.config['disc_start'] step_count = 0 start_epoch = 0 best_val_loss = float('inf') # This is for accumulating gradients incase the images are huge # And one cant afford higher batch sizes acc_steps = wandb.config['autoencoder_acc_steps'] image_save_steps = train_config['autoencoder_img_save_steps'] img_save_count = 0 # Lists to store epoch metrics train_losses_history = [] val_losses_history = [] # Check if checkpoint exists and load it for resuming training # checkpoint_path = os.path.join(train_config['task_name'], 'epoch_10_best_autoencoder_model_checkpoint.pth') # if os.path.exists(checkpoint_path): # logging.info(f"Loading checkpoint from {checkpoint_path}") # checkpoint = torch.load(checkpoint_path, weights_only=False) # model.load_state_dict(checkpoint['model_state_dict']) # discriminator.load_state_dict(checkpoint['discriminator_state_dict']) # optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict']) # optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict']) # start_epoch = checkpoint['epoch'] # step_count = checkpoint.get('step_count', 0) # img_save_count = checkpoint.get('img_save_count', 0) # best_val_loss = checkpoint['best_val_loss'] # train_losses_history = checkpoint.get('train_losses_history', []) # val_losses_history = checkpoint.get('val_losses_history', []) # logging.info(f"Resuming from epoch {start_epoch} with best validation loss: {best_val_loss}") #checkpoint_path="/home/project/ldr_image_to_ldr_shading/LDR_image_to_LDR_shading_hyperism/train_vae_mithlesh/ldr_to_sh_st_15/epoch_5_best_autoencoder_model_checkpoint.pth" # Load checkpoint checkpoint_path = os.path.join(train_config['task_name'], 'epoch_95_best_autoencoder_model_checkpoint.pth') if os.path.exists(checkpoint_path): logging.info(f"Loading checkpoint from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, weights_only=False) # Remove _orig_mod. prefix from model and discriminator state_dicts model_state_dict = strip_prefix_if_present(checkpoint['model_state_dict'], '_orig_mod.') discriminator_state_dict = strip_prefix_if_present(checkpoint['discriminator_state_dict'], '_orig_mod.') model.load_state_dict(model_state_dict) discriminator.load_state_dict(discriminator_state_dict) optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict']) optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict']) start_epoch = checkpoint['epoch'] step_count = checkpoint.get('step_count', 0) img_save_count = checkpoint.get('img_save_count', 0) best_val_loss = checkpoint['best_val_loss'] train_losses_history = checkpoint.get('train_losses_history', []) val_losses_history = checkpoint.get('val_losses_history', []) logging.info(f"Resuming from epoch {start_epoch} with best validation loss: {best_val_loss}") # Check if the model is already compiled if hasattr(torch, 'compile'): model = torch.compile(model) discriminator = torch.compile(discriminator) #logging.info(f"Learning rates updated: Generator -> {new_lr_g}, Discriminator -> {new_lr_d}") for epoch_idx in range(start_epoch, num_epochs): # Training metrics #recon_losses_shading = [] recon_losses_sh = [] kl_losses = [] perceptual_losses = [] disc_losses = [] gen_losses = [] train_real_preds = [] train_fake_preds = [] #grad_losses_shading = [] # grad_losses_albedo = [] losses = [] optimizer_g.zero_grad() optimizer_d.zero_grad() # Training loop for batch in tqdm(train_loader): #start_time = time.perf_counter() step_count += 1 # Unpack the batch - each element is a batch of images from each dataset sh_im, scene_infos = batch #sh_im = sh_im.float().to(device) sh_im = sh_im.float().to(device) with autocast(device_type='cuda'): # Fetch autoencoders output(reconstructions) model_output = model(sh_im) output, h,_ = model_output mean, logvar = torch.chunk(h, 2, dim=1) if check_nan(output, "raw_model_output"): print("NaN values detected in model output! Skipping this batch.") continue # Image Saving Logic if step_count % image_save_steps == 0 or step_count == 1: img_save_count = save_training_samples( output, sh_im, scene_infos, train_config, step_count, img_save_count ) ######### Optimize Generator ########## # L2 Loss for shading recon_loss = recon_criterion(output, sh_im) recon_loss = recon_loss / acc_steps recon_losses_sh.append(wandb.config['sh_weight'] * recon_loss.item()) # add average loss for 1 image kl_weight = wandb.config['kl_weight'] kl_loss = kl_divergence(mean, logvar) kl_loss = kl_loss / acc_steps kl_losses.append(kl_weight * kl_loss.item()) # total_loss_generator g_loss = (wandb.config['sh_weight'] * recon_loss + # (wandb.config['gradient_weight'] * grad_loss) + (kl_weight * kl_loss )) # Adversarial loss only if disc_step_start steps passed if step_count > disc_step_start: disc_fake_pred = discriminator(model_output[0]) disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones(disc_fake_pred.shape, device=disc_fake_pred.device)) disc_fake_loss = disc_fake_loss / acc_steps gen_losses.append(wandb.config['disc_weight'] * disc_fake_loss.item()) g_loss += wandb.config['disc_weight'] * disc_fake_loss # LPIPS Loss output_rgb = to_rgb(output) sh_im_rgb = to_rgb(sh_im) # Calculate perceptual loss lpips_loss = (loss_fn_alex(output_rgb, sh_im_rgb).mean()) # Ensure lpips_loss is a scalar lpips_loss = lpips_loss / acc_steps perceptual_losses.append(wandb.config['perceptual_weight'] * lpips_loss.item()) g_loss += wandb.config['perceptual_weight'] * lpips_loss losses.append(g_loss.item()) #g_loss.backward() ##################################### if scaler is not None: scaler.scale(g_loss).backward() else: g_loss.backward() ######### Optimize Discriminator ####### if step_count > disc_step_start and step_count % 2 == 0: with autocast(device_type='cuda'): fake = output disc_fake_pred = discriminator(fake.detach()) disc_real_pred = discriminator(sh_im) disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros(disc_fake_pred.shape, device=disc_fake_pred.device)) disc_real_loss = disc_criterion(disc_real_pred, torch.ones(disc_real_pred.shape, device=disc_real_pred.device)) disc_loss = wandb.config['disc_weight'] * (disc_fake_loss + disc_real_loss) / 2 disc_loss = disc_loss / acc_steps disc_losses.append(disc_loss.item()) with torch.no_grad(): # Convert logits to probabilities using sigmoid real_probs = torch.sigmoid(disc_real_pred).mean().item() fake_probs = torch.sigmoid(disc_fake_pred).mean().item() train_real_preds.append(real_probs) train_fake_preds.append(fake_probs) # Scale the discriminator loss and backward if scaler is not None: scaler.scale(disc_loss).backward() else: disc_loss.backward() if step_count % acc_steps == 0: # Apply gradient clipping (see below) if scaler is not None: # Unscale before clipping scaler.unscale_(optimizer_d) # Here we'll add gradient clipping (code below) torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Step with scaler scaler.step(optimizer_d) scaler.update() else: # Here we'll add gradient clipping (code below) torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) optimizer_d.step() optimizer_d.zero_grad() ##################################### if step_count % acc_steps == 0: # Apply gradient clipping (see below) if scaler is not None: # Unscale before clipping scaler.unscale_(optimizer_g) # Here we'll add gradient clipping (code below) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Step with scaler scaler.step(optimizer_g) scaler.update() else: # Here we'll add gradient clipping (code below) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer_g.step() optimizer_g.zero_grad() if step_count > disc_step_start and step_count % 2 == 0: # Final optimizer steps at end of epoch optimizer_d.step() optimizer_d.zero_grad() optimizer_g.step() # Calculate validation metrics val_metrics = validate(model, val_loader, discriminator, recon_criterion, disc_criterion, train_config,kl_weight,step_count,disc_step_start) # Store epoch metrics for plotting train_loss = np.mean(losses) val_loss = val_metrics['total_loss'] train_losses_history.append(train_loss) val_losses_history.append(val_loss) epochs_since_disc_start = max(0, epoch_idx - (disc_step_start // len(train_loader))) if step_count <= disc_step_start: # Before discriminator starts - continue normal generator training scheduler_g.step(val_loss) elif epochs_since_disc_start <= 20: # Stabilization period - don't adjust any learning rates scheduler_d.step(val_metrics["disc_loss"]) else: # After stabilization - resume normal scheduling scheduler_g.step(val_loss) scheduler_d.step(val_metrics["disc_loss"]) # After validation and calculating epoch metrics wandb.log({ "epoch": epoch_idx + 1, "train/recon_loss_sh": np.mean(recon_losses_sh), #"train/gradient_loss_albedo": np.mean(grad_losses_albedo), "train/kl_loss": np.mean(kl_losses), "train/perceptual_loss": np.mean(perceptual_losses), "train/gen_loss": np.mean(gen_losses) if len(gen_losses) > 0 else 0, "train/disc_loss": np.mean(disc_losses) if len(disc_losses) > 0 else 0, "train/total_loss": train_loss, "train/real_prediction": np.mean(train_real_preds) if len(train_real_preds) > 0 else 0, "train/fake_prediction": np.mean(train_fake_preds) if len(train_fake_preds) > 0 else 0, "val/recon_loss_sh": val_metrics["recon_loss_sh"], #"val/gradient_loss_albedo": val_metrics["gradient_loss_albedo"], "val/kl_loss": val_metrics["kl_loss"], "val/perceptual_loss": val_metrics["perceptual_loss"], "val/gen_loss": val_metrics["gen_loss"], "val/disc_loss": val_metrics["disc_loss"], "val/total_loss": val_metrics["total_loss"], "val/real_prediction": val_metrics["real_prediction"], "val/fake_prediction": val_metrics["fake_prediction"], "learning_rate/generator": optimizer_g.param_groups[0]['lr'], "learning_rate/discriminator": optimizer_d.param_groups[0]['lr'] }) # Print epoch results print('\n' + '=' * 80) print(f'Epoch {epoch_idx + 1}/{num_epochs}') print('-' * 80) print('TRAINING:') print(f'Recon Loss_sh: {np.mean(recon_losses_sh):.4f} | ' #f'Gradient Loss_albedo: {np.mean(grad_losses_albedo):.4f} | ' f'KL Loss: {np.mean(kl_losses):.4f} | ' f'Perceptual Loss: {np.mean(perceptual_losses):.4f}') if len(disc_losses) > 0 and len(gen_losses) > 0: print(f'Generator Loss: {np.mean(gen_losses):.4f} | ' f'Discriminator Loss: {np.mean(disc_losses):.4f}') print(f'Total Training Loss: {train_loss:.4f}') print('\nVALIDATION:') print(f'Recon Loss_shading: {val_metrics["recon_loss_sh"]:.4f} | ' #f'Gradient Loss_shading: {val_metrics["gradient_loss_albedo"]:.4f} | ' f'KL Loss: {val_metrics["kl_loss"]:.4f} | ' f'Perceptual Loss: {val_metrics["perceptual_loss"]:.4f}') print(f'Generator Loss: {val_metrics["gen_loss"]:.4f} | ' f'Discriminator Loss: {val_metrics["disc_loss"]:.4f}') print(f'Total Validation Loss: {val_metrics["total_loss"]:.4f}') # Check if validation loss improved if val_loss < best_val_loss: print(f"\nValidation loss improved from {best_val_loss:.4f} to {val_loss:.4f}. Saving model...") best_val_loss = val_loss # # Save only the best model checkpoint # checkpoint = { # 'epoch': epoch_idx + 1, # 'model_state_dict': model.state_dict(), # 'discriminator_state_dict': discriminator.state_dict(), # 'optimizer_g_state_dict': optimizer_g.state_dict(), # 'optimizer_d_state_dict': optimizer_d.state_dict(), # 'best_val_loss': best_val_loss, # 'step_count': step_count, # 'img_save_count': img_save_count, # 'train_losses_history': train_losses_history, # 'val_losses_history': val_losses_history # } # torch.save(checkpoint, os.path.join(train_config['task_name'], 'best_autoencoder_model_checkpoint.pth')) # # Save individual model files for compatibility with original code # torch.save(model.state_dict(), os.path.join(train_config['task_name'], # train_config['vae_autoencoder_ckpt_name'])) # torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'], # else: print(f"Validation loss did not improve from {best_val_loss:.4f}") if (epoch_idx + 1) % 5 == 0: checkpoint = { 'epoch': epoch_idx + 1, 'model_state_dict': model.state_dict(), 'discriminator_state_dict': discriminator.state_dict(), 'optimizer_g_state_dict': optimizer_g.state_dict(), 'optimizer_d_state_dict': optimizer_d.state_dict(), 'best_val_loss': best_val_loss, 'step_count': step_count, 'img_save_count': img_save_count, 'train_losses_history': train_losses_history, 'val_losses_history': val_losses_history } torch.save(checkpoint, os.path.join(train_config['task_name'], f'epoch_{epoch_idx + 1}_best_autoencoder_model_checkpoint.pth')) # Save individual model files for compatibility with original code torch.save(model.state_dict(), os.path.join(train_config['task_name'], train_config['vae_autoencoder_ckpt_name'])) torch.save(discriminator.state_dict(), os.path.join(train_config['task_name'], train_config['vae_discriminator_ckpt_name'])) # Every 10 epochs if epoch_idx % 1 == 0: # Every 5 epochs gc.collect() torch.cuda.empty_cache() print('=' * 80 + '\n') print('Done Training...') # Save final training history np.savez(os.path.join(train_config['task_name'], 'training_history.npz'), train_losses=np.array(train_losses_history), val_losses=np.array(val_losses_history)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Arguments for vae training') parser.add_argument('--config', dest='config_path', default='config/autoen_alb_1.yaml', type=str) # Handle Jupyter/IPython arguments import sys args, unknown = parser.parse_known_args(sys.argv[1:]) train(args)