|
import os |
|
from os import listdir |
|
from os.path import isfile |
|
import torch |
|
import numpy as np |
|
import torchvision |
|
import torch.utils.data |
|
import re |
|
import random |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import tqdm |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import torch.distributed as dist |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from datetime import datetime |
|
import torch.nn as nn |
|
from torch import optim |
|
import h5py |
|
from datetime import datetime |
|
|
|
|
|
class spec:
    """Build a ``DataLoader`` over the spectra files found under ``data_dir``.

    Args:
        data_dir: Root directory; samples are read from
            ``os.path.join(data_dir, dataset_subdir)``.
        batch_size: Number of samples per batch.
        num_workers: Number of ``DataLoader`` worker processes.
        dataset_subdir: Sub-directory of ``data_dir`` that holds the sample
            files. Defaults to ``'./dataset'`` (the previously hard-coded value).
    """

    def __init__(self, data_dir, batch_size, num_workers, dataset_subdir='./dataset'):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_subdir = dataset_subdir
        # Converts each numpy array read from the HDF5 files into a torch tensor;
        # specDataset reshapes the result itself.
        self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    def get_loaders(self, shuffle=True):
        """Return a ``DataLoader`` over every file in the dataset directory.

        Args:
            shuffle: Reshuffle the samples every epoch (default ``True``, as before).

        Returns:
            torch.utils.data.DataLoader yielding batches of ``specDataset`` items.
        """
        print("=> Loading the spectra dataset...")

        train_dataset = specDataset(
            dir=os.path.join(self.data_dir, self.dataset_subdir),
            transforms=self.transforms,
        )

        # Pinned host memory only helps host->GPU copies; without CUDA it just
        # produces a warning, so enable it only when a GPU is present.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
        )
        return train_loader
|
|
|
|
|
class specDataset(torch.utils.data.Dataset):
    """Dataset where each regular file directly inside ``dir`` is one HDF5 sample.

    Each file must contain an ``IC`` dataset (viewed as 1x100x100) and a
    ``chSpec`` dataset (viewed as 1x64x64). Sub-directories are ignored. The
    file list is shuffled once, at construction time.

    Args:
        dir: Directory containing the sample files.
        transforms: Callable applied to each array read from a file. Its result
            must support ``.view`` (for example torchvision's ``ToTensor``).
    """

    def __init__(self, dir, transforms):
        super().__init__()
        self.dir = dir

        # Sort first so that, for a fixed ``random`` seed, the shuffled order
        # does not depend on the arbitrary order returned by os.listdir.
        profiles = sorted(f for f in listdir(dir) if isfile(os.path.join(dir, f)))
        input_names = [os.path.join(dir, name) for name in profiles]

        # Shuffle the list directly. The previous enumerate/zip(*...) round-trip
        # produced the same permutation but crashed on an empty directory.
        random.shuffle(input_names)

        self.input_names = input_names
        self.transforms = transforms

    def get_profiles(self, index):
        """Load sample ``index``.

        Returns:
            tuple: ``(IC, chSpec, shear, input_name)``, where ``IC`` is viewed as
            ``(1, 100, 100)``, ``chSpec`` is viewed as ``(1, 64, 64)``, ``shear``
            is the integer parsed from the file name and ``input_name`` is the
            file path.
        """
        input_name = self.input_names[index]

        # Close the HDF5 handle as soon as the arrays are read. ``[:]`` reads the
        # data into numpy arrays, so they stay valid after the file is closed.
        with h5py.File(input_name, 'r') as dataset:
            IC = dataset['IC'][:]
            chSpec = dataset['chSpec'][:]

        # NOTE(review): uses the 4th character from the end as a single-digit
        # shear value, i.e. assumes names like "<...><digit>.h5". Confirm this.
        shear = int(input_name[-4])

        return (
            self.transforms(IC).view(1, 100, 100),
            self.transforms(chSpec).view(1, 64, 64),
            shear,
            input_name,
        )

    def __getitem__(self, index):
        return self.get_profiles(index)

    def __len__(self):
        return len(self.input_names)
|
|
|
|
|
# Loader configuration: read from ./dataset in batches of 32 using 4 workers.
data_dir, batch_size, num_workers = './', 32, 4

# Build the loader once and echo the object for a quick sanity check.
DATASET = spec(data_dir, batch_size, num_workers)
train_loader = DATASET.get_loaders()
print(train_loader)
|
|
|
from diffusion_cond import * |
|
from Diff_unet_attn import * |
|
|
|
# Use the GPU only when one is available. ``torch.cuda.is_available`` must be
# *called*: the bare function object is always truthy, so the old expression
# always picked "cuda:0" and failed on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DiffusionUNet(ch=128, num_res_blocks=2, image_size=64, drop_out=0).to(device)
diffusion = Diffusion(img_size=64, device=device)

# map_location lets weights saved from a GPU run load on whichever device was
# selected above (e.g. a CPU-only machine).
model.load_state_dict(torch.load('pretrained_weight.pth', map_location=device))

# This script only runs inference, so put the network in evaluation mode
# (this is harmless if Diffusion.sample also switches modes itself).
model.eval()
|
|
|
# Output location and per-condition sample count.
out_dir = './test_res/'
sample_times = 1  # number of independent samples generated per condition
# Center-crop the 100x100 condition to the model's 64x64 resolution:
# (100 - 64) // 2 = 18, so the window is [18, 82).
crop = slice(18, 82)
# h5py.File(..., 'w') fails when the target directory is missing.
os.makedirs(out_dir, exist_ok=True)

print(len(train_loader))

# Pure inference: autograd is not needed, so disabling it saves memory and time.
with torch.no_grad():
    for condition, Spec, shear, file_name in train_loader:
        condition = condition.to(device).to(torch.float32)
        Spec = Spec.to(device).to(torch.float32)
        shear = shear.to(device)

        # One slot per requested sample: (batch, sample_times, 64, 64).
        x = torch.zeros(Spec.size(0), sample_times, 64, 64).to(device)
        for ix in range(sample_times):
            x[:, ix, :, :] = diffusion.sample(model, condition[:, :, crop, crop], Spec, shear).squeeze(1)

        # Write one HDF5 file per sample: ground truth, generations, condition.
        for j in range(x.size(0)):
            file = 'shear1_ddim1_' + os.path.basename(file_name[j])
            print(file)

            temp_truth = Spec[j].view(64, 64).cpu().numpy()
            temp_genera = x[j].view(sample_times, 64, 64).cpu().numpy()
            temp_condition = condition[j].view(100, 100).cpu().numpy()

            with h5py.File(os.path.join(out_dir, file), 'w') as h5f:
                h5f.create_dataset('truth', data=temp_truth)
                for ix in range(sample_times):
                    h5f.create_dataset('genera_{}'.format(ix), data=temp_genera[ix])
                h5f.create_dataset('IC', data=temp_condition)
|
|
|
|