import tensorflow as tf
from tensorflow.keras import layers, models, Input  # type: ignore
import numpy as np
class TrajectoryGRU2D(layers.Layer):
def __init__(self, filters, kernel_size, return_sequences=True, **kwargs):
super().__init__(**kwargs)
self.filters = filters
self.kernel_size = kernel_size
self.return_sequences = return_sequences
# Projection layer to match GRU feature space
self.input_projection = layers.Conv2D(filters, (1, 1), padding="same")
# GRU Gates
self.conv_z = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")
self.conv_r = layers.Conv2D(filters, kernel_size, padding="same", activation="sigmoid")
self.conv_h = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")
# Motion-based trajectory update
self.motion_conv = layers.Conv2D(filters, kernel_size, padding="same", activation="tanh")
    def build(self, input_shape):
        # input_shape: (batch, time_steps, height, width, channels).
        # Build the projection conv on a per-frame shape (batch, height, width, channels).
        self.input_projection.build((None, *input_shape[2:]))
        super().build(input_shape)
def call(self, inputs):
        # inputs shape: (batch_size, time_steps, height, width, channels)
        batch_size = tf.shape(inputs)[0]
        height, width = tf.shape(inputs)[2], tf.shape(inputs)[3]
        # The Python loop below requires a statically known time dimension
        time_steps = inputs.shape[1]
# Initialize hidden state
h_t = tf.zeros((batch_size, height, width, self.filters))
# List to store outputs at each time step
outputs = []
# Iterate over time steps
for t in range(time_steps):
# Get the input at time step t
x_t = inputs[:, t, :, :, :]
# Project input to match GRU feature dimension
x_projected = self.input_projection(x_t)
# Compute motion-based trajectory update
motion_update = self.motion_conv(x_projected)
# Concatenate projected input, previous hidden state, and motion update
combined = tf.concat([x_projected, h_t, motion_update], axis=-1)
# Compute GRU gates
z = self.conv_z(combined) # Update gate
r = self.conv_r(combined) # Reset gate
# Compute candidate hidden state
h_tilde = self.conv_h(tf.concat([x_projected, r * h_t], axis=-1))
# Update hidden state with motion-based trajectory
h_t = (1 - z) * h_t + z * h_tilde + motion_update # Add motion update
# Store the output if return_sequences is True
if self.return_sequences:
outputs.append(h_t)
# Stack outputs along the time dimension if return_sequences is True
if self.return_sequences:
outputs = tf.stack(outputs, axis=1)
else:
outputs = h_t
return outputs
def compute_output_shape(self, input_shape):
if self.return_sequences:
return (input_shape[0], input_shape[1], input_shape[2], input_shape[3], self.filters)
else:
return (input_shape[0], input_shape[2], input_shape[3], self.filters)
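# Quick shape check for the layer (illustrative sketch; the batch size of 2 and the
# 8-frame, 2-channel input below are arbitrary choices, not values from the model):
# demo_layer = TrajectoryGRU2D(filters=16, kernel_size=(3, 3), return_sequences=True)
# demo_out = demo_layer(tf.random.normal((2, 8, 95, 95, 2)))
# print(demo_out.shape)  # expected: (2, 8, 95, 95, 16)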
def build_tgru_model(input_shape=(8, 95, 95, 2)): # (time_steps, height, width, channels)
input_tensor = layers.Input(shape=input_shape)
# Apply TGRU Layers
x = TrajectoryGRU2D(filters=32, kernel_size=(3, 3), return_sequences=True)(input_tensor)
x = layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
x = TrajectoryGRU2D(filters=64, kernel_size=(3, 3), return_sequences=True)(x)
x = layers.Conv3D(filters=64, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
x = TrajectoryGRU2D(filters=128, kernel_size=(3, 3), return_sequences=True)(x)
x = layers.Conv3D(filters=128, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
# Flatten before Fully Connected Layer
x = layers.Flatten()(x)
# x = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(inputs=input_tensor, outputs=x)
return model
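# Standalone use (sketch, assuming the default 8-frame, 2-channel input):
# tgru = build_tgru_model()
# tgru.summary()  # output is a flattened feature vector; the Dense head above is commented out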
def radial_structure_subnet(input_shape):
"""
Creates the subnet for extracting TC radial structure features using a five-branch CNN design with 2D convolutions.
Parameters:
- input_shape: tuple, shape of the input data (e.g., (95, 95, 3))
Returns:
- model: tf.keras.Model, the radial structure subnet model
"""
input_tensor = layers.Input(shape=input_shape)
# Divide input data into four quadrants (NW, NE, SW, SE)
# Assuming the input shape is (batch_size, height, width, channels)
# Quadrant extraction - using slicing to separate quadrants
nw_quadrant = input_tensor[:, :input_shape[0]//2, :input_shape[1]//2, :]
ne_quadrant = input_tensor[:, :input_shape[0]//2, input_shape[1]//2:, :]
sw_quadrant = input_tensor[:, input_shape[0]//2:, :input_shape[1]//2, :]
se_quadrant = input_tensor[:, input_shape[0]//2:, input_shape[1]//2:, :]
target_height = max(input_shape[0]//2, input_shape[0] - input_shape[0]//2) # 48
target_width = max(input_shape[1]//2, input_shape[1] - input_shape[1]//2) # 48
# Padding the quadrants to match the target size (48, 48)
nw_quadrant = layers.ZeroPadding2D(padding=((0, target_height - nw_quadrant.shape[1]),
(0, target_width - nw_quadrant.shape[2])))(nw_quadrant)
ne_quadrant = layers.ZeroPadding2D(padding=((0, target_height - ne_quadrant.shape[1]),
(0, target_width - ne_quadrant.shape[2])))(ne_quadrant)
sw_quadrant = layers.ZeroPadding2D(padding=((0, target_height - sw_quadrant.shape[1]),
(0, target_width - sw_quadrant.shape[2])))(sw_quadrant)
se_quadrant = layers.ZeroPadding2D(padding=((0, target_height - se_quadrant.shape[1]),
(0, target_width - se_quadrant.shape[2])))(se_quadrant)
    # All four quadrants are now padded to (batch, 48, 48, channels)
# Main branch (processing the entire structure)
main_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(input_tensor)
    # Downsample the full-frame branch, then pad it to the quadrant size (48, 48)
    y = layers.MaxPool2D()(main_branch)
    y = layers.ZeroPadding2D(padding=((0, target_height - y.shape[1]),
                                      (0, target_width - y.shape[2])))(y)
# Side branches (processing the individual quadrants)
nw_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(nw_quadrant)
ne_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(ne_quadrant)
sw_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(sw_quadrant)
se_branch = layers.Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu')(se_quadrant)
    # Alternative (disabled): upsample the side branches instead of zero-padding
    # nw_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(nw_branch)
    # ne_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(ne_branch)
    # sw_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(sw_branch)
    # se_branch = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(se_branch)
# Fusion operations (concatenate the outputs from the main branch and side branches)
fusion = layers.concatenate([y, nw_branch, ne_branch, sw_branch, se_branch], axis=-1)
# Additional convolution layer to combine the fused features
x = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(fusion)
    x = layers.MaxPool2D(pool_size=(2, 2))(x)
    # Second fusion stage: deepen and downsample each side branch to match x
nw_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(nw_branch)
ne_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(ne_branch)
sw_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(sw_branch)
se_branch = layers.Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu')(se_branch)
nw_branch = layers.MaxPool2D(pool_size=(2, 2))(nw_branch)
ne_branch = layers.MaxPool2D(pool_size=(2, 2))(ne_branch)
sw_branch = layers.MaxPool2D(pool_size=(2, 2))(sw_branch)
se_branch = layers.MaxPool2D(pool_size=(2, 2))(se_branch)
fusion = layers.concatenate([x, nw_branch, ne_branch, sw_branch, se_branch], axis=-1)
x = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(fusion)
    x = layers.MaxPool2D(pool_size=(2, 2))(x)
nw_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(nw_branch)
ne_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(ne_branch)
sw_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(sw_branch)
se_branch = layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', activation='relu')(se_branch)
nw_branch = layers.MaxPool2D(pool_size=(2, 2))(nw_branch)
ne_branch = layers.MaxPool2D(pool_size=(2, 2))(ne_branch)
sw_branch = layers.MaxPool2D(pool_size=(2, 2))(sw_branch)
se_branch = layers.MaxPool2D(pool_size=(2, 2))(se_branch)
fusion = layers.concatenate([x, nw_branch, ne_branch, sw_branch, se_branch], axis=-1)
    # Final valid-padding convolutions, then flatten into a feature vector
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(fusion)
    x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation=None)(x)
    x = layers.Flatten()(x)
    # Create and return the model
model = models.Model(inputs=input_tensor, outputs=x)
return model
# Example usage:
# input_shape = (95, 95, 8)  # (height, width, channels): 95x95 spatial resolution, 8 channels
# model = radial_structure_subnet(input_shape)
# model.summary()
def build_cnn_model(input_shape=(8, 8, 1)):
# Define the input layer
input_tensor = layers.Input(shape=input_shape)
# Convolutional layer
x = layers.Conv2D(64, (3, 3), padding='same')(input_tensor)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
# Flatten layer
x = layers.Flatten()(x)
# Create the model
model = models.Model(inputs=input_tensor, outputs=x)
return model
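# Illustrative check of the flattened feature size (8 * 8 * 64 = 4096):
# cnn = build_cnn_model()
# print(cnn.output_shape)  # expected: (None, 4096)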
def build_combined_model():
# Define input shapes
input_shape_3d = (8, 95, 95, 2)
input_shape_radial = (95, 95, 8)
input_shape_cnn = (8, 8, 1)
input_shape_latitude = (8,)
input_shape_longitude = (8,)
input_shape_other = (9,)
# Build individual models
model_3d = build_tgru_model(input_shape=input_shape_3d)
model_radial = radial_structure_subnet(input_shape=input_shape_radial)
model_cnn = build_cnn_model(input_shape=input_shape_cnn)
# Define new inputs
    input_latitude = Input(shape=input_shape_latitude, name="latitude_input")
input_longitude = Input(shape=input_shape_longitude, name="longitude_input")
input_other = Input(shape=input_shape_other, name="other_input")
    # Embed the auxiliary scalar inputs with small dense layers
    flat_latitude = layers.Dense(32, activation='relu')(input_latitude)
    flat_longitude = layers.Dense(32, activation='relu')(input_longitude)
    flat_other = layers.Dense(64, activation='relu')(input_other)
# Combine all outputs
combined = layers.concatenate([
model_3d.output,
model_radial.output,
model_cnn.output,
flat_latitude,
flat_longitude,
flat_other
])
# Add dense layers for final processing
x = layers.Dense(128, activation='relu')(combined)
x = layers.Dense(1, activation=None)(x)
# Create the final model
final_model = models.Model(
inputs=[model_3d.input, model_radial.input, model_cnn.input,
input_latitude, input_longitude, input_other ],
outputs=x
)
return final_model
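# Sketch: the combined model expects six inputs, in this order:
#   [TGRU image sequence, radial-structure image, Vmax grid, latitude, longitude, other scalars]
# combined = build_combined_model()
# combined.summary()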
import h5py  # only needed for the optional checkpoint inspection below

# Optional: inspect a legacy HDF5 checkpoint
# with h5py.File(r"Trj_GRU.h5", 'r') as f:
#     print(f.attrs.get('keras_version'))
#     print(f.attrs.get('backend'))
#     print("Model layers:", list(f['model_weights'].keys()))
# Step 1: Build the full combined model (with six inputs)
model = build_combined_model()

# Step 2: Call the model once with dummy data to create all weights
dummy_input = [
tf.random.normal((1, 8, 95, 95, 2)), # reduced_images_test
tf.random.normal((1, 95, 95, 8)), # hov_m_test
tf.random.normal((1, 8, 8, 1)), # test_vmax_3d
tf.random.normal((1, 8)), # lat_test
tf.random.normal((1, 8)), # lon_test
tf.random.normal((1, 9)), # other_scalar_inputs
]
_ = model(dummy_input) # Build model by doing one forward pass
# Step 3: Load weights
model.load_weights("Trj_GRU.weights.h5") # Make sure this matches the architecture
def predict_trajgru(reduced_images_test, hov_m_test, test_vmax_3d, lat_test, lon_test, int_diff_test):
    # Run inference with the six inputs in the same order the model was built with
    y = model.predict([reduced_images_test, hov_m_test, test_vmax_3d, lat_test, lon_test, int_diff_test])
    return y
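# Example call with random placeholder arrays (shapes follow the model's six inputs;
# real use would pass preprocessed test data instead, and the batch size of 4 is arbitrary):
# preds = predict_trajgru(
#     np.random.rand(4, 8, 95, 95, 2).astype("float32"),
#     np.random.rand(4, 95, 95, 8).astype("float32"),
#     np.random.rand(4, 8, 8, 1).astype("float32"),
#     np.random.rand(4, 8).astype("float32"),
#     np.random.rand(4, 8).astype("float32"),
#     np.random.rand(4, 9).astype("float32"),
# )
# print(preds.shape)  # expected: (4, 1)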