File size: 3,782 Bytes
1964059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop

from pathlib import Path
import numpy as np

from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset


class TyphoonDataModule(pl.LightningDataModule):
    """Typhoon Dataset Module using lightning architecture.

    Builds a ``DigitalTyphoonDataset`` (infrared spectrum) from ``dataroot``,
    filters and pre-processes its images, and splits it into train /
    validation / test subsets exposed through Lightning dataloaders.

    Args:
        dataroot: Root directory expected to contain an ``image/`` folder,
            a ``metadata/`` folder and a ``metadata.json`` file.
        batch_size: Batch size used by the train and validation dataloaders.
        num_workers: Number of ``DataLoader`` worker processes.
        labels: Label specification forwarded to ``DigitalTyphoonDataset``.
        split_by: Split granularity forwarded to ``random_split``.
        load_data: Forwarded as ``load_data_into_memory``.
        dataset_split: Split lengths forwarded to ``random_split``; the call
            must produce exactly three subsets (train, val, test).
        standardize_range: ``(low, high)`` pixel range with ``low < high``.
            Pixels are clipped to it and rescaled linearly to ``[0, 1]``.
        downsample_size: Target ``(height, width)`` (or a single int) for
            bilinear resizing. A size of ``(512, 512)`` (tuple or list)
            disables resizing and cropping — presumably the native image
            size; confirm against the dataset.
        cropped: If true, center-crop to 224x224 instead of resizing. Only
            applies when ``downsample_size`` is not ``(512, 512)``.
        corruption_ceiling_pct: Stored on the instance; not read anywhere in
            this class.
        split_seed: Optional integer seed for the train/val/test split. When
            set, the split is reproducible across runs and processes (e.g.
            DDP ranks). When ``None`` (default) the dataset's default random
            generator is used, as before.

    Raises:
        ValueError: If ``standardize_range[1] <= standardize_range[0]``.
    """

    def __init__(
        self,
        dataroot,
        batch_size,
        num_workers,
        labels,
        split_by,
        load_data,
        dataset_split,
        standardize_range,
        downsample_size,
        cropped,
        corruption_ceiling_pct=100,
        split_seed=None,
    ):
        super().__init__()

        # Fail fast: a zero-width or inverted range would make the min-max
        # rescaling in ``transform_func`` divide by zero (NaN/inf pixels) or
        # collapse every pixel to the same value.
        if standardize_range[1] <= standardize_range[0]:
            raise ValueError(
                "standardize_range must be (low, high) with low < high, "
                f"got {standardize_range!r}"
            )

        self.batch_size = batch_size
        self.num_workers = num_workers

        data_path = Path(dataroot)
        # Directory paths keep a trailing slash; presumably the dataset
        # appends file names directly to them — confirm before removing.
        self.images_path = str(data_path / "image") + "/"
        self.track_path = str(data_path / "metadata") + "/"
        self.metadata_path = str(data_path / "metadata.json")
        self.load_data = load_data
        self.split_by = split_by
        self.labels = labels

        self.dataset_split = dataset_split
        self.standardize_range = standardize_range
        self.downsample_size = downsample_size
        self.cropped = cropped

        self.corruption_ceiling_pct = corruption_ceiling_pct
        self.split_seed = split_seed

        # Populated by ``setup``.
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def setup(self, stage=None):
        """Load the filtered dataset and split it into train/val/test subsets.

        Lightning can invoke this hook more than once (e.g. for ``fit`` and
        later ``validate``/``test``). The split is computed only on the first
        call, so later stages reuse the same subsets instead of drawing a new
        random split that could overlap the training data.

        Args:
            stage: Lightning stage name; unused because every stage shares
                the same split.
        """
        if getattr(self, "train_set", None) is not None:
            return

        dataset = DigitalTyphoonDataset(
            self.images_path,
            self.track_path,
            self.metadata_path,
            self.labels,
            load_data_into_memory=self.load_data,
            filter_func=self.image_filter,
            transform_func=self.transform_func,
            spectrum="Infrared",
            verbose=False,
        )

        split_kwargs = {"split_by": self.split_by}
        if self.split_seed is not None:
            # A seeded generator yields the same split in every process/run.
            split_kwargs["generator"] = torch.Generator().manual_seed(
                self.split_seed
            )

        self.train_set, self.val_set, self.test_set = dataset.random_split(
            self.dataset_split, **split_kwargs
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training subset."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation subset."""
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def image_filter(self, image):
        """Return True if ``image`` should be kept in the dataset.

        Keeps images whose grade is strictly between 2 and 6, whose year is
        not 2023, and whose longitude lies in ``[100.0, 180.0]``.
        """
        return (
            2 < image.grade() < 6
            and image.year() != 2023
            and 100.0 <= image.long() <= 180.0
        )

    def transform_func(self, image_batch):
        """Pre-process one image: clip, rescale to [0, 1], then resize or crop.

        Despite the parameter name, the resize path treats the input as a
        single 2-D ``(H, W)`` array (it adds exactly two leading dims).

        Args:
            image_batch: 2-D numpy array of raw pixel values.

        Returns:
            A numpy array. It is only rescaled when ``downsample_size`` is
            ``(512, 512)``. Otherwise it is a float32 array of shape
            ``downsample_size`` (resize) or ``(224, 224)`` (crop).
        """
        low, high = self.standardize_range[0], self.standardize_range[1]
        image_batch = (np.clip(image_batch, low, high) - low) / (high - low)

        target_size = self.downsample_size
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        # Compare as a tuple so list-valued configs such as [512, 512] are
        # also recognised as "keep full size" (a list never equals a tuple).
        if tuple(target_size) == (512, 512):
            return image_batch

        image = torch.as_tensor(image_batch, dtype=torch.float32)
        if self.cropped:
            # NOTE: crop size is fixed at 224x224 regardless of target_size.
            image = center_crop(image, (224, 224))
        else:
            # interpolate expects (N, C, H, W): add, then drop, batch/channel.
            image = nn.functional.interpolate(
                image[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )[0, 0]
        return image.numpy()