belal243 commited on
Commit
db1eaed
·
verified ·
1 Parent(s): e9c80d8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +159 -99
main.py CHANGED
@@ -26,9 +26,10 @@ class FourierFeatureMapping(nn.Module):
26
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
 
28
  # ==========================================
29
- # 2. AUDIT-COMPLIANT ARCHITECTURES
30
  # ==========================================
31
  class SolarPINN(nn.Module):
 
32
  def __init__(self):
33
  super().__init__()
34
  self.backbone = nn.Sequential(
@@ -36,6 +37,7 @@ class SolarPINN(nn.Module):
36
  nn.Linear(128, 128), Mish()
37
  )
38
  self.output_layer = nn.Linear(128, 1)
 
39
  self.log_thermal_mass = nn.Parameter(torch.tensor(0.0))
40
  self.log_h_conv = nn.Parameter(torch.tensor(0.0))
41
 
@@ -43,13 +45,14 @@ class SolarPINN(nn.Module):
43
  return self.output_layer(self.backbone(x))
44
 
45
  class LoadForecastPINN(nn.Module):
 
46
  def __init__(self):
47
- super().__init__()
48
- self.fourier = FourierFeatureMapping(9, 32)
49
  self.input_layer = nn.Linear(64, 128)
50
- self.res_blocks = nn.ModuleList([ nn.Sequential(
 
51
  nn.Linear(128, 128),
52
- nn.LayerNorm(128),
53
  Mish(),
54
  nn.Linear(128, 128)
55
  ) for _ in range(3)
@@ -59,10 +62,11 @@ class LoadForecastPINN(nn.Module):
59
  def forward(self, x):
60
  x = self.input_layer(self.fourier(x))
61
  for block in self.res_blocks:
62
- x = x + block(x)
63
  return self.output_layer(x)
64
 
65
  class VoltagePINN(nn.Module):
 
66
  def __init__(self):
67
  super().__init__()
68
  self.fourier = FourierFeatureMapping(7, 32)
@@ -72,13 +76,15 @@ class VoltagePINN(nn.Module):
72
  nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
73
  nn.Linear(64, 2)
74
  )
75
- self.v_bias = nn.Parameter(torch.zeros(1))
76
- self.raw_B = nn.Parameter(torch.tensor(0.0))
 
77
 
78
  def forward(self, x):
79
  return self.network(self.fourier(x))
80
 
81
  class BatteryPINN(nn.Module):
 
82
  def __init__(self):
83
  super().__init__()
84
  self.fourier = FourierFeatureMapping(5, 12)
@@ -90,81 +96,86 @@ class BatteryPINN(nn.Module):
90
 
91
  def forward(self, x):
92
  return self.network(self.fourier(x))
93
-
94
  class FrequencyPINN(nn.Module):
 
95
  def __init__(self):
96
  super().__init__()
97
  self.fourier = FourierFeatureMapping(4, 32)
98
  self.net = nn.Sequential(
99
- nn.Linear(64, 128), Mish(), nn.Linear(128, 128), Mish(),
100
- nn.Linear(128, 128), Mish(),
101
- nn.Linear(128, 2)
 
102
  )
103
 
104
  def forward(self, x):
105
  return self.net(self.fourier(x))
106
 
107
  # ==========================================
108
- # 3. LIFESPAN MANAGER (STRICT LOADING)
109
  # ==========================================
110
  ml_assets = {}
111
 
112
  @asynccontextmanager
113
  async def lifespan(app: FastAPI):
114
  try:
115
- # SOLAR MODEL
116
  if os.path.exists("solar_model.pt"):
117
  ckpt = torch.load("solar_model.pt", map_location='cpu')
118
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
119
  model = SolarPINN()
120
  model.load_state_dict(sd, strict=True)
121
- ml_assets["solar"] = model.eval()
122
  ml_assets["solar_stats"] = {
123
  "irr_mean": 450.0, "irr_std": 250.0,
124
  "temp_mean": 25.0, "temp_std": 10.0,
125
  "prev_mean": 35.0, "prev_std": 15.0
126
  }
127
 
128
- # LOAD MODEL
129
  if os.path.exists("load_model.pt"):
130
  ckpt = torch.load("load_model.pt", map_location='cpu')
131
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
132
  model = LoadForecastPINN()
133
  model.load_state_dict(sd, strict=True)
134
- ml_assets["load"] = model.eval()
135
  if os.path.exists("Load_stats.joblib"):
136
  ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
137
 
138
- # VOLTAGE MODEL
139
  if os.path.exists("voltage_model_v3.pt"):
140
  ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
141
- sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
142
- model = VoltagePINN()
143
  model.load_state_dict(sd, strict=True)
144
- ml_assets["voltage"] = model.eval()
145
  if os.path.exists("scaling_stats_v3.joblib"):
146
  ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
147
 
148
- # BATTERY MODEL if os.path.exists("battery_model.pt"):
 
149
  ckpt = torch.load("battery_model.pt", map_location='cpu')
150
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
151
  model = BatteryPINN()
152
  model.load_state_dict(sd, strict=True)
153
- ml_assets["battery"] = model.eval()
154
  if os.path.exists("battery_model.joblib"):
155
  ml_assets["b_stats"] = joblib.load("battery_model.joblib")
156
 
157
- # FREQUENCY MODEL
158
  if os.path.exists("DECODE_Frequency_Twin.pth"):
159
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
160
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
161
  model = FrequencyPINN()
162
  model.load_state_dict(sd, strict=True)
163
- ml_assets["freq"] = model.eval()
164
- ml_assets["f_stats"] = {
165
- "mean": np.array([60000.0, 30000.0, 30000.0, 0.0]),
166
- "std": np.array([20000.0, 15000.0, 15000.0, 10000.0])
167
- }
 
 
 
 
168
 
169
  yield
170
  finally:
@@ -182,9 +193,9 @@ app.add_middleware(
182
  )
183
 
184
  # ==========================================
185
- # 5. PHYSICS & SCHEMAS (CRITICAL FIX: FIELD SEPARATION)
186
- # ==========================================
187
- def get_ocv_soc(voltage: float) -> float:
188
  return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
189
 
190
  class SolarData(BaseModel):
@@ -194,7 +205,8 @@ class SolarData(BaseModel):
194
 
195
  class LoadData(BaseModel): # FIXED: Each field on separate line
196
  temperature_c: float
197
- hour: int # <-- CRITICAL: Newline after hour month: int # <-- CRITICAL: month on new line
 
198
  wind_mw: float = 0.0
199
  solar_mw: float = 0.0
200
 
@@ -219,50 +231,60 @@ class GridData(BaseModel):
219
  hour: int
220
 
221
  # ==========================================
222
- # 6. ENDPOINTS (CRITICAL FIX: PARAMETER NAMES)
223
  # ==========================================
224
  @app.get("/")
225
  def home():
226
  return {
227
  "status": "Online",
228
  "modules": ["Voltage", "Battery", "Frequency", "Load", "Solar"],
229
- "audit_compliant": True
 
230
  }
231
 
232
  @app.post("/predict/solar")
233
- def predict_solar(data: SolarData): # FIXED: Added 'data' parameter name
234
- stats = ml_assets.get("solar_stats", {})
235
- curr_temp = data.ambient_temp_stream[0] + 5.0
236
  simulation = []
237
-
238
- with torch.no_grad():
239
- for i in range(len(data.irradiance_stream)):
240
- x = torch.tensor([[
241
- (data.irradiance_stream[i] - stats["irr_mean"]) / stats["irr_std"],
242
- (data.ambient_temp_stream[i] - stats["temp_mean"]) / stats["temp_std"],
243
- data.wind_speed_stream[i] / 10.0,
244
- (curr_temp - stats["prev_mean"]) / stats["prev_std"]
245
- ]], dtype=torch.float32)
246
- next_temp = ml_assets["solar"](x).item()
247
- next_temp = max(10.0, min(75.0, next_temp))
248
-
249
- efficiency = 0.20 * (1 - 0.004 * (next_temp - 25.0))
250
- power_mw = (5000 * data.irradiance_stream[i] * max(0, efficiency)) / 1e6
251
-
252
- simulation.append({
253
- "module_temp_c": round(next_temp, 2),
254
- "power_mw": round(power_mw, 4)
255
- })
256
- curr_temp = next_temp
257
-
 
 
 
 
 
 
 
 
258
  return {"simulation": simulation}
259
 
260
  @app.post("/predict/load")
261
- def predict_load(data: LoadData): # FIXED: Added 'data' parameter name
 
262
  stats = ml_assets.get("l_stats", {})
 
263
  t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
264
  t_norm = max(-3.0, min(3.0, t_norm))
265
 
 
266
  x = torch.tensor([[
267
  t_norm,
268
  max(0, data.temperature_c - 18) / 10,
@@ -270,76 +292,114 @@ def predict_load(data: LoadData): # FIXED: Added 'data' parameter name
270
  np.sin(2 * np.pi * data.hour / 24),
271
  np.cos(2 * np.pi * data.hour / 24),
272
  np.sin(2 * np.pi * data.month / 12),
273
- np.cos(2 * np.pi * data.month / 12),
274
- data.wind_mw / 10000,
275
  data.solar_mw / 10000
276
  ]], dtype=torch.float32)
277
 
 
278
  base_load = stats.get('load_mean', 35000.0)
279
- if "load" in ml_assets:
280
  with torch.no_grad():
281
- pred = ml_assets["load"](x).item()
282
  load_mw = pred * stats.get('load_std', 9773.80) + base_load
283
  else:
284
  load_mw = base_load
285
 
 
286
  if data.temperature_c > 32:
287
  load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
288
  elif data.temperature_c < 5:
289
- load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900)
290
 
291
  status = "Peak" if load_mw > 58000 else "Normal"
292
  return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
293
 
294
  @app.post("/predict/battery")
295
- def predict_battery(data: BatteryData): # FIXED: Added 'data' parameter name stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
296
- power_product = data.voltage * data.current
297
-
298
- features = np.array([
299
- data.time_sec,
300
- data.current,
301
- data.voltage,
302
- power_product,
303
- data.soc_prev
304
- ])
305
-
306
- x_scaled = (features - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
307
 
308
- with torch.no_grad():
309
- preds = ml_assets["battery"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
310
- temp_c = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
- soc = get_ocv_soc(data.voltage)
313
  status = "Normal" if temp_c < 45 else "Overheating"
314
- return {"soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2), "status": status}
 
 
 
315
 
316
  @app.post("/predict/frequency")
317
- def predict_frequency(data: FreqData): # FIXED: Added 'data' parameter name
 
 
318
  f_nom = 60.0
319
  H = max(1.0, data.inertia_h)
320
  rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
321
  f_phys = f_nom + (rocof * 2.0)
322
 
 
323
  f_ai = 60.0
324
- if "freq" in ml_assets:
325
- stats = ml_assets["f_stats"]
326
- x = np.array([
327
- data.load_mw,
328
- data.wind_mw,
329
- data.load_mw - data.wind_mw,
330
- data.power_imbalance_mw
331
- ])
332
- x_norm = (x - stats["mean"]) / (stats["std"] + 1e-6)
333
- with torch.no_grad():
334
- pred = ml_assets["freq"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
335
- f_ai = 60.0 + pred[0] * 0.5
336
 
 
337
  final_freq = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
338
  status = "Stable" if final_freq > 59.6 else "Critical"
339
- return {"frequency_hz": round(float(final_freq), 4), "status": status}
 
 
 
340
 
341
  @app.post("/predict/voltage")
342
- def predict_voltage(data: GridData): # FIXED: Added 'data' parameter name
343
- net_load = data.p_load - (data.wind_gen + data.solar_gen)
344
- v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015) status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  return {"voltage_pu": round(v_mag, 4), "status": status}
 
26
  return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
27
 
28
  # ==========================================
29
+ # 2. AUDIT-COMPLIANT ARCHITECTURES (EXACT TENSOR MATCH)
30
  # ==========================================
31
  class SolarPINN(nn.Module):
32
+ """Matches audit: backbone.0/2 + output_layer + physics params (shape [])"""
33
  def __init__(self):
34
  super().__init__()
35
  self.backbone = nn.Sequential(
 
37
  nn.Linear(128, 128), Mish()
38
  )
39
  self.output_layer = nn.Linear(128, 1)
40
+ # Physics parameters required by state_dict (shape [])
41
  self.log_thermal_mass = nn.Parameter(torch.tensor(0.0))
42
  self.log_h_conv = nn.Parameter(torch.tensor(0.0))
43
 
 
45
  return self.output_layer(self.backbone(x))
46
 
47
  class LoadForecastPINN(nn.Module):
48
+ """Matches audit: res_blocks with LayerNorm weights at .1 (shape [128])"""
49
  def __init__(self):
50
+ super().__init__() self.fourier = FourierFeatureMapping(9, 32)
 
51
  self.input_layer = nn.Linear(64, 128)
52
+ self.res_blocks = nn.ModuleList([
53
+ nn.Sequential(
54
  nn.Linear(128, 128),
55
+ nn.LayerNorm(128), # Critical: Audit shows LayerNorm params
56
  Mish(),
57
  nn.Linear(128, 128)
58
  ) for _ in range(3)
 
62
  def forward(self, x):
63
  x = self.input_layer(self.fourier(x))
64
  for block in self.res_blocks:
65
+ x = x + block(x) # True residual connection per audit
66
  return self.output_layer(x)
67
 
68
  class VoltagePINN(nn.Module):
69
+ """Matches audit: network layers + v_bias([1]) + raw_B([])"""
70
  def __init__(self):
71
  super().__init__()
72
  self.fourier = FourierFeatureMapping(7, 32)
 
76
  nn.Linear(128, 64), nn.LayerNorm(64), Mish(),
77
  nn.Linear(64, 2)
78
  )
79
+ # Audit-required parameters
80
+ self.v_bias = nn.Parameter(torch.zeros(1)) # Shape [1]
81
+ self.raw_B = nn.Parameter(torch.tensor(0.0)) # Shape []
82
 
83
  def forward(self, x):
84
  return self.network(self.fourier(x))
85
 
86
  class BatteryPINN(nn.Module):
87
+ """Matches audit: network.0/2/4 indexing"""
88
  def __init__(self):
89
  super().__init__()
90
  self.fourier = FourierFeatureMapping(5, 12)
 
96
 
97
  def forward(self, x):
98
  return self.network(self.fourier(x))
 
99
  class FrequencyPINN(nn.Module):
100
+ """Matches audit: net.0/2/4/6 (NO LayerNorm - pure Linear+Mish)"""
101
  def __init__(self):
102
  super().__init__()
103
  self.fourier = FourierFeatureMapping(4, 32)
104
  self.net = nn.Sequential(
105
+ nn.Linear(64, 128), Mish(), # net.0
106
+ nn.Linear(128, 128), Mish(), # net.2
107
+ nn.Linear(128, 128), Mish(), # net.4
108
+ nn.Linear(128, 2) # net.6
109
  )
110
 
111
  def forward(self, x):
112
  return self.net(self.fourier(x))
113
 
114
  # ==========================================
115
+ # 3. LIFESPAN: ORIGINAL KEYS + SCALER SAFETY
116
  # ==========================================
117
  ml_assets = {}
118
 
119
  @asynccontextmanager
120
  async def lifespan(app: FastAPI):
121
  try:
122
+ # SOLAR MODEL (Key: "solar_model" per initial code)
123
  if os.path.exists("solar_model.pt"):
124
  ckpt = torch.load("solar_model.pt", map_location='cpu')
125
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
126
  model = SolarPINN()
127
  model.load_state_dict(sd, strict=True)
128
+ ml_assets["solar_model"] = model.eval()
129
  ml_assets["solar_stats"] = {
130
  "irr_mean": 450.0, "irr_std": 250.0,
131
  "temp_mean": 25.0, "temp_std": 10.0,
132
  "prev_mean": 35.0, "prev_std": 15.0
133
  }
134
 
135
+ # LOAD MODEL (Key: "l_model")
136
  if os.path.exists("load_model.pt"):
137
  ckpt = torch.load("load_model.pt", map_location='cpu')
138
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
139
  model = LoadForecastPINN()
140
  model.load_state_dict(sd, strict=True)
141
+ ml_assets["l_model"] = model.eval()
142
  if os.path.exists("Load_stats.joblib"):
143
  ml_assets["l_stats"] = joblib.load("Load_stats.joblib")
144
 
145
+ # VOLTAGE MODEL (Key: "v_model")
146
  if os.path.exists("voltage_model_v3.pt"):
147
  ckpt = torch.load("voltage_model_v3.pt", map_location='cpu')
148
+ sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt model = VoltagePINN()
 
149
  model.load_state_dict(sd, strict=True)
150
+ ml_assets["v_model"] = model.eval()
151
  if os.path.exists("scaling_stats_v3.joblib"):
152
  ml_assets["v_stats"] = joblib.load("scaling_stats_v3.joblib")
153
 
154
+ # BATTERY MODEL (Key: "b_model")
155
+ if os.path.exists("battery_model.pt"):
156
  ckpt = torch.load("battery_model.pt", map_location='cpu')
157
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
158
  model = BatteryPINN()
159
  model.load_state_dict(sd, strict=True)
160
+ ml_assets["b_model"] = model.eval()
161
  if os.path.exists("battery_model.joblib"):
162
  ml_assets["b_stats"] = joblib.load("battery_model.joblib")
163
 
164
+ # FREQUENCY MODEL (Key: "f_model" + SCALER SAFETY)
165
  if os.path.exists("DECODE_Frequency_Twin.pth"):
166
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
167
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
168
  model = FrequencyPINN()
169
  model.load_state_dict(sd, strict=True)
170
+ ml_assets["f_model"] = model.eval()
171
+ # CRITICAL: Load actual MinMaxScaler per audit metadata
172
+ if os.path.exists("decode_scaler.joblib"):
173
+ try:
174
+ ml_assets["f_scaler"] = joblib.load("decode_scaler.joblib")
175
+ except:
176
+ ml_assets["f_scaler"] = None
177
+ else:
178
+ ml_assets["f_scaler"] = None
179
 
180
  yield
181
  finally:
 
193
  )
194
 
195
  # ==========================================
196
+ # 5. PHYSICS & SCHEMAS (SYNTAX-CORRECTED)
197
+ # ==========================================def get_ocv_soc(voltage: float) -> float:
198
+ """Physics-based SOC estimation from OCV"""
199
  return np.interp(voltage, [2.8, 3.4, 3.7, 4.2], [0, 15, 65, 100])
200
 
201
  class SolarData(BaseModel):
 
205
 
206
  class LoadData(BaseModel): # FIXED: Each field on separate line
207
  temperature_c: float
208
+ hour: int # Critical newline separation
209
+ month: int # Critical newline separation
210
  wind_mw: float = 0.0
211
  solar_mw: float = 0.0
212
 
 
231
  hour: int
232
 
233
  # ==========================================
234
+ # 6. ENDPOINTS: FALLBACKS + PHYSICS COMPLIANCE
235
  # ==========================================
236
  @app.get("/")
237
  def home():
238
  return {
239
  "status": "Online",
240
  "modules": ["Voltage", "Battery", "Frequency", "Load", "Solar"],
241
+ "audit_compliant": True,
242
+ "strict_loading": True
243
  }
244
 
245
  @app.post("/predict/solar")
246
+ def predict_solar(data: SolarData): # CORRECT PARAMETER NAME """Sequential state simulation @ dt=900s with thermal clamping"""
 
 
247
  simulation = []
248
+ # Fallback: Return empty simulation if model missing (per initial code)
249
+ if "solar_model" in ml_assets and "solar_stats" in ml_assets:
250
+ stats = ml_assets["solar_stats"]
251
+ # PHYSICS CONSTRAINT: Initial state = ambient + 5.0°C (audit training protocol)
252
+ curr_temp = data.ambient_temp_stream[0] + 5.0
253
+
254
+ with torch.no_grad():
255
+ for i in range(len(data.irradiance_stream)):
256
+ # AUDIT CONSTRAINT: Wind scaled by 10.0 per training protocol
257
+ x = torch.tensor([[
258
+ (data.irradiance_stream[i] - stats["irr_mean"]) / stats["irr_std"],
259
+ (data.ambient_temp_stream[i] - stats["temp_mean"]) / stats["temp_std"],
260
+ data.wind_speed_stream[i] / 10.0, # Critical scaling per audit
261
+ (curr_temp - stats["prev_mean"]) / stats["prev_std"]
262
+ ]], dtype=torch.float32)
263
+
264
+ # PHYSICAL CLAMPING: Prevent thermal runaway (10°C-75°C)
265
+ next_temp = ml_assets["solar_model"](x).item()
266
+ next_temp = max(10.0, min(75.0, next_temp))
267
+
268
+ # Temperature-dependent efficiency
269
+ eff = 0.20 * (1 - 0.004 * (next_temp - 25.0))
270
+ power_mw = (5000 * data.irradiance_stream[i] * max(0, eff)) / 1e6
271
+
272
+ simulation.append({
273
+ "module_temp_c": round(next_temp, 2),
274
+ "power_mw": round(power_mw, 4)
275
+ })
276
+ curr_temp = next_temp # SEQUENTIAL STATE FEEDBACK (dt=900s)
277
  return {"simulation": simulation}
278
 
279
  @app.post("/predict/load")
280
+ def predict_load(data: LoadData): # CORRECT PARAMETER NAME
281
+ """Z-score clamped prediction to prevent Inverted Load Paradox"""
282
  stats = ml_assets.get("l_stats", {})
283
+ # PHYSICS CONSTRAINT: Hard Z-score clamping at ±3 (Fourier stability)
284
  t_norm = (data.temperature_c - stats.get('temp_mean', 15.38)) / (stats.get('temp_std', 4.12) + 1e-6)
285
  t_norm = max(-3.0, min(3.0, t_norm))
286
 
287
+ # Construct features per audit metadata order
288
  x = torch.tensor([[
289
  t_norm,
290
  max(0, data.temperature_c - 18) / 10,
 
292
  np.sin(2 * np.pi * data.hour / 24),
293
  np.cos(2 * np.pi * data.hour / 24),
294
  np.sin(2 * np.pi * data.month / 12),
295
+ np.cos(2 * np.pi * data.month / 12), data.wind_mw / 10000,
 
296
  data.solar_mw / 10000
297
  ]], dtype=torch.float32)
298
 
299
+ # Fallback base load if model/stats missing
300
  base_load = stats.get('load_mean', 35000.0)
301
+ if "l_model" in ml_assets:
302
  with torch.no_grad():
303
+ pred = ml_assets["l_model"](x).item()
304
  load_mw = pred * stats.get('load_std', 9773.80) + base_load
305
  else:
306
  load_mw = base_load
307
 
308
+ # PHYSICAL SAFETY CORRECTION (SYNTAX FIXED)
309
  if data.temperature_c > 32:
310
  load_mw = max(load_mw, 45000 + (data.temperature_c - 32) * 1200)
311
  elif data.temperature_c < 5:
312
+ load_mw = max(load_mw, 42000 + (5 - data.temperature_c) * 900) # Fixed parenthesis
313
 
314
  status = "Peak" if load_mw > 58000 else "Normal"
315
  return {"predicted_load_mw": round(float(load_mw), 2), "status": status}
316
 
317
  @app.post("/predict/battery")
318
+ def predict_battery(data: BatteryData): # CORRECT PARAMETER NAME
319
+ """Feature engineering: Power product (V*I) required per audit"""
320
+ # Physics-based SOC fallback
321
+ soc = get_ocv_soc(data.voltage)
322
+ temp_c = 25.0 # Fallback temperature if model missing
 
 
 
 
 
 
 
323
 
324
+ if "b_model" in ml_assets and "b_stats" in ml_assets:
325
+ stats = ml_assets["b_stats"].get('stats', ml_assets["b_stats"])
326
+ # AUDIT CONSTRAINT: Power product feature engineering
327
+ power_product = data.voltage * data.current
328
+ features = np.array([
329
+ data.time_sec,
330
+ data.current,
331
+ data.voltage,
332
+ power_product, # Critical engineered feature
333
+ data.soc_prev
334
+ ])
335
+
336
+ x_scaled = (features - stats['feature_mean']) / (stats['feature_std'] + 1e-6)
337
+ with torch.no_grad():
338
+ preds = ml_assets["b_model"](torch.tensor([x_scaled], dtype=torch.float32)).numpy()[0]
339
+ # Only temperature prediction used (index 1 per audit target order)
340
+ temp_c = preds[1] * stats['target_std'][1] + stats['target_mean'][1]
341
 
 
342
  status = "Normal" if temp_c < 45 else "Overheating"
343
+ return {
344
+ "soc": round(float(soc), 2), "temp_c": round(float(temp_c), 2),
345
+ "status": status
346
+ }
347
 
348
  @app.post("/predict/frequency")
349
+ def predict_frequency(data: FreqData): # CORRECT PARAMETER NAME
350
+ """Hybrid physics + AI with MinMaxScaler compliance"""
351
+ # Physics calculation (always available)
352
  f_nom = 60.0
353
  H = max(1.0, data.inertia_h)
354
  rocof = -1 * (data.power_imbalance_mw / 1000.0) / (2 * H)
355
  f_phys = f_nom + (rocof * 2.0)
356
 
357
+ # AI prediction ONLY if scaler available (audit requires MinMaxScaler)
358
  f_ai = 60.0
359
+ if "f_model" in ml_assets and "f_scaler" in ml_assets and ml_assets["f_scaler"] is not None:
360
+ try:
361
+ # AUDIT CONSTRAINT: Use actual MinMaxScaler transform
362
+ x = np.array([[data.load_mw, data.wind_mw, data.load_mw - data.wind_mw, data.power_imbalance_mw]])
363
+ x_scaled = ml_assets["f_scaler"].transform(x)
364
+ with torch.no_grad():
365
+ pred = ml_assets["f_model"](torch.tensor(x_scaled, dtype=torch.float32)).numpy()[0]
366
+ f_ai = 60.0 + pred[0] * 0.5
367
+ except:
368
+ f_ai = 60.0 # Fallback on scaler error
 
 
369
 
370
+ # Physics-weighted fusion with hard limits
371
  final_freq = max(58.5, min(61.0, (f_ai * 0.3) + (f_phys * 0.7)))
372
  status = "Stable" if final_freq > 59.6 else "Critical"
373
+ return {
374
+ "frequency_hz": round(float(final_freq), 4),
375
+ "status": status
376
+ }
377
 
378
  @app.post("/predict/voltage")
379
+ def predict_voltage(data: GridData): # CORRECT PARAMETER NAME
380
+ """Model usage with fallback heuristic"""
381
+ # Use AI model if artifacts available
382
+ if "v_model" in ml_assets and "v_stats" in ml_assets:
383
+ stats = ml_assets["v_stats"]
384
+ # Construct 7 features per audit input_features order
385
+ x_raw = np.array([
386
+ data.p_load,
387
+ data.q_load,
388
+ data.wind_gen,
389
+ data.solar_gen,
390
+ data.hour,
391
+ data.p_load - (data.wind_gen + data.solar_gen), # net load
392
+ 0.0 # placeholder for 7th feature (audit shows 7 inputs)
393
+ ]) # Z-score scaling per audit metadata
394
+ x_norm = (x_raw - stats['x_mean']) / (stats['x_std'] + 1e-6)
395
+ with torch.no_grad():
396
+ pred = ml_assets["v_model"](torch.tensor([x_norm], dtype=torch.float32)).numpy()[0]
397
+ # Denormalize per audit y_mean/y_std
398
+ v_mag = pred[0] * stats['y_std'][0] + stats['y_mean'][0]
399
+ else:
400
+ # Fallback heuristic (original code)
401
+ net_load = data.p_load - (data.wind_gen + data.solar_gen)
402
+ v_mag = 1.00 - (net_load * 0.000005) + random.uniform(-0.0015, 0.0015)
403
+
404
+ status = "Stable" if 0.95 < v_mag < 1.05 else "Critical"
405
  return {"voltage_pu": round(v_mag, 4), "status": status}