|
import os |
|
from os import listdir |
|
from os.path import isfile |
|
import torch |
|
import numpy as np |
|
import torchvision |
|
import torch.utils.data |
|
import re |
|
import random |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import tqdm |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import torch.distributed as dist |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from datetime import datetime |
|
import torch.nn as nn |
|
from torch import optim |
|
import h5py |
|
class EMAHelper(object):
    """Track an exponential moving average (EMA) of a module's trainable parameters.

    ``shadow`` maps parameter names (as reported by ``named_parameters``) to
    the averaged tensors; ``mu`` is the weight given to the previous average
    on every :meth:`update`.  A module wrapped in ``nn.DataParallel`` is
    unwrapped first, so the stored names never carry the ``module.`` prefix.
    """

    def __init__(self, mu=0.999):
        self.shadow = {}
        self.mu = mu

    @staticmethod
    def _trainable_parameters(module):
        """Yield ``(name, param)`` for every parameter with ``requires_grad`` set."""
        target = module.module if isinstance(module, nn.DataParallel) else module
        return (
            (name, param)
            for name, param in target.named_parameters()
            if param.requires_grad
        )

    def register(self, module):
        """Seed the shadow dict with clones of the module's trainable parameters."""
        for name, param in self._trainable_parameters(module):
            self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend the module's current parameters into the shadow averages.

        Each shadow tensor becomes ``(1 - mu) * param + mu * shadow``.
        Raises ``KeyError`` for a trainable parameter that was never registered.
        """
        decay = self.mu
        for name, param in self._trainable_parameters(module):
            averaged = self.shadow[name]
            averaged.data = (1. - decay) * param.data + decay * averaged.data

    def ema(self, module):
        """Overwrite the module's trainable parameters in place with the shadow values."""
        for name, param in self._trainable_parameters(module):
            param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a new module of the same type whose weights are the EMA values.

        The copy is built as ``type(m)(m.config)`` and moved to
        ``m.config.device``, so the module must expose a ``config`` attribute
        accepted by its own constructor.  If ``module`` is an
        ``nn.DataParallel`` wrapper, the returned copy is wrapped as well.
        """
        is_parallel = isinstance(module, nn.DataParallel)
        source = module.module if is_parallel else module

        duplicate = type(source)(source.config).to(source.config.device)
        duplicate.load_state_dict(source.state_dict())
        if is_parallel:
            duplicate = nn.DataParallel(duplicate)

        self.ema(duplicate)
        return duplicate

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow dict with ``state_dict`` (stored by reference)."""
        self.shadow = state_dict
|
|
|
|
|
class spec:
    """Build the training and validation ``DataLoader`` objects for the spectra data.

    The two datasets are read from ``../0_40PbPb502train`` and
    ``../0_40PbPb502val``, both resolved relative to ``data_dir``.
    """

    def __init__(self, data_dir, batch_size, num_workers):
        to_tensor = torchvision.transforms.ToTensor()
        self.transforms = torchvision.transforms.Compose([to_tensor])
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.data_dir = data_dir

    def get_loaders(self):
        """Return ``(train_loader, val_loader)``.

        Both loaders use the configured batch size and worker count, shuffle
        their samples, and pin host memory.
        """
        print("=> Loader the spectra dataset...")

        train_dir = os.path.join(self.data_dir, '../0_40PbPb502train')
        val_dir = os.path.join(self.data_dir, '../0_40PbPb502val')

        # Datasets are created before the loaders, training set first.
        train_dataset = specDataset(dir=train_dir, transforms=self.transforms)
        val_dataset = specDataset(dir=val_dir, transforms=self.transforms)

        # Options shared by both loaders.
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)
        val_loader = torch.utils.data.DataLoader(val_dataset, **loader_options)

        return train_loader, val_loader
|
|
|
|
|
class specDataset(torch.utils.data.Dataset):
    """Dataset with one sample per file found directly inside ``dir``.

    Every file is opened as HDF5 and must contain an ``IC`` and a ``chSpec``
    dataset.  File paths are shuffled once at construction time.

    Args:
        dir: Directory holding the sample files (sub-directories are ignored).
        transforms: Callable applied to both arrays; its result must support
            ``.view(1, 64, 64)``.

    Raises:
        ValueError: If ``dir`` contains no regular files.
    """

    def __init__(self, dir, transforms):
        super().__init__()
        self.dir = dir

        input_names = [
            os.path.join(dir, entry)
            for entry in listdir(dir)
            if isfile(os.path.join(dir, entry))
        ]
        if not input_names:
            # The previous enumerate/zip shuffle failed here with an opaque
            # "not enough values to unpack" ValueError; keep the exception
            # type but say what is actually wrong.
            raise ValueError(f"No data files found in directory: {dir!r}")

        random.shuffle(input_names)

        self.input_names = tuple(input_names)
        self.transforms = transforms

    def get_profiles(self, index):
        """Load sample ``index`` and return ``(IC, chSpec, shear)``.

        ``IC`` is cropped to rows/columns 18..81 (a 64x64 window) before the
        transform; both transformed arrays are viewed as ``(1, 64, 64)``.
        ``shear`` is the integer parsed from the fourth character from the end
        of the file path -- presumably the digit just before a three-character
        extension such as ``.h5``; verify against the data naming scheme.
        """
        input_name = self.input_names[index]

        # Close the file as soon as the arrays are in memory; previously the
        # handle was never closed, leaking one open file per sample.
        with h5py.File(input_name, 'r') as dataset:
            IC = dataset['IC'][:]
            chSpec = dataset['chSpec'][:]

        shear = int(input_name[-4])
        IC = IC[18:82, 18:82]

        return (
            self.transforms(IC).view(1, 64, 64),
            self.transforms(chSpec).view(1, 64, 64),
            shear,
        )

    def __getitem__(self, index):
        return self.get_profiles(index)

    def __len__(self):
        return len(self.input_names)
|
|
|
# Loader configuration: dataset root, batch size and DataLoader worker count.
data_dir, batch_size, num_workers = './', 32, 4

# Build both loaders once at module level; train() reads them as globals.
DATASET = spec(data_dir=data_dir, batch_size=batch_size, num_workers=num_workers)
train_loader, valloader = DATASET.get_loaders()

# Report the number of batches per epoch for each split, one per line.
print(len(train_loader), len(valloader), sep="\n")
|
|
|
from diffusion_cond import * |
|
from Diff_unet_attn import * |
|
|
|
def init_weights(m):
    """Initialise one layer in place; intended for ``Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: Xavier-uniform weights, and a zero
      bias when the layer has one.
    * ``BatchNorm2d``: weight set to 1 and bias set to 0.
    * Any other module type is left untouched.
    """
    conv_layers = (nn.Conv2d, nn.ConvTranspose2d)

    if isinstance(m, conv_layers):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is None:
            return
        nn.init.constant_(m.bias, 0)
        return

    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
|
|
|
def validate(model, valloader, diffusion, device):
    """Return the mean noise-prediction MSE of ``model`` over ``valloader``.

    For every batch a timestep is drawn per sample, the target is noised with
    ``diffusion.noise_images``, and the model predicts the noise from the
    channel-wise concatenation of the condition and the noised target.

    The model is evaluated in eval mode under ``torch.no_grad()``.  Whatever
    train/eval mode the model was in on entry is restored on exit, so calling
    this between training epochs no longer leaves the model stuck in eval mode.

    Args:
        model: Network called as ``model(x, t, shear)``.
        valloader: Iterable of ``(condition, Spec, shear)`` batches that also
            supports ``len()``.
        diffusion: Object providing ``sample_timesteps(n)`` and
            ``noise_images(x, t)``.
        device: Device the batches are moved to.

    Returns:
        The sum of per-batch losses divided by ``len(valloader)``.

    Raises:
        ZeroDivisionError: If ``valloader`` has length zero.
    """
    was_training = model.training
    model.eval()

    val_loss = 0
    mse = nn.MSELoss()
    try:
        with torch.no_grad():
            for condition, Spec, shear in valloader:
                condition = condition.to(device).to(torch.float32)
                Spec = Spec.to(device).to(torch.float32)
                shear = shear.to(device)

                t = diffusion.sample_timesteps(Spec.shape[0]).to(device)
                Spec_t, noise = diffusion.noise_images(Spec, t)

                predicted_noise = model(torch.cat([condition, Spec_t], dim=1), t, shear)
                loss = mse(predicted_noise, noise)
                val_loss += loss.item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    val_loss /= len(valloader)
    return val_loss
|
|
|
epochs = 10000  # module-level epoch count; train() takes its own `epochs` argument


def train(epochs):
    """Train the noise-prediction network and checkpoint on validation improvement.

    Builds a ``DiffusionUNet`` (wrapped in ``nn.DataParallel`` when more than
    one CUDA device is visible), applies ``init_weights`` to it, and optimises
    it with AdamW (lr=1e-4) on the MSE between predicted and sampled noise.
    An ``EMAHelper`` is updated after every optimiser step.

    After each epoch the model is scored with ``validate``.  Whenever the
    validation loss improves, the model ``state_dict`` is written to
    ``model.pth`` and the EMA shadow weights to ``ema_model.pth`` in the
    current working directory.

    Reads the module-level ``train_loader`` and ``valloader``.

    Args:
        epochs: Number of passes over ``train_loader``.
    """
    # `is_available` must be *called*: the bare function object is always
    # truthy, which selected "cuda" even on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DiffusionUNet(ch=128, num_res_blocks=2, image_size=64, drop_out=0).to(device)

    if torch.cuda.device_count() > 1:
        print(torch.cuda.device_count())
        model = nn.DataParallel(model)

    model.apply(init_weights)
    ema_helper = EMAHelper()
    ema_helper.register(model)
    optimizer = optim.AdamW(model.parameters(), lr=0.0001)

    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=64, device=device)

    num_batches = len(train_loader)
    best_loss = float("inf")

    for epoch in range(epochs):
        # validate() puts the model in eval mode; make sure every epoch
        # actually trains in train mode.
        model.train()

        epoch_loss = 0
        for condition, Spec, shear in train_loader:
            condition = condition.to(device).to(torch.float32)
            Spec = Spec.to(device).to(torch.float32)
            shear = shear.to(device)

            t = diffusion.sample_timesteps(Spec.shape[0]).to(device)
            Spec_t, noise = diffusion.noise_images(Spec, t)
            predicted_noise = model(torch.cat([condition, Spec_t], dim=1), t, shear)
            loss = mse(predicted_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema_helper.update(model)

            # Running mean of the per-batch losses over the epoch.
            epoch_loss = epoch_loss + loss.item() / num_batches

        val_loss = validate(model, valloader, diffusion, device)
        current_time = datetime.now()
        print(current_time,f"Epoch [{epoch+1}/{epochs}], Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")

        if best_loss > val_loss:
            best_loss = val_loss
            # Note: when the model is wrapped in nn.DataParallel the saved
            # keys carry the "module." prefix.
            torch.save(model.state_dict(), 'model.pth')
            torch.save(ema_helper.state_dict(), 'ema_model.pth')
            current_time = datetime.now()
            print(current_time,'model is saved. The loss is ',val_loss)
|
|
|
# Run training only when this file is executed as a script.  The DataLoaders
# use worker processes; on platforms that start workers with "spawn" each
# worker re-imports this module, and an unguarded call would try to start
# training again inside every worker.
if __name__ == "__main__":
    train(epochs=10000)
|
|