import h5py
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, layers, models  # type: ignore


class SpatiotemporalLSTMCell(layers.Layer):
    """One step of a convolutional LSTM with an extra spatiotemporal memory.

    The cell carries three states -- hidden ``h``, cell ``c`` and
    spatiotemporal memory ``m`` -- and produces the new hidden state by fusing
    the new ``c`` and ``m`` with a 1x1 convolution, gated by the output gate.

    NOTE(review): the update equations deviate from a standard (ST-)ConvLSTM.
    They are intentionally left unchanged, because the pretrained weights
    loaded at the bottom of this module were trained with exactly this forward
    pass and this sub-layer layout; changing either would silently change
    predictions or break ``load_weights``. A real fix needs retraining.
    The deviations are:

    * each gate uses ONE ``Conv2D`` for both the input and the previous state
      (shared weights), so the states fed in must have the input's channel
      count -- ``SpatiotemporalLSTM`` enforces this by slicing them;
    * the gate convolutions already apply sigmoid/tanh, and the summed results
      are squashed a second time with ``tf.sigmoid`` / ``tf.tanh``;
    * the cell update uses ``conv_xo(prev_h)`` where ``prev_c`` belongs, so
      ``prev_c`` is never read;
    * the memory update uses ``conv_xf_st(prev_m)`` where ``prev_m`` belongs.

    Args:
        filters: Number of output channels of every convolution.
        kernel_size: Kernel size of the gate convolutions.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters  # output channels of each convolution
        self.kernel_size = kernel_size  # kernel size of the gate convolutions

        # NOTE(review): keep the assignment order below. Keras tracks sub-layers
        # in assignment order, which fixes the order of this layer's weights --
        # presumably what the pretrained HDF5 weights are matched against.

        # LSTM gate convolutions (applied to both the input and prev_h).
        self.conv_xg = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")  # cell input
        self.conv_xi = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")  # input gate
        self.conv_xf = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")  # forget gate
        self.conv_xo = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")  # output gate

        # Spatiotemporal-memory gate convolutions (applied to the input and prev_m).
        self.conv_xg_st = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")  # ST cell input
        self.conv_xi_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")  # ST input gate
        self.conv_xf_st = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")  # ST forget gate

        # 1x1 convolution fusing the concatenated [c_t, m_t] back to `filters` channels.
        self.conv_fusion = layers.Conv2D(filters, (1, 1), padding="same")

    def call(self, inputs, states):
        """Run one time step.

        Args:
            inputs: Frame for this step, ``[batch, height, width, channels]``.
            states: ``[prev_h, prev_c, prev_m]``. Each must have the same
                channel count as ``inputs`` because the gate convolutions are
                shared between the input and the states.

        Returns:
            ``(h_t, [h_t, c_t, m_t])``; every new state has ``filters`` channels.
        """
        prev_h, prev_c, prev_m = states  # prev_c is unused -- see class docstring

        # LSTM gates: conv(input) + conv(prev_h), with the same conv for both.
        g_t = self.conv_xg(inputs) + self.conv_xg(prev_h)  # cell input
        i_t = self.conv_xi(inputs) + self.conv_xi(prev_h)  # input gate
        f_t = self.conv_xf(inputs) + self.conv_xf(prev_h)  # forget gate
        o_from_x = self.conv_xo(inputs)
        o_from_h = self.conv_xo(prev_h)  # computed once; also reused in c_t below
        o_t = o_from_x + o_from_h  # output gate

        # Cell update. NOTE(review): conv_xo(prev_h) stands where prev_c belongs.
        c_t = tf.sigmoid(f_t) * o_from_h + tf.sigmoid(i_t) * tf.tanh(g_t)

        # Spatiotemporal-memory gates: conv(input) + conv(prev_m), shared conv.
        g_t_st = self.conv_xg_st(inputs) + self.conv_xg_st(prev_m)  # ST cell input
        i_t_st = self.conv_xi_st(inputs) + self.conv_xi_st(prev_m)  # ST input gate
        f_st_from_x = self.conv_xf_st(inputs)
        f_st_from_m = self.conv_xf_st(prev_m)  # computed once; also reused in m_t below
        f_t_st = f_st_from_x + f_st_from_m  # ST forget gate

        # Memory update. NOTE(review): conv_xf_st(prev_m) stands where prev_m belongs.
        m_t = tf.sigmoid(f_t_st) * f_st_from_m + tf.sigmoid(i_t_st) * tf.tanh(g_t_st)

        # Hidden state: output gate times tanh of the 1x1-fused [c_t, m_t].
        h_t = tf.sigmoid(o_t) * tf.tanh(self.conv_fusion(tf.concat([c_t, m_t], axis=-1)))

        return h_t, [h_t, c_t, m_t]


class SpatiotemporalLSTM(layers.Layer):
    """Unroll a ``SpatiotemporalLSTMCell`` over the time axis of a 5-D tensor.

    Input is ``[batch, time, height, width, channels]``; the output stacks the
    per-step hidden states into ``[batch, time, height, width, filters]``.

    Args:
        filters: Output channels of the wrapped cell.
        kernel_size: Kernel size of the wrapped cell's gate convolutions.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.cell = SpatiotemporalLSTMCell(filters, kernel_size)

    def call(self, inputs):
        """Apply the cell step by step and stack the hidden states along time.

        Args:
            inputs: Tensor of shape ``[batch, time, height, width, channels]``;
                time, height, width and channels must be statically known.

        Returns:
            Tensor of shape ``[batch, time, height, width, filters]``.
        """
        batch_size = tf.shape(inputs)[0]  # dynamic, so an unknown batch dim works
        time_steps = inputs.shape[1]  # static: the loop below is unrolled in Python
        height, width, channels = inputs.shape[2], inputs.shape[3], inputs.shape[4]

        # All states start as zeros with the INPUT channel count.
        h_t = tf.zeros((batch_size, height, width, channels))
        c_t = tf.zeros((batch_size, height, width, channels))
        m_t = tf.zeros((batch_size, height, width, channels))

        outputs = []
        for t in range(time_steps):
            # The cell returns `filters`-channel states, but its gate convs are
            # shared with the input, so each state is cut back to the input
            # channel count before the next step.
            # NOTE(review): this discards state channels beyond `channels` and
            # appears to require filters >= channels; kept as-is so the
            # pretrained weights stay valid (see SpatiotemporalLSTMCell).
            prev_states = [state[:, :, :, :channels] for state in (h_t, c_t, m_t)]
            h_t, (_, c_t, m_t) = self.cell(inputs[:, t], prev_states)
            outputs.append(h_t)

        return tf.stack(outputs, axis=1)


def build_st_lstm_model(input_shape=(8, 95, 95, 2), batch_size=16):
    """Build the spatiotemporal-LSTM feature extractor.

    Three blocks of SpatiotemporalLSTM -> Conv3D(relu) -> MaxPooling3D(2x2x2,
    'same'), with 32, 64 and 128 filters, followed by Flatten.

    Args:
        input_shape: ``(time_steps, height, width, channels)`` of one sample.
        batch_size: Fixed batch dimension of the Input layer (16 by default,
            as before). Pass ``None`` for a variable batch size.

    Returns:
        A Keras functional model whose output is the flattened feature map;
        for the default ``input_shape`` that is (1, 12, 12, 128) -> 18432 values.
    """
    input_tensor = layers.Input(shape=input_shape, batch_size=batch_size)

    # Layers are created in the same order as the original hand-unrolled code.
    x = input_tensor
    for filters in (32, 64, 128):
        x = SpatiotemporalLSTM(filters=filters, kernel_size=(3, 3))(x)
        x = layers.Conv3D(filters=filters, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)

    # Flatten for the downstream dense head (built in build_combined_model).
    x = layers.Flatten()(x)

    model = models.Model(inputs=input_tensor, outputs=x)
    return model


def radial_structure_subnet(input_shape):
    """Create the five-branch 2-D CNN that extracts TC radial-structure features.

    One main branch sees the whole image; four side branches each see one
    quadrant (NW, NE, SW, SE). The branches are concatenated ("fused") after
    every stage, and the side branches keep running in parallel.

    Parameters:
    - input_shape: tuple ``(height, width, channels)``, e.g. ``(95, 95, 8)`` as
      used by ``build_combined_model``.

    Returns:
    - model: tf.keras.Model whose output is the flattened final feature map
      (8 x 8 x 32 = 2048 values for a 95 x 95 input).
    """
    input_tensor = layers.Input(shape=input_shape)
    half_h = input_shape[0] // 2
    half_w = input_shape[1] // 2

    # Quadrants via slicing of (batch, height, width, channels); with an odd
    # side the second half (S / E) gets the extra row / column.
    quadrants = [
        input_tensor[:, :half_h, :half_w, :],  # NW
        input_tensor[:, :half_h, half_w:, :],  # NE
        input_tensor[:, half_h:, :half_w, :],  # SW
        input_tensor[:, half_h:, half_w:, :],  # SE
    ]

    # Common spatial size: the larger half (48 for a 95-pixel side).
    target_height = max(half_h, input_shape[0] - half_h)
    target_width = max(half_w, input_shape[1] - half_w)

    def pad_to_target(tensor):
        """Zero-pad `tensor` at the bottom/right up to (target_height, target_width)."""
        return layers.ZeroPadding2D(
            padding=((0, target_height - tensor.shape[1]), (0, target_width - tensor.shape[2]))
        )(tensor)

    def conv_relu(filters, tensor):
        """3x3 'same'-padded Conv2D with ReLU."""
        return layers.Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')(tensor)

    quadrants = [pad_to_target(quadrant) for quadrant in quadrants]
    for quadrant in quadrants:
        print(quadrant.shape)  # debug output, kept from the original

    # Main branch: whole image, pooled by 2 and padded to the quadrant size.
    y = conv_relu(8, input_tensor)
    y = layers.MaxPool2D()(y)
    y = pad_to_target(y)

    # Side branches, one per quadrant (order NW, NE, SW, SE).
    branches = [conv_relu(8, quadrant) for quadrant in quadrants]

    # First fusion of the main branch with all side branches.
    fusion = layers.concatenate([y, *branches], axis=-1)

    # Two more stages: conv + pool on the fused map and on each side branch,
    # then fuse again.
    for filters in (16, 32):
        x = conv_relu(filters, fusion)
        x = layers.MaxPool2D(pool_size=(2, 2))(x)
        branches = [conv_relu(filters, branch) for branch in branches]
        branches = [layers.MaxPool2D(pool_size=(2, 2))(branch) for branch in branches]
        fusion = layers.concatenate([x, *branches], axis=-1)

    # Two 'valid' 3x3 convolutions (each shrinks height/width by 2), the last linear.
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(fusion)
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation=None)(x)

    x = layers.Flatten()(x)
    model = models.Model(inputs=input_tensor, outputs=x)
    return model


def build_cnn_model(input_shape=(8, 8, 1)):
    """Build a small CNN branch: Conv2D(64, 3x3, 'same') -> BatchNorm -> ReLU -> Flatten.

    Args:
        input_shape: ``(height, width, channels)`` of one sample.

    Returns:
        A Keras functional model with a flattened output
        (8 * 8 * 64 = 4096 values for the default shape).
    """
    input_tensor = layers.Input(shape=input_shape)

    x = layers.Conv2D(64, (3, 3), padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Flatten()(x)

    model = models.Model(inputs=input_tensor, outputs=x)
    return model


def build_combined_model():
    """Build the full model that fuses all feature branches into one scalar output.

    Inputs, in order:
        1. spatiotemporal sequence ``(8, 95, 95, 2)`` -> build_st_lstm_model
        2. radial-structure image ``(95, 95, 8)`` -> radial_structure_subnet
        3. small grid ``(8, 8, 1)`` -> build_cnn_model
        4. latitude vector ``(8,)`` (name "latitude_input")
        5. longitude vector ``(8,)`` (name "longitude_input")
        6. other features ``(9,)`` (name "other_input")

    Returns:
        A Keras functional model producing one linear output per sample.
    """
    input_shape_3d = (8, 95, 95, 2)
    input_shape_radial = (95, 95, 8)
    input_shape_cnn = (8, 8, 1)
    input_shape_latitude = (8,)
    input_shape_longitude = (8,)
    input_shape_other = (9,)

    # Image / sequence branches.
    model_3d = build_st_lstm_model(input_shape=input_shape_3d)
    model_radial = radial_structure_subnet(input_shape=input_shape_radial)
    model_cnn = build_cnn_model(input_shape=input_shape_cnn)

    # Vector inputs.
    input_latitude = Input(shape=input_shape_latitude, name="latitude_input")
    input_longitude = Input(shape=input_shape_longitude, name="longitude_input")
    input_other = Input(shape=input_shape_other, name="other_input")

    # Embed the vector inputs with one ReLU Dense layer each.
    flat_latitude = layers.Dense(32, activation='relu')(input_latitude)
    flat_longitude = layers.Dense(32, activation='relu')(input_longitude)
    flat_other = layers.Dense(64, activation='relu')(input_other)

    # Concatenate every branch (order matters for the pretrained weights).
    combined = layers.concatenate([
        model_3d.output,
        model_radial.output,
        model_cnn.output,
        flat_latitude,
        flat_longitude,
        flat_other,
    ])

    # Dense head: 128 ReLU units, then a single linear output.
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dense(1, activation=None)(x)

    final_model = models.Model(
        inputs=[model_3d.input, model_radial.input, model_cnn.input,
                input_latitude, input_longitude, input_other],
        outputs=x,
    )
    return final_model


# Pretrained weights for build_combined_model().
WEIGHTS_PATH = r"spatio_tempral_LSTM.h5"

# Print basic metadata of the weights file (debug output).
with h5py.File(WEIGHTS_PATH, 'r') as f:
    print(f.attrs.get('keras_version'))
    print(f.attrs.get('backend'))
    # Full-model saves nest layer weights under "model_weights"; weights-only
    # saves presumably keep them at the file root. Fall back to the root
    # instead of raising KeyError at import time.
    weights_group = f['model_weights'] if 'model_weights' in f else f
    print("Model layers:", list(weights_group.keys()))

# Build the architecture and load the pretrained weights at import time.
model = build_combined_model()
model.load_weights(WEIGHTS_PATH)


def predict_stlstm(reduced_images_test, hov_m_test, test_vmax_3d, lat_test, lon_test,
                   int_diff_test, batch_size=16):
    """Run the pretrained combined model on a batch of samples.

    Arguments map, in order, to the inputs of ``build_combined_model``:
    reduced_images_test -> ``(N, 8, 95, 95, 2)``, hov_m_test -> ``(N, 95, 95, 8)``,
    test_vmax_3d -> ``(N, 8, 8, 1)``, lat_test -> ``(N, 8)``,
    lon_test -> ``(N, 8)``, int_diff_test -> ``(N, 9)``.

    Args:
        batch_size: Passed to ``model.predict``. Defaults to 16 to match the
            fixed batch dimension declared in ``build_st_lstm_model`` (Keras'
            own default is 32). NOTE(review): when N is not a multiple of 16
            the final partial batch still differs from that declared size --
            confirm this is accepted by the installed Keras version.

    Returns:
        The model predictions, one value per sample (shape ``(N, 1)``).
    """
    y = model.predict(
        [reduced_images_test, hov_m_test, test_vmax_3d, lat_test, lon_test, int_diff_test],
        batch_size=batch_size,
    )
    return y