| import numpy as np
|
| import matplotlib.pyplot as plt
|
| import networkx as nx
|
| from matplotlib.gridspec import GridSpec
|
| from matplotlib.patches import FancyArrowPatch
|
|
|
| import os
|
| import re
|
|
|
def setup_figure(title, rows, cols):
    """Create a 20x10 inch figure with a bold super-title and a grid spec.

    Args:
        title: Text shown as the figure's super-title.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        A ``(figure, gridspec)`` tuple; the grid spec is bound to the figure.
    """
    # constrained_layout lets Matplotlib arrange the panels on its own.
    figure = plt.figure(constrained_layout=True, figsize=(20, 10))
    figure.suptitle(title, fontweight='bold', fontsize=18)
    grid = GridSpec(rows, cols, figure=figure)
    return figure, grid
|
|
|
def plot_agent_env_loop(ax):
    """Draw the agent-environment interaction loop as a two-box flowchart.

    An "Agent" box sits above an "Environment" box. A black curved arrow
    carries the action label and a green curved arrow carries the next-state
    and reward label.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Agent-Environment Interaction", fontsize=12, fontweight='bold')

    box_style = {"boxstyle": "round,pad=0.8", "fc": "ivory", "ec": "black", "lw": 1.5}
    for caption, y_pos in (("Agent", 0.8), ("Environment", 0.2)):
        ax.text(0.5, y_pos, caption, ha="center", va="center", bbox=box_style, fontsize=12)

    upper, lower = (0.5, 0.65), (0.5, 0.35)
    action_arrow = {"arrowstyle": "->", "connectionstyle": "arc3,rad=-0.5", "lw": 2}
    feedback_arrow = dict(action_arrow, color='green')

    # annotate() draws its arrow from xytext towards xy.
    ax.annotate("Action $A_t$", xy=lower, xytext=upper, arrowprops=action_arrow)
    ax.annotate("State $S_{t+1}$, Reward $R_{t+1}$", xy=upper, xytext=lower,
                arrowprops=feedback_arrow)
|
|
|
def plot_mdp_graph(ax):
    """Draw a three-state MDP as a directed graph with probability labels.

    Every edge stores a ``weight`` attribute that is rendered next to the
    edge as ``P=<weight>``. Node positions come from a spring layout with a
    fixed seed.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    transitions = [
        ('S0', 'S1', {'weight': 0.8}),
        ('S0', 'S2', {'weight': 0.2}),
        ('S1', 'S2', {'weight': 1.0}),
        ('S2', 'S0', {'weight': 0.5}),
        ('S2', 'S2', {'weight': 0.5}),  # self-loop
    ]
    graph = nx.DiGraph()
    graph.add_edges_from(transitions)
    layout = nx.spring_layout(graph, seed=42)

    nx.draw_networkx_nodes(graph, layout, ax=ax, node_size=1500, node_color='lightblue')
    nx.draw_networkx_labels(graph, layout, ax=ax, font_weight='bold')

    prob_labels = {}
    for src, dst, attrs in graph.edges(data=True):
        prob_labels[(src, dst)] = f"P={attrs['weight']}"

    nx.draw_networkx_edges(graph, layout, ax=ax, arrowsize=20, edge_color='gray',
                           connectionstyle="arc3,rad=0.1")
    nx.draw_networkx_edge_labels(graph, layout, ax=ax, edge_labels=prob_labels, font_size=9)

    ax.set_title("MDP State Transition Graph", fontsize=12, fontweight='bold')
    ax.axis('off')
|
|
|
def plot_reward_landscape(fig, gs):
    """Add a 3D surface plot of a synthetic reward function to ``fig``.

    The surface is ``sin(sqrt(x^2 + y^2)) + 0.1 * x`` sampled on a 50x50 grid
    over [-5, 5] in both directions.

    Args:
        fig: Figure that receives the new 3D subplot.
        gs: Grid spec indexed as ``gs[row, col]``. Cell ``[0, 1]`` is used,
            falling back to ``[0, 0]`` when that lookup raises ``IndexError``.
    """
    # Only the grid lookup is guarded, so an IndexError raised while the
    # subplot itself is being built is not mistaken for a missing column.
    try:
        cell = gs[0, 1]
    except IndexError:
        cell = gs[0, 0]
    ax = fig.add_subplot(cell, projection='3d')

    samples = np.linspace(-5, 5, 50)
    X, Y = np.meshgrid(samples, samples)
    Z = np.sin(np.sqrt(X**2 + Y**2)) + (X * 0.1)

    ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.9)
    ax.set_title("Reward Function Landscape", fontsize=12, fontweight='bold')
    ax.set_xlabel('State X')
    ax.set_ylabel('State Y')
    ax.set_zlabel('Reward R(s)')
|
|
|
def plot_trajectory(ax):
    """Draw one episode as a left-to-right chain of states.

    States are circled labels at integer x positions. Between consecutive
    states an arrow is drawn with the action (blue) above it and the reward
    (red) below it.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.set_title("Trajectory Sequence", fontsize=12, fontweight='bold')

    states = ['s0', 's1', 's2', 's3', 'sT']
    actions = ['a0', 'a1', 'a2', 'a3']
    rewards = ['r1', 'r2', 'r3', 'r4']
    transitions = list(zip(actions, rewards))

    for idx, state in enumerate(states):
        ax.text(idx, 0.5, state, ha='center', va='center',
                bbox=dict(boxstyle="circle", fc="white"))
        if idx >= len(transitions):
            # The last state has no outgoing transition.
            continue
        action, reward = transitions[idx]
        ax.annotate("", xy=(idx+0.8, 0.5), xytext=(idx+0.2, 0.5),
                    arrowprops=dict(arrowstyle="->"))
        ax.text(idx+0.5, 0.6, action, ha='center', color='blue')
        ax.text(idx+0.5, 0.4, reward, ha='center', color='red')

    ax.set_xlim(-0.5, len(states)-0.5)
    ax.set_ylim(0, 1)
    ax.axis('off')
|
|
|
def plot_continuous_space(ax):
    """Scatter 200 Gaussian 2D points, coloured by distance from the origin.

    Points whose Euclidean norm exceeds 1.0 are labelled 'High Reward'; the
    rest are labelled 'Low Reward'.

    Note:
        Calls ``np.random.seed(42)``, which resets NumPy's global random
        state for everything that runs afterwards.

    Args:
        ax: Matplotlib axes to draw on.
    """
    np.random.seed(42)
    points = np.random.randn(200, 2)
    beyond_unit = np.linalg.norm(points, axis=1) > 1.0

    groups = (
        (beyond_unit, 'coral', 'High Reward'),
        (~beyond_unit, 'skyblue', 'Low Reward'),
    )
    for mask, colour, name in groups:
        ax.scatter(points[mask, 0], points[mask, 1], c=colour, alpha=0.6, label=name)

    ax.set_title("Continuous State Space (2D Projection)", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
|
|
|
def plot_discount_decay(ax):
    """Plot the discount weight gamma**t over 20 time steps for three gammas.

    Args:
        ax: Matplotlib axes to draw on.
    """
    steps = np.arange(0, 20)
    for gamma in (0.5, 0.9, 0.99):
        weights = np.power(gamma, steps)
        ax.plot(steps, weights, marker='o', markersize=4, label=rf"$\gamma={gamma}$")

    ax.set_title(r"Discount Factor $\gamma^t$ Decay", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time steps (t)")
    ax.set_ylabel("Weight")
    ax.legend()
    ax.grid(True, alpha=0.3)
|
|
|
def plot_value_heatmap(ax):
    """Draw a 5x5 gridworld state-value heatmap with per-cell annotations.

    Each cell's value is minus half the squared distance (in grid indices)
    to the bottom-right cell, which itself is set to 10.0.

    Args:
        ax: Matplotlib axes to draw on.
    """
    grid_size = 5

    # Squared offset of every index from the last index; broadcasting the
    # row and column offsets replaces the explicit double loop.
    sq_offset = (grid_size - 1 - np.arange(grid_size)) ** 2
    values = -(sq_offset[:, None] + sq_offset[None, :]) * 0.5
    values[-1, -1] = 10.0

    ax.matshow(values, cmap='magma')
    for (row, col), value in np.ndenumerate(values):
        # Cells below -5 get white text, all others black.
        text_color = 'white' if value < -5 else 'black'
        ax.text(col, row, f'{value:0.1f}', ha='center', va='center',
                color=text_color, fontsize=9)

    ax.set_title("State-Value Function V(s) Heatmap", fontsize=12, fontweight='bold', pad=15)
    ax.set_xticks(range(grid_size))
    ax.set_yticks(range(grid_size))
|
|
|
def plot_backup_diagram(ax):
    """Draw a three-layer policy-evaluation backup diagram.

    The root state and the successor states are white circles labelled
    "s" / "s'"; the two action nodes are small unlabelled black dots.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    state_nodes = ["s", "s'_1", "s'_2", "s'_3"]
    action_nodes = ["a1", "a2"]

    graph = nx.DiGraph()
    layers = (("s",), ("a1", "a2"), ("s'_1", "s'_2", "s'_3"))
    for depth, names in enumerate(layers):
        graph.add_nodes_from(names, layer=depth)
    graph.add_edges_from([
        ("s", "a1"), ("s", "a2"),
        ("a1", "s'_1"), ("a1", "s'_2"), ("a2", "s'_3"),
    ])

    layout = {
        "s": (0.5, 1),
        "a1": (0.25, 0.5),
        "a2": (0.75, 0.5),
        "s'_1": (0.1, 0),
        "s'_2": (0.4, 0),
        "s'_3": (0.75, 0),
    }

    nx.draw_networkx_nodes(graph, layout, ax=ax, nodelist=state_nodes, node_size=800,
                           node_color='white', edgecolors='black')
    nx.draw_networkx_nodes(graph, layout, ax=ax, nodelist=action_nodes, node_size=300,
                           node_color='black')
    nx.draw_networkx_edges(graph, layout, ax=ax, arrows=True)

    # Only state nodes get a caption.
    captions = {"s": "s", "s'_1": "s'", "s'_2": "s'", "s'_3": "s'"}
    nx.draw_networkx_labels(graph, layout, ax=ax, labels=captions, font_size=10)

    ax.set_title("DP Policy Eval Backup", fontsize=12, fontweight='bold')
    ax.set_ylim(-0.2, 1.2)
    ax.axis('off')
|
|
|
def plot_action_value_q(ax):
    """Show a 3x3 heatmap of random values with each cell annotated.

    The values are drawn from NumPy's global random state, which this
    function does not seed.

    Args:
        ax: Matplotlib axes to draw on; its ticks are removed.
    """
    q_values = np.random.rand(3, 3)
    ax.imshow(q_values, cmap='YlGnBu')

    n_rows, n_cols = q_values.shape
    for row in range(n_rows):
        for col in range(n_cols):
            ax.text(col, row, f'{q_values[row, col]:0.1f}', ha='center', va='center',
                    fontsize=8)

    ax.set_title(r"Action-Value $Q(s, a_{up})$", fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
|
|
|
def plot_policy_arrows(ax):
    """Draw one randomly oriented arrow per cell of a 4x4 grid.

    Each arrow's x and y components are sampled from {0, 0.3, -0.3} using
    NumPy's global random state (not seeded here). A zero-length arrow is
    replaced by one pointing in +x.

    Args:
        ax: Matplotlib axes to draw on.
    """
    grid_size = 4
    components = [0, 0.3, -0.3]
    upper = grid_size - 0.5

    ax.set_xlim(-0.5, upper)
    ax.set_ylim(-0.5, upper)

    cells = [(row, col) for row in range(grid_size) for col in range(grid_size)]
    for row, col in cells:
        # Sample dx before dy so the random stream is consumed in a fixed order.
        dx = np.random.choice(components)
        dy = np.random.choice(components)
        if not (dx or dy):
            dx = 0.3
        arrow = FancyArrowPatch((col, row), (col+dx, row+dy), arrowstyle='->',
                                mutation_scale=15)
        ax.add_patch(arrow)

    ax.set_title(r"Policy $\pi(s)$ Arrows", fontsize=12, fontweight='bold')
    ax.set_xticks(range(grid_size))
    ax.set_yticks(range(grid_size))
    ax.grid(True, alpha=0.2)
|
|
|
def plot_advantage_function(ax):
    """Draw a bar chart of four fixed advantage values around a zero line.

    Bars with a positive value are green; all others are red.

    Args:
        ax: Matplotlib axes to draw on.
    """
    advantage_by_action = {'A1': 2.1, 'A2': -1.2, 'A3': 0.5, 'A4': -0.8}
    names = list(advantage_by_action)
    heights = list(advantage_by_action.values())
    bar_colors = ['red' if height <= 0 else 'green' for height in heights]

    ax.bar(names, heights, color=bar_colors, alpha=0.7)
    ax.axhline(0, color='black', lw=1)
    ax.set_title(r"Advantage $A(s, a)$", fontsize=12, fontweight='bold')
    ax.set_ylabel("Value")
|
|
|
def plot_policy_improvement(ax):
    """Sketch policy improvement as an arrow from the old to the new policy.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Policy Improvement", fontsize=12, fontweight='bold')

    old_box = {"boxstyle": "round", "fc": "lightgrey"}
    new_box = {"boxstyle": "round", "fc": "lightgreen"}
    arrow = {"arrowstyle": "->", "lw": 2}

    ax.text(0.2, 0.5, r"$\pi_{old}$", fontsize=15, bbox=old_box)
    ax.annotate("", xy=(0.8, 0.5), xytext=(0.3, 0.5), arrowprops=arrow)
    ax.text(0.5, 0.6, "Greedy\nImprovement", ha='center', fontsize=9)
    ax.text(0.85, 0.5, r"$\pi_{new}$", fontsize=15, bbox=new_box)
|
|
|
def plot_value_iteration_backup(ax):
    """Draw the value-iteration backup: state -> max node -> three successors.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    layout = {
        "s": (0.5, 1),
        "max": (0.5, 0.5),
        "s1": (0.2, 0),
        "s2": (0.5, 0),
        "s3": (0.8, 0),
    }
    successors = ("s1", "s2", "s3")

    graph = nx.DiGraph()
    graph.add_nodes_from(layout)
    graph.add_edge("s", "max")
    graph.add_edges_from(("max", child) for child in successors)

    captions = {"s": "s", "max": "max"}
    captions.update({child: "s'" for child in successors})

    nx.draw_networkx_nodes(graph, layout, ax=ax, node_size=500, node_color='white',
                           edgecolors='black')
    nx.draw_networkx_edges(graph, layout, ax=ax, arrows=True)
    nx.draw_networkx_labels(graph, layout, ax=ax, labels=captions, font_size=9)
    ax.set_title("Value Iteration Backup", fontsize=12, fontweight='bold')
    ax.axis('off')
|
|
|
def plot_policy_iteration_cycle(ax):
    """Draw the evaluation/improvement cycle of policy iteration.

    Two boxes ("Policy Evaluation" on top, "Policy Improvement" below) are
    connected by two curved arrows, one on each side.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Policy Iteration Cycle", fontsize=12, fontweight='bold')

    box_style = dict(boxstyle="round", fc="aliceblue", ec="black")
    curved = dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5")

    ax.text(0.5, 0.8, r"Policy Evaluation" + "\n" + r"$V \leftarrow V^\pi$",
            ha="center", bbox=box_style)
    # \mathrm is used instead of \text: older Matplotlib mathtext parsers
    # reject \text, whereas \mathrm is a long-standing mathtext command.
    ax.text(0.5, 0.2, r"Policy Improvement" + "\n" + r"$\pi \leftarrow \mathrm{greedy}(V)$",
            ha="center", bbox=box_style)

    # Right-hand arrow runs top to bottom, left-hand arrow bottom to top.
    ax.annotate("", xy=(0.7, 0.3), xytext=(0.7, 0.7), arrowprops=curved)
    ax.annotate("", xy=(0.3, 0.7), xytext=(0.3, 0.3), arrowprops=curved)
|
|
|
def plot_mc_backup(ax):
    """Draw a Monte Carlo backup as a vertical chain ending in terminal 'sT'.

    A red curved arrow labelled "Update V(s) using G" runs back up beside
    the chain.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Monte Carlo Backup", fontsize=12, fontweight='bold')

    chain = ['s', 's1', 's2', 'sT']
    coords = {name: (0.5, 0.9 - rank*0.25) for rank, name in enumerate(chain)}

    for here, following in zip(chain, chain[1:]):
        ax.annotate("", xy=coords[following], xytext=coords[here],
                    arrowprops=dict(arrowstyle="->", lw=1.5))
        ax.text(coords[here][0]+0.05, coords[here][1], here, va='center')

    # The terminal node is labelled separately, in bold.
    terminal = coords['sT']
    ax.text(terminal[0]+0.05, terminal[1], 'sT', va='center', fontweight='bold')

    ax.annotate("Update V(s) using G", xy=(0.3, 0.9), xytext=(0.3, 0.15),
                arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.3"))
|
|
|
def plot_mcts(ax):
    """Draw a depth-2 binary tree as a stand-in for an MCTS search tree.

    Node positions are fixed by hand so the tree is laid out top-down. No
    layout algorithm is run: the previous version computed one and then
    discarded it straight away.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    tree = nx.balanced_tree(2, 2, create_using=nx.DiGraph())

    # Root at the top, two children below it, four leaves on the bottom row.
    pos = {
        0: (0, 0),
        1: (-1, -1), 2: (1, -1),
        3: (-1.5, -2), 4: (-0.5, -2), 5: (0.5, -2), 6: (1.5, -2),
    }

    nx.draw_networkx_nodes(ax=ax, G=tree, pos=pos, node_size=300, node_color='lightyellow',
                           edgecolors='black')
    nx.draw_networkx_edges(ax=ax, G=tree, pos=pos, arrows=True)
    ax.set_title("MCTS Tree", fontsize=12, fontweight='bold')
    ax.axis('off')
|
|
|
def plot_importance_sampling(ax):
    """Draw target and behaviour policies with the ratio rho = pi / b.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Importance Sampling", fontsize=12, fontweight='bold')

    policies = (
        (0.8, r"$\pi(a|s)$", "lightgreen"),
        (0.2, r"$b(a|s)$", "lightpink"),
    )
    for y_pos, formula, fill in policies:
        ax.text(0.5, y_pos, formula, bbox=dict(boxstyle="circle", fc=fill), ha='center')

    # annotate() without an arrow simply places the ratio text.
    ax.annotate(r"$\rho = \frac{\pi}{b}$", xy=(0.7, 0.5), fontsize=15)
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65),
                arrowprops=dict(arrowstyle="<->", lw=2))
|
|
|
def plot_td_backup(ax):
    """Draw the one-step TD(0) backup between state s and successor s'.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("TD(0) Backup", fontsize=12, fontweight='bold')

    for y_pos, name in ((0.8, "s"), (0.2, "s'")):
        ax.text(0.5, y_pos, name, bbox=dict(boxstyle="circle", fc="white"), ha='center')

    ax.annotate(r"$R + \gamma V(s')$", xy=(0.5, 0.4), ha='center', color='blue')
    # "<-" puts the arrow head at the xytext end, i.e. the upper point.
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65),
                arrowprops=dict(arrowstyle="<-", lw=2))
|
|
|
def plot_nstep_td(ax):
    """Draw an n-step TD backup as a vertical chain of four states.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("n-step TD Backup", fontsize=12, fontweight='bold')

    last = 3
    for k in range(last + 1):
        drop = k*0.2
        ax.text(0.5, 0.9-drop, f"s_{k}", bbox=dict(boxstyle="circle", fc="white"),
                ha='center', fontsize=8)
        if k != last:
            # Arrow down to the next state; the final state has none.
            ax.annotate("", xy=(0.5, 0.75-drop), xytext=(0.5, 0.85-drop),
                        arrowprops=dict(arrowstyle="->"))

    ax.annotate(r"$G_t^{(n)}$", xy=(0.7, 0.5), fontsize=12, color='red')
|
|
|
def plot_eligibility_traces(ax):
    """Plot an accumulating trace over 50 steps with visits at t=5, 20, 35.

    Each visit adds a curve that starts at 1 and is multiplied by 0.8 at
    every following step; the plotted trace is the sum of those curves.

    Args:
        ax: Matplotlib axes to draw on.
    """
    steps = np.arange(0, 50)
    horizon = len(steps)
    trace = np.zeros(horizon, dtype=float)

    for visit in (5, 20, 35):
        trace[visit:] += 0.8 ** np.arange(horizon - visit)

    ax.plot(steps, trace, color='brown', lw=2)
    ax.set_title(r"Eligibility Trace $z_t(\lambda)$", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time")
    ax.fill_between(steps, trace, color='brown', alpha=0.1)
|
|
|
def plot_sarsa_backup(ax):
    """Draw the SARSA backup between (s,a) and (s',a').

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("SARSA Backup", fontsize=12, fontweight='bold')

    for y_pos, pair in ((0.9, "(s,a)"), (0.1, "(s',a')")):
        ax.text(0.5, y_pos, pair, ha='center')

    backup_arrow = {"arrowstyle": "<-", "lw": 2, "color": 'orange'}
    ax.annotate("", xy=(0.5, 0.2), xytext=(0.5, 0.8), arrowprops=backup_arrow)
    ax.text(0.6, 0.5, "On-policy", rotation=90)
|
|
|
def plot_q_learning_backup(ax):
    """Draw the Q-learning backup from (s,a) to the boxed max over a'.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Q-Learning Backup", fontsize=12, fontweight='bold')

    target_box = {"boxstyle": "round", "fc": "lightcyan"}
    backup_arrow = {"arrowstyle": "<-", "lw": 2, "color": 'blue'}

    ax.text(0.5, 0.9, "(s,a)", ha='center')
    ax.text(0.5, 0.1, r"$\max_{a'} Q(s',a')$", ha='center', bbox=target_box)
    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=backup_arrow)
|
|
|
def plot_double_q(ax):
    """Draw two networks: one annotated "Select", the other "Eval".

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Double Q-Learning", fontsize=12, fontweight='bold')

    networks = ((0.8, "Network A", "lightyellow"), (0.2, "Network B", "lightcyan"))
    for y_pos, name, fill in networks:
        ax.text(0.5, y_pos, name, bbox=dict(fc=fill), ha='center')

    plain_arrow = {"arrowstyle": "->"}
    ax.annotate("Select $a^*$", xy=(0.3, 0.8), xytext=(0.5, 0.85), arrowprops=plain_arrow)
    ax.annotate("Eval $Q(s', a^*)$", xy=(0.7, 0.2), xytext=(0.5, 0.15),
                arrowprops=plain_arrow)
|
|
|
def plot_dueling_dqn(ax):
    """Draw the dueling architecture: backbone -> V(s) and A(s,a) -> Q(s,a).

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Dueling DQN", fontsize=12, fontweight='bold')

    ax.text(0.1, 0.5, "Backbone", bbox=dict(fc="lightgrey"), ha='center', rotation=90)
    ax.text(0.5, 0.7, "V(s)", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.5, 0.3, "A(s,a)", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.9, 0.5, "Q(s,a)", bbox=dict(boxstyle="circle", fc="orange"), ha='center')

    # (start, end) of each connector: backbone to both heads, heads to Q.
    connectors = (
        ((0.15, 0.55), (0.35, 0.7)),
        ((0.15, 0.45), (0.35, 0.3)),
        ((0.6, 0.7), (0.75, 0.55)),
        ((0.6, 0.3), (0.75, 0.45)),
    )
    for start, end in connectors:
        ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->"))
|
|
|
def plot_prioritized_replay(ax):
    """Plot a 20-bin histogram of 100 Pareto(3)-distributed priorities.

    The samples come from NumPy's global random state, which this function
    does not seed.

    Args:
        ax: Matplotlib axes to draw on.
    """
    sampled_priorities = np.random.pareto(3, 100)
    ax.hist(sampled_priorities, bins=20, color='teal', alpha=0.7)

    ax.set_title("Prioritized Replay (TD-Error)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Priority $P_i$")
    ax.set_ylabel("Count")
|
|
|
def plot_rainbow_dqn(ax):
    """List six Rainbow DQN components as a vertical stack of rounded boxes.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Rainbow DQN", fontsize=12, fontweight='bold')

    components = ("Double", "Dueling", "PER", "Noisy", "Distributional", "n-step")
    box_style = {"boxstyle": "round", "fc": "ghostwhite"}
    for slot, component in enumerate(components):
        # Boxes are stacked top-down, 0.15 apart.
        ax.text(0.5, 0.9 - slot*0.15, component, ha='center', bbox=box_style, fontsize=8)
|
|
|
def plot_linear_fa(ax):
    """Draw features phi(s) feeding into the linear form w^T phi(s).

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Linear Function Approx", fontsize=12, fontweight='bold')

    stages = (
        (0.8, r"$\phi(s)$ Features", "white"),
        (0.2, r"$w^T \phi(s)$", "lightgrey"),
    )
    for y_pos, caption, fill in stages:
        ax.text(0.5, y_pos, caption, ha='center', bbox=dict(fc=fill))

    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65),
                arrowprops=dict(arrowstyle="->", lw=2))
|
|
|
def plot_nn_layers(ax):
    """Draw a 4-8-8-2 fully connected network as columns of dots.

    Layers are 0.3 apart horizontally; neurons in a layer are 0.1 apart
    vertically and each column is centred on y = 0.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("NN Layers (Deep RL)", fontsize=12, fontweight='bold')

    layer_sizes = [4, 8, 8, 2]
    for layer_idx, size in enumerate(layer_sizes):
        # A column of `size` neurons spans (size - 1) * 0.1, so half of that
        # is subtracted to centre it. Subtracting size * 0.05 instead would
        # leave every column 0.05 below the middle of the y-range.
        ys = np.arange(size) * 0.1 - (size - 1) * 0.05
        xs = np.full(size, layer_idx * 0.3)
        ax.scatter(xs, ys, s=20, c='black')

    ax.set_xlim(-0.1, 1.0)
    ax.set_ylim(-0.5, 0.5)
|
|
|
def plot_computation_graph(ax):
    """Draw Input -> Op -> Loss with a red "Grad" arrow running back.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Computation Graph (DAG)", fontsize=12, fontweight='bold')

    nodes = (
        (0.1, "Input", "circle", "white"),
        (0.5, "Op", "square", "lightgrey"),
        (0.9, "Loss", "circle", "salmon"),
    )
    for x_pos, caption, shape, fill in nodes:
        ax.text(x_pos, 0.5, caption, bbox=dict(boxstyle=shape, fc=fill))

    # Forward edges, given as (tail x, head x).
    for tail_x, head_x in ((0.2, 0.35), (0.6, 0.75)):
        ax.annotate("", xy=(head_x, 0.5), xytext=(tail_x, 0.5),
                    arrowprops=dict(arrowstyle="->"))

    grad_arrow = {"arrowstyle": "->", "color": 'red', "connectionstyle": "arc3,rad=0.2"}
    ax.annotate("Grad", xy=(0.1, 0.3), xytext=(0.9, 0.3), arrowprops=grad_arrow)
|
|
|
def plot_target_network(ax):
    """Draw the active and target Q-networks joined by a dashed copy arrow.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Target Network Updates", fontsize=12, fontweight='bold')

    networks = (
        (0.3, r"$Q_\theta$ (Active)", "lightgreen"),
        (0.7, r"$Q_{\theta^-}$ (Target)", "lightblue"),
    )
    for x_pos, caption, fill in networks:
        ax.text(x_pos, 0.8, caption, bbox=dict(fc=fill))

    dashed_arrow = {"arrowstyle": "<-", "ls": '--'}
    ax.annotate("periodic copy", xy=(0.6, 0.8), xytext=(0.4, 0.8), arrowprops=dashed_arrow)
|
|
|
def plot_ppo_clip(ax):
    """Plot the PPO clipped surrogate against the unclipped one.

    Uses a fixed advantage of 1.0 and epsilon of 0.2, with the probability
    ratio swept over [0.5, 1.5].

    Args:
        ax: Matplotlib axes to draw on.
    """
    epsilon = 0.2
    advantage = 1.0
    ratio = np.linspace(0.5, 1.5, 100)

    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1-epsilon, 1+epsilon) * advantage
    objective = np.minimum(unclipped, clipped)

    ax.plot(ratio, unclipped, '--', label="r*A")
    ax.plot(ratio, objective, 'r', label="min(r*A, clip*A)")
    ax.set_title("PPO-Clip Objective", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    # Vertical marker at ratio == 1.
    ax.axvline(1, color='gray', linestyle=':')
|
|
|
def plot_trpo_trust_region(ax):
    """Draw a trust region as a circle around the old policy, plus an update.

    The circle, the centre point and the update arrow each carry a label,
    and a legend is drawn so those labels are visible.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.set_title("TRPO Trust Region", fontsize=12, fontweight='bold')

    region = plt.Circle((0.5, 0.5), 0.3, color='blue', fill=False, label="KL Constraint")
    # add_patch (rather than add_artist) registers the circle as a patch so
    # legend handle detection can find it.
    ax.add_patch(region)
    ax.scatter(0.5, 0.5, c='black', label=r"$\pi_{old}$")
    ax.arrow(0.5, 0.5, 0.15, 0.1, head_width=0.03, color='red', label="Update")

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    # Without a legend the three labels above would never be displayed.
    ax.legend(fontsize=8)
    ax.axis('off')
|
|
|
def plot_a3c_multi_worker(ax):
    """Draw three workers, each linked both ways to a global parameter box.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("A3C Multi-worker", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Global Parameters", bbox=dict(fc="gold"), ha='center')

    for worker in range(3):
        x_pos = 0.2 + worker*0.3
        ax.text(x_pos, 0.2, f"Worker {worker+1}", bbox=dict(fc="lightgrey"), ha='center',
                fontsize=8)
        # Double-headed arrow between the worker and the global box.
        ax.annotate("", xy=(0.5, 0.7), xytext=(x_pos, 0.3),
                    arrowprops=dict(arrowstyle="<->"))
|
|
|
def plot_sac_arch(ax):
    """Draw State -> Actor -> Action with an entropy bonus feeding the actor.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("SAC Architecture", fontsize=12, fontweight='bold')

    ax.text(0.5, 0.7, "Actor", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.5, 0.3, "Entropy Bonus", bbox=dict(fc="salmon"), ha='center')
    for x_pos, caption in ((0.1, "State"), (0.9, "Action")):
        ax.text(x_pos, 0.5, caption, ha='center')

    # (start, end) pairs: state to actor, entropy to actor, actor to action.
    flows = (
        ((0.15, 0.5), (0.4, 0.7)),
        ((0.5, 0.4), (0.5, 0.55)),
        ((0.6, 0.7), (0.85, 0.5)),
    )
    for start, end in flows:
        ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->"))
|
|
|
def plot_softmax_exploration(ax):
    """Plot softmax probabilities of four fixed logits at three temperatures.

    Args:
        ax: Matplotlib axes to draw on.
    """
    action_ids = np.arange(4)
    scores = np.array([1, 2, 5, 3])

    for tau in (0.5, 1.0, 5.0):
        weights = np.exp(scores/tau)
        probs = weights / weights.sum()
        ax.plot(action_ids, probs, marker='o', label=rf"$\tau={tau}$")

    ax.set_title("Softmax Exploration", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    ax.set_xticks(action_ids)
|
|
|
def plot_ucb_confidence(ax):
    """Draw mean action values as bars with confidence-width error bars.

    The bars are labelled 'Mean Q' and a legend is drawn so that label is
    shown, matching the other labelled plots in this file.

    Args:
        ax: Matplotlib axes to draw on.
    """
    actions = ['A1', 'A2', 'A3']
    means = [0.6, 0.8, 0.5]
    conf = [0.3, 0.1, 0.4]

    ax.bar(actions, means, yerr=conf, capsize=10, color='skyblue', label='Mean Q')
    ax.set_title("UCB Action Values", fontsize=12, fontweight='bold')
    ax.set_ylim(0, 1.2)
    # Without a legend the 'Mean Q' label would never be displayed.
    ax.legend(fontsize=8)
|
|
|
def plot_intrinsic_motivation(ax):
    """Draw a world model feeding a prediction-error node labelled R_int.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Intrinsic Motivation", fontsize=12, fontweight='bold')

    model_box = {"fc": "lightyellow"}
    error_box = {"boxstyle": "circle", "fc": "orange"}

    ax.text(0.3, 0.5, "World Model", bbox=model_box, ha='center')
    ax.text(0.7, 0.5, "Prediction\nError", bbox=error_box, ha='center')
    ax.annotate("", xy=(0.58, 0.5), xytext=(0.42, 0.5), arrowprops={"arrowstyle": "->"})
    ax.text(0.85, 0.5, r"$R_{int}$", fontweight='bold')
|
|
|
def plot_entropy_bonus(ax):
    """Plot the entropy of a two-outcome distribution against P(a).

    The probability is sampled at 50 points in [0.01, 0.99], which keeps
    both logarithms finite.

    Args:
        ax: Matplotlib axes to draw on.
    """
    prob = np.linspace(0.01, 0.99, 50)
    complement = 1-prob
    entropy = -(prob * np.log(prob) + complement * np.log(complement))

    ax.plot(prob, entropy, color='purple')
    ax.set_title(r"Entropy $H(\pi)$", fontsize=12, fontweight='bold')
    ax.set_xlabel("$P(a)$")
|
|
|
def plot_options_framework(ax):
    """Draw a high-level policy box with arrows down to two option boxes.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Options Framework", fontsize=12, fontweight='bold')

    ax.text(0.5, 0.8, r"High-level policy" + "\n" + r"$\pi_{hi}$", bbox=dict(fc="lightblue"),
            ha='center')
    for x_pos, caption in ((0.2, "Option 1"), (0.8, "Option 2")):
        ax.text(x_pos, 0.2, caption, bbox=dict(fc="ivory"), ha='center')

    # (start, end) of the arrows from the policy box to each option.
    branches = (((0.45, 0.7), (0.3, 0.3)), ((0.55, 0.7), (0.7, 0.3)))
    for start, end in branches:
        ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->"))
|
|
|
def plot_feudal_networks(ax):
    """Draw a Manager box above a Worker box, linked by a goal arrow.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Feudal Networks", fontsize=12, fontweight='bold')

    levels = ((0.85, "Manager", "plum"), (0.15, "Worker", "wheat"))
    for y_pos, role, fill in levels:
        ax.text(0.5, y_pos, role, bbox=dict(fc=fill), ha='center')

    goal_arrow = {"arrowstyle": "->", "lw": 2}
    ax.annotate("Goal $g_t$", xy=(0.5, 0.3), xytext=(0.5, 0.75), arrowprops=goal_arrow)
|
|
|
def plot_world_model(ax):
    """Draw (s,a) entering a learned model that emits a state and a reward.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("World Model (Dynamics)", fontsize=12, fontweight='bold')

    ax.text(0.1, 0.5, "(s,a)", ha='center')
    ax.text(0.5, 0.5, r"$\hat{P}$", bbox=dict(boxstyle="circle", fc="lightgrey"), ha='center')
    for y_pos, output in ((0.7, r"$\hat{s}'$"), (0.3, r"$\hat{r}$")):
        ax.text(0.9, y_pos, output, ha='center')

    # (start, end): input into the model, then model to each output.
    links = (
        ((0.2, 0.5), (0.4, 0.5)),
        ((0.6, 0.55), (0.8, 0.65)),
        ((0.6, 0.45), (0.8, 0.35)),
    )
    for start, end in links:
        ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->"))
|
|
|
def plot_model_planning(ax):
    """Draw three grey "imagined" rollout arrows starting from a real state.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Model-Based Planning", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Real s", ha='center', fontweight='bold')

    for step in range(3):
        shift = step*0.2
        # Odd-numbered steps end 0.1 higher than even ones.
        bump = (step%2)*0.1
        ax.annotate("", xy=(0.3+shift, 0.5+bump), xytext=(0.1+shift, 0.5),
                    arrowprops=dict(arrowstyle="->", color='gray'))
        ax.text(0.3+shift, 0.55+bump, "imagined", fontsize=7)
|
|
|
def plot_offline_rl(ax):
    """Draw a static dataset box, a "No interaction" arrow and four markers.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Offline RL Dataset", fontsize=12, fontweight='bold')

    dataset_label = r"Static" + "\n" + r"Dataset" + "\n" + r"$\mathcal{D}$"
    ax.text(0.5, 0.5, dataset_label, bbox=dict(boxstyle="round", fc="lightgrey"), ha='center')
    ax.annotate("No interaction", xy=(0.5, 0.9), xytext=(0.5, 0.75),
                arrowprops=dict(arrowstyle="->", color='red'))

    # Four 'x' markers, one near each corner.
    corner_x = [0.2, 0.8, 0.3, 0.7]
    corner_y = [0.8, 0.8, 0.2, 0.2]
    ax.scatter(corner_x, corner_y, marker='x', color='blue')
|
|
|
def plot_cql_regularization(ax):
    """Plot a quadratic penalty, 0.1 * q**2, over Q-values in [-5, 5].

    Args:
        ax: Matplotlib axes to draw on.
    """
    q_values = np.linspace(-5, 5, 100)
    quadratic_penalty = q_values**2 * 0.1

    ax.plot(q_values, quadratic_penalty, 'r', label='CQL Penalty')
    ax.set_title("CQL Regularization", fontsize=12, fontweight='bold')
    ax.set_xlabel("Q-value")
    ax.legend(fontsize=8)
|
|
|
def plot_multi_agent_interaction(ax):
    """Draw three agents as a complete graph with dashed edges.

    The spring layout is called without a seed, so node placement can vary
    between runs.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    agents = nx.complete_graph(3)
    layout = nx.spring_layout(agents)
    agent_colors = ['red', 'blue', 'green']

    nx.draw_networkx_nodes(agents, layout, ax=ax, node_size=500, node_color=agent_colors)
    nx.draw_networkx_edges(agents, layout, ax=ax, style='dashed')
    ax.set_title("Multi-Agent Interaction", fontsize=12, fontweight='bold')
    ax.axis('off')
|
|
|
def plot_ctde(ax):
    """Draw two agent boxes linked to one centralized critic box.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("CTDE Architecture", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Centralized Critic", bbox=dict(fc="gold"), ha='center')

    for x_pos, agent in ((0.2, "Agent 1"), (0.8, "Agent 2")):
        ax.text(x_pos, 0.2, agent, bbox=dict(fc="lightblue"), ha='center')

    # Both grey connectors end at the critic; "<-" puts the head at xytext.
    for tail_x in (0.25, 0.75):
        ax.annotate("", xy=(0.5, 0.7), xytext=(tail_x, 0.35),
                    arrowprops=dict(arrowstyle="<-", color='gray'))
|
|
|
def plot_payoff_matrix(ax):
    """Draw a 2x2 payoff matrix with each cell shown as a "(a, b)" pair.

    Row 0 is drawn at the top. Payoffs are kept as plain tuples: wrapping
    them in ``np.array`` turns each cell into a NumPy array whose ``str()``
    is "[3 3]" rather than "(3, 3)".

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    payoffs = (
        ((3, 3), (0, 5)),
        ((5, 0), (1, 1)),
    )

    ax.axis('off')
    ax.set_title("Payoff Matrix (Prisoner's)", fontsize=12, fontweight='bold')

    for row, row_payoffs in enumerate(payoffs):
        for col, cell in enumerate(row_payoffs):
            # y is flipped so that the first row appears at the top.
            ax.text(col, 1 - row, str(cell), ha='center', va='center', bbox=dict(fc="white"))

    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(-0.5, 1.5)
|
|
|
def plot_irl_reward_inference(ax):
    """Show a 5x5 heatmap that is zero except for a 2x2 block of ones.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Inferred Reward Heatmap", fontsize=12, fontweight='bold')

    reward_map = np.zeros((5, 5))
    # Rows 2-3 and columns 2-3 form the only non-zero region.
    reward_map[2:4, 2:4] = 1.0
    ax.imshow(reward_map, cmap='hot')
|
|
|
def plot_gail_flow(ax):
    """Draw expert data and the policy both feeding a discriminator.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("GAIL Architecture", fontsize=12, fontweight='bold')

    sources = ((0.8, "Expert Data", "lightgrey"), (0.2, "Policy (Gen)", "lightgreen"))
    for y_pos, caption, fill in sources:
        ax.text(0.2, y_pos, caption, bbox=dict(fc=fill), ha='center')
    ax.text(0.8, 0.5, "Discriminator", bbox=dict(boxstyle="square", fc="salmon"), ha='center')

    # (start, end) of the arrows from each source to the discriminator.
    inputs = (((0.35, 0.75), (0.6, 0.55)), ((0.35, 0.25), (0.6, 0.45)))
    for start, end in inputs:
        ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->"))
|
|
|
def plot_meta_rl_nested_loop(ax):
    """Draw two concentric circles labelled as the outer and inner loop.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Meta-RL Loops", fontsize=12, fontweight='bold')

    centre = (0.5, 0.5)
    outer_ring = plt.Circle(centre, 0.4, fill=False, ls='--')
    inner_ring = plt.Circle(centre, 0.2, fill=False)
    ax.add_patch(outer_ring)
    ax.add_patch(inner_ring)

    ax.text(0.5, 0.5, "Inner\nLoop", ha='center', fontsize=8)
    ax.text(0.5, 0.8, "Outer Loop", ha='center', fontsize=10)
|
|
|
def plot_task_distribution(ax):
    """Draw three task boxes in a row beneath a "sample" arrow.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Task Distribution", fontsize=12, fontweight='bold')

    task_box = {"boxstyle": "round", "fc": "ivory"}
    for slot in range(3):
        ax.text(0.2 + slot*0.3, 0.5, f"Task {slot+1}", bbox=task_box, fontsize=8)

    ax.annotate("sample", xy=(0.5, 0.8), xytext=(0.5, 0.6), arrowprops={"arrowstyle": "<-"})
|
|
|
def plot_replay_buffer(ax):
    """Draw a replay buffer as five grey slots with "In" and "Out" arrows.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Experience Replay Buffer", fontsize=12, fontweight='bold')

    for slot in range(5):
        offset = slot*0.15
        cell = plt.Rectangle((0.1+offset, 0.4), 0.1, 0.2, fill=True, color='lightgrey')
        ax.add_patch(cell)
        ax.text(0.15+offset, 0.5, f"e_{slot}", ha='center')

    # annotation_clip=False keeps both annotations visible although their
    # text anchors (x = -0.1 and x = 1.0) are at or beyond the default range.
    ax.annotate("In", xy=(0.05, 0.5), xytext=(-0.1, 0.5),
                arrowprops=dict(arrowstyle="->"), annotation_clip=False)
    ax.annotate("Out (Batch)", xy=(0.85, 0.5), xytext=(1.0, 0.5),
                arrowprops=dict(arrowstyle="<-"), annotation_clip=False)
|
|
|
def plot_state_visitation(ax):
    """Show a hexbin density of 1000 correlated 2D Gaussian samples.

    The samples come from NumPy's global random state, which this function
    does not seed.

    Args:
        ax: Matplotlib axes to draw on.
    """
    mean = [0, 0]
    covariance = [[1, 0.5], [0.5, 1]]
    visits = np.random.multivariate_normal(mean, covariance, 1000)

    ax.hexbin(visits[:, 0], visits[:, 1], gridsize=15, cmap='Blues')
    ax.set_title("State Visitation Heatmap", fontsize=12, fontweight='bold')
|
|
|
def plot_regret_curve(ax):
    """Plot sqrt(t) plus Gaussian noise over 100 steps as a regret curve.

    The noise (standard deviation 0.5) comes from NumPy's global random
    state, which this function does not seed.

    Args:
        ax: Matplotlib axes to draw on.
    """
    steps = np.arange(100)
    noise = np.random.normal(0, 0.5, 100)
    noisy_regret = np.sqrt(steps) + noise

    ax.plot(steps, noisy_regret, color='red', label='Sub-linear Regret')
    ax.set_title("Cumulative Regret", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time")
    ax.legend(fontsize=8)
|
|
|
def plot_attention_weights(ax):
    """Show a 5x5 heatmap of uniform random values.

    The values come from NumPy's global random state, which this function
    does not seed.

    Args:
        ax: Matplotlib axes to draw on; its ticks are removed.
    """
    attention = np.random.rand(5, 5)
    ax.imshow(attention, cmap='viridis')
    ax.set_title("Attention Weight Matrix", fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
|
|
|
def plot_diffusion_policy(ax):
    """Draw four dots that shrink and fade from left to right, with arrows.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Diffusion Policy (Denoising)", fontsize=12, fontweight='bold')

    final_step = 3
    for step in range(final_step + 1):
        shift = step*0.25
        # Marker area and opacity both decrease with the step index.
        ax.scatter(0.1+shift, 0.5, s=100/(step+1), c='black', alpha=1.0 - step*0.2)
        if step != final_step:
            ax.annotate("", xy=(0.25+shift, 0.5), xytext=(0.15+shift, 0.5),
                        arrowprops=dict(arrowstyle="->"))

    ax.text(0.5, 0.3, r"Noise $\rightarrow$ Action", ha='center', fontsize=8)
|
|
|
def plot_gnn_rl(ax):
    """Draw a star graph (one hub, four leaves) in orange.

    The spring layout is called without a seed, so node placement can vary
    between runs.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    star = nx.star_graph(4)
    layout = nx.spring_layout(star)

    nx.draw_networkx_nodes(star, layout, ax=ax, node_size=200, node_color='orange')
    nx.draw_networkx_edges(star, layout, ax=ax)
    ax.set_title("GNN Message Passing", fontsize=12, fontweight='bold')
    ax.axis('off')
|
|
|
def plot_latent_space(ax):
    """Draw Image -> Latent z -> Reconstruction as a three-stage pipeline.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Latent Space (VAE/Dreamer)", fontsize=12, fontweight='bold')

    stages = (
        (0.1, "Image", dict(fc="lightgrey")),
        (0.5, "Latent $z$", dict(boxstyle="circle", fc="lightpink")),
        (0.9, "Reconstruction", dict(fc="lightgrey")),
    )
    for x_pos, caption, box in stages:
        ax.text(x_pos, 0.5, caption, bbox=box, ha='center')

    # Arrows between consecutive stages, given as (tail x, head x).
    for tail_x, head_x in ((0.2, 0.4), (0.6, 0.8)):
        ax.annotate("", xy=(head_x, 0.5), xytext=(tail_x, 0.5),
                    arrowprops=dict(arrowstyle="->"))
|
|
|
def plot_convergence_log(ax):
    """Plot the error curve 10 / k**2 for k = 1..99 on log-log axes.

    Args:
        ax: Matplotlib axes to draw on.
    """
    iteration_ids = np.arange(1, 100)
    residual = 10 / iteration_ids**2

    ax.loglog(iteration_ids, residual, color='green')
    ax.set_title("Value Convergence (Log)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Error")
    # Grid lines on both major and minor ticks.
    ax.grid(True, which="both", ls="-", alpha=0.3)
|
|
|
def plot_expected_sarsa_backup(ax):
    """Draw the Expected SARSA backup from (s,a) to the policy-weighted sum.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are hidden.
    """
    ax.axis('off')
    ax.set_title("Expected SARSA Backup", fontsize=12, fontweight='bold')

    target_box = {"boxstyle": "round", "fc": "ivory"}
    backup_arrow = {"arrowstyle": "<-", "lw": 2, "color": 'purple'}

    ax.text(0.5, 0.9, "(s,a)", ha='center')
    ax.text(0.5, 0.1, r"$\sum_{a'} \pi(a'|s') Q(s',a')$", ha='center', bbox=target_box)
    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=backup_arrow)
|
|
|
def plot_reinforce_flow(ax):
    """Policy Gradients: REINFORCE (Full trajectory flow).

    Writes the trajectory tokens as a row of circled labels and places the
    gradient expression above them.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are switched off.
    """
    ax.axis('off')
    ax.set_title("REINFORCE Flow", fontsize=12, fontweight='bold')

    tokens = ("s0", "a0", "r1", "s1", "...", "GT")
    # Tokens are spaced 0.15 apart, starting at x = 0.1.
    x_positions = [0.1 + slot*0.15 for slot in range(len(tokens))]
    for x_pos, token in zip(x_positions, tokens):
        ax.text(x_pos, 0.5, token, bbox=dict(boxstyle="circle", fc="white"))

    gradient_label = r"$\nabla_\theta J \propto G_t \nabla \ln \pi$"
    ax.annotate(gradient_label, xy=(0.5, 0.8), ha='center', fontsize=12, color='darkgreen')
|
|
|
def plot_advantage_scaled_grad(ax):
    """Policy Gradients: Baseline / Advantage scaled gradient.

    Shows the baseline-subtracted return in a box at the top with an arrow
    down to the "Scale" label for the log-probability gradient.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are switched off.
    """
    ax.axis('off')
    ax.set_title("Baseline Subtraction", fontsize=12, fontweight='bold')

    centre_x = 0.5
    ax.text(centre_x, 0.8, r"$(G_t - b(s))$", bbox=dict(fc="salmon"), ha='center')
    ax.text(centre_x, 0.3, r"Scale $\nabla \ln \pi$", ha='center')
    ax.annotate(
        "",
        xy=(centre_x, 0.4),
        xytext=(centre_x, 0.7),
        arrowprops=dict(arrowstyle="->"),
    )
|
|
|
def plot_skill_discovery(ax):
    """Hierarchical RL: Skill Discovery (Unsupervised clusters).

    Scatters three Gaussian point clouds of 20 points each, one per skill,
    around randomly drawn centres. NumPy's global random state is reseeded
    with 0 first, so the clouds are the same on every call.

    Args:
        ax: Matplotlib axes to draw on.
    """
    np.random.seed(0)

    n_skills, points_per_skill = 3, 20
    for skill_no in range(1, n_skills + 1):
        # Draw the centre first, then the cloud around it.
        centroid = 2 * np.random.randn(2)
        cloud = centroid + 0.5 * np.random.randn(points_per_skill, 2)
        xs, ys = cloud.T
        ax.scatter(xs, ys, alpha=0.6, label=f"Skill {skill_no}")

    ax.set_title("Skill Embedding Space", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
|
|
|
def plot_imagination_rollout(ax):
    """Model-Based RL: Imagination-Augmented Rollouts (I2A).

    Draws an input label, a lavender rectangle captioned as the imagination
    module, and a curved grey arrow labelled "Imagined Paths".

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are switched off.
    """
    ax.axis('off')
    ax.set_title("Imagination Rollout (I2A)", fontsize=12, fontweight='bold')

    ax.text(0.1, 0.5, "Input s", ha='center')

    module_box = plt.Rectangle((0.3, 0.3), 0.4, 0.4, fill=True, color='lavender')
    ax.add_patch(module_box)
    ax.text(0.5, 0.5, "Imagination\nModule", ha='center')

    rollout_arrow = dict(arrowstyle="->", color='gray', connectionstyle="arc3,rad=0.3")
    ax.annotate("Imagined Paths", xy=(0.8, 0.5), xytext=(0.5, 0.5), arrowprops=rollout_arrow)
|
|
|
def plot_policy_gradient_flow(ax):
    """Policy Gradients: Gradient flow from reward to log-prob (DAG).

    Nodes: trajectory, reward and log-probability boxes plus a gold circle
    for the policy gradient. Arrows run trajectory -> reward,
    reward -> gradient and log-probability -> gradient.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are switched off.
    """
    ax.axis('off')
    ax.set_title("Policy Gradient Flow (DAG)", fontsize=12, fontweight='bold')

    box_style = dict(boxstyle="round,pad=0.5", fc="lightgrey", ec="black", lw=1.5)
    gradient_style = dict(boxstyle="circle,pad=0.3", fc="gold", ec="black")

    # (x, y, label, box style) for every node.
    nodes = (
        (0.1, 0.8, r"Trajectory $\tau$", box_style),
        (0.5, 0.8, r"Reward $R(\tau)$", box_style),
        (0.1, 0.2, r"Log-Prob $\log \pi_\theta$", box_style),
        (0.7, 0.5, r"$\nabla_\theta J(\theta)$", gradient_style),
    )
    for x_pos, y_pos, label, style in nodes:
        ax.text(x_pos, y_pos, label, ha="center", va="center", bbox=style)

    # (tail, head) for every edge.
    edges = (
        ((0.2, 0.8), (0.35, 0.8)),
        ((0.5, 0.75), (0.7, 0.65)),
        ((0.25, 0.2), (0.6, 0.4)),
    )
    for tail, head in edges:
        ax.annotate("", xy=head, xytext=tail, arrowprops=dict(arrowstyle="->", lw=2))
|
|
|
def plot_actor_critic_arch(ax):
    """Actor-Critic: Three-network diagram (TD3 - actor + two critics).

    The state node fans out to the actor and both critics. The actor feeds
    the action node, and both critics feed the "Min Q-value" node.

    Args:
        ax: Matplotlib axes to draw on; its frame and ticks are switched off.
    """
    ax.axis('off')
    ax.set_title("TD3 Architecture Diagram", fontsize=12, fontweight='bold')

    centred = dict(ha="center", va="center")

    # Input node.
    ax.text(0.1, 0.5, "State\n$s$", bbox=dict(boxstyle="circle,pad=0.5", fc="lightblue"), **centred)

    # The three networks share one box style and one column.
    network_box = dict(boxstyle="square,pad=0.8", fc="lightgreen", ec="black")
    networks = (
        (0.8, r"Actor $\pi_\phi$"),
        (0.5, r"Critic 1 $Q_{\theta_1}$"),
        (0.2, r"Critic 2 $Q_{\theta_2}$"),
    )
    for y_pos, label in networks:
        ax.text(0.5, y_pos, label, bbox=network_box, **centred)

    # Output nodes.
    ax.text(0.8, 0.8, "Action $a$", bbox=dict(boxstyle="circle,pad=0.3", fc="coral"), **centred)
    ax.text(0.8, 0.35, "Min Q-value", bbox=dict(boxstyle="round,pad=0.3", fc="gold"), **centred)

    # (tail, head) pairs: state -> networks, then networks -> outputs.
    arrow_style = dict(arrowstyle="->", lw=1.5)
    links = (
        ((0.15, 0.55), (0.38, 0.8)),
        ((0.15, 0.5), (0.38, 0.5)),
        ((0.15, 0.45), (0.38, 0.2)),
        ((0.62, 0.8), (0.73, 0.8)),
        ((0.62, 0.5), (0.68, 0.35)),
        ((0.62, 0.2), (0.68, 0.35)),
    )
    for tail, head in links:
        ax.annotate("", xy=head, xytext=tail, arrowprops=arrow_style)
|
|
|
def plot_epsilon_decay(ax):
    """Exploration: epsilon-Greedy Strategy Decay Curve.

    Plots ``exp(-0.005 * episode)`` floored at 0.01 over episodes 0..999
    and shades the area under the curve.

    Args:
        ax: Matplotlib axes to draw on.
    """
    episode_idx = np.arange(1000)
    decayed = np.exp(-0.005 * episode_idx)
    eps = np.maximum(decayed, 0.01)  # never drop below the 0.01 floor

    ax.plot(episode_idx, eps, color='purple', lw=2)
    ax.set_title(r"$\epsilon$-Greedy Decay Curve", fontsize=12, fontweight='bold')
    ax.set(xlabel="Episodes", ylabel=r"Probability $\epsilon$")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.fill_between(episode_idx, eps, color='purple', alpha=0.1)
|
|
|
def plot_learning_curve(ax, seed=0):
    """Advanced / Misc: Learning Curve with Confidence Bands.

    Draws a synthetic, saturating return curve with additive Gaussian noise
    over 100 points between 0 and 1e6 environment steps, together with a
    shaded band of one standard deviation on each side of it. The band
    width decays exponentially with the step count.

    Args:
        ax: Matplotlib axes to draw on.
        seed: Seed for the noise generator. The default of 0 yields the same
            curve on every call; pass ``None`` to get fresh noise each time.
    """
    # A private generator keeps the figure reproducible and neither reads
    # nor modifies NumPy's global random state.
    rng = np.random.default_rng(seed)

    steps = np.linspace(0, 1e6, 100)
    noise = rng.normal(0, 2, len(steps))
    mean_return = 100 * (1 - np.exp(-5e-6 * steps)) + noise
    std_dev = 15 * np.exp(-2e-6 * steps)

    ax.plot(steps, mean_return, color='blue', lw=2, label="PPO (Mean)")
    ax.fill_between(steps, mean_return - std_dev, mean_return + std_dev,
                    color='blue', alpha=0.2, label="±1 Std Dev")

    ax.set_title("Learning Curve (Return vs Steps)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Average Episodic Return")
    ax.legend(loc="lower right")
    ax.grid(True, linestyle='--', alpha=0.6)
|
|
|
def main():
    """Build the eight overview dashboards and show them.

    Every dashboard is a 2x4 grid created through ``setup_figure``; its
    plotters are placed in row-major order. ``plot_reward_landscape`` is
    called with the figure and grid spec instead of an axes, and still
    occupies one slot in the ordering.
    """
    n_rows, n_cols = 2, 4

    dashboards = (
        ("RL: MDP & Environment", (
            plot_agent_env_loop,
            plot_mdp_graph,
            plot_trajectory,
            plot_continuous_space,
            plot_reward_landscape,
            plot_discount_decay,
        )),
        ("RL: Value, Policy & Dynamic Programming", (
            plot_value_heatmap,
            plot_action_value_q,
            plot_policy_arrows,
            plot_advantage_function,
            plot_backup_diagram,
            plot_policy_improvement,
            plot_value_iteration_backup,
            plot_policy_iteration_cycle,
        )),
        ("RL: Monte Carlo & Temporal Difference", (
            plot_mc_backup,
            plot_mcts,
            plot_importance_sampling,
            plot_td_backup,
            plot_nstep_td,
            plot_eligibility_traces,
            plot_sarsa_backup,
            plot_q_learning_backup,
        )),
        ("RL: TD Extensions & Function Approximation", (
            plot_double_q,
            plot_dueling_dqn,
            plot_prioritized_replay,
            plot_rainbow_dqn,
            plot_linear_fa,
            plot_nn_layers,
            plot_computation_graph,
            plot_target_network,
        )),
        ("RL: Policy Gradients, Actor-Critic & Exploration", (
            plot_policy_gradient_flow,
            plot_ppo_clip,
            plot_trpo_trust_region,
            plot_actor_critic_arch,
            plot_a3c_multi_worker,
            plot_sac_arch,
            plot_softmax_exploration,
            plot_ucb_confidence,
        )),
        ("RL: Hierarchical, Model-Based & Offline", (
            plot_options_framework,
            plot_feudal_networks,
            plot_world_model,
            plot_model_planning,
            plot_offline_rl,
            plot_cql_regularization,
            plot_epsilon_decay,
            plot_intrinsic_motivation,
        )),
        ("RL: Multi-Agent, IRL & Meta-RL", (
            plot_multi_agent_interaction,
            plot_ctde,
            plot_payoff_matrix,
            plot_irl_reward_inference,
            plot_gail_flow,
            plot_meta_rl_nested_loop,
            plot_task_distribution,
        )),
        ("RL: Advanced & Miscellaneous", (
            plot_replay_buffer,
            plot_state_visitation,
            plot_regret_curve,
            plot_attention_weights,
            plot_diffusion_policy,
            plot_gnn_rl,
            plot_latent_space,
            plot_convergence_log,
        )),
    )

    for heading, plotters in dashboards:
        fig, grid = setup_figure(heading, n_rows, n_cols)
        for slot, plotter in enumerate(plotters):
            if plotter is plot_reward_landscape:
                # Receives the figure and grid spec rather than an axes.
                plotter(fig, grid)
                continue
            row, col = divmod(slot, n_cols)
            plotter(fig.add_subplot(grid[row, col]))

    plt.show()
|
|
|
def save_all_graphs(output_dir="graphs"):
    """Save every RL component diagram as its own PNG file.

    Each entry of the component-name -> plotter table below is rendered on
    a fresh 10x8 inch figure and written to ``output_dir`` at 100 dpi. The
    file name is the lower-cased component name with every run of
    non-alphanumeric characters replaced by a single underscore (leading
    and trailing underscores removed), followed by ``.png``.

    Args:
        output_dir: Directory that receives the PNG files. It is created
            (including missing parents) when it does not exist yet.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first.
    os.makedirs(output_dir, exist_ok=True)

    # Several component names intentionally share a plotter.
    mapping = {
        "Agent-Environment Interaction Loop": plot_agent_env_loop,
        "Markov Decision Process (MDP) Tuple": plot_mdp_graph,
        "State Transition Graph": plot_mdp_graph,
        "Trajectory / Episode Sequence": plot_trajectory,
        "Continuous State/Action Space Visualization": plot_continuous_space,
        "Reward Function / Landscape": plot_reward_landscape,
        "Discount Factor (gamma) Effect": plot_discount_decay,
        "State-Value Function V(s)": plot_value_heatmap,
        "Action-Value Function Q(s,a)": plot_action_value_q,
        "Policy pi(s) or pi(a|s)": plot_policy_arrows,
        "Advantage Function A(s,a)": plot_advantage_function,
        "Optimal Value Function V* / Q*": plot_value_heatmap,
        "Policy Evaluation Backup": plot_backup_diagram,
        "Policy Improvement": plot_policy_improvement,
        "Value Iteration Backup": plot_value_iteration_backup,
        "Policy Iteration Full Cycle": plot_policy_iteration_cycle,
        "Monte Carlo Backup": plot_mc_backup,
        "Monte Carlo Tree (MCTS)": plot_mcts,
        "Importance Sampling Ratio": plot_importance_sampling,
        "TD(0) Backup": plot_td_backup,
        "Bootstrapping (general)": plot_td_backup,
        "n-step TD Backup": plot_nstep_td,
        "TD(lambda) & Eligibility Traces": plot_eligibility_traces,
        "SARSA Update": plot_sarsa_backup,
        "Q-Learning Update": plot_q_learning_backup,
        "Expected SARSA": plot_expected_sarsa_backup,
        "Double Q-Learning / Double DQN": plot_double_q,
        "Dueling DQN Architecture": plot_dueling_dqn,
        "Prioritized Experience Replay": plot_prioritized_replay,
        "Rainbow DQN Components": plot_rainbow_dqn,
        "Linear Function Approximation": plot_linear_fa,
        "Neural Network Layers (MLP, CNN, RNN, Transformer)": plot_nn_layers,
        "Computation Graph / Backpropagation Flow": plot_computation_graph,
        "Target Network": plot_target_network,
        "Policy Gradient Theorem": plot_policy_gradient_flow,
        "REINFORCE Update": plot_reinforce_flow,
        "Baseline / Advantage Subtraction": plot_advantage_scaled_grad,
        "Trust Region (TRPO)": plot_trpo_trust_region,
        "Proximal Policy Optimization (PPO)": plot_ppo_clip,
        "Actor-Critic Architecture": plot_actor_critic_arch,
        "Advantage Actor-Critic (A2C/A3C)": plot_a3c_multi_worker,
        "Soft Actor-Critic (SAC)": plot_sac_arch,
        "Twin Delayed DDPG (TD3)": plot_actor_critic_arch,
        "epsilon-Greedy Strategy": plot_epsilon_decay,
        "Softmax / Boltzmann Exploration": plot_softmax_exploration,
        "Upper Confidence Bound (UCB)": plot_ucb_confidence,
        "Intrinsic Motivation / Curiosity": plot_intrinsic_motivation,
        "Entropy Regularization": plot_entropy_bonus,
        "Options Framework": plot_options_framework,
        "Feudal Networks / Hierarchical Actor-Critic": plot_feudal_networks,
        "Skill Discovery": plot_skill_discovery,
        "Learned Dynamics Model": plot_world_model,
        "Model-Based Planning": plot_model_planning,
        "Imagination-Augmented Agents (I2A)": plot_imagination_rollout,
        "Offline Dataset": plot_offline_rl,
        "Conservative Q-Learning (CQL)": plot_cql_regularization,
        "Multi-Agent Interaction Graph": plot_multi_agent_interaction,
        "Centralized Training Decentralized Execution (CTDE)": plot_ctde,
        "Cooperative / Competitive Payoff Matrix": plot_payoff_matrix,
        "Reward Inference": plot_irl_reward_inference,
        "Generative Adversarial Imitation Learning (GAIL)": plot_gail_flow,
        "Meta-RL Architecture": plot_meta_rl_nested_loop,
        "Task Distribution Visualization": plot_task_distribution,
        "Experience Replay Buffer": plot_replay_buffer,
        "State Visitation / Occupancy Measure": plot_state_visitation,
        "Learning Curve": plot_learning_curve,
        "Regret / Cumulative Regret": plot_regret_curve,
        "Attention Mechanisms (Transformers in RL)": plot_attention_weights,
        "Diffusion Policy": plot_diffusion_policy,
        "Graph Neural Networks for RL": plot_gnn_rl,
        "World Model / Latent Space": plot_latent_space,
        "Convergence Analysis Plots": plot_convergence_log,
    }

    # Matches a whole run of characters that are not allowed in the slug, so
    # one substitution both replaces and collapses them.
    non_slug_run = re.compile(r'[^a-z0-9]+')

    for name, func in mapping.items():
        filename = non_slug_run.sub('_', name.lower()).strip('_') + ".png"
        filepath = os.path.join(output_dir, filename)

        print(f"Generating: {filename} ...")

        # Drop any figure a previous plotter may have left open.
        plt.close('all')

        if func is plot_reward_landscape:
            # This plotter is called with a figure and grid spec, not an axes.
            fig = plt.figure(figsize=(10, 8))
            plot_args = (fig, GridSpec(1, 1, figure=fig))
        else:
            fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True)
            plot_args = (ax,)

        try:
            func(*plot_args)
            # Save this figure explicitly instead of relying on pyplot's
            # notion of the current figure.
            fig.savefig(filepath, bbox_inches='tight', dpi=100)
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

    print(f"\n[SUCCESS] Saved {len(mapping)} graphs to '{output_dir}/' directory.")
|
|
|
if __name__ == "__main__":
    import sys

    # With "--save" among the arguments the graphs are written to disk;
    # otherwise the dashboards are built and shown.
    entry_point = save_all_graphs if "--save" in sys.argv else main
    entry_point()
|