"""Sample chSpec spectra from a conditional diffusion model and save results.

Reads h5 profile files from ./dataset, runs conditional diffusion sampling
with the IC field (cropped to its central 64x64 window) as the condition,
and writes truth / generated / IC arrays into ./test_res/.
"""

import os
import random
from os import listdir
from os.path import isfile

import h5py
import numpy as np
import torch
import torch.utils.data
import torchvision

# NOTE(review): the following imports are unused in this script but are kept
# because they were present in the original file (other tooling may rely on
# import side effects — confirm before deleting).
import re
import pandas as pd
import matplotlib.pyplot as plt
import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from datetime import datetime
from torch import optim
from torch.nn.parallel import DistributedDataParallel as DDP

# Project-local modules providing DiffusionUNet and Diffusion.
# Star imports must stay at module level (SyntaxError inside a function).
from diffusion_cond import *
from Diff_unet_attn import *


class spec:
    """Bundles dataset location and DataLoader hyper-parameters."""

    def __init__(self, data_dir, batch_size, num_workers):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # ToTensor converts the HxW numpy arrays read from h5 into tensors.
        self.transforms = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()]
        )

    def get_loaders(self):
        """Build and return the training DataLoader over <data_dir>/dataset."""
        print("=> Loader the spectra dataset...")
        train_dataset = specDataset(
            dir=os.path.join(self.data_dir, './dataset'),
            transforms=self.transforms,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return train_loader


class specDataset(torch.utils.data.Dataset):
    """Dataset yielding (IC, chSpec, shear, filename) tuples from h5 files.

    Each file in *dir* is expected to contain 'IC' (100x100) and
    'chSpec' (64x64) datasets — TODO confirm shapes against the data.
    """

    def __init__(self, dir, transforms):
        super().__init__()
        self.dir = dir
        # Collect every regular file in the directory as a sample path.
        profiles = [f for f in listdir(dir) if isfile(os.path.join(dir, f))]
        input_names = [os.path.join(dir, p) for p in profiles]
        # Shuffle the file order once up front.  (The DataLoader also
        # shuffles per epoch; this is kept from the original behavior.
        # Unlike the original enumerate/zip dance, this is safe when the
        # directory is empty.)
        random.shuffle(input_names)
        self.input_names = input_names
        self.transforms = transforms

    def get_profiles(self, index):
        """Load one profile file and return its tensors plus metadata."""
        input_name = self.input_names[index]
        # Context manager closes the h5 file (the original leaked the handle).
        # `[:]` copies the data, so closing afterwards is safe.
        with h5py.File(input_name, 'r') as dataset:
            IC = dataset['IC'][:]
            chSpec = dataset['chSpec'][:]
        # Shear label is the single digit just before the extension,
        # e.g. "..._3.h5" -> 3.  Fragile: assumes a 2-character extension.
        shear = int(input_name[-4])
        return (
            self.transforms(IC).view(1, 100, 100),
            self.transforms(chSpec).view(1, 64, 64),
            shear,
            input_name,
        )

    def __getitem__(self, index):
        return self.get_profiles(index)

    def __len__(self):
        return len(self.input_names)


def main():
    """Run conditional sampling over the dataset and write results to disk."""
    data_dir = './'
    batch_size = 32
    num_workers = 4

    DATASET = spec(data_dir, batch_size, num_workers)
    train_loader = DATASET.get_loaders()
    print(train_loader)

    # BUG FIX: the original tested `torch.cuda.is_available` (the function
    # object, always truthy) instead of calling it.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = DiffusionUNet(
        ch=128, num_res_blocks=2, image_size=64, drop_out=0
    ).to(device)
    diffusion = Diffusion(img_size=64, device=device)
    # map_location lets a GPU-saved checkpoint load on CPU-only hosts.
    model.load_state_dict(
        torch.load('pretrained_weight.pth', map_location=device)
    )
    model.eval()  # inference mode for norm/dropout layers

    out_dir = './test_res'
    os.makedirs(out_dir, exist_ok=True)

    print(len(train_loader))
    for i, (condition, Spec, shear, file_name) in enumerate(train_loader):
        condition = condition.to(device).to(torch.float32)
        Spec = Spec.to(device).to(torch.float32)
        shear = shear.to(device)

        sample_times = 1
        # (The original also drew `t = diffusion.sample_timesteps(...)` here
        # but never used it; removed.)
        with torch.no_grad():
            samples = torch.zeros(
                Spec.size(0), sample_times, 64, 64, device=device
            )
            for ix in range(sample_times):
                # Condition on the central 64x64 crop of the 100x100 IC.
                samples[:, ix, :, :] = diffusion.sample(
                    model, condition[:, :, 18:82, 18:82], Spec, shear
                ).squeeze(1)

        for j in range(samples.size(0)):
            out_name = 'shear1_ddim1_' + os.path.basename(file_name[j])
            print(out_name)
            temp_truth = Spec[j].view(64, 64).cpu().numpy()
            temp_genera = samples[j].view(sample_times, 64, 64).cpu().numpy()
            temp_condition = condition[j].view(100, 100).cpu().numpy()
            with h5py.File(os.path.join(out_dir, out_name), 'w') as h5f:
                h5f.create_dataset('truth', data=temp_truth)
                for ix in range(sample_times):
                    h5f.create_dataset(
                        'genera_{}'.format(ix),
                        data=temp_genera[ix, :, :].reshape((64, 64)),
                    )
                h5f.create_dataset('IC', data=temp_condition)


if __name__ == "__main__":
    main()