TechLearnr4S committed on
Commit
ffdc641
·
verified ·
1 Parent(s): a4edc1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -24
app.py CHANGED
@@ -21,7 +21,8 @@ VECNORM_PATH = "models/vecnormalize_lstm_final.pkl"
21
  MAX_STEPS = 50
22
 
23
  ZONE_LABELS = ["Zone 1 (Residential)", "Zone 2 (Commercial)", "Zone 3 (Hospital)"]
24
- ZONE_TYPES = ["Commercial (medium)", "Residential (low)", "Commercial (medium)"]
 
25
 
26
  _DARK_BG = "#1e1e1e"
27
  _DARK_AX = "#181818"
@@ -46,12 +47,38 @@ def _extract_array_obs(obs, length=3):
46
  return [0.33] * length
47
 
48
 
49
- def _heuristic_action(demand):
50
- weights = [demand[0], demand[1], demand[2] * 1.2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  total = sum(weights)
52
  if total <= 0:
53
- return np.array([[0.25, 0.35, 0.40]])
54
- return np.array([[w / total for w in weights]])
 
 
 
 
 
 
55
 
56
 
57
  def _fig_to_pil(fig) -> PILImage.Image:
@@ -84,7 +111,7 @@ def _make_grid_chart(demand_vals, supply_vals, fault_status=None) -> PILImage.Im
84
  # FAULT annotations
85
  if fault_status is None:
86
  fault_status = [
87
- demand_vals[i] > 0.01 and supply_vals[i] < demand_vals[i] * 0.6
88
  for i in range(3)
89
  ]
90
  for i, fault in enumerate(fault_status):
@@ -99,7 +126,8 @@ def _make_grid_chart(demand_vals, supply_vals, fault_status=None) -> PILImage.Im
99
  ax.set_title("Grid Status per Zone", color="white", fontsize=12,
100
  pad=12, fontweight="bold")
101
  ax.set_xticks(x)
102
- ax.set_xticklabels(["Zone 1", "Zone 2", "Zone 3"], color="#cccccc", fontsize=10)
 
103
  ax.set_ylabel("Power Units", color="#888888", fontsize=9)
104
  ax.tick_params(axis="y", colors="#888888", labelsize=8)
105
  ax.tick_params(axis="x", colors="#cccccc")
@@ -131,8 +159,10 @@ class GridSimulator:
131
 
132
  if os.path.exists(MODEL_PATH + ".zip"):
133
  self.model = RecurrentPPO.load(MODEL_PATH, env=self.env)
 
134
  else:
135
- self.model = None # heuristic-only demo mode
 
136
 
137
  self.ready = True
138
  self.reset()
@@ -158,6 +188,7 @@ class GridSimulator:
158
  self.current_supply = [0.33, 0.33, 0.33]
159
  self.action_taken = [0.33, 0.33, 0.33]
160
  self.fault_status = [False, False, False]
 
161
 
162
  # ------------------------------------------------------------------
163
  def step(self, action=None, manual=False):
@@ -176,7 +207,8 @@ class GridSimulator:
176
  )
177
  else:
178
  raw = self.obs[0] if hasattr(self.obs, '__len__') else self.obs
179
- action = _heuristic_action(_extract_array_obs(raw))
 
180
  else:
181
  total = sum(action)
182
  action = np.array([[a / total if total > 0 else 0.33 for a in action]])
@@ -200,13 +232,20 @@ class GridSimulator:
200
  self.action_taken = (action[0].tolist()
201
  if hasattr(action[0], 'tolist') else list(action[0]))
202
 
 
 
 
203
  for key in ("blackout_count", "blackouts", "fault_count"):
204
- if ep_info.get(key):
205
- self.blackouts += int(ep_info[key])
206
  break
 
 
207
 
208
- self.stability = float(ep_info.get("stability_score",
209
- ep_info.get("stability", 1.0)))
 
 
210
  self.stability_history.append(self.stability)
211
  self.unmet_demand += float(ep_info.get("unmet_demand", 0.0))
212
 
@@ -219,21 +258,65 @@ class GridSimulator:
219
  pass
220
 
221
  self.current_supply = self.action_taken[:3]
222
- self.fault_status = [
 
223
  self.current_demand[i] > 0.01 and
224
- self.current_supply[i] < self.current_demand[i] * 0.6
225
  for i in range(3)
226
  ]
227
- return self.get_ui_state("Step completed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  # ------------------------------------------------------------------
230
  def get_ui_state(self, status_msg="Step completed."):
231
  reward_str = f"{self.last_reward:.4f}"
232
- done_str = "✅ Yes — Episode Complete" if self.done else ""
233
- env_desc = self._build_env_description()
234
- chart = _make_grid_chart(self.current_demand,
235
- self.current_supply,
236
- self.fault_status)
 
 
 
 
 
 
 
 
 
237
  return reward_str, done_str, status_msg, env_desc, chart
238
 
239
  # ------------------------------------------------------------------
@@ -263,13 +346,24 @@ class GridSimulator:
263
  f" Total unmet demand: {self.unmet_demand:.3f}",
264
  f" Total reward: {self.total_reward:.2f}",
265
  "",
 
 
266
  "Task: Allocate power to 3 zones as fractions summing to 1.0.",
267
  "Priority: Serve Zone 3 (Hospital) first. "
268
  "Avoid overloads – they cascade into blackouts.",
269
  "Reply with exactly 3 space-separated floats. Example: 0.20 0.30 0.50",
270
  ]
271
  if self.done:
272
- lines.append("\n✅ Episode complete. Click 'Reset Env' to start a new episode.")
 
 
 
 
 
 
 
 
 
273
  return "\n".join(lines)
274
 
275
  # ------------------------------------------------------------------
@@ -293,7 +387,7 @@ def ui_reset():
293
  if not sim.ready:
294
  return _ERR
295
  sim.reset()
296
- return sim.get_ui_state("Environment reset. Ready.")
297
 
298
 
299
  def ui_ai_step():
@@ -313,7 +407,7 @@ def ui_auto_run():
313
  return
314
  while not sim.done and sim.steps < MAX_STEPS:
315
  yield sim.step()
316
- time.sleep(0.15)
317
 
318
 
319
  def ui_take_step(z1, z2, z3):
 
21
  MAX_STEPS = 50
22
 
23
  ZONE_LABELS = ["Zone 1 (Residential)", "Zone 2 (Commercial)", "Zone 3 (Hospital)"]
24
+ ZONE_TYPES = ["Residential (low)", "Commercial (medium)", "Hospital (critical)"]
25
+ ZONE_SHORT = ["Residential", "Commercial", "Hospital"]
26
 
27
  _DARK_BG = "#1e1e1e"
28
  _DARK_AX = "#181818"
 
47
  return [0.33] * length
48
 
49
 
50
+ def _heuristic_action(demand, fault_status=None):
51
+ """
52
+ Deterministic, fault-aware heuristic:
53
+ - Allocate proportionally to demand
54
+ - Penalise zones currently in fault (halve their weight)
55
+ - Give Hospital (Zone 3) a 30 % priority boost
56
+ - Clip each allocation to [0.10, 0.60] to avoid extreme values
57
+ - Always returns a normalised 3-element action
58
+ """
59
+ if fault_status is None:
60
+ fault_status = [False, False, False]
61
+
62
+ weights = [max(float(d), 0.0) for d in demand]
63
+
64
+ # Penalise faulty zones so AI routes around them
65
+ for i, fault in enumerate(fault_status):
66
+ if fault:
67
+ weights[i] *= 0.5
68
+
69
+ # Hospital (Zone 3) always gets a boost
70
+ weights[2] *= 1.3
71
+
72
  total = sum(weights)
73
  if total <= 0:
74
+ raw = [0.20, 0.30, 0.50]
75
+ else:
76
+ raw = [w / total for w in weights]
77
+
78
+ # Clip to sensible range so no zone gets starved or flooded
79
+ clipped = [max(0.10, min(0.60, v)) for v in raw]
80
+ s = sum(clipped)
81
+ return np.array([[v / s for v in clipped]])
82
 
83
 
84
  def _fig_to_pil(fig) -> PILImage.Image:
 
111
  # FAULT annotations
112
  if fault_status is None:
113
  fault_status = [
114
+ demand_vals[i] > 0.01 and supply_vals[i] < demand_vals[i] * 0.8
115
  for i in range(3)
116
  ]
117
  for i, fault in enumerate(fault_status):
 
126
  ax.set_title("Grid Status per Zone", color="white", fontsize=12,
127
  pad=12, fontweight="bold")
128
  ax.set_xticks(x)
129
+ ax.set_xticklabels(["Zone 1\n(Residential)", "Zone 2\n(Commercial)", "Zone 3\n(Hospital)"],
130
+ color="#cccccc", fontsize=9)
131
  ax.set_ylabel("Power Units", color="#888888", fontsize=9)
132
  ax.tick_params(axis="y", colors="#888888", labelsize=8)
133
  ax.tick_params(axis="x", colors="#cccccc")
 
159
 
160
  if os.path.exists(MODEL_PATH + ".zip"):
161
  self.model = RecurrentPPO.load(MODEL_PATH, env=self.env)
162
+ print("✅ Trained model loaded successfully.")
163
  else:
164
+ self.model = None
165
+ print("⚠️ Using heuristic AI (trained model not found)")
166
 
167
  self.ready = True
168
  self.reset()
 
188
  self.current_supply = [0.33, 0.33, 0.33]
189
  self.action_taken = [0.33, 0.33, 0.33]
190
  self.fault_status = [False, False, False]
191
+ self.ai_explanation = "System reset. Ready for AI power allocation."
192
 
193
  # ------------------------------------------------------------------
194
  def step(self, action=None, manual=False):
 
207
  )
208
  else:
209
  raw = self.obs[0] if hasattr(self.obs, '__len__') else self.obs
210
+ # Pass current fault_status so heuristic avoids broken zones
211
+ action = _heuristic_action(_extract_array_obs(raw), self.fault_status)
212
  else:
213
  total = sum(action)
214
  action = np.array([[a / total if total > 0 else 0.33 for a in action]])
 
232
  self.action_taken = (action[0].tolist()
233
  if hasattr(action[0], 'tolist') else list(action[0]))
234
 
235
+ # Blackout counting — use explicit key presence check (not falsy)
236
+ # to avoid double-counting and to correctly handle 0-value steps
237
+ step_blk = None
238
  for key in ("blackout_count", "blackouts", "fault_count"):
239
+ if key in ep_info:
240
+ step_blk = int(ep_info[key])
241
  break
242
+ if step_blk is not None:
243
+ self.blackouts += step_blk
244
 
245
+ # Stability: trust env value but also penalise for accumulated blackouts
246
+ env_stab = float(ep_info.get("stability_score", ep_info.get("stability", 1.0)))
247
+ penalty_stab = max(0.0, 1.0 - self.blackouts / 10.0)
248
+ self.stability = min(env_stab, penalty_stab)
249
  self.stability_history.append(self.stability)
250
  self.unmet_demand += float(ep_info.get("unmet_demand", 0.0))
251
 
 
258
  pass
259
 
260
  self.current_supply = self.action_taken[:3]
261
+ # Fault threshold raised to 0.8 — only flag meaningful undersupply
262
+ self.fault_status = [
263
  self.current_demand[i] > 0.01 and
264
+ self.current_supply[i] < self.current_demand[i] * 0.8
265
  for i in range(3)
266
  ]
267
+ self._update_ai_explanation()
268
+ return self.get_ui_state(f"✅ Reward: {self.last_reward:.2f}")
269
+
270
+ # ------------------------------------------------------------------
271
+ def _update_ai_explanation(self):
272
+ z = self.action_taken[:3]
273
+ d = self.current_demand[:3]
274
+ f = self.fault_status[:3]
275
+
276
+ faulty_zones = [ZONE_SHORT[i] for i, fault in enumerate(f) if fault]
277
+ max_z = z.index(max(z))
278
+
279
+ if faulty_zones:
280
+ self.ai_explanation = (
281
+ f"AI reduced allocation to faulty zone(s) ({', '.join(faulty_zones)}) "
282
+ f"and rerouted power to maintain grid stability."
283
+ )
284
+ elif max_z == 2:
285
+ self.ai_explanation = (
286
+ f"AI prioritized Hospital (Zone 3) — demand at {d[2]:.0%}. "
287
+ f"Critical load protected ({z[2]:.0%} share allocated)."
288
+ )
289
+ elif max_z == 1:
290
+ self.ai_explanation = (
291
+ f"AI allocated maximum to Commercial Zone to handle demand spike "
292
+ f"({z[1]:.0%} share). Hospital share: {z[2]:.0%}."
293
+ )
294
+ else:
295
+ self.ai_explanation = (
296
+ f"AI balanced supply across zones. "
297
+ f"Residential: {z[0]:.0%} | Commercial: {z[1]:.0%} | Hospital: {z[2]:.0%}."
298
+ )
299
+
300
+ if self.blackouts > 0:
301
+ self.ai_explanation += f" ⚠️ Cumulative blackouts: {int(self.blackouts)} — adapting allocation."
302
 
303
  # ------------------------------------------------------------------
304
  def get_ui_state(self, status_msg="Step completed."):
305
  reward_str = f"{self.last_reward:.4f}"
306
+ if self.done:
307
+ stab_icon = "🟢" if self.stability >= 0.75 else ("🟡" if self.stability >= 0.45 else "🔴")
308
+ done_str = (
309
+ f"🏁 Episode Complete\n"
310
+ f"⚡ Final Stability: {self.stability:.2f} {stab_icon}\n"
311
+ f"🚨 Total Blackouts: {int(self.blackouts)}\n"
312
+ f"🏆 Total Reward: {self.total_reward:.2f}"
313
+ )
314
+ else:
315
+ done_str = ""
316
+ env_desc = self._build_env_description()
317
+ chart = _make_grid_chart(self.current_demand,
318
+ self.current_supply,
319
+ self.fault_status)
320
  return reward_str, done_str, status_msg, env_desc, chart
321
 
322
  # ------------------------------------------------------------------
 
346
  f" Total unmet demand: {self.unmet_demand:.3f}",
347
  f" Total reward: {self.total_reward:.2f}",
348
  "",
349
+ f"🤖 AI Decision: {self.ai_explanation}",
350
+ "",
351
  "Task: Allocate power to 3 zones as fractions summing to 1.0.",
352
  "Priority: Serve Zone 3 (Hospital) first. "
353
  "Avoid overloads – they cascade into blackouts.",
354
  "Reply with exactly 3 space-separated floats. Example: 0.20 0.30 0.50",
355
  ]
356
  if self.done:
357
+ lines += [
358
+ "",
359
+ "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
360
+ "🏁 Episode Complete",
361
+ f" ⚡ Final Stability: {self.stability:.2f} {stab_icon}",
362
+ f" 🚨 Total Blackouts: {int(self.blackouts)}",
363
+ f" 🏆 Total Reward: {self.total_reward:.2f}",
364
+ "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
365
+ "Click 'Reset Env' to start a new episode.",
366
+ ]
367
  return "\n".join(lines)
368
 
369
  # ------------------------------------------------------------------
 
387
  if not sim.ready:
388
  return _ERR
389
  sim.reset()
390
+ return sim.get_ui_state("✅ Environment reset. Click AI Step or Take Step to begin.")
391
 
392
 
393
  def ui_ai_step():
 
407
  return
408
  while not sim.done and sim.steps < MAX_STEPS:
409
  yield sim.step()
410
+ time.sleep(0.30)
411
 
412
 
413
  def ui_take_step(z1, z2, z3):