File size: 15,514 Bytes
d9fdf57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b639d03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
import tensorflow as tf
from tensorflow.keras import layers, models # type: ignore
import numpy as np


class SpatiotemporalLSTMCell(layers.Layer):
    """
    SpatiotemporalLSTMCell: a convolutional LSTM-style cell that carries, in
    addition to the hidden state h_t and the cell state c_t, a spatiotemporal
    memory state m_t.

    Every gate is computed by a single Conv2D applied twice -- once to the
    input and once to a previous state -- and the two results are summed.
    Because the same Conv2D (same kernel) is reused for both terms, the input
    and the state it is paired with must have the same number of channels.

    NOTE(review): the c_t and m_t updates in ``call`` deviate from the standard
    formulation (see the comments there). The weights loaded elsewhere in this
    file were presumably trained with this exact math, so changing it would
    require retraining -- confirm before "fixing" it.

    Args:
        filters: Output channels of every gate convolution, and therefore the
            channel count of the returned h_t, c_t and m_t.
        kernel_size: Spatial kernel size of the gate convolutions.
        **kwargs: Forwarded to ``layers.Layer``.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters  # Number of output filters in the convolution
        self.kernel_size = kernel_size  # Size of the convolutional kernel

        # Gate convolutions for the standard LSTM path. Each one already applies
        # its own activation (tanh/sigmoid) before input and state terms are summed.
        self.conv_xg = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")    # For cell input
        self.conv_xi = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For input gate
        self.conv_xf = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For forget gate
        self.conv_xo = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For output gate

        # Gate convolutions for the spatiotemporal-memory path (it has no output gate).
        self.conv_xg_st = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")    # For ST cell input
        self.conv_xi_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For ST input gate
        self.conv_xf_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid") # For ST forget gate

        # 1x1 conv mapping concat([c_t, m_t]) (2 * filters channels) back to `filters` channels.
        self.conv_fusion = layers.Conv2D(filters, (1, 1), padding="same")  # 1x1 conv for dimensionality reduction

    def call(self, inputs, states):
        """
        Forward pass of the spatiotemporal LSTM cell for one time step.

        Args:
            inputs: Input tensor of shape [batch_size, height, width, channels]
            states: List of previous states [h_t-1, c_t-1, m_t-1]
                   h_t-1: previous hidden state
                   c_t-1: previous cell state (currently never read, see below)
                   m_t-1: previous spatiotemporal memory
                   h_t-1 and m_t-1 must have the same channel count as
                   ``inputs`` because the gate convolutions are shared
                   between the input term and the state term.

        Returns:
            (h_t, [h_t, c_t, m_t]); each tensor has ``filters`` channels.
        """
        prev_h, prev_c, prev_m = states

        # Standard LSTM gates: the same conv is applied to the input and to prev_h, then summed.
        g_t = self.conv_xg(inputs) + self.conv_xg(prev_h)  # Cell input activation
        i_t = self.conv_xi(inputs) + self.conv_xi(prev_h)  # Input gate
        f_t = self.conv_xf(inputs) + self.conv_xf(prev_h)  # Forget gate
        o_t = self.conv_xo(inputs) + self.conv_xo(prev_h)  # Output gate

        # NOTE(review): known deviation -- a standard LSTM multiplies the forget
        # gate by prev_c here, but this line uses a fresh call of
        # self.conv_xo(prev_h), so prev_c is never read. The gates are also
        # activated twice (by each conv's own activation and again by
        # tf.sigmoid/tf.tanh). Kept as-is because the loaded weights presumably
        # depend on it. Note also that prev_c arrives with the input's channel
        # count, not `filters`, so substituting it would not broadcast.
        c_t = tf.sigmoid(f_t) * self.conv_xo(prev_h) + tf.sigmoid(i_t) * tf.tanh(g_t)

        # Spatiotemporal memory gates: same pattern, paired with prev_m instead of prev_h.
        g_t_st = self.conv_xg_st(inputs) + self.conv_xg_st(prev_m)  # ST cell input
        i_t_st = self.conv_xi_st(inputs) + self.conv_xi_st(prev_m)  # ST input gate
        f_t_st = self.conv_xf_st(inputs) + self.conv_xf_st(prev_m)  # ST forget gate

        # NOTE(review): known deviation -- the standard form uses prev_m directly,
        # but this line uses self.conv_xf_st(prev_m) (another call of the ST
        # forget-gate conv). Kept for the same weight-compatibility reason as above.
        m_t = tf.sigmoid(f_t_st) * self.conv_xf_st(prev_m) + tf.sigmoid(i_t_st) * tf.tanh(g_t_st)

        # Hidden state: output gate times tanh of a 1x1-conv fusion of c_t and m_t.
        h_t = tf.sigmoid(o_t) * tf.tanh(self.conv_fusion(tf.concat([c_t, m_t], axis=-1)))

        return h_t, [h_t, c_t, m_t]  # Return the hidden state and all updated states

class SpatiotemporalLSTM(layers.Layer):
    """
    SpatiotemporalLSTM: unrolls a SpatiotemporalLSTMCell over the time axis.

    The wrapped cell is applied once per time step in a Python loop, and the
    hidden state produced at each step is stacked back into a sequence.

    Args:
        filters: Output channels of the wrapped cell.
        kernel_size: Convolution kernel size of the wrapped cell.
        **kwargs: Forwarded to ``layers.Layer``.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.cell = SpatiotemporalLSTMCell(filters, kernel_size)

    def call(self, inputs):
        """
        Run the cell over every time step of ``inputs``.

        Args:
            inputs: Tensor of shape [batch_size, time_steps, height, width, channels].
                time_steps, height, width and channels must be statically known;
                the batch size is read dynamically.

        Returns:
            Tensor of shape [batch_size, time_steps, height, width, filters]
            holding the hidden state produced at each step.
        """
        dynamic_batch = tf.shape(inputs)[0]
        n_steps = inputs.shape[1]
        rows = inputs.shape[2]
        cols = inputs.shape[3]
        in_channels = inputs.shape[4]

        # All three carried states start as zeros with the input's channel count.
        zero_shape = (dynamic_batch, rows, cols, in_channels)
        hidden = tf.zeros(zero_shape)
        cell_mem = tf.zeros(zero_shape)
        st_mem = tf.zeros(zero_shape)

        per_step = []
        for step in range(n_steps):
            # The cell's convolutions are first built on inputs[:, step] (in_channels
            # wide), so each carried state is trimmed to that width before reuse.
            trimmed = [state[:, :, :, :in_channels] for state in (hidden, cell_mem, st_mem)]
            hidden, (hidden, cell_mem, st_mem) = self.cell(inputs[:, step], trimmed)
            per_step.append(hidden)

        return tf.stack(per_step, axis=1)

def build_st_lstm_model(input_shape=(8, 95, 95, 2), batch_size=16):
    """
    Build the spatiotemporal LSTM feature extractor.

    Three identical blocks -- SpatiotemporalLSTM -> Conv3D(relu, 'same') ->
    MaxPooling3D(2x2x2, 'same') -- with 32, 64 and 128 filters, followed by
    Flatten. No output/head layers are included.

    Args:
        input_shape: Tuple of (time_steps, height, width, channels).
        batch_size: Static batch size given to the Input layer. Defaults to 16,
            the value that used to be hard-coded. Pass None for a variable batch
            dimension. NOTE(review): SpatiotemporalLSTM reads the batch size via
            tf.shape, so None should work -- verify before relying on it.

    Returns:
        A Keras model mapping the input sequence to a flat feature vector.
    """
    input_tensor = layers.Input(shape=input_shape, batch_size=batch_size)

    # Layers are created in the same order as the original three hand-written
    # blocks, so auto-generated layer names and weight order are unchanged.
    x = input_tensor
    for n_filters in (32, 64, 128):
        x = SpatiotemporalLSTM(filters=n_filters, kernel_size=(3, 3))(x)
        x = layers.Conv3D(filters=n_filters, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)

    # Flatten and prepare for output layers (not included in this model).
    x = layers.Flatten()(x)

    return models.Model(inputs=input_tensor, outputs=x)

def radial_structure_subnet(input_shape):
    """
    Create the subnet that extracts TC radial-structure features with a
    five-branch 2-D CNN: one main branch over the whole image plus one side
    branch per quadrant (NW, NE, SW, SE).

    After each of three stages the main branch and the four side branches are
    fused by channel-wise concatenation. The fused map is then reduced by two
    unpadded 3x3 convolutions and flattened.

    Parameters:
    - input_shape: tuple (height, width, channels), e.g. (95, 95, 8)

    Returns:
    - model: tf.keras.Model, the radial structure subnet model (flat output)
    """
    input_tensor = layers.Input(shape=input_shape)

    half_h = input_shape[0] // 2
    half_w = input_shape[1] // 2

    # Split the image into quadrants (NW, NE, SW, SE) by slicing. For odd sizes
    # the southern/eastern quadrants receive the extra row/column.
    quadrants = [
        input_tensor[:, :half_h, :half_w, :],  # NW
        input_tensor[:, :half_h, half_w:, :],  # NE
        input_tensor[:, half_h:, :half_w, :],  # SW
        input_tensor[:, half_h:, half_w:, :],  # SE
    ]

    # Common spatial size every branch is padded to before fusion: ceil(size / 2),
    # e.g. 48 x 48 for a 95 x 95 input.
    target_height = max(half_h, input_shape[0] - half_h)
    target_width = max(half_w, input_shape[1] - half_w)

    def pad_to_target(tensor):
        """Zero-pad `tensor` at the bottom/right up to (target_height, target_width)."""
        return layers.ZeroPadding2D(
            padding=((0, target_height - tensor.shape[1]),
                     (0, target_width - tensor.shape[2]))
        )(tensor)

    quadrants = [pad_to_target(q) for q in quadrants]

    # NOTE: layers are created in a fixed order (main-branch layer first, then
    # NW, NE, SW, SE). HDF5 load_weights without by_name matches weights by
    # layer order, so keep this order when editing.

    # Stage 1, main branch: conv over the whole image, 2x2 max-pool (floor),
    # then pad up to the quadrant size so it can be concatenated with them.
    main_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(input_tensor)
    y = layers.MaxPool2D()(main_branch)
    y = pad_to_target(y)

    # Stage 1, side branches: one conv per quadrant.
    branches = [
        layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(q)
        for q in quadrants
    ]
    fusion = layers.concatenate([y, *branches], axis=-1)

    # Stages 2 and 3: conv + 2x2 max-pool on the fused map and on each side
    # branch, then fuse again.
    for n_filters in (16, 32):
        x = layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding='same', activation='relu')(fusion)
        x = layers.MaxPool2D(pool_size=(2, 2))(x)
        branches = [
            layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding='same', activation='relu')(b)
            for b in branches
        ]
        branches = [layers.MaxPool2D(pool_size=(2, 2))(b) for b in branches]
        fusion = layers.concatenate([x, *branches], axis=-1)

    # Head: two unpadded ('valid') 3x3 convs, the last one linear, then flatten.
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(fusion)
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation=None)(x)
    x = layers.Flatten()(x)

    return models.Model(inputs=input_tensor, outputs=x)

# Define input shape (height, width, channels); the batch dimension is implicit.
# input_shape = (95, 95, 8)  # Example input shape (95x95 spatial resolution, 8 channels)

# # Build the model
# model = radial_structure_subnet(input_shape)

# # Model summary
# model.summary()

def build_cnn_model(input_shape=(8, 8, 1)):
    """
    Build a single-stage CNN feature extractor.

    Pipeline: Conv2D(64, 3x3, 'same') -> BatchNormalization -> ReLU -> Flatten.

    Args:
        input_shape: Tuple of (height, width, channels).

    Returns:
        A Keras model producing a flat feature vector.
    """
    features_in = layers.Input(shape=input_shape)

    # Layers are instantiated in pipeline order, then applied one after another.
    pipeline = (
        layers.Conv2D(64, (3, 3), padding='same'),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Flatten(),
    )
    out = features_in
    for stage in pipeline:
        out = stage(out)

    return models.Model(inputs=features_in, outputs=out)

from tensorflow.keras import layers, models, Input # type: ignore

def build_combined_model():
    """
    Assemble the full multi-input regression model.

    Branches:
      * spatiotemporal LSTM sub-model on an (8, 95, 95, 2) sequence,
      * radial-structure sub-model on a (95, 95, 8) image,
      * small CNN sub-model on an (8, 8, 1) grid,
      * Dense(32, relu) on an 8-vector input named "latitude_input",
      * Dense(32, relu) on an 8-vector input named "longitude_input",
      * Dense(64, relu) on a 9-vector input named "other_input".

    The branch features are concatenated, passed through Dense(128, relu), and
    reduced to a single linear output.

    Returns:
        A Keras model whose inputs are ordered
        [st_lstm, radial, cnn, latitude, longitude, other].
    """
    # Image-like branches, each built as its own sub-model.
    model_3d = build_st_lstm_model(input_shape=(8, 95, 95, 2))
    model_radial = radial_structure_subnet(input_shape=(95, 95, 8))
    model_cnn = build_cnn_model(input_shape=(8, 8, 1))

    # Vector branches: (input name, vector length, Dense width).
    vector_specs = [
        ("latitude_input", 8, 32),
        ("longitude_input", 8, 32),
        ("other_input", 9, 64),
    ]
    vector_inputs = [Input(shape=(length,), name=name) for name, length, _ in vector_specs]
    vector_features = [
        layers.Dense(units, activation='relu')(tensor)
        for tensor, (_, _, units) in zip(vector_inputs, vector_specs)
    ]

    merged = layers.concatenate(
        [model_3d.output, model_radial.output, model_cnn.output, *vector_features]
    )
    hidden = layers.Dense(128, activation='relu')(merged)
    prediction = layers.Dense(1, activation=None)(hidden)

    return models.Model(
        inputs=[model_3d.input, model_radial.input, model_cnn.input, *vector_inputs],
        outputs=prediction,
    )

import h5py
# Inspect the saved HDF5 file and print its Keras metadata before loading.
# NOTE(review): this runs at import time. Reading f['model_weights'] assumes the
# file contains that group (presumably written by a whole-model save, not
# save_weights) -- verify. The relative path resolves against the current
# working directory.
with h5py.File(r"spatio_tempral_LSTM.h5", 'r') as f:
    print(f.attrs.get('keras_version'))
    print(f.attrs.get('backend'))
    print("Model layers:", list(f['model_weights'].keys()))

# Rebuild the architecture in code and load the trained weights into it. The
# module-level `model` is what predict_stlstm uses. Because the layers are
# rebuilt rather than deserialized, the architecture built here must match the
# one the weights were saved from.
model = build_combined_model()  # Your original model building function
model.load_weights(r"spatio_tempral_LSTM.h5")


def predict_stlstm(reduced_images_test, hov_m_test, test_vmax_3d, lat_test, lon_test, int_diff_test):
    """
    Run the module-level combined `model` on test inputs.

    The arguments are forwarded positionally to ``model.predict`` in the order
    in which build_combined_model declares its inputs.

    Args:
        reduced_images_test: fed to the spatiotemporal LSTM input (8, 95, 95, 2).
        hov_m_test: fed to the radial-structure input (95, 95, 8).
        test_vmax_3d: fed to the CNN input (8, 8, 1).
        lat_test: fed to "latitude_input" (length 8).
        lon_test: fed to "longitude_input" (length 8).
        int_diff_test: fed to "other_input" (length 9).

    Returns:
        Whatever ``model.predict`` returns for these inputs.

    NOTE(review): the ST-LSTM branch's Input is built with a fixed batch size
    of 16; confirm that predict's default batching is compatible with it.
    """
    ordered_inputs = [
        reduced_images_test,
        hov_m_test,
        test_vmax_3d,
        lat_test,
        lon_test,
        int_diff_test,
    ]
    return model.predict(ordered_inputs)