Wonder-Griffin commited on
Commit
5376ec3
·
verified ·
1 Parent(s): c546054

Create modeling_storm_oracle.py

Browse files
Files changed (1) hide show
  1. modeling_storm_oracle.py +24 -0
modeling_storm_oracle.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from .configuration_storm_oracle import StormOracleConfig
5
+
6
+ # ---- import your actual model code ----
7
+ # If your code lives in tornado_predictor.py (as pasted), import from there:
8
+ from .tornado_predictor import TornadoSuperPredictor # adjust if filename differs
9
+
10
+ class StormOracleModel(PreTrainedModel):
11
+ config_class = StormOracleConfig
12
+
13
+ def __init__(self, config: StormOracleConfig):
14
+ super().__init__(config)
15
+ self.model = TornadoSuperPredictor(in_channels=config.in_channels)
16
+ self.post_init() # HF bookkeeping
17
+
18
+ def forward(self, radar_x: torch.Tensor, atmo: dict):
19
+ """
20
+ radar_x: (B, C, H, W)
21
+ atmo: dict of tensors (cape, wind_shear, helicity, temperature, dewpoint, pressure)
22
+ returns TornadoPredictionBatch (your dataclass)
23
+ """
24
+ return self.model(radar_x, atmo)