Update trjgru.py
Browse files
    	
        trjgru.py
    CHANGED
    
    | 
         @@ -285,16 +285,25 @@ import h5py 
     | 
|
| 285 | 
         | 
| 286 | 
         
             
            model = build_combined_model()  # Your original model building function
         
     | 
| 287 | 
         
             
            # Rebuild the model architecture
         
     | 
| 288 | 
         
            -
            model  
     | 
| 
         | 
|
| 289 | 
         | 
| 290 | 
         
            -
            #  
     | 
| 291 | 
         
            -
             
     | 
| 292 | 
         
            -
            _ = model(dummy_input)  # Forward pass to build all layers
         
     | 
| 293 | 
         | 
| 294 | 
         
            -
             
     | 
| 295 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 296 | 
         | 
| 297 | 
         
            -
            model.load_weights(r"Trj_GRU.weights.h5")
         
     | 
| 298 | 
         | 
| 299 | 
         | 
| 300 | 
         
             
            def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
         
     | 
| 
         | 
|
| 285 | 
         | 
| 286 | 
         
             
            model = build_combined_model()  # Your original model building function
         
     | 
| 287 | 
         
             
            # Rebuild the model architecture
         
     | 
| 288 | 
         
            +
            # Step 1: Build the full combined model (with 6 inputs)
         
     | 
| 289 | 
         
            +
            # model = build_combined_model()
         
     | 
| 290 | 
         | 
| 291 | 
         
            +
            # Step 2: Call the model once with dummy data to build the weights
         
     | 
| 292 | 
         
            +
            # import tensorflow as tf
         
     | 
| 
         | 
|
| 293 | 
         | 
| 294 | 
         
            +
            dummy_input = [
         
     | 
| 295 | 
         
            +
                tf.random.normal((1, 8, 95, 95, 2)),   # reduced_images_test
         
     | 
| 296 | 
         
            +
                tf.random.normal((1, 95, 95, 8)),      # hov_m_test
         
     | 
| 297 | 
         
            +
                tf.random.normal((1, 8, 8, 1)),        # test_vmax_3d
         
     | 
| 298 | 
         
            +
                tf.random.normal((1, 8)),              # lat_test
         
     | 
| 299 | 
         
            +
                tf.random.normal((1, 8)),              # lon_test
         
     | 
| 300 | 
         
            +
                tf.random.normal((1, 9)),              # other_scalar_inputs
         
     | 
| 301 | 
         
            +
            ]
         
     | 
| 302 | 
         
            +
            _ = model(dummy_input)  # Build model by doing one forward pass
         
     | 
| 303 | 
         
            +
             
     | 
| 304 | 
         
            +
            # Step 3: Load weights
         
     | 
| 305 | 
         
            +
            model.load_weights("Trj_GRU.weights.h5")  # Make sure this matches the architecture
         
     | 
| 306 | 
         | 
| 
         | 
|
| 307 | 
         | 
| 308 | 
         | 
| 309 | 
         
             
            def predict_trajgru(reduced_images_test,hov_m_test,test_vmax_3d,lat_test,lon_test,int_diff_test):
         
     |