Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,8 @@ VECNORM_PATH = "models/vecnormalize_lstm_final.pkl"
|
|
| 21 |
MAX_STEPS = 50
|
| 22 |
|
| 23 |
ZONE_LABELS = ["Zone 1 (Residential)", "Zone 2 (Commercial)", "Zone 3 (Hospital)"]
|
| 24 |
-
ZONE_TYPES = ["
|
|
|
|
| 25 |
|
| 26 |
_DARK_BG = "#1e1e1e"
|
| 27 |
_DARK_AX = "#181818"
|
|
@@ -46,12 +47,38 @@ def _extract_array_obs(obs, length=3):
|
|
| 46 |
return [0.33] * length
|
| 47 |
|
| 48 |
|
| 49 |
-
def _heuristic_action(demand):
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
total = sum(weights)
|
| 52 |
if total <= 0:
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def _fig_to_pil(fig) -> PILImage.Image:
|
|
@@ -84,7 +111,7 @@ def _make_grid_chart(demand_vals, supply_vals, fault_status=None) -> PILImage.Im
|
|
| 84 |
# FAULT annotations
|
| 85 |
if fault_status is None:
|
| 86 |
fault_status = [
|
| 87 |
-
demand_vals[i] > 0.01 and supply_vals[i] < demand_vals[i] * 0.
|
| 88 |
for i in range(3)
|
| 89 |
]
|
| 90 |
for i, fault in enumerate(fault_status):
|
|
@@ -99,7 +126,8 @@ def _make_grid_chart(demand_vals, supply_vals, fault_status=None) -> PILImage.Im
|
|
| 99 |
ax.set_title("Grid Status per Zone", color="white", fontsize=12,
|
| 100 |
pad=12, fontweight="bold")
|
| 101 |
ax.set_xticks(x)
|
| 102 |
-
ax.set_xticklabels(["Zone 1", "Zone 2", "Zone 3"],
|
|
|
|
| 103 |
ax.set_ylabel("Power Units", color="#888888", fontsize=9)
|
| 104 |
ax.tick_params(axis="y", colors="#888888", labelsize=8)
|
| 105 |
ax.tick_params(axis="x", colors="#cccccc")
|
|
@@ -131,8 +159,10 @@ class GridSimulator:
|
|
| 131 |
|
| 132 |
if os.path.exists(MODEL_PATH + ".zip"):
|
| 133 |
self.model = RecurrentPPO.load(MODEL_PATH, env=self.env)
|
|
|
|
| 134 |
else:
|
| 135 |
-
self.model = None
|
|
|
|
| 136 |
|
| 137 |
self.ready = True
|
| 138 |
self.reset()
|
|
@@ -158,6 +188,7 @@ class GridSimulator:
|
|
| 158 |
self.current_supply = [0.33, 0.33, 0.33]
|
| 159 |
self.action_taken = [0.33, 0.33, 0.33]
|
| 160 |
self.fault_status = [False, False, False]
|
|
|
|
| 161 |
|
| 162 |
# ------------------------------------------------------------------
|
| 163 |
def step(self, action=None, manual=False):
|
|
@@ -176,7 +207,8 @@ class GridSimulator:
|
|
| 176 |
)
|
| 177 |
else:
|
| 178 |
raw = self.obs[0] if hasattr(self.obs, '__len__') else self.obs
|
| 179 |
-
|
|
|
|
| 180 |
else:
|
| 181 |
total = sum(action)
|
| 182 |
action = np.array([[a / total if total > 0 else 0.33 for a in action]])
|
|
@@ -200,13 +232,20 @@ class GridSimulator:
|
|
| 200 |
self.action_taken = (action[0].tolist()
|
| 201 |
if hasattr(action[0], 'tolist') else list(action[0]))
|
| 202 |
|
|
|
|
|
|
|
|
|
|
| 203 |
for key in ("blackout_count", "blackouts", "fault_count"):
|
| 204 |
-
if
|
| 205 |
-
|
| 206 |
break
|
|
|
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
| 210 |
self.stability_history.append(self.stability)
|
| 211 |
self.unmet_demand += float(ep_info.get("unmet_demand", 0.0))
|
| 212 |
|
|
@@ -219,21 +258,65 @@ class GridSimulator:
|
|
| 219 |
pass
|
| 220 |
|
| 221 |
self.current_supply = self.action_taken[:3]
|
| 222 |
-
|
|
|
|
| 223 |
self.current_demand[i] > 0.01 and
|
| 224 |
-
self.current_supply[i] < self.current_demand[i] * 0.
|
| 225 |
for i in range(3)
|
| 226 |
]
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
# ------------------------------------------------------------------
|
| 230 |
def get_ui_state(self, status_msg="Step completed."):
|
| 231 |
reward_str = f"{self.last_reward:.4f}"
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
return reward_str, done_str, status_msg, env_desc, chart
|
| 238 |
|
| 239 |
# ------------------------------------------------------------------
|
|
@@ -263,13 +346,24 @@ class GridSimulator:
|
|
| 263 |
f" Total unmet demand: {self.unmet_demand:.3f}",
|
| 264 |
f" Total reward: {self.total_reward:.2f}",
|
| 265 |
"",
|
|
|
|
|
|
|
| 266 |
"Task: Allocate power to 3 zones as fractions summing to 1.0.",
|
| 267 |
"Priority: Serve Zone 3 (Hospital) first. "
|
| 268 |
"Avoid overloads — they cascade into blackouts.",
|
| 269 |
"Reply with exactly 3 space-separated floats. Example: 0.20 0.30 0.50",
|
| 270 |
]
|
| 271 |
if self.done:
|
| 272 |
-
lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
return "\n".join(lines)
|
| 274 |
|
| 275 |
# ------------------------------------------------------------------
|
|
@@ -293,7 +387,7 @@ def ui_reset():
|
|
| 293 |
if not sim.ready:
|
| 294 |
return _ERR
|
| 295 |
sim.reset()
|
| 296 |
-
return sim.get_ui_state("Environment reset.
|
| 297 |
|
| 298 |
|
| 299 |
def ui_ai_step():
|
|
@@ -313,7 +407,7 @@ def ui_auto_run():
|
|
| 313 |
return
|
| 314 |
while not sim.done and sim.steps < MAX_STEPS:
|
| 315 |
yield sim.step()
|
| 316 |
-
time.sleep(0.
|
| 317 |
|
| 318 |
|
| 319 |
def ui_take_step(z1, z2, z3):
|
|
|
|
| 21 |
MAX_STEPS = 50
|
| 22 |
|
| 23 |
ZONE_LABELS = ["Zone 1 (Residential)", "Zone 2 (Commercial)", "Zone 3 (Hospital)"]
|
| 24 |
+
ZONE_TYPES = ["Residential (low)", "Commercial (medium)", "Hospital (critical)"]
|
| 25 |
+
ZONE_SHORT = ["Residential", "Commercial", "Hospital"]
|
| 26 |
|
| 27 |
_DARK_BG = "#1e1e1e"
|
| 28 |
_DARK_AX = "#181818"
|
|
|
|
| 47 |
return [0.33] * length
|
| 48 |
|
| 49 |
|
| 50 |
+
def _heuristic_action(demand, fault_status=None):
|
| 51 |
+
"""
|
| 52 |
+
Deterministic, fault-aware heuristic:
|
| 53 |
+
- Allocate proportionally to demand
|
| 54 |
+
- Penalise zones currently in fault (halve their weight)
|
| 55 |
+
- Give Hospital (Zone 3) a 30 % priority boost
|
| 56 |
+
- Clip each allocation to [0.10, 0.60] to avoid extreme values
|
| 57 |
+
- Always returns a normalised 3-element action
|
| 58 |
+
"""
|
| 59 |
+
if fault_status is None:
|
| 60 |
+
fault_status = [False, False, False]
|
| 61 |
+
|
| 62 |
+
weights = [max(float(d), 0.0) for d in demand]
|
| 63 |
+
|
| 64 |
+
# Penalise faulty zones so AI routes around them
|
| 65 |
+
for i, fault in enumerate(fault_status):
|
| 66 |
+
if fault:
|
| 67 |
+
weights[i] *= 0.5
|
| 68 |
+
|
| 69 |
+
# Hospital (Zone 3) always gets a boost
|
| 70 |
+
weights[2] *= 1.3
|
| 71 |
+
|
| 72 |
total = sum(weights)
|
| 73 |
if total <= 0:
|
| 74 |
+
raw = [0.20, 0.30, 0.50]
|
| 75 |
+
else:
|
| 76 |
+
raw = [w / total for w in weights]
|
| 77 |
+
|
| 78 |
+
# Clip to sensible range so no zone gets starved or flooded
|
| 79 |
+
clipped = [max(0.10, min(0.60, v)) for v in raw]
|
| 80 |
+
s = sum(clipped)
|
| 81 |
+
return np.array([[v / s for v in clipped]])
|
| 82 |
|
| 83 |
|
| 84 |
def _fig_to_pil(fig) -> PILImage.Image:
|
|
|
|
| 111 |
# FAULT annotations
|
| 112 |
if fault_status is None:
|
| 113 |
fault_status = [
|
| 114 |
+
demand_vals[i] > 0.01 and supply_vals[i] < demand_vals[i] * 0.8
|
| 115 |
for i in range(3)
|
| 116 |
]
|
| 117 |
for i, fault in enumerate(fault_status):
|
|
|
|
| 126 |
ax.set_title("Grid Status per Zone", color="white", fontsize=12,
|
| 127 |
pad=12, fontweight="bold")
|
| 128 |
ax.set_xticks(x)
|
| 129 |
+
ax.set_xticklabels(["Zone 1\n(Residential)", "Zone 2\n(Commercial)", "Zone 3\n(Hospital)"],
|
| 130 |
+
color="#cccccc", fontsize=9)
|
| 131 |
ax.set_ylabel("Power Units", color="#888888", fontsize=9)
|
| 132 |
ax.tick_params(axis="y", colors="#888888", labelsize=8)
|
| 133 |
ax.tick_params(axis="x", colors="#cccccc")
|
|
|
|
| 159 |
|
| 160 |
if os.path.exists(MODEL_PATH + ".zip"):
|
| 161 |
self.model = RecurrentPPO.load(MODEL_PATH, env=self.env)
|
| 162 |
+
print("✅ Trained model loaded successfully.")
|
| 163 |
else:
|
| 164 |
+
self.model = None
|
| 165 |
+
print("⚠️ Using heuristic AI (trained model not found)")
|
| 166 |
|
| 167 |
self.ready = True
|
| 168 |
self.reset()
|
|
|
|
| 188 |
self.current_supply = [0.33, 0.33, 0.33]
|
| 189 |
self.action_taken = [0.33, 0.33, 0.33]
|
| 190 |
self.fault_status = [False, False, False]
|
| 191 |
+
self.ai_explanation = "System reset. Ready for AI power allocation."
|
| 192 |
|
| 193 |
# ------------------------------------------------------------------
|
| 194 |
def step(self, action=None, manual=False):
|
|
|
|
| 207 |
)
|
| 208 |
else:
|
| 209 |
raw = self.obs[0] if hasattr(self.obs, '__len__') else self.obs
|
| 210 |
+
# Pass current fault_status so heuristic avoids broken zones
|
| 211 |
+
action = _heuristic_action(_extract_array_obs(raw), self.fault_status)
|
| 212 |
else:
|
| 213 |
total = sum(action)
|
| 214 |
action = np.array([[a / total if total > 0 else 0.33 for a in action]])
|
|
|
|
| 232 |
self.action_taken = (action[0].tolist()
|
| 233 |
if hasattr(action[0], 'tolist') else list(action[0]))
|
| 234 |
|
| 235 |
+
# Blackout counting — use explicit key presence check (not falsy)
|
| 236 |
+
# to avoid double-counting and to correctly handle 0-value steps
|
| 237 |
+
step_blk = None
|
| 238 |
for key in ("blackout_count", "blackouts", "fault_count"):
|
| 239 |
+
if key in ep_info:
|
| 240 |
+
step_blk = int(ep_info[key])
|
| 241 |
break
|
| 242 |
+
if step_blk is not None:
|
| 243 |
+
self.blackouts += step_blk
|
| 244 |
|
| 245 |
+
# Stability: trust env value but also penalise for accumulated blackouts
|
| 246 |
+
env_stab = float(ep_info.get("stability_score", ep_info.get("stability", 1.0)))
|
| 247 |
+
penalty_stab = max(0.0, 1.0 - self.blackouts / 10.0)
|
| 248 |
+
self.stability = min(env_stab, penalty_stab)
|
| 249 |
self.stability_history.append(self.stability)
|
| 250 |
self.unmet_demand += float(ep_info.get("unmet_demand", 0.0))
|
| 251 |
|
|
|
|
| 258 |
pass
|
| 259 |
|
| 260 |
self.current_supply = self.action_taken[:3]
|
| 261 |
+
# Fault threshold raised to 0.8 — only flag meaningful undersupply
|
| 262 |
+
self.fault_status = [
|
| 263 |
self.current_demand[i] > 0.01 and
|
| 264 |
+
self.current_supply[i] < self.current_demand[i] * 0.8
|
| 265 |
for i in range(3)
|
| 266 |
]
|
| 267 |
+
self._update_ai_explanation()
|
| 268 |
+
return self.get_ui_state(f"✅ Reward: {self.last_reward:.2f}")
|
| 269 |
+
|
| 270 |
+
# ------------------------------------------------------------------
|
| 271 |
+
def _update_ai_explanation(self):
|
| 272 |
+
z = self.action_taken[:3]
|
| 273 |
+
d = self.current_demand[:3]
|
| 274 |
+
f = self.fault_status[:3]
|
| 275 |
+
|
| 276 |
+
faulty_zones = [ZONE_SHORT[i] for i, fault in enumerate(f) if fault]
|
| 277 |
+
max_z = z.index(max(z))
|
| 278 |
+
|
| 279 |
+
if faulty_zones:
|
| 280 |
+
self.ai_explanation = (
|
| 281 |
+
f"AI reduced allocation to faulty zone(s) ({', '.join(faulty_zones)}) "
|
| 282 |
+
f"and rerouted power to maintain grid stability."
|
| 283 |
+
)
|
| 284 |
+
elif max_z == 2:
|
| 285 |
+
self.ai_explanation = (
|
| 286 |
+
f"AI prioritized Hospital (Zone 3) — demand at {d[2]:.0%}. "
|
| 287 |
+
f"Critical load protected ({z[2]:.0%} share allocated)."
|
| 288 |
+
)
|
| 289 |
+
elif max_z == 1:
|
| 290 |
+
self.ai_explanation = (
|
| 291 |
+
f"AI allocated maximum to Commercial Zone to handle demand spike "
|
| 292 |
+
f"({z[1]:.0%} share). Hospital share: {z[2]:.0%}."
|
| 293 |
+
)
|
| 294 |
+
else:
|
| 295 |
+
self.ai_explanation = (
|
| 296 |
+
f"AI balanced supply across zones. "
|
| 297 |
+
f"Residential: {z[0]:.0%} | Commercial: {z[1]:.0%} | Hospital: {z[2]:.0%}."
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
if self.blackouts > 0:
|
| 301 |
+
self.ai_explanation += f" ⚠️ Cumulative blackouts: {int(self.blackouts)} — adapting allocation."
|
| 302 |
|
| 303 |
# ------------------------------------------------------------------
|
| 304 |
def get_ui_state(self, status_msg="Step completed."):
|
| 305 |
reward_str = f"{self.last_reward:.4f}"
|
| 306 |
+
if self.done:
|
| 307 |
+
stab_icon = "🟢" if self.stability >= 0.75 else ("🟡" if self.stability >= 0.45 else "🔴")
|
| 308 |
+
done_str = (
|
| 309 |
+
f"π Episode Complete\n"
|
| 310 |
+
f"⚡ Final Stability: {self.stability:.2f} {stab_icon}\n"
|
| 311 |
+
f"🚨 Total Blackouts: {int(self.blackouts)}\n"
|
| 312 |
+
f"π Total Reward: {self.total_reward:.2f}"
|
| 313 |
+
)
|
| 314 |
+
else:
|
| 315 |
+
done_str = ""
|
| 316 |
+
env_desc = self._build_env_description()
|
| 317 |
+
chart = _make_grid_chart(self.current_demand,
|
| 318 |
+
self.current_supply,
|
| 319 |
+
self.fault_status)
|
| 320 |
return reward_str, done_str, status_msg, env_desc, chart
|
| 321 |
|
| 322 |
# ------------------------------------------------------------------
|
|
|
|
| 346 |
f" Total unmet demand: {self.unmet_demand:.3f}",
|
| 347 |
f" Total reward: {self.total_reward:.2f}",
|
| 348 |
"",
|
| 349 |
+
f"🤖 AI Decision: {self.ai_explanation}",
|
| 350 |
+
"",
|
| 351 |
"Task: Allocate power to 3 zones as fractions summing to 1.0.",
|
| 352 |
"Priority: Serve Zone 3 (Hospital) first. "
|
| 353 |
"Avoid overloads — they cascade into blackouts.",
|
| 354 |
"Reply with exactly 3 space-separated floats. Example: 0.20 0.30 0.50",
|
| 355 |
]
|
| 356 |
if self.done:
|
| 357 |
+
lines += [
|
| 358 |
+
"",
|
| 359 |
+
"βββββββββββββββββββββββββββββββββββββββ",
|
| 360 |
+
"π Episode Complete",
|
| 361 |
+
f" ⚡ Final Stability: {self.stability:.2f} {stab_icon}",
|
| 362 |
+
f" 🚨 Total Blackouts: {int(self.blackouts)}",
|
| 363 |
+
f" π Total Reward: {self.total_reward:.2f}",
|
| 364 |
+
"βββββββββββββββββββββββββββββββββββββββ",
|
| 365 |
+
"Click 'Reset Env' to start a new episode.",
|
| 366 |
+
]
|
| 367 |
return "\n".join(lines)
|
| 368 |
|
| 369 |
# ------------------------------------------------------------------
|
|
|
|
| 387 |
if not sim.ready:
|
| 388 |
return _ERR
|
| 389 |
sim.reset()
|
| 390 |
+
return sim.get_ui_state("✅ Environment reset. Click AI Step or Take Step to begin.")
|
| 391 |
|
| 392 |
|
| 393 |
def ui_ai_step():
|
|
|
|
| 407 |
return
|
| 408 |
while not sim.done and sim.steps < MAX_STEPS:
|
| 409 |
yield sim.step()
|
| 410 |
+
time.sleep(0.30)
|
| 411 |
|
| 412 |
|
| 413 |
def ui_take_step(z1, z2, z3):
|