import os
from os import listdir
from os.path import isfile
import torch
import numpy as np
import torchvision
import torch.utils.data
import re
import random
import pandas as pd
import matplotlib.pyplot as plt
import tqdm
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from datetime import datetime
import torch.nn as nn
from torch import optim
import h5py


class EMAHelper(object):
    """Maintains an exponential moving average (EMA) of model parameters."""

    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                # shadow <- mu * shadow + (1 - mu) * param
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        """Copy the EMA (shadow) weights into the module's parameters."""
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a fresh copy of the module with EMA weights applied."""
        if isinstance(module, nn.DataParallel):
            inner_module = module.module
            module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device)
            module_copy.load_state_dict(inner_module.state_dict())
            module_copy = nn.DataParallel(module_copy)
        else:
            module_copy = type(module)(module.config).to(module.config.device)
            module_copy.load_state_dict(module.state_dict())
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict


class spec:
    """Builds train/validation DataLoaders for the spectra dataset."""

    def __init__(self, data_dir, batch_size, num_workers):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    def get_loaders(self):
        print("=> Loading the spectra dataset...")
        train_dataset = specDataset(dir=os.path.join(self.data_dir, '../0_40PbPb502train'),
                                    transforms=self.transforms)
        val_dataset = specDataset(dir=os.path.join(self.data_dir, '../0_40PbPb502val'),
                                  transforms=self.transforms)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=self.num_workers,
                                                   pin_memory=True)
        # No need to shuffle the validation set; the loss is averaged over all batches.
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=self.batch_size,
                                                 shuffle=False,
                                                 num_workers=self.num_workers,
                                                 pin_memory=True)
        return train_loader, val_loader


class specDataset(torch.utils.data.Dataset):
    def __init__(self, dir, transforms):
        super().__init__()
        self.dir = dir
        # Build the list of HDF5 profile files in the directory and shuffle it once.
        input_names = [os.path.join(dir, f) for f in listdir(dir)
                       if isfile(os.path.join(dir, f))]
        random.shuffle(input_names)
        self.input_names = input_names
        self.transforms = transforms

    def get_profiles(self, index):
        input_name = self.input_names[index]
        # Read the initial condition (IC) and charged-particle spectrum from the HDF5 file.
        with h5py.File(input_name, 'r') as dataset:
            IC = dataset['IC'][:]
            chSpec = dataset['chSpec'][:]
        # The shear label is the digit just before the '.h5' extension in the filename.
        shear = int(input_name[-4])
        # Crop the initial condition to the central 64x64 window.
        IC = IC[18:82, 18:82]
        return (self.transforms(IC).view(1, 64, 64),
                self.transforms(chSpec).view(1, 64, 64),
                shear)

    def __getitem__(self, index):
        return self.get_profiles(index)

    def __len__(self):
        return len(self.input_names)
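# ---------------------------------------------------------------------------
# Illustrative sketch (safe to remove): how EMAHelper is meant to be used.
# The tiny nn.Linear below is a hypothetical stand-in, not part of the
# training pipeline; it only demonstrates the register -> update -> ema call
# order that train() performs with the real DiffusionUNet.
_demo_net = nn.Linear(4, 4)
_demo_ema = EMAHelper(mu=0.999)
_demo_ema.register(_demo_net)      # snapshot the initial weights
_demo_net.weight.data.add_(1.0)    # pretend an optimizer step changed them
_demo_ema.update(_demo_net)        # shadow <- mu * shadow + (1 - mu) * param
_demo_ema.ema(_demo_net)           # copy the smoothed weights back into the net
del _demo_net, _demo_ema
# ---------------------------------------------------------------------------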
data_dir = './'
batch_size = 32
num_workers = 4

DATASET = spec(data_dir, batch_size, num_workers)
train_loader, val_loader = DATASET.get_loaders()
print(len(train_loader))
print(len(val_loader))

# Local project modules: Diffusion (forward/noising process) and DiffusionUNet (denoiser).
from diffusion_cond import *
from Diff_unet_attn import *


def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


def validate(model, val_loader, diffusion, device):
    model.eval()
    val_loss = 0
    mse = nn.MSELoss()
    with torch.no_grad():
        for i, (condition, Spec, shear) in enumerate(val_loader):
            condition = condition.to(device).to(torch.float32)
            Spec = Spec.to(device).to(torch.float32)
            shear = shear.to(device)
            # Draw random diffusion timesteps and noise the spectra accordingly.
            t = diffusion.sample_timesteps(Spec.shape[0]).to(device)
            Spec_t, noise = diffusion.noise_images(Spec, t)
            # The model predicts the added noise, conditioned on the IC and shear label.
            predicted_noise = model(torch.cat([condition, Spec_t], dim=1), t, shear)
            loss = mse(predicted_noise, noise)
            val_loss += loss.item()
    val_loss /= len(val_loader)
    return val_loss


def train(epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DiffusionUNet(ch=128, num_res_blocks=2, image_size=64, drop_out=0).to(device)
    if torch.cuda.device_count() > 1:
        print(torch.cuda.device_count())
        model = nn.DataParallel(model)
    model.apply(init_weights)

    ema_helper = EMAHelper()
    ema_helper.register(model)

    optimizer = optim.AdamW(model.parameters(), lr=0.0001)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=64, device=device)

    l = len(train_loader)
    best_loss = float("inf")
    for epoch in range(epochs):
        # validate() switches to eval mode, so restore train mode each epoch.
        model.train()
        epoch_loss = 0
        for i, (condition, Spec, shear) in enumerate(train_loader):
            condition = condition.to(device).to(torch.float32)
            Spec = Spec.to(device).to(torch.float32)
            shear = shear.to(device)
            t = diffusion.sample_timesteps(Spec.shape[0]).to(device)
            Spec_t, noise = diffusion.noise_images(Spec, t)
            predicted_noise = model(torch.cat([condition, Spec_t], dim=1), t, shear)
            loss = mse(predicted_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema_helper.update(model)

            epoch_loss += loss.item() / l

        val_loss = validate(model, val_loader, diffusion, device)
        current_time = datetime.now()
        print(current_time, f"Epoch [{epoch+1}/{epochs}], Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'model.pth')
            torch.save(ema_helper.state_dict(), 'ema_model.pth')
            current_time = datetime.now()
            print(current_time, 'model is saved. The loss is ', val_loss)


train(epochs=10000)
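# ---------------------------------------------------------------------------
# Hedged sketch (not invoked above): reloading the saved checkpoints for later
# evaluation. It assumes the same DiffusionUNet constructor arguments used in
# train(); load_for_inference is a hypothetical helper name introduced here
# purely for illustration.
def load_for_inference(device):
    model = DiffusionUNet(ch=128, num_res_blocks=2, image_size=64, drop_out=0).to(device)
    state = torch.load('model.pth', map_location=device)
    # Strip the 'module.' prefix that nn.DataParallel adds to state-dict keys.
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    model.load_state_dict(state)
    # 'ema_model.pth' holds EMAHelper's shadow dict (parameter name -> tensor);
    # ema() overwrites the live weights with the smoothed EMA weights.
    ema_helper = EMAHelper()
    ema_helper.load_state_dict(torch.load('ema_model.pth', map_location=device))
    ema_helper.ema(model)
    model.eval()
    return model
# ---------------------------------------------------------------------------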