import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop
from pathlib import Path
import numpy as np
from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset
class TyphoonDataModule(pl.LightningDataModule):
    """Lightning DataModule wrapping the Digital Typhoon dataset.

    Builds train/val splits from the on-disk dataset and serves batches of
    clipped, [0, 1]-normalized infrared images, optionally center-cropped
    or bilinearly downsampled.
    """

    def __init__(
        self,
        dataroot,
        batch_size,
        num_workers,
        labels,
        split_by,
        load_data,
        dataset_split,
        standardize_range,
        downsample_size,
        cropped,
        corruption_ceiling_pct=100,
    ):
        """Store configuration and derive dataset paths from *dataroot*.

        Args:
            dataroot: Root directory expected to contain ``image/``,
                ``metadata/`` and ``metadata.json``.
            batch_size: Batch size for both train and val dataloaders.
            num_workers: Worker-process count for both dataloaders.
            labels: Label spec forwarded to ``DigitalTyphoonDataset``.
            split_by: Split granularity forwarded to ``random_split``
                (e.g. by frame or by sequence — see pyphoon2 docs).
            load_data: Forwarded as ``load_data_into_memory``.
            dataset_split: (train, val, test) proportions for
                ``random_split``; the test portion is discarded in setup().
            standardize_range: ``(lo, hi)`` bounds used to clip and rescale
                pixel values to [0, 1].
            downsample_size: Target ``(H, W)``; ``(512, 512)`` means the
                image is left at native resolution.
            cropped: If True, center-crop to 224x224 instead of
                interpolating to ``downsample_size``.
            corruption_ceiling_pct: Stored but not read elsewhere in this
                module — presumably consumed by callers; TODO confirm.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        data_path = Path(dataroot)
        # pyphoon2 expects directory arguments as strings with a trailing
        # separator, hence the explicit "/" suffix.
        self.images_path = str(data_path / "image") + "/"
        self.track_path = str(data_path / "metadata") + "/"
        self.metadata_path = str(data_path / "metadata.json")
        self.load_data = load_data
        self.split_by = split_by
        self.labels = labels
        self.dataset_split = dataset_split
        self.standardize_range = standardize_range
        self.downsample_size = downsample_size
        self.cropped = cropped
        self.corruption_ceiling_pct = corruption_ceiling_pct

    def setup(self, stage):
        """Instantiate the dataset and carve out train/val subsets.

        NOTE(review): ``random_split`` is called without a seeded
        ``torch.Generator``, so each invocation (and each Lightning stage)
        may yield a different partition; pass a seeded generator here if
        reproducible splits are required.
        """
        dataset = DigitalTyphoonDataset(
            self.images_path,
            self.track_path,
            self.metadata_path,
            self.labels,
            load_data_into_memory=self.load_data,
            filter_func=self.image_filter,
            transform_func=self.transform_func,
            spectrum="Infrared",
            verbose=False,
        )
        # The third (test) subset is intentionally discarded.
        self.train_set, self.val_set, _ = dataset.random_split(
            self.dataset_split, split_by=self.split_by
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the deterministic (unshuffled) validation DataLoader."""
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def image_filter(self, image):
        """Keep only grade-3..5 images, outside 2023, in 100-180 deg E.

        Grade bounds are exclusive (strictly between 2 and 6); the 2023
        exclusion presumably reserves that year for held-out evaluation —
        TODO confirm with the experiment config.
        """
        return (
            (image.grade() < 6)
            and (image.grade() > 2)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )

    def transform_func(self, image_batch):
        """Pre-process a single 2-D image array.

        Clips pixel values to ``standardize_range`` and rescales them to
        [0, 1]; then either center-crops to 224x224 (when ``cropped``) or
        bilinearly resizes to ``downsample_size``. Returns a numpy array.
        """
        lo, hi = self.standardize_range
        image_batch = np.clip(image_batch, lo, hi)
        image_batch = (image_batch - lo) / (hi - lo)
        # Normalize to tuple: configs often deserialize sizes as lists, and
        # [512, 512] != (512, 512) would otherwise force a no-op resize.
        if tuple(self.downsample_size) != (512, 512):
            image_batch = torch.Tensor(image_batch)
            if self.cropped:
                image_batch = center_crop(image_batch, (224, 224))
            else:
                # interpolate() requires NCHW input: add the two singleton
                # dims, resize, then strip them back off.
                image_batch = torch.reshape(
                    image_batch, [1, 1, image_batch.size()[0], image_batch.size()[1]]
                )
                image_batch = nn.functional.interpolate(
                    image_batch,
                    size=self.downsample_size,
                    mode="bilinear",
                    align_corners=False,
                )
                image_batch = torch.reshape(
                    image_batch, [image_batch.size()[2], image_batch.size()[3]]
                )
            image_batch = image_batch.numpy()
        return image_batch