# Source: SFM_Inference_Demo / util / datasets.py
# Last commit a3f0d6c by Anirudh Bhalekar: "added models and util folder" (20.8 kB)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import os
import PIL
import os, random, glob
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from os.path import isfile, join
import segyio
from itertools import permutations
random.seed(42)
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args):
    """Build an ``ImageFolder`` dataset for the train or validation split.

    Args:
        is_train: when true, read ``<args.data_path>/train`` with the training
            transform; otherwise read ``<args.data_path>/val`` with the eval
            transform.
        args: namespace providing ``data_path`` plus every field read by
            ``build_transform``.

    Returns:
        The constructed ``datasets.ImageFolder`` (also printed for logging).
    """
    transform = build_transform(is_train, args)
    split_dir = 'train' if is_train else 'val'
    dataset = datasets.ImageFolder(
        os.path.join(args.data_path, split_dir),
        transform=transform,
    )
    print(dataset)
    return dataset
def build_transform(is_train, args):
    """Return the image transform for the training or evaluation pipeline.

    Both pipelines normalize with the ImageNet default mean/std from timm.

    Training delegates to timm's ``create_transform`` with bicubic
    interpolation, colour jitter ``args.color_jitter``, auto-augment policy
    ``args.aa`` and random erasing (``args.reprob``, ``args.remode``,
    ``args.recount``).

    Evaluation resizes to ``input_size / crop_pct`` (bicubic), center-crops
    to ``args.input_size``, converts to a tensor and normalizes.

    Args:
        is_train: select the augmenting training pipeline when true.
        args: namespace with ``input_size`` and, for training, the
            augmentation fields listed above.

    Returns:
        A callable transform.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Eval: inputs up to 224 use the standard 224/256 center-crop ratio;
    # larger inputs are resized straight to input_size (no crop margin).
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    size = int(args.input_size / crop_pct)
    return transforms.Compose([
        # InterpolationMode instead of PIL.Image.BICUBIC: the file only does
        # `import PIL`, which does not itself load the PIL.Image submodule,
        # and torchvision deprecates PIL/int interpolation constants.
        # Resizing first maintains the same ratio w.r.t. 224 images.
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
## pretrain
class SeismicSet(data.Dataset):
    """Unlabelled pretraining dataset of raw float32 seismic patches.

    Every entry of ``path`` is treated as one sample: a flat float32 binary
    file holding ``input_size * input_size`` values.
    """

    def __init__(self, path, input_size) -> None:
        super().__init__()
        self.get_file_list(path)
        self.input_size = input_size
        print(len(self.file_list))

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, index):
        """Return ``(patch, torch.tensor([1]))``.

        ``patch`` is a ``(1, input_size, input_size)`` NumPy array standardized
        to zero mean and (approximately) unit standard deviation. The second
        element is a constant placeholder label.
        """
        d = np.fromfile(self.file_list[index], dtype=np.float32)
        d = d.reshape(1, self.input_size, self.input_size)
        # The 1e-6 guards against division by zero on constant patches.
        d = (d - d.mean()) / (d.std() + 1e-6)
        return d, torch.tensor([1])

    def get_file_list(self, path):
        """Set ``self.file_list`` to every entry of ``path``, shuffled.

        The listing is sorted before shuffling because ``os.listdir`` order
        is arbitrary and filesystem-dependent. Sorting makes the shuffle
        reproducible under a fixed ``random`` seed. Returns ``None``.
        """
        self.file_list = sorted(os.path.join(path, f) for f in os.listdir(path))
        random.shuffle(self.file_list)
def to_transforms(d, input_size):
    """Apply a random resized crop (scale 0.2-1.0, bicubic) and a random horizontal flip.

    Args:
        d: a channel-first ``(C, H, W)`` NumPy array, such as the patches
            returned by ``SeismicSet.__getitem__``, or a PIL image.
        input_size: side length of the square output.

    Returns:
        A tensor of shape ``(C, input_size, input_size)``.
    """
    crop_and_flip = [
        transforms.RandomResizedCrop(
            input_size,
            scale=(0.2, 1.0),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
    ]
    if isinstance(d, np.ndarray):
        # RandomResizedCrop does not accept raw ndarrays. ToTensor would read
        # a (C, H, W) array as HWC and transpose it. So convert the
        # channel-first array straight to a tensor and crop/flip that.
        return transforms.Compose(crop_and_flip)(torch.from_numpy(np.ascontiguousarray(d)))
    # PIL input: crop and flip the image, then convert to a tensor.
    return transforms.Compose(crop_and_flip + [transforms.ToTensor()])(d)
### finetune
class FacesSet(data.Dataset):
    """Facies-segmentation dataset of flat float32 ``.dat`` files.

    Reads ``<folder>seismic/<i>.dat`` and ``<folder>label/<i>.dat`` for
    i in 0..116. ``folder`` is used as a plain string prefix, so it should
    end with a path separator. Files 0-99 form the training split and
    100-116 the evaluation split.

    Args:
        folder: dataset root (string prefix, see above).
        shape: ``(H, W)`` of each stored section; a list or tuple.
        is_train: select the training split when true.
    """

    def __init__(self,
                 folder,
                 shape=(768, 768),
                 is_train=True) -> None:
        super().__init__()
        # Copy to a list: __getitem__ builds the reshape target as
        # ``[1] + self.shape`` (so tuples must work too). The copy also avoids
        # aliasing the caller's object.
        self.shape = list(shape)
        indices = range(100) if is_train else range(100, 117)
        self.data_list = [folder + 'seismic/' + str(i) + '.dat' for i in indices]
        # Build label paths directly rather than by str.replace on the data
        # path. replace would also rewrite any '/seismic/' inside `folder`
        # and silently yield label == data when the separator is not '/'.
        self.label_list = [folder + 'label/' + str(i) + '.dat' for i in indices]

    def __getitem__(self, index):
        """Return ``(seismic, label)``.

        ``seismic`` is a float32 tensor of shape ``(1, H, W)``. ``label`` is
        an int64 tensor of shape ``(H, W)`` holding the stored label value
        minus one (presumably 1-based facies ids shifted to 0-based; verify
        against the data).
        """
        d = np.fromfile(self.data_list[index], np.float32).reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape) - 1
        # Explicit int64: ``astype(int)`` is int32 on Windows with NumPy < 2,
        # which yields int32 targets that losses such as CrossEntropyLoss reject.
        l = l.astype(np.int64)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)
class SaltSet(data.Dataset):
    """Salt-segmentation dataset of flat float32 ``.dat`` files.

    Reads ``<folder>seismic/<i>.dat`` and ``<folder>label/<i>.dat`` for
    i in 0..3999. ``folder`` is used as a plain string prefix, so it should
    end with a path separator. Files 0-3499 form the training split and
    3500-3999 the evaluation split.

    Args:
        folder: dataset root (string prefix, see above).
        shape: ``(H, W)`` of each stored section; a list or tuple.
        is_train: select the training split when true.
    """

    def __init__(self,
                 folder,
                 shape=(224, 224),
                 is_train=True) -> None:
        super().__init__()
        # Copy to a list so ``[1] + self.shape`` works for tuple input and the
        # caller's object is never aliased.
        self.shape = list(shape)
        indices = range(3500) if is_train else range(3500, 4000)
        self.data_list = [folder + 'seismic/' + str(i) + '.dat' for i in indices]
        # Build label paths directly, not via str.replace on the data path
        # (which would also rewrite a '/seismic/' inside `folder`).
        self.label_list = [folder + 'label/' + str(i) + '.dat' for i in indices]

    def __getitem__(self, index):
        """Return ``(seismic, label)``.

        ``seismic`` is a ``(1, H, W)`` float32 tensor. ``label`` is an
        ``(H, W)`` int64 tensor of the stored label values (truncated to int).
        """
        d = np.fromfile(self.data_list[index], np.float32).reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape)
        # Explicit int64: ``astype(int)`` is int32 on Windows with NumPy < 2.
        l = l.astype(np.int64)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)
class InterpolationSet(data.Dataset):
    """Reconstruction dataset whose target is the input itself.

    Training reads ``<folder><i>.dat`` for i in 0..5999. Evaluation reads
    ``<folder>U<i>.dat`` for i in 2000..3999. ``folder`` is used as a plain
    string prefix. Any masking of the input for interpolation is presumably
    done downstream (not visible here).

    Args:
        folder: dataset root (string prefix).
        shape: ``(H, W)`` of each stored section; a list or tuple.
        is_train: select the training file set when true.
    """

    def __init__(self,
                 folder,
                 shape=(224, 224),
                 is_train=True) -> None:
        super().__init__()
        # Copy to a list so ``[1] + self.shape`` works for tuple input.
        self.shape = list(shape)
        if is_train:
            self.data_list = [folder + str(i) + '.dat' for i in range(6000)]
        else:
            self.data_list = [folder + 'U' + str(i) + '.dat' for i in range(2000, 4000)]
        self.label_list = self.data_list

    def __getitem__(self, index):
        """Return ``(x, x)``: the same ``(1, H, W)`` float32 section as two tensors."""
        d = np.fromfile(self.data_list[index], np.float32).reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(d)

    def __len__(self):
        return len(self.data_list)
class DenoiseSet(data.Dataset):
    """Denoising dataset of flat float32 ``.dat`` sections.

    Training pairs ``<folder>seismic/<i>.dat`` with ``<folder>label/<i>.dat``
    for i in 0..1999. Evaluation reads ``<folder>field/<i>.dat`` for
    i in 0..3999 and uses each file as its own target (no reference labels).
    ``folder`` is used as a plain string prefix, so it should end with a
    path separator.

    Args:
        folder: dataset root (string prefix, see above).
        shape: ``(H, W)`` of each stored section; a list or tuple.
        is_train: select the training split when true.
    """

    def __init__(self,
                 folder,
                 shape=(224, 224),
                 is_train=True) -> None:
        super().__init__()
        # Copy to a list so ``[1] + self.shape`` works for tuple input.
        self.shape = list(shape)
        if is_train:
            self.data_list = [folder + 'seismic/' + str(i) + '.dat' for i in range(2000)]
            # Built directly rather than by str.replace on the data path.
            self.label_list = [folder + 'label/' + str(i) + '.dat' for i in range(2000)]
        else:
            self.data_list = [folder + 'field/' + str(i) + '.dat' for i in range(4000)]
            self.label_list = self.data_list

    def __getitem__(self, index):
        """Return ``(noisy, target)`` as ``(1, H, W)`` float32 tensors (no normalization)."""
        d = np.fromfile(self.data_list[index], np.float32).reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)
class ReflectSet(data.Dataset):
    """Reflectivity-inversion dataset of flat float32 ``.dat`` sections.

    Training pairs ``<folder>seismic/<i>.dat`` with ``<folder>label/<i>.dat``
    for i in 0..2199. Evaluation pairs ``<folder>SEAMseismic/<i>.dat`` with
    ``<folder>SEAMreflect/<i>.dat`` for i in 0..3999. ``folder`` is used as a
    plain string prefix, so it should end with a path separator.

    Args:
        folder: dataset root (string prefix, see above).
        shape: ``(H, W)`` of each stored section; a list or tuple.
        is_train: select the training split when true.
    """

    def __init__(self,
                 folder,
                 shape=(224, 224),
                 is_train=True) -> None:
        super().__init__()
        # Copy to a list so ``[1] + self.shape`` works for tuple input.
        self.shape = list(shape)
        if is_train:
            data_dir, label_dir, count = 'seismic/', 'label/', 2200
        else:
            data_dir, label_dir, count = 'SEAMseismic/', 'SEAMreflect/', 4000
        self.data_list = [folder + data_dir + str(i) + '.dat' for i in range(count)]
        # Built directly rather than by str.replace on the data path.
        self.label_list = [folder + label_dir + str(i) + '.dat' for i in range(count)]

    @staticmethod
    def _standardize(a):
        """Shift ``a`` to zero mean and scale to ~unit std (1e-6 avoids divide-by-zero)."""
        a = a - a.mean()
        return a / (a.std() + 1e-6)

    def __getitem__(self, index):
        """Return ``(seismic, reflectivity)``, each standardized, as ``(1, H, W)`` float32 tensors."""
        d = self._standardize(np.fromfile(self.data_list[index], np.float32))
        l = self._standardize(np.fromfile(self.label_list[index], np.float32))
        d = d.reshape([1] + self.shape)
        l = l.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)
class ThebeSet(data.Dataset):
    """Thebe fault-segmentation dataset of paired 2-D ``.npy`` sections.

    Expects ``<folder>/fault/<i>.npy`` and ``<folder>/seis/<i>.npy`` for
    i = 1..N, where N is the number of entries in ``<folder>/fault``. The
    numbered files are split in order: the first 75% for ``'train'``, the
    next 15% for ``'val'`` and the remainder for ``'test'``.

    Each item is a random ``shape`` window taken at the same position in both
    arrays and randomly mirrored. It is returned as ``(seis, fault)`` float32
    tensors of shape ``(1, H, W)``, with ``seis`` standardized per patch. The
    random crop/flip is applied in every mode, including val/test.
    """

    def __init__(self, folder, shape=[224, 224], mode='train') -> None:
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        total = len(os.listdir(join(folder, 'fault')))
        self.num_files = total
        self.shape = shape
        numbers = range(1, total + 1)
        self.fault_list = [folder + f'/fault/{n}.npy' for n in numbers]
        self.seis_list = [folder + f'/seis/{n}.npy' for n in numbers]

        # 75% / 15% / remainder split by file number.
        self.train_size = int(0.75 * total)
        self.val_size = int(0.15 * total)
        self.test_size = total - self.train_size - self.val_size
        self.train_index = self.train_size
        self.val_index = self.train_index + self.val_size

        splits = (
            ('train', slice(None, self.train_index)),
            ('val', slice(self.train_index, self.val_index)),
            ('test', slice(self.val_index, None)),
        )
        chosen = next((span for name, span in splits if mode == name), None)
        if chosen is None:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")
        self.fault_list = self.fault_list[chosen]
        self.seis_list = self.seis_list[chosen]

    def __len__(self):
        return len(self.fault_list)

    def retrieve_patch(self, fault, seis):
        """Cut the same random ``self.shape`` window out of ``fault`` and ``seis``.

        Full sections are expected to be large 2-D arrays (roughly 3174x1537
        per the original note). Raises ValueError if ``fault`` is smaller
        than the window.
        """
        ph, pw = self.shape[0], self.shape[1]
        h, w = fault.shape
        if h < ph or w < pw:
            raise ValueError(f"Image dimensions must be at least {ph}x{pw}.")
        top = random.randint(0, h - ph)
        left = random.randint(0, w - pw)
        window = (slice(top, top + ph), slice(left, left + pw))
        return fault[window], seis[window]

    def random_transform(self, fault, seis):
        """Mirror ``fault`` and ``seis`` identically.

        A left-right flip, then an up-down flip, are each applied when
        ``random.random() > 0.5``.
        """
        if random.random() > 0.5:
            fault, seis = np.fliplr(fault), np.fliplr(seis)
        if random.random() > 0.5:
            fault, seis = np.flipud(fault), np.flipud(seis)
        return fault, seis

    def __getitem__(self, index):
        # TODO(review): decide on further data pre-treatment; only the seismic
        # patch is standardized here.
        fault, seis = self.retrieve_patch(np.load(self.fault_list[index]),
                                          np.load(self.seis_list[index]))
        fault, seis = self.random_transform(fault, seis)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        def as_chw(arr):
            # copy() drops the negative strides np.flip* may have introduced.
            return torch.tensor(arr.copy(), dtype=torch.float32).unsqueeze(0)

        return as_chw(seis), as_chw(fault)
class FSegSet(data.Dataset):
    """FaultSeg dataset of paired ``.npy`` sections.

    Reads ``<folder>/<mode>/fault/*.npy`` and ``<folder>/<mode>/seis/*.npy``
    for ``mode`` in ``{'train', 'val'}``. After sorting, the i-th fault file
    is paired with the i-th seismic file, so both directories are expected
    to contain the same file names.

    Args:
        folder: dataset root directory.
        shape: stored for reference; sections are loaded at their stored size
            (128x128 per the original note).
        mode: ``'train'`` or ``'val'``.

    Raises:
        FileNotFoundError: if ``folder`` does not exist.
        ValueError: for any other ``mode``.
    """

    def __init__(self, folder, shape=(128, 128), mode='train') -> None:
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.shape = list(shape)
        self.mode = mode
        if mode not in ('train', 'val'):
            raise ValueError("Mode must be 'train' or 'val'.")
        self.fault_path = join(self.folder, f'{mode}/fault')
        self.seis_path = join(self.folder, f'{mode}/seis')
        # os.listdir order is arbitrary and can differ between the two
        # directories. Sorting both keeps index i pointing at same-named files.
        self.fault_list = self._list_npy(self.fault_path)
        self.seis_list = self._list_npy(self.seis_path)

    @staticmethod
    def _list_npy(directory):
        """Return the sorted full paths of the ``.npy`` files in ``directory``."""
        return sorted(join(directory, f) for f in os.listdir(directory) if f.endswith('.npy'))

    def __len__(self):
        return len(self.fault_list)

    def __getitem__(self, index):
        """Return ``(seis, fault)`` as ``(1, H, W)`` float32 tensors; ``seis`` is standardized."""
        fault_img = np.load(self.fault_list[index])
        seis_img = np.load(self.seis_list[index])
        seis_img = (seis_img - seis_img.mean()) / (seis_img.std() + 1e-6)
        fault = torch.tensor(fault_img.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis_img.copy(), dtype=torch.float32).unsqueeze(0)
        return seis, fault
class F3DFaciesSet(data.Dataset):
    """F3D facies-classification dataset built from one 3-D volume per split.

    Loads ``<folder>/<mode>/seismic.npy`` and ``<folder>/<mode>/labels.npy``.
    Items are 2-D cross-sections along all three axes: the flat index walks
    axis 0, then axis 1, then axis 2. Each section is randomly mirrored and
    returned as ``(seis, label)``. ``seis`` is a standardized ``(1, H, W)``
    float32 tensor; ``label`` is a ``(6, H, W)`` one-hot mask over integer
    classes 0..5.

    ``shape`` is stored as ``image_shape``; ``random_resize`` is accepted but
    not used by this class.
    """

    def __init__(self, folder, shape=[128, 128], mode='train', random_resize=False):
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.seises = np.load(join(folder, f"{mode}/seismic.npy"))
        self.labels = np.load(join(folder, f"{mode}/labels.npy"))
        self.image_shape = shape

        per_mode_sizes = (
            ('train', [(401, 701), (701, 255), (401, 255)]),
            ('val', [(601, 200), (200, 255), (601, 255)]),
            ('test', [(701, 255), (200, 701), (200, 255)]),
        )
        for name, sizes in per_mode_sizes:
            if mode == name:
                self.size_categories = sizes
                break
        else:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")

    def __len__(self):
        # One item per cross-section along each axis, so the sum of all dimensions.
        return sum(self.seises.shape)

    def random_transform(self, label, seis):
        """Mirror ``label`` and ``seis`` identically.

        A left-right flip, then an up-down flip, are each applied when
        ``random.random() > 0.5``.
        """
        if random.random() > 0.5:
            label, seis = np.fliplr(label), np.fliplr(seis)
        if random.random() > 0.5:
            label, seis = np.flipud(label), np.flipud(seis)
        return label, seis

    def __getitem__(self, index):
        m1, m2, m3 = self.seises.shape
        # Map the flat index to (axis, position along that axis).
        if index < m1:
            axis, pos = 0, index
        elif index < m1 + m2:
            axis, pos = 1, index - m1
        elif index < m1 + m2 + m3:
            axis, pos = 2, index - m1 - m2
        else:
            raise IndexError("Index out of bounds")
        key = (slice(None),) * axis + (pos,)
        seis, label = self.seises[key], self.labels[key]

        label, seis = self.random_transform(label, seis)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label.copy(), dtype=torch.float32)
        # One-hot over classes 0..5: channel k is 1 where the label equals k.
        classes = torch.arange(6, device=label.device).view(6, 1, 1)
        label = (label == classes).float()
        return seis, label
class P3DFaciesSet(data.Dataset):
    """P3D facies-classification dataset read from SEG-Y trace dumps.

    ``<folder>/<mode>/seismic.segy`` is read with ``segyio`` as raw traces and
    reshaped to a fixed per-mode volume size. For ``'train'``,
    ``<folder>/<mode>/labels.segy`` is read the same way. Items are 2-D
    sections along the first two axes, randomly mirrored and optionally
    randomly cropped (``random_resize``). Each item is ``(seis, label)``:
    ``seis`` is a standardized ``(1, H, W)`` float32 tensor and ``label`` a
    ``(6, H, W)`` one-hot mask over classes 1..6.

    Modes:
        ``'train'``: labelled training volume.
        ``'val'``: currently an alias of ``'train'`` because P3D has no
            labelled validation set.
        ``'test_1'``, ``'test_2'``: unlabelled; labels are all zeros, so
            their one-hot masks are empty.

    Raises:
        FileNotFoundError: if ``folder`` does not exist.
        ValueError: for an unknown ``mode`` (checked before any file is read).
    """

    # Volume dimensions (m1, m2, m3) the flat trace arrays are reshaped to.
    _VOLUME_DIMS = {
        'train': (590, 782, 1006),
        'test_1': (841, 334, 1006),
        'test_2': (251, 782, 1006),
    }

    def __init__(self, folder, shape=(128, 128), mode='train', random_resize=False):
        super().__init__()
        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
        self.random_resize = random_resize
        # TEMPORARY: P3D has no labelled val set, so 'val' reads the train volume.
        if mode == 'val':
            mode = 'train'
        if mode not in self._VOLUME_DIMS:
            raise ValueError("Mode must be 'train', 'test_2', 'val', or 'test_1'.")
        self.mode = mode
        self.image_shape = list(shape)
        self.s_path = join(folder, "{}/seismic.segy".format(mode))
        self.l_path = join(folder, "{}/labels.segy".format(mode))

        with segyio.open(self.s_path, ignore_geometry=True) as seis_file:
            self.seises = seis_file.trace.raw[:]
        if mode == 'train':
            with segyio.open(self.l_path, ignore_geometry=True) as label_file:
                self.labels = label_file.trace.raw[:]
        else:
            # Test volumes are unlabelled.
            self.labels = np.zeros_like(self.seises)

        m1, m2, m3 = self._VOLUME_DIMS[mode]
        self.size_categories = list(permutations([m1, m2, m3], 2))
        self.seises = self.seises.reshape(m1, m2, m3)
        self.labels = self.labels.reshape(m1, m2, m3)

    def __len__(self):
        # Cross-sections along the first two dimensions only.
        return self.seises.shape[0] + self.seises.shape[1]

    def _random_transform(self, label, seis):
        """Mirror ``label`` and ``seis`` identically.

        A left-right flip, then an up-down flip, are each applied when
        ``random.random() > 0.5``.
        """
        if random.random() > 0.5:
            label, seis = np.fliplr(label), np.fliplr(seis)
        if random.random() > 0.5:
            label, seis = np.flipud(label), np.flipud(seis)
        return label, seis

    def _random_resize(self, label, seis, min_size=(256, 256)):
        """Crop ``label`` and ``seis`` to the same random window.

        The window height/width are drawn uniformly from ``[min, full]``,
        then a uniformly random position is chosen. The minimum is clamped to
        the section size, so sections smaller than ``min_size`` are returned
        whole instead of ``random.randint`` raising ValueError.
        """
        h, w = seis.shape[0], seis.shape[1]
        r_height = random.randint(min(min_size[0], h), h)
        r_width = random.randint(min(min_size[1], w), w)
        r_pos_x = random.randint(0, h - r_height)
        r_pos_y = random.randint(0, w - r_width)
        window = (slice(r_pos_x, r_pos_x + r_height), slice(r_pos_y, r_pos_y + r_width))
        return label[window], seis[window]

    def __getitem__(self, index):
        m1, m2, m3 = self.seises.shape
        if index < m1:
            seis, label = self.seises[index, :, :], self.labels[index, :, :]
        elif index < m1 + m2:
            seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
        elif index < m1 + m2 + m3:
            # Unreachable through __len__, kept for direct indexing.
            seis, label = self.seises[:, :, index - m1 - m2], self.labels[:, :, index - m1 - m2]
        else:
            raise IndexError("Index out of bounds")

        label, seis = self._random_transform(label, seis)
        if self.random_resize:
            label, seis = self._random_resize(label, seis)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label.copy(), dtype=torch.float32)
        # One-hot over classes 1..6: channel k is 1 where the label equals k + 1.
        # Any other value (e.g. the all-zero test labels) sets no channel.
        label = (label == torch.arange(1, 7, device=label.device).view(6, 1, 1)).float()
        return seis, label