diff --git "a/checkpoint/core.py" "b/checkpoint/core.py" new file mode 100644--- /dev/null +++ "b/checkpoint/core.py" @@ -0,0 +1,2478 @@ +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +from matplotlib.gridspec import GridSpec +from matplotlib.patches import FancyArrowPatch +from scipy.stats import norm + +import os +import re + +def setup_figure(title, rows, cols): + """Initializes a new figure and grid layout with constrained_layout to avoid warnings.""" + fig = plt.figure(figsize=(20, 10), constrained_layout=True) + fig.suptitle(title, fontsize=18, fontweight='bold') + gs = GridSpec(rows, cols, figure=fig) + return fig, gs + +def plot_agent_env_loop(ax): + """MDP & Environment: Agent-Environment Interaction Loop (Flowchart).""" + ax.axis('off') + ax.set_title("Agent-Environment Interaction", fontsize=12, fontweight='bold') + + props = dict(boxstyle="round,pad=0.8", fc="ivory", ec="black", lw=1.5) + ax.text(0.5, 0.8, "Agent", ha="center", va="center", bbox=props, fontsize=12) + ax.text(0.5, 0.2, "Environment", ha="center", va="center", bbox=props, fontsize=12) + + # Arrows + # Agent to Env: Action + ax.annotate("Action $A_t$", xy=(0.5, 0.35), xytext=(0.5, 0.65), + arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2)) + # Env to Agent: State & Reward + ax.annotate("State $S_{t+1}$, Reward $R_{t+1}$", xy=(0.5, 0.65), xytext=(0.5, 0.35), + arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2, color='green')) + +def plot_mdp_graph(ax): + """MDP & Environment: Directed graph with probability-weighted arrows.""" + G = nx.DiGraph() + # Corrected syntax: using a dictionary for edge attributes + G.add_edges_from([ + ('S0', 'S1', {'weight': 0.8}), ('S0', 'S2', {'weight': 0.2}), + ('S1', 'S2', {'weight': 1.0}), ('S2', 'S0', {'weight': 0.5}), ('S2', 'S2', {'weight': 0.5}) + ]) + pos = nx.spring_layout(G, seed=42) + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=1500, node_color='lightblue') + nx.draw_networkx_labels(ax=ax, G=G, pos=pos, font_weight='bold') + + edge_labels = {(u, v): f"P={d['weight']}" for u, v, d in G.edges(data=True)} + nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrowsize=20, edge_color='gray', connectionstyle="arc3,rad=0.1") + nx.draw_networkx_edge_labels(ax=ax, G=G, pos=pos, edge_labels=edge_labels, font_size=9) + ax.set_title("MDP State Transition Graph", fontsize=12, fontweight='bold') + ax.axis('off') + +def plot_reward_landscape(fig, gs): + """MDP & Environment: 3D surface plot of a reward function.""" + # Use the first available slot in gs (handled flexibly for dashboard vs save) + try: + ax = fig.add_subplot(gs[0, 1], projection='3d') + except IndexError: + ax = fig.add_subplot(gs[0, 0], projection='3d') + X = np.linspace(-5, 5, 50) + Y = np.linspace(-5, 5, 50) + X, Y = np.meshgrid(X, Y) + Z = np.sin(np.sqrt(X**2 + Y**2)) + (X * 0.1) # Simulated reward landscape + + surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.9) + ax.set_title("Reward Function Landscape", fontsize=12, fontweight='bold') + ax.set_xlabel('State X') + ax.set_ylabel('State Y') + ax.set_zlabel('Reward R(s)') + +def plot_trajectory(ax): + """MDP & Environment: Trajectory / Episode Sequence.""" + ax.set_title("Trajectory Sequence", fontsize=12, fontweight='bold') + states = ['s0', 's1', 's2', 's3', 'sT'] + actions = ['a0', 'a1', 'a2', 'a3'] + rewards = ['r1', 'r2', 'r3', 'r4'] + + for i, s in enumerate(states): + ax.text(i, 0.5, s, ha='center', va='center', bbox=dict(boxstyle="circle", fc="white")) + if i < len(actions): + ax.annotate("", xy=(i+0.8, 0.5), xytext=(i+0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(i+0.5, 0.6, actions[i], ha='center', color='blue') + ax.text(i+0.5, 0.4, rewards[i], ha='center', color='red') + + ax.set_xlim(-0.5, len(states)-0.5) + ax.set_ylim(0, 1) + ax.axis('off') + +def plot_continuous_space(ax): + """MDP & Environment: Continuous State/Action Space Visualization.""" + np.random.seed(42) + x = np.random.randn(200, 2) + labels = np.linalg.norm(x, axis=1) > 1.0 + ax.scatter(x[labels, 0], x[labels, 1], c='coral', alpha=0.6, label='High Reward') + ax.scatter(x[~labels, 0], x[~labels, 1], c='skyblue', alpha=0.6, label='Low Reward') + ax.set_title("Continuous State Space (2D Projection)", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_discount_decay(ax): + """MDP & Environment: Discount Factor (gamma) Effect.""" + t = np.arange(0, 20) + for gamma in [0.5, 0.9, 0.99]: + ax.plot(t, gamma**t, marker='o', markersize=4, label=rf"$\gamma={gamma}$") + ax.set_title(r"Discount Factor $\gamma^t$ Decay", fontsize=12, fontweight='bold') + ax.set_xlabel("Time steps (t)") + ax.set_ylabel("Weight") + ax.legend() + ax.grid(True, alpha=0.3) + +def plot_value_heatmap(ax): + """Value & Policy: State-Value Function V(s) Heatmap (Gridworld).""" + grid_size = 5 + # Simulate a value landscape where the top right is the goal + values = np.zeros((grid_size, grid_size)) + for i in range(grid_size): + for j in range(grid_size): + values[i, j] = -( (grid_size-1-i)**2 + (grid_size-1-j)**2 ) * 0.5 + values[-1, -1] = 10.0 # Goal state + + cax = ax.matshow(values, cmap='magma') + for (i, j), z in np.ndenumerate(values): + ax.text(j, i, f'{z:0.1f}', ha='center', va='center', color='white' if z < -5 else 'black', fontsize=9) + + ax.set_title("State-Value Function V(s) Heatmap", fontsize=12, fontweight='bold', pad=15) + ax.set_xticks(range(grid_size)) + ax.set_yticks(range(grid_size)) + +def plot_backup_diagram(ax): + """Dynamic Programming: Policy Evaluation Backup Diagram.""" + G = nx.DiGraph() + G.add_node("s", layer=0) + G.add_node("a1", layer=1); G.add_node("a2", layer=1) + G.add_node("s'_1", layer=2); G.add_node("s'_2", layer=2); G.add_node("s'_3", layer=2) + + G.add_edges_from([("s", "a1"), ("s", "a2")]) + G.add_edges_from([("a1", "s'_1"), ("a1", "s'_2"), ("a2", "s'_3")]) + + pos = { + "s": (0.5, 1), + "a1": (0.25, 0.5), "a2": (0.75, 0.5), + "s'_1": (0.1, 0), "s'_2": (0.4, 0), "s'_3": (0.75, 0) + } + + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, nodelist=["s", "s'_1", "s'_2", "s'_3"], node_size=800, node_color='white', edgecolors='black') + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, nodelist=["a1", "a2"], node_size=300, node_color='black') # Action nodes are solid black dots + nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True) + nx.draw_networkx_labels(ax=ax, G=G, pos=pos, labels={"s": "s", "s'_1": "s'", "s'_2": "s'", "s'_3": "s'"}, font_size=10) + + ax.set_title("DP Policy Eval Backup", fontsize=12, fontweight='bold') + ax.set_ylim(-0.2, 1.2) + ax.axis('off') + +def plot_action_value_q(ax): + """Value & Policy: Action-Value Function Q(s,a) (Heatmap per action stack).""" + grid = np.random.rand(3, 3) + ax.imshow(grid, cmap='YlGnBu') + for (i, j), z in np.ndenumerate(grid): + ax.text(j, i, f'{z:0.1f}', ha='center', va='center', fontsize=8) + ax.set_title(r"Action-Value $Q(s, a_{up})$", fontsize=12, fontweight='bold') + ax.set_xticks([]); ax.set_yticks([]) + +def plot_policy_arrows(ax): + """Value & Policy: Policy π(s) as arrow overlays on grid.""" + grid_size = 4 + ax.set_xlim(-0.5, grid_size-0.5) + ax.set_ylim(-0.5, grid_size-0.5) + for i in range(grid_size): + for j in range(grid_size): + dx, dy = np.random.choice([0, 0.3, -0.3]), np.random.choice([0, 0.3, -0.3]) + if dx == 0 and dy == 0: dx = 0.3 + ax.add_patch(FancyArrowPatch((j, i), (j+dx, i+dy), arrowstyle='->', mutation_scale=15)) + ax.set_title(r"Policy $\pi(s)$ Arrows", fontsize=12, fontweight='bold') + ax.set_xticks(range(grid_size)); ax.set_yticks(range(grid_size)); ax.grid(True, alpha=0.2) + +def plot_advantage_function(ax): + """Value & Policy: Advantage Function A(s,a) = Q-V.""" + actions = ['A1', 'A2', 'A3', 'A4'] + advantage = [2.1, -1.2, 0.5, -0.8] + colors = ['green' if v > 0 else 'red' for v in advantage] + ax.bar(actions, advantage, color=colors, alpha=0.7) + ax.axhline(0, color='black', lw=1) + ax.set_title(r"Advantage $A(s, a)$", fontsize=12, fontweight='bold') + ax.set_ylabel("Value") + +def plot_policy_improvement(ax): + """Dynamic Programming: Policy Improvement (Before vs After).""" + ax.axis('off') + ax.set_title("Policy Improvement", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, r"$\pi_{old}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgrey")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", lw=2)) + ax.text(0.5, 0.6, "Greedy\nImprovement", ha='center', fontsize=9) + ax.text(0.85, 0.5, r"$\pi_{new}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgreen")) + +def plot_value_iteration_backup(ax): + """Dynamic Programming: Value Iteration Backup Diagram (Max over actions).""" + G = nx.DiGraph() + pos = {"s": (0.5, 1), "max": (0.5, 0.5), "s1": (0.2, 0), "s2": (0.5, 0), "s3": (0.8, 0)} + G.add_nodes_from(pos.keys()) + G.add_edges_from([("s", "max"), ("max", "s1"), ("max", "s2"), ("max", "s3")]) + + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=500, node_color='white', edgecolors='black') + nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True) + nx.draw_networkx_labels(ax=ax, G=G, pos=pos, labels={"s": "s", "max": "max", "s1": "s'", "s2": "s'", "s3": "s'"}, font_size=9) + ax.set_title("Value Iteration Backup", fontsize=12, fontweight='bold') + ax.axis('off') + +def plot_policy_iteration_cycle(ax): + """Dynamic Programming: Policy Iteration Full Cycle Flowchart.""" + ax.axis('off') + ax.set_title("Policy Iteration Cycle", fontsize=12, fontweight='bold') + props = dict(boxstyle="round", fc="aliceblue", ec="black") + ax.text(0.5, 0.8, r"Policy Evaluation" + "\n" + r"$V \leftarrow V^\pi$", ha="center", bbox=props) + ax.text(0.5, 0.2, r"Policy Improvement" + "\n" + r"$\pi \leftarrow \text{greedy}(V)$", ha="center", bbox=props) + ax.annotate("", xy=(0.7, 0.3), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5")) + ax.annotate("", xy=(0.3, 0.7), xytext=(0.3, 0.3), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5")) + +def plot_mc_backup(ax): + """Monte Carlo: Backup diagram (Full trajectory until terminal sT).""" + ax.axis('off') + ax.set_title("Monte Carlo Backup", fontsize=12, fontweight='bold') + nodes = ['s', 's1', 's2', 'sT'] + pos = {n: (0.5, 0.9 - i*0.25) for i, n in enumerate(nodes)} + for i in range(len(nodes)-1): + ax.annotate("", xy=pos[nodes[i+1]], xytext=pos[nodes[i]], arrowprops=dict(arrowstyle="->", lw=1.5)) + ax.text(pos[nodes[i]][0]+0.05, pos[nodes[i]][1], nodes[i], va='center') + ax.text(pos['sT'][0]+0.05, pos['sT'][1], 'sT', va='center', fontweight='bold') + ax.annotate("Update V(s) using G", xy=(0.3, 0.9), xytext=(0.3, 0.15), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.3")) + +def plot_mcts(ax): + """Monte Carlo: Monte Carlo Tree Search (MCTS) tree diagram.""" + G = nx.balanced_tree(2, 2, create_using=nx.DiGraph()) + pos = nx.drawing.nx_agraph.graphviz_layout(G, prog='dot') if 'pygraphviz' in globals() else nx.shell_layout(G) + # Simple tree fallback + pos = {0:(0,0), 1:(-1,-1), 2:(1,-1), 3:(-1.5,-2), 4:(-0.5,-2), 5:(0.5,-2), 6:(1.5,-2)} + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=300, node_color='lightyellow', edgecolors='black') + nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True) + ax.set_title("MCTS Tree", fontsize=12, fontweight='bold') + ax.axis('off') + +def plot_importance_sampling(ax): + """Monte Carlo: Importance Sampling Ratio Flow.""" + ax.axis('off') + ax.set_title("Importance Sampling", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, r"$\pi(a|s)$", bbox=dict(boxstyle="circle", fc="lightgreen"), ha='center') + ax.text(0.5, 0.2, r"$b(a|s)$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center') + ax.annotate(r"$\rho = \frac{\pi}{b}$", xy=(0.7, 0.5), fontsize=15) + ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<->", lw=2)) + +def plot_td_backup(ax): + """Temporal Difference: TD(0) 1-step backup.""" + ax.axis('off') + ax.set_title("TD(0) Backup", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "s", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.text(0.5, 0.2, "s'", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.annotate(r"$R + \gamma V(s')$", xy=(0.5, 0.4), ha='center', color='blue') + ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<-", lw=2)) + +def plot_nstep_td(ax): + """Temporal Difference: n-step TD backup.""" + ax.axis('off') + ax.set_title("n-step TD Backup", fontsize=12, fontweight='bold') + for i in range(4): + ax.text(0.5, 0.9-i*0.2, f"s_{i}", bbox=dict(boxstyle="circle", fc="white"), ha='center', fontsize=8) + if i < 3: ax.annotate("", xy=(0.5, 0.75-i*0.2), xytext=(0.5, 0.85-i*0.2), arrowprops=dict(arrowstyle="->")) + ax.annotate(r"$G_t^{(n)}$", xy=(0.7, 0.5), fontsize=12, color='red') + +def plot_eligibility_traces(ax): + """Temporal Difference: TD(lambda) Eligibility Traces decay curve.""" + t = np.arange(0, 50) + # Simulate multiple highlights (visits) + trace = np.zeros_like(t, dtype=float) + visits = [5, 20, 35] + for v in visits: + trace[v:] += (0.8 ** np.arange(len(t)-v)) + ax.plot(t, trace, color='brown', lw=2) + ax.set_title(r"Eligibility Trace $z_t(\lambda)$", fontsize=12, fontweight='bold') + ax.set_xlabel("Time") + ax.fill_between(t, trace, color='brown', alpha=0.1) + +def plot_sarsa_backup(ax): + """Temporal Difference: SARSA (On-policy) backup.""" + ax.axis('off') + ax.set_title("SARSA Backup", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "(s,a)", ha='center') + ax.text(0.5, 0.1, "(s',a')", ha='center') + ax.annotate("", xy=(0.5, 0.2), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='orange')) + ax.text(0.6, 0.5, "On-policy", rotation=90) + +def plot_q_learning_backup(ax): + """Temporal Difference: Q-Learning (Off-policy) backup.""" + ax.axis('off') + ax.set_title("Q-Learning Backup", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "(s,a)", ha='center') + ax.text(0.5, 0.1, r"$\max_{a'} Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="lightcyan")) + ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='blue')) + +def plot_double_q(ax): + """Temporal Difference: Double Q-Learning / Double DQN.""" + ax.axis('off') + ax.set_title("Double Q-Learning", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Network A", bbox=dict(fc="lightyellow"), ha='center') + ax.text(0.5, 0.2, "Network B", bbox=dict(fc="lightcyan"), ha='center') + ax.annotate("Select $a^*$", xy=(0.3, 0.8), xytext=(0.5, 0.85), arrowprops=dict(arrowstyle="->")) + ax.annotate("Eval $Q(s', a^*)$", xy=(0.7, 0.2), xytext=(0.5, 0.15), arrowprops=dict(arrowstyle="->")) + +def plot_dueling_dqn(ax): + """Temporal Difference: Dueling DQN Architecture.""" + ax.axis('off') + ax.set_title("Dueling DQN", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Backbone", bbox=dict(fc="lightgrey"), ha='center', rotation=90) + ax.text(0.5, 0.7, "V(s)", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.3, "A(s,a)", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.9, 0.5, "Q(s,a)", bbox=dict(boxstyle="circle", fc="orange"), ha='center') + ax.annotate("", xy=(0.35, 0.7), xytext=(0.15, 0.55), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.35, 0.3), xytext=(0.15, 0.45), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.55), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.45), xytext=(0.6, 0.3), arrowprops=dict(arrowstyle="->")) + +def plot_prioritized_replay(ax): + """Temporal Difference: Prioritized Experience Replay (PER).""" + priorities = np.random.pareto(3, 100) + ax.hist(priorities, bins=20, color='teal', alpha=0.7) + ax.set_title("Prioritized Replay (TD-Error)", fontsize=12, fontweight='bold') + ax.set_xlabel("Priority $P_i$") + ax.set_ylabel("Count") + +def plot_rainbow_dqn(ax): + """Temporal Difference: Rainbow DQN Composite.""" + ax.axis('off') + ax.set_title("Rainbow DQN", fontsize=12, fontweight='bold') + features = ["Double", "Dueling", "PER", "Noisy", "Distributional", "n-step"] + for i, f in enumerate(features): + ax.text(0.5, 0.9 - i*0.15, f, ha='center', bbox=dict(boxstyle="round", fc="ghostwhite"), fontsize=8) + +def plot_linear_fa(ax): + """Function Approximation: Linear Function Approximation.""" + ax.axis('off') + ax.set_title("Linear Function Approx", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, r"$\phi(s)$ Features", ha='center', bbox=dict(fc="white")) + ax.text(0.5, 0.2, r"$w^T \phi(s)$", ha='center', bbox=dict(fc="lightgrey")) + ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="->", lw=2)) + +def plot_nn_layers(ax): + """Function Approximation: Neural Network Layers diagram.""" + ax.axis('off') + ax.set_title("NN Layers (Deep RL)", fontsize=12, fontweight='bold') + layers = [4, 8, 8, 2] + for i, l in enumerate(layers): + for j in range(l): + ax.scatter(i*0.3, j*0.1 - l*0.05, s=20, c='black') + ax.set_xlim(-0.1, 1.0) + ax.set_ylim(-0.5, 0.5) + +def plot_computation_graph(ax): + """Function Approximation: Computation Graph / Backprop Flow.""" + ax.axis('off') + ax.set_title("Computation Graph (DAG)", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Input", bbox=dict(boxstyle="circle", fc="white")) + ax.text(0.5, 0.5, "Op", bbox=dict(boxstyle="square", fc="lightgrey")) + ax.text(0.9, 0.5, "Loss", bbox=dict(boxstyle="circle", fc="salmon")) + ax.annotate("", xy=(0.35, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("Grad", xy=(0.1, 0.3), xytext=(0.9, 0.3), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.2")) + +def plot_target_network(ax): + """Function Approximation: Target Network concept.""" + ax.axis('off') + ax.set_title("Target Network Updates", fontsize=12, fontweight='bold') + ax.text(0.3, 0.8, r"$Q_\theta$ (Active)", bbox=dict(fc="lightgreen")) + ax.text(0.7, 0.8, r"$Q_{\theta^-}$ (Target)", bbox=dict(fc="lightblue")) + ax.annotate("periodic copy", xy=(0.6, 0.8), xytext=(0.4, 0.8), arrowprops=dict(arrowstyle="<-", ls='--')) + +def plot_ppo_clip(ax): + """Policy Gradients: PPO Clipped Surrogate Objective.""" + epsilon = 0.2 + r = np.linspace(0.5, 1.5, 100) + advantage = 1.0 + surr1 = r * advantage + surr2 = np.clip(r, 1-epsilon, 1+epsilon) * advantage + ax.plot(r, surr1, '--', label="r*A") + ax.plot(r, np.minimum(surr1, surr2), 'r', label="min(r*A, clip*A)") + ax.set_title("PPO-Clip Objective", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + ax.axvline(1, color='gray', linestyle=':') + +def plot_trpo_trust_region(ax): + """Policy Gradients: TRPO Trust Region / KL Constraint.""" + ax.set_title("TRPO Trust Region", fontsize=12, fontweight='bold') + circle = plt.Circle((0.5, 0.5), 0.3, color='blue', fill=False, label="KL Constraint") + ax.add_artist(circle) + ax.scatter(0.5, 0.5, c='black', label=r"$\pi_{old}$") + ax.arrow(0.5, 0.5, 0.15, 0.1, head_width=0.03, color='red', label="Update") + ax.set_xlim(0, 1); ax.set_ylim(0, 1) + ax.axis('off') + +def plot_a3c_multi_worker(ax): + """Actor-Critic: Asynchronous Multi-worker (A3C).""" + ax.axis('off') + ax.set_title("A3C Multi-worker", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Global Parameters", bbox=dict(fc="gold"), ha='center') + for i in range(3): + ax.text(0.2 + i*0.3, 0.2, f"Worker {i+1}", bbox=dict(fc="lightgrey"), ha='center', fontsize=8) + ax.annotate("", xy=(0.5, 0.7), xytext=(0.2 + i*0.3, 0.3), arrowprops=dict(arrowstyle="<->")) + +def plot_sac_arch(ax): + """Actor-Critic: SAC (Entropy-regularized).""" + ax.axis('off') + ax.set_title("SAC Architecture", fontsize=12, fontweight='bold') + ax.text(0.5, 0.7, "Actor", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.3, "Entropy Bonus", bbox=dict(fc="salmon"), ha='center') + ax.text(0.1, 0.5, "State", ha='center') + ax.text(0.9, 0.5, "Action", ha='center') + ax.annotate("", xy=(0.4, 0.7), xytext=(0.15, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.55), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.85, 0.5), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->")) + +def plot_softmax_exploration(ax): + """Exploration: Softmax / Boltzmann probabilities.""" + x = np.arange(4) + logits = [1, 2, 5, 3] + for tau in [0.5, 1.0, 5.0]: + probs = np.exp(np.array(logits)/tau) + probs /= probs.sum() + ax.plot(x, probs, marker='o', label=rf"$\tau={tau}$") + ax.set_title("Softmax Exploration", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + ax.set_xticks(x) + +def plot_ucb_confidence(ax): + """Exploration: Upper Confidence Bound (UCB).""" + actions = ['A1', 'A2', 'A3'] + means = [0.6, 0.8, 0.5] + conf = [0.3, 0.1, 0.4] + ax.bar(actions, means, yerr=conf, capsize=10, color='skyblue', label='Mean Q') + ax.set_title("UCB Action Values", fontsize=12, fontweight='bold') + ax.set_ylim(0, 1.2) + +def plot_intrinsic_motivation(ax): + """Exploration: Intrinsic Motivation / Curiosity.""" + ax.axis('off') + ax.set_title("Intrinsic Motivation", fontsize=12, fontweight='bold') + ax.text(0.3, 0.5, "World Model", bbox=dict(fc="lightyellow"), ha='center') + ax.text(0.7, 0.5, "Prediction\nError", bbox=dict(boxstyle="circle", fc="orange"), ha='center') + ax.annotate("", xy=(0.58, 0.5), xytext=(0.42, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.85, 0.5, r"$R_{int}$", fontweight='bold') + +def plot_entropy_bonus(ax): + """Exploration: Entropy Regularization curve.""" + p = np.linspace(0.01, 0.99, 50) + entropy = -(p * np.log(p) + (1-p) * np.log(1-p)) + ax.plot(p, entropy, color='purple') + ax.set_title(r"Entropy $H(\pi)$", fontsize=12, fontweight='bold') + ax.set_xlabel("$P(a)$") + +def plot_options_framework(ax): + """Hierarchical RL: Options Framework.""" + ax.axis('off') + ax.set_title("Options Framework", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, r"High-level policy" + "\n" + r"$\pi_{hi}$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.2, 0.2, "Option 1", bbox=dict(fc="ivory"), ha='center') + ax.text(0.8, 0.2, "Option 2", bbox=dict(fc="ivory"), ha='center') + ax.annotate("", xy=(0.3, 0.3), xytext=(0.45, 0.7), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.7, 0.3), xytext=(0.55, 0.7), arrowprops=dict(arrowstyle="->")) + +def plot_feudal_networks(ax): + """Hierarchical RL: Feudal Networks / Hierarchy.""" + ax.axis('off') + ax.set_title("Feudal Networks", fontsize=12, fontweight='bold') + ax.text(0.5, 0.85, "Manager", bbox=dict(fc="plum"), ha='center') + ax.text(0.5, 0.15, "Worker", bbox=dict(fc="wheat"), ha='center') + ax.annotate("Goal $g_t$", xy=(0.5, 0.3), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2)) + +def plot_world_model(ax): + """Model-Based RL: Learned Dynamics Model.""" + ax.axis('off') + ax.set_title("World Model (Dynamics)", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "(s,a)", ha='center') + ax.text(0.5, 0.5, r"$\hat{P}$", bbox=dict(boxstyle="circle", fc="lightgrey"), ha='center') + ax.text(0.9, 0.7, r"$\hat{s}'$", ha='center') + ax.text(0.9, 0.3, r"$\hat{r}$", ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.65), xytext=(0.6, 0.55), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.35), xytext=(0.6, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_model_planning(ax): + """Model-Based RL: Planning / Rollouts in imagination.""" + ax.axis('off') + ax.set_title("Model-Based Planning", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Real s", ha='center', fontweight='bold') + for i in range(3): + ax.annotate("", xy=(0.3+i*0.2, 0.5+(i%2)*0.1), xytext=(0.1+i*0.2, 0.5), arrowprops=dict(arrowstyle="->", color='gray')) + ax.text(0.3+i*0.2, 0.55+(i%2)*0.1, "imagined", fontsize=7) + +def plot_offline_rl(ax): + """Offline RL: Fixed dataset of trajectories.""" + ax.axis('off') + ax.set_title("Offline RL Dataset", fontsize=12, fontweight='bold') + ax.text(0.5, 0.5, r"Static" + "\n" + r"Dataset" + "\n" + r"$\mathcal{D}$", bbox=dict(boxstyle="round", fc="lightgrey"), ha='center') + ax.annotate("No interaction", xy=(0.5, 0.9), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", color='red')) + ax.scatter([0.2, 0.8, 0.3, 0.7], [0.8, 0.8, 0.2, 0.2], marker='x', color='blue') + +def plot_cql_regularization(ax): + """Offline RL: CQL regularization visualization.""" + q = np.linspace(-5, 5, 100) + penalty = q**2 * 0.1 + ax.plot(q, penalty, 'r', label='CQL Penalty') + ax.set_title("CQL Regularization", fontsize=12, fontweight='bold') + ax.set_xlabel("Q-value") + ax.legend(fontsize=8) + +def plot_multi_agent_interaction(ax): + """Multi-Agent RL: Agents communicating or competing.""" + G = nx.complete_graph(3) + pos = nx.spring_layout(G) + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=500, node_color=['red', 'blue', 'green']) + nx.draw_networkx_edges(ax=ax, G=G, pos=pos, style='dashed') + ax.set_title("Multi-Agent Interaction", fontsize=12, fontweight='bold') + ax.axis('off') + +def plot_ctde(ax): + """Multi-Agent RL: Centralized Training Decentralized Execution (CTDE).""" + ax.axis('off') + ax.set_title("CTDE Architecture", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Centralized Critic", bbox=dict(fc="gold"), ha='center') + ax.text(0.2, 0.2, "Agent 1", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.2, "Agent 2", bbox=dict(fc="lightblue"), ha='center') + ax.annotate("", xy=(0.5, 0.7), xytext=(0.25, 0.35), arrowprops=dict(arrowstyle="<-", color='gray')) + ax.annotate("", xy=(0.5, 0.7), xytext=(0.75, 0.35), arrowprops=dict(arrowstyle="<-", color='gray')) + +def plot_payoff_matrix(ax): + """Multi-Agent RL: Cooperative / Competitive Payoff Matrix.""" + matrix = np.array([[(3,3), (0,5)], [(5,0), (1,1)]]) + ax.axis('off') + ax.set_title("Payoff Matrix (Prisoner's)", fontsize=12, fontweight='bold') + for i in range(2): + for j in range(2): + ax.text(j, 1-i, str(matrix[i, j]), ha='center', va='center', bbox=dict(fc="white")) + ax.set_xlim(-0.5, 1.5); ax.set_ylim(-0.5, 1.5) + +def plot_irl_reward_inference(ax): + """Inverse RL: Infer reward from expert demonstrations.""" + ax.axis('off') + ax.set_title("Inferred Reward Heatmap", fontsize=12, fontweight='bold') + grid = np.zeros((5, 5)) + grid[2:4, 2:4] = 1.0 # Expert path + ax.imshow(grid, cmap='hot') + +def plot_gail_flow(ax): + """Inverse RL: GAIL (Generative Adversarial Imitation Learning).""" + ax.axis('off') + ax.set_title("GAIL Architecture", fontsize=12, fontweight='bold') + ax.text(0.2, 0.8, "Expert Data", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.2, 0.2, "Policy (Gen)", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.8, 0.5, "Discriminator", bbox=dict(boxstyle="square", fc="salmon"), ha='center') + ax.annotate("", xy=(0.6, 0.55), xytext=(0.35, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.6, 0.45), xytext=(0.35, 0.25), arrowprops=dict(arrowstyle="->")) + +def plot_meta_rl_nested_loop(ax): + """Meta-RL: Outer loop (meta) + inner loop (adaptation).""" + ax.axis('off') + ax.set_title("Meta-RL Loops", fontsize=12, fontweight='bold') + ax.add_patch(plt.Circle((0.5, 0.5), 0.4, fill=False, ls='--')) + ax.add_patch(plt.Circle((0.5, 0.5), 0.2, fill=False)) + ax.text(0.5, 0.5, "Inner\nLoop", ha='center', fontsize=8) + ax.text(0.5, 0.8, "Outer Loop", ha='center', fontsize=10) + +def plot_task_distribution(ax): + """Meta-RL: Multiple MDPs from distribution.""" + ax.axis('off') + ax.set_title("Task Distribution", fontsize=12, fontweight='bold') + for i in range(3): + ax.text(0.2 + i*0.3, 0.5, f"Task {i+1}", bbox=dict(boxstyle="round", fc="ivory"), fontsize=8) + ax.annotate("sample", xy=(0.5, 0.8), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="<-")) + +def plot_replay_buffer(ax): + """Advanced: Experience Replay Buffer (FIFO).""" + ax.axis('off') + ax.set_title("Experience Replay Buffer", fontsize=12, fontweight='bold') + for i in range(5): + ax.add_patch(plt.Rectangle((0.1+i*0.15, 0.4), 0.1, 0.2, fill=True, color='lightgrey')) + ax.text(0.15+i*0.15, 0.5, f"e_{i}", ha='center') + ax.annotate("In", xy=(0.05, 0.5), xytext=(-0.1, 0.5), arrowprops=dict(arrowstyle="->"), annotation_clip=False) + ax.annotate("Out (Batch)", xy=(0.85, 0.5), xytext=(1.0, 0.5), arrowprops=dict(arrowstyle="<-"), annotation_clip=False) + +def plot_state_visitation(ax): + """Advanced: State Visitation / Occupancy Measure.""" + data = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], 1000) + ax.hexbin(data[:, 0], data[:, 1], gridsize=15, cmap='Blues') + ax.set_title("State Visitation Heatmap", fontsize=12, fontweight='bold') + +def plot_regret_curve(ax): + """Advanced: Regret / Cumulative Regret.""" + t = np.arange(100) + regret = np.sqrt(t) + np.random.normal(0, 0.5, 100) + ax.plot(t, regret, color='red', label='Sub-linear Regret') + ax.set_title("Cumulative Regret", fontsize=12, fontweight='bold') + ax.set_xlabel("Time") + ax.legend(fontsize=8) + +def plot_attention_weights(ax): + """Advanced: Attention Mechanisms (Heatmap).""" + weights = np.random.rand(5, 5) + ax.imshow(weights, cmap='viridis') + ax.set_title("Attention Weight Matrix", fontsize=12, fontweight='bold') + ax.set_xticks([]); ax.set_yticks([]) + +def plot_diffusion_policy(ax): + """Advanced: Diffusion Policy denoising steps.""" + ax.axis('off') + ax.set_title("Diffusion Policy (Denoising)", fontsize=12, fontweight='bold') + for i in range(4): + ax.scatter(0.1+i*0.25, 0.5, s=100/(i+1), c='black', alpha=1.0 - i*0.2) + if i < 3: ax.annotate("", xy=(0.25+i*0.25, 0.5), xytext=(0.15+i*0.25, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.3, "Noise $\\rightarrow$ Action", ha='center', fontsize=8) + +def plot_gnn_rl(ax): + """Advanced: Graph Neural Networks for RL.""" + G = nx.star_graph(4) + pos = nx.spring_layout(G) + nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=200, node_color='orange') + nx.draw_networkx_edges(ax=ax, G=G, pos=pos) + ax.set_title("GNN Message Passing", fontsize=12, fontweight='bold') + ax.axis('off') + +def plot_latent_space(ax): + """Advanced: World Model / Latent Space.""" + ax.axis('off') + ax.set_title("Latent Space (VAE/Dreamer)", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Image", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.5, 0.5, "Latent $z$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center') + ax.text(0.9, 0.5, "Reconstruction", bbox=dict(fc="lightgrey"), ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_convergence_log(ax): + """Advanced: Convergence Analysis Plots (Log-scale).""" + iterations = np.arange(1, 100) + error = 10 / iterations**2 + ax.loglog(iterations, error, color='green') + ax.set_title("Value Convergence (Log)", fontsize=12, fontweight='bold') + ax.set_xlabel("Iterations") + ax.set_ylabel("Error") + ax.grid(True, which="both", ls="-", alpha=0.3) + +def plot_expected_sarsa_backup(ax): + """Temporal Difference: Expected SARSA (Expectation over policy).""" + ax.axis('off') + ax.set_title("Expected SARSA Backup", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "(s,a)", ha='center') + ax.text(0.5, 0.1, r"$\sum_{a'} \pi(a'|s') Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="ivory")) + ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='purple')) + +def plot_reinforce_flow(ax): + """Policy Gradients: REINFORCE (Full trajectory flow).""" + ax.axis('off') + ax.set_title("REINFORCE Flow", fontsize=12, fontweight='bold') + steps = ["s0", "a0", "r1", "s1", "...", "GT"] + for i, s in enumerate(steps): + ax.text(0.1 + i*0.15, 0.5, s, bbox=dict(boxstyle="circle", fc="white")) + ax.annotate(r"$\nabla_\theta J \propto G_t \nabla \ln \pi$", xy=(0.5, 0.8), ha='center', fontsize=12, color='darkgreen') + +def plot_advantage_scaled_grad(ax): + """Policy Gradients: Baseline / Advantage scaled gradient.""" + ax.axis('off') + ax.set_title("Baseline Subtraction", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, r"$(G_t - b(s))$", bbox=dict(fc="salmon"), ha='center') + ax.text(0.5, 0.3, r"Scale $\nabla \ln \pi$", ha='center') + ax.annotate("", xy=(0.5, 0.4), xytext=(0.5, 0.7), arrowprops=dict(arrowstyle="->")) + +def plot_skill_discovery(ax): + """Hierarchical RL: Skill Discovery (Unsupervised clusters).""" + np.random.seed(0) + for i in range(3): + center = np.random.randn(2) * 2 + pts = np.random.randn(20, 2) * 0.5 + center + ax.scatter(pts[:, 0], pts[:, 1], alpha=0.6, label=f"Skill {i+1}") + ax.set_title("Skill Embedding Space", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_imagination_rollout(ax): + """Model-Based RL: Imagination-Augmented Rollouts (I2A).""" + ax.axis('off') + ax.set_title("Imagination Rollout (I2A)", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Input s", ha='center') + ax.add_patch(plt.Rectangle((0.3, 0.3), 0.4, 0.4, fill=True, color='lavender')) + ax.text(0.5, 0.5, "Imagination\nModule", ha='center') + ax.annotate("Imagined Paths", xy=(0.8, 0.5), xytext=(0.5, 0.5), arrowprops=dict(arrowstyle="->", color='gray', connectionstyle="arc3,rad=0.3")) + +def plot_policy_gradient_flow(ax): + """Policy Gradients: Gradient flow from reward to log-prob (DAG).""" + ax.axis('off') + ax.set_title("Policy Gradient Flow (DAG)", fontsize=12, fontweight='bold') + + bbox_props = dict(boxstyle="round,pad=0.5", fc="lightgrey", ec="black", lw=1.5) + ax.text(0.1, 0.8, r"Trajectory $\tau$", ha="center", va="center", bbox=bbox_props) + ax.text(0.5, 0.8, r"Reward $R(\tau)$", ha="center", va="center", bbox=bbox_props) + ax.text(0.1, 0.2, r"Log-Prob $\log \pi_\theta$", ha="center", va="center", bbox=bbox_props) + ax.text(0.7, 0.5, r"$\nabla_\theta J(\theta)$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.3", fc="gold", ec="black")) + + # Draw arrows + ax.annotate("", xy=(0.35, 0.8), xytext=(0.2, 0.8), arrowprops=dict(arrowstyle="->", lw=2)) + ax.annotate("", xy=(0.7, 0.65), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2)) + ax.annotate("", xy=(0.6, 0.4), xytext=(0.25, 0.2), arrowprops=dict(arrowstyle="->", lw=2)) + +def plot_rl_as_inference_pgm(ax): + """PGM: RL as Inference (Control as Inference).""" + ax.axis('off') + ax.set_title("RL as Inference (PGM)", fontsize=12, fontweight='bold') + nodes = { + 's_t': (0.1, 0.8), 'a_t': (0.1, 0.4), 's_tp1': (0.5, 0.8), + 'r_t': (0.5, 0.4), 'O_t': (0.8, 0.4) + } + for name, pos in nodes.items(): + color = 'white' if 'O' not in name else 'lightcoral' + ax.text(pos[0], pos[1], name, bbox=dict(boxstyle="circle", fc=color), ha='center') + + # Dependencies + arrows = [('s_t', 's_tp1'), ('a_t', 's_tp1'), ('s_t', 'a_t'), ('a_t', 'r_t'), ('r_t', 'O_t')] + for start, end in arrows: + ax.annotate("", xy=nodes[end], xytext=nodes[start], arrowprops=dict(arrowstyle="->")) + +def plot_rl_taxonomy_tree(ax): + """Taxonomy: RL Algorithm Classification Tree.""" + ax.axis('off') + ax.set_title("RL Algorithm Taxonomy", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "Reinforcement Learning", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.25, 0.6, "Model-Free", bbox=dict(fc="ivory"), ha='center') + ax.text(0.75, 0.6, "Model-Based", bbox=dict(fc="ivory"), ha='center') + ax.text(0.1, 0.3, "Policy Opt", fontsize=8, ha='center') + ax.text(0.4, 0.3, "Value-Based", fontsize=8, ha='center') + for x in [0.25, 0.75]: ax.annotate("", xy=(x, 0.65), xytext=(0.5, 0.85), arrowprops=dict(arrowstyle="->")) + for x in [0.1, 0.4]: ax.annotate("", xy=(x, 0.35), xytext=(0.25, 0.55), arrowprops=dict(arrowstyle="->")) + +def plot_distributional_rl_atoms(ax): + """Distributional RL: C51 return probability atoms.""" + returns = np.linspace(-10, 10, 51) + probs = np.exp(-(returns - 2)**2 / 4) + np.exp(-(returns + 4)**2 / 2) + probs /= probs.sum() + ax.bar(returns, probs, width=0.3, color='steelblue', alpha=0.8) + ax.set_title("Distributional RL (Atoms)", fontsize=12, fontweight='bold') + ax.set_xlabel("Return $Z$") + ax.set_ylabel("Probability") + +def plot_her_goal_relabeling(ax): + """HER: Hindsight Experience Replay goal relabeling.""" + ax.axis('off') + ax.set_title("HER Goal Relabeling", fontsize=12, fontweight='bold') + path = np.array([[0.1, 0.2], [0.3, 0.4], [0.6, 0.5], [0.8, 0.7]]) + ax.plot(path[:, 0], path[:, 1], 'k--', alpha=0.3) + ax.scatter(path[:, 0], path[:, 1], c='black', s=20) + ax.text(0.9, 0.9, "True Goal G", color='red', fontweight='bold', ha='center') + ax.text(0.8, 0.6, "Relabeled G'", color='blue', fontweight='bold', ha='center') + ax.annotate("", xy=(0.8, 0.7), xytext=(0.8, 0.63), arrowprops=dict(arrowstyle="->", color='blue')) + +def plot_dyna_q_flow(ax): + """Dyna-Q: Real interaction + Model-based planning flow.""" + ax.axis('off') + ax.set_title("Dyna-Q Architecture", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Agent Policy", bbox=dict(fc="white"), ha='center') + ax.text(0.2, 0.5, "Real World", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.8, 0.5, "Model", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.5, 0.2, "Value Function / Q", bbox=dict(fc="gold"), ha='center') + # Loop + ax.annotate("Direct RL", xy=(0.35, 0.25), xytext=(0.2, 0.45), arrowprops=dict(arrowstyle="->")) + ax.annotate("Planning", xy=(0.65, 0.25), xytext=(0.8, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_noisy_nets_parameters(ax): + """Noisy Nets: Parameter noise distribution σ for weights.""" + x = np.linspace(-3, 3, 100) + y = np.exp(-x**2 / 2) # Base weight (constant) + ax.plot(x, y, color='black', label=r"$\mu$ (Mean)") + ax.fill_between(x, y-0.2, y+0.2, color='gray', alpha=0.3, label=r"$\sigma \cdot \epsilon$ (Noise)") + ax.set_title("Noisy Nets Parameter Noise", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_icm_curiosity(ax): + """Exploration: Intrinsic Curiosity Module (ICM).""" + ax.axis('off') + ax.set_title("ICM: Inverse & Forward Models", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "s_t, s_t+1", ha='center') + ax.text(0.5, 0.8, "Inverse Model", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.2, "Forward Model", bbox=dict(fc="ivory"), ha='center') + ax.text(0.9, 0.5, "Intrinsic Reward", ha='center', color='red') + ax.annotate("", xy=(0.35, 0.75), xytext=(0.2, 0.55), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.35, 0.25), xytext=(0.2, 0.45), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.65, 0.3), arrowprops=dict(arrowstyle="->")) + +def plot_v_trace_impala(ax): + """IMPALA: V-trace asynchronous importance sampling.""" + ax.axis('off') + ax.set_title("V-trace (IMPALA)", fontsize=12, fontweight='bold') + for i in range(4): + h = 0.5 + 0.3*np.sin(i) + ax.bar(0.2+i*0.2, h, width=0.1, color='teal') + ax.text(0.2+i*0.2, h+0.05, rf"$\rho_{i}$", ha='center', fontsize=8) + ax.axhline(0.5, ls='--', color='red', label="Clipped $\\rho$") + ax.set_ylim(0, 1.2) + +def plot_qmix_mixing_net(ax): + """Multi-Agent RL: QMIX Mixing Network.""" + ax.axis('off') + ax.set_title("QMIX Architecture", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Mixing Network", bbox=dict(boxstyle="round,pad=1", fc="gold"), ha='center') + for i in range(3): + ax.text(0.2+i*0.3, 0.4, f"Agent {i+1} Q", bbox=dict(fc="grey"), ha='center', fontsize=7) + ax.annotate("", xy=(0.5, 0.65), xytext=(0.2+i*0.3, 0.45), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.1, "Global State s", ha='center') + ax.annotate("hypernets", xy=(0.5, 0.68), xytext=(0.5, 0.2), arrowprops=dict(arrowstyle="->", ls=':')) + +def plot_saliency_heatmaps(ax): + """Interpretability: Attention/Saliency Heatmap on input.""" + # Dummy "state" (e.g. Breakout screen) + img = np.zeros((20, 20)) + img[15, 8:12] = 1.0 # Paddle + img[5:7, 5:15] = 0.5 # Bricks + heatmap = np.random.rand(20, 20) * 0.5 + heatmap[14:17, 7:13] = 1.0 # High attention on paddle + ax.imshow(img, cmap='gray') + ax.imshow(heatmap, cmap='hot', alpha=0.5) + ax.set_title("Action Saliency Heatmap", fontsize=12, fontweight='bold') + ax.axis('off') + +def plot_action_selection_noise(ax): + """Exploration: OU-noise vs Gaussian Noise paths.""" + t = np.arange(100) + gaussian = np.random.normal(0, 0.1, 100) + ou = np.zeros(100) + for i in range(1, 100): + ou[i] = ou[i-1] * 0.9 + np.random.normal(0, 0.1) + ax.plot(t, gaussian, label="Gaussian", alpha=0.5) + ax.plot(t, ou, label="Ornstein-Uhlenbeck", color='red') + ax.set_title("Action Selection Noise", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_tsne_state_embeddings(ax): + """Interpretability: t-SNE / UMAP State Clusters.""" + np.random.seed(42) + for i in range(3): + center = np.random.randn(2) * 5 + pts = np.random.randn(30, 2) + center + ax.scatter(pts[:, 0], pts[:, 1], alpha=0.6, label=f"Cluster {i+1}") + ax.set_title("t-SNE State Embeddings", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_loss_landscape(fig, gs): + """Optimization: Loss Landscape / Surface.""" + ax = fig.add_subplot(gs[0, 0], projection='3d') + x = np.linspace(-2, 2, 30) + y = np.linspace(-2, 2, 30) + X, Y = np.meshgrid(x, y) + Z = X**2 + Y**2 + 0.5*np.sin(5*X) # Non-convex surface + ax.plot_surface(X, Y, Z, cmap='terrain', alpha=0.8) + ax.set_title("Policy Loss Landscape", fontsize=12, fontweight='bold') + +def plot_success_rate_curve(ax): + """Evaluation: Success Rate over training.""" + steps = np.linspace(0, 1e6, 100) + success = 1.0 / (1.0 + np.exp(-1e-5 * (steps - 4e5))) # S-curve + ax.plot(steps, success, color='darkgreen', lw=2) + ax.set_title("Success Rate vs Steps", fontsize=12, fontweight='bold') + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3) + +def plot_hyperparameter_sensitivity(ax): + """Analysis: Hyperparameter Sensitivity Heatmap.""" + lr = [1e-5, 1e-4, 1e-3] + batches = [32, 64, 128] + data = np.array([[60, 85, 40], [75, 95, 80], [30, 50, 45]]) + im = ax.imshow(data, cmap='RdYlGn') + ax.set_xticks(range(3)); ax.set_xticklabels(batches) + ax.set_yticks(range(3)); ax.set_yticklabels(lr) + ax.set_xlabel("Batch Size"); ax.set_ylabel("Learning Rate") + ax.set_title("Hyperparam Sensitivity", fontsize=12, fontweight='bold') + for (i, j), z in np.ndenumerate(data): + ax.text(j, i, f'{z}%', ha='center', va='center') + +def plot_action_persistence(ax): + """Dynamics: Action Persistence (Frame Skipping).""" + ax.axis('off') + ax.set_title("Action Persistence (k=4)", fontsize=12, fontweight='bold') + for i in range(2): + ax.add_patch(plt.Rectangle((0.1, 0.6-i*0.4), 0.8, 0.2, fill=False)) + ax.text(0.5, 0.7-i*0.4, f"Action A_{i}", ha='center') + for j in range(4): + ax.add_patch(plt.Rectangle((0.1+j*0.2, 0.6-i*0.4), 0.2, 0.2, fill=True, alpha=0.2)) + ax.text(0.5, 0.45, "Repeat Action for k frames", ha='center', color='blue', fontsize=8) + +def plot_muzero_search_tree(ax): + """Model-Based: MuZero Search Tree with dynamics.""" + ax.axis('off') + ax.set_title("MuZero Search Tree", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "Node $s$", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.text(0.3, 0.5, "Dyn $g$", bbox=dict(fc="lavender"), ha='center') + ax.text(0.3, 0.1, "Pred $f$", bbox=dict(fc="ivory"), ha='center') + ax.annotate("", xy=(0.3, 0.6), xytext=(0.5, 0.85), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.3, 0.2), xytext=(0.3, 0.4), arrowprops=dict(arrowstyle="->")) + +def plot_policy_distillation(ax): + """Deep RL: Policy Distillation (Teacher-Student).""" + ax.axis('off') + ax.set_title("Policy Distillation", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, r"Teacher $\pi_T$", bbox=dict(fc="gold"), ha='center') + ax.text(0.8, 0.5, r"Student $\pi_S$", bbox=dict(fc="lightgrey"), ha='center') + ax.annotate("KL-Divergence Loss", xy=(0.7, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", lw=2, color='red')) + +def plot_decision_transformer_tokens(ax): + """Transformers: Token Sequence (DT/TT).""" + ax.axis('off') + ax.set_title("Decision Transformer Tokens", fontsize=12, fontweight='bold') + tokens = [r"$\hat{R}_t$", "$s_t$", "$a_t$", r"$\hat{R}_{t+1}$", "$s_{t+1}$"] + for i, t in enumerate(tokens): + ax.text(0.1+i*0.2, 0.5, t, bbox=dict(boxstyle="round", fc="white")) + ax.annotate("causal attention", xy=(0.5, 0.7), xytext=(0.5, 0.6), annotation_clip=False) + +def plot_performance_profiles_rliable(ax): + """Evaluation: Success Probability Profiles (rliable).""" + x = np.linspace(0, 1, 100) + y1 = x**2 + y2 = np.sqrt(x) + ax.plot(x, y1, label="Algo A") + ax.plot(x, y2, label="Algo B") + ax.set_title("Performance Profiles", fontsize=12, fontweight='bold') + ax.set_xlabel("Normalized Score") + ax.set_ylabel("Probability of higher score") + ax.legend(fontsize=8) + +def plot_safety_shielding(ax): + """Safety RL: Action Shielding / Constraints.""" + ax.axis('off') + ax.set_title("Safety Shielding", fontsize=12, fontweight='bold') + ax.add_patch(plt.Circle((0.5, 0.5), 0.4, fill=True, color='red', alpha=0.1)) + ax.text(0.5, 0.5, "Forbidden\nRegion", ha='center', color='red') + ax.annotate("Shielded Action", xy=(0.2, 0.2), xytext=(0.4, 0.4), arrowprops=dict(arrowstyle="->", color='green', lw=2)) + +def plot_automated_curriculum(ax): + """Training: Automated Curriculum Difficulty.""" + t = np.arange(100) + difficulty = 1.0 / (1.0 + np.exp(-0.05 * (t - 50))) + performance = 0.8 / (1.0 + np.exp(-0.05 * (t - 40))) + ax.plot(t, difficulty, label="Task Difficulty", color='black') + ax.plot(t, performance, '--', label="Agent Performance", color='blue') + ax.set_title("Automated Curriculum", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_domain_randomization(ax): + """Sim-to-Real: Domain Randomization parameter distribution.""" + params = np.random.normal(1.0, 0.3, 1000) + ax.hist(params, bins=30, color='orange', alpha=0.6) + ax.set_title("Domain Randomization ($P(\\mu)$)", fontsize=12, fontweight='bold') + ax.set_xlabel("Friction / Mass Parameter") + +def plot_rlhf_flow(ax): + """Alignment: RL with Human Feedback (RLHF).""" + ax.axis('off') + ax.set_title("RLHF Flow Diagram", fontsize=12, fontweight='bold') + ax.text(0.1, 0.8, "Human Pref", bbox=dict(fc="salmon"), ha='center') + ax.text(0.5, 0.8, "Reward Model", bbox=dict(fc="gold"), ha='center') + ax.text(0.9, 0.8, "Fine-tuned Policy", bbox=dict(fc="lightgreen"), ha='center') + ax.annotate("", xy=(0.4, 0.8), xytext=(0.2, 0.8), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.8), xytext=(0.6, 0.8), arrowprops=dict(arrowstyle="->")) + ax.annotate("PPO Update", xy=(0.5, 0.5), xytext=(0.9, 0.7), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0.3")) + +def plot_successor_representations(ax): + """Neuro-inspired RL: Successor Representation (SR) Matrix M.""" + M = np.zeros((10, 10)) + for i in range(10): + for j in range(10): + M[i, j] = 0.9**abs(i-j) # Decaying future occupancy + ax.imshow(M, cmap='viridis') + ax.set_title("Successor Representation $M$", fontsize=12, fontweight='bold') + ax.set_xlabel("State $j$") + ax.set_ylabel("State $i$") + +def plot_maxent_irl_trajectories(ax): + """IRL: MaxEnt IRL (Log-probability of trajectories).""" + ax.axis('off') + ax.set_title("MaxEnt IRL Distribution", fontsize=12, fontweight='bold') + for i in range(5): + alpha = 0.1 + i*0.2 + ax.plot([0, 1], [0.5, 0.5+0.1*i], color='blue', alpha=alpha) + ax.plot([0, 1], [0.5, 0.5-0.1*i], color='blue', alpha=alpha) + ax.text(0.5, 0.8, r"$P(\tau) \propto \exp(R(\tau))$", ha='center', fontsize=12) + +def plot_information_bottleneck(ax): + """Theory: Information Bottleneck in RL.""" + ax.axis('off') + ax.set_title("Information Bottleneck", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "S", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.text(0.5, 0.5, "Z", bbox=dict(boxstyle="circle", fc="gold"), ha='center') + ax.text(0.9, 0.5, "A", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.annotate("Compress", xy=(0.4, 0.5), xytext=(0.15, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("Extract", xy=(0.85, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.2, r"$\min I(S;Z)$ s.t. $I(Z;A) \geq I_c$", ha='center', fontsize=8) + +def plot_es_population_distribution(ax): + """Evolutionary Strategies: ES Population Distribution.""" + np.random.seed(0) + mu = [0, 0] + points = np.random.randn(50, 2) * 0.5 + mu + ax.scatter(points[:, 0], points[:, 1], color='blue', alpha=0.4, label="Population") + ax.scatter(mu[0], mu[1], color='red', marker='x', label=r"$\mu$") + ax.annotate("Gradient Estimate", xy=(1.0, 1.0), xytext=(0, 0), arrowprops=dict(arrowstyle="->", color='red')) + ax.set_title("ES Population Update", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_cbf_safe_set(ax): + """Safety RL: Control Barrier Function (CBF) Safe Set.""" + ax.axis('off') + ax.set_title("CBF Safe Set Boundary", fontsize=12, fontweight='bold') + ax.add_patch(plt.Circle((0.5, 0.5), 0.35, fill=False, color='black', lw=2)) + ax.text(0.5, 0.5, r"Safe Set $h(s) \geq 0$", ha='center') + ax.text(0.5, 0.1, "Unsafe $h(s) < 0$", ha='center', color='red') + ax.annotate("", xy=(0.8, 0.8), xytext=(0.6, 0.6), arrowprops=dict(arrowstyle="->", color='blue')) + ax.text(0.75, 0.65, r"$\nabla h$", color='blue') + +def plot_count_based_exploration(ax): + """Exploration: Count-based Heatmap N(s).""" + grid = np.random.poisson(2, (10, 10)) + grid[0, 0] = 50; grid[9, 9] = 1 + im = ax.imshow(grid, cmap='hot') + ax.set_title("Visit Counts $N(s)$", fontsize=12, fontweight='bold') + plt.colorbar(im, ax=ax, label="Visits") + +def plot_thompson_sampling(ax): + """Exploration: Thompson Sampling Posterior Distribution.""" + x = np.linspace(0, 1, 100) + import scipy.stats as stats + y1 = stats.beta.pdf(x, 2, 5) + y2 = stats.beta.pdf(x, 10, 4) + ax.plot(x, y1, label="Action 1 (Uncertain)") + ax.plot(x, y2, label="Action 2 (Certain)") + ax.fill_between(x, y1, alpha=0.2) + ax.fill_between(x, y2, alpha=0.2) + ax.set_title("Thompson Sampling Posteriors", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_adversarial_rl_interaction(ax): + """Multi-Agent: Adversarial RL (Protaganist vs Antagonist).""" + ax.axis('off') + ax.set_title("Adversarial RL Interaction", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, "Protaganist", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.5, "Antagonist", bbox=dict(fc="salmon"), ha='center') + ax.annotate("Force Distortion", xy=(0.35, 0.5), xytext=(0.65, 0.5), arrowprops=dict(arrowstyle="->", color='red')) + ax.annotate("Policy Update", xy=(0.5, 0.8), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="<-", connectionstyle="arc3,rad=-0.3")) + +def plot_hierarchical_subgoals(ax): + """Hierarchical RL: Subgoal Trajectory Waypoints.""" + ax.set_title("Subgoal Trajectory", fontsize=12, fontweight='bold') + ax.plot([0, 1], [0, 1], 'k--', alpha=0.3) + ax.scatter([0, 0.3, 0.7, 1], [0, 0.4, 0.6, 1], c=['black', 'red', 'red', 'gold'], s=100) + ax.text(0.3, 0.45, "Subgoal 1", color='red', fontsize=8) + ax.text(0.7, 0.65, "Subgoal 2", color='red', fontsize=8) + ax.text(1, 1.1, "Final Goal", color='gold', fontweight='bold', ha='center') + +def plot_offline_distribution_shift(ax): + """Offline RL: Distribution Shift (Shift between D and pi).""" + x = np.linspace(-5, 5, 200) + d = np.exp(-(x+1)**2 / 2) + pi = np.exp(-(x-2)**2 / 1.5) + ax.plot(x, d, label=r"Offline Dataset $\mathcal{D}$", color='grey') + ax.plot(x, pi, label=r"Learned Policy $\pi$", color='blue') + ax.fill_between(x, 0, d, color='grey', alpha=0.1) + ax.fill_between(x, 0, pi, color='blue', alpha=0.1) + ax.set_title("Action Distribution Shift", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_rnd_curiosity(ax): + """Exploration: Random Network Distillation (RND).""" + ax.axis('off') + ax.set_title("RND: Predictor vs Target", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "State $s$", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.text(0.3, 0.5, "Fixed Target Net", bbox=dict(fc="lightgrey"), ha='center', fontsize=8) + ax.text(0.7, 0.5, "Predictor Net", bbox=dict(fc="ivory"), ha='center', fontsize=8) + ax.text(0.5, 0.2, "MSE Error = Intrinsic Reward", ha='center', color='red', fontsize=9) + ax.annotate("", xy=(0.3, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.7, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->")) + +def plot_bcq_offline_constraint(ax): + """Offline RL: Batch-Constrained Q-learning (BCQ).""" + ax.axis('off') + ax.set_title("BCQ: Action Constraint", fontsize=12, fontweight='bold') + ax.add_patch(plt.Circle((0.5, 0.5), 0.35, fill=True, color='blue', alpha=0.1)) + ax.text(0.5, 0.5, "Dataset Action\nDistribution", ha='center', color='blue') + ax.annotate("Constrained Action", xy=(0.4, 0.45), xytext=(0.2, 0.2), arrowprops=dict(arrowstyle="->", lw=2)) + ax.text(0.5, 0.1, r"$\max Q(s, a)$ s.t. $a \in \mathcal{D}$", ha='center', fontsize=9) + +def plot_pbt_evolution(ax): + """Training: Population-Based Training (PBT).""" + ax.axis('off') + ax.set_title("Population-Based Training", fontsize=12, fontweight='bold') + for i in range(3): + ax.plot([0.1, 0.9], [0.8-i*0.3, 0.8-i*0.3], 'grey', alpha=0.3) + ax.text(0.1, 0.8-i*0.3, f"Agent {i+1}", ha='right') + ax.scatter([0.2, 0.5, 0.8], [0.8-i*0.3, 0.8-i*0.3, 0.8-i*0.3], color='blue') + ax.annotate("Exploit & Perturb", xy=(0.5, 0.2), xytext=(0.5, 0.5), arrowprops=dict(arrowstyle="->", color='red')) + +def plot_recurrent_state_flow(ax): + """Deep RL: Recurrent State Flow (DRQN/R2D2).""" + ax.axis('off') + ax.set_title("Recurrent $h_t$ Flow", fontsize=12, fontweight='bold') + for i in range(3): + ax.text(0.2+i*0.3, 0.5, f"Cell {i}", bbox=dict(fc="ivory"), ha='center') + if i < 2: + ax.annotate("", xy=(0.35+i*0.3, 0.5), xytext=(0.25+i*0.3, 0.5), arrowprops=dict(arrowstyle="->", color='blue')) + ax.text(0.3+i*0.3, 0.55, rf"$h_{i}$", color='blue', fontsize=8) + +def plot_belief_state_pomdp(ax): + """Theory: Belief State in POMDPs.""" + x = np.linspace(0, 1, 100) + y = np.exp(-(x-0.3)**2 / 0.02) + 0.3*np.exp(-(x-0.8)**2 / 0.01) + ax.plot(x, y, color='purple') + ax.fill_between(x, y, alpha=0.2, color='purple') + ax.set_title(r"Belief State $b(s)$", fontsize=12, fontweight='bold') + ax.set_xlabel("State Space") + ax.set_ylabel("Probability") + +def plot_pareto_front_morl(ax): + """Multi-Objective RL: Pareto Front.""" + np.random.seed(42) + x = np.random.rand(50) + y = np.random.rand(50) + ax.scatter(x, y, alpha=0.3, color='grey') + # Pareto front + px = np.sort(x)[-10:] + py = np.sort(y)[-10:][::-1] + ax.plot(px, py, 'r-o', label="Pareto Front") + ax.set_title("Multi-Objective Pareto Front", fontsize=12, fontweight='bold') + ax.set_xlabel("Reward A") + ax.set_ylabel("Reward B") + ax.legend(fontsize=8) + +def plot_differential_value_average_reward(ax): + """Theory: Differential Value (Average Reward RL).""" + t = np.arange(100) + v = np.sin(0.2*t) + 0.05*t # Increasing with oscillation + rho = 0.05 # average gain + ax.plot(t, v, label="Value $V(s_t)$") + ax.plot(t, rho*t, '--', label=r"Gain $\rho \cdot t$", color='red') + ax.set_title("Differential Value $v(s)$", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_distributed_rl_cluster(ax): + """Infrastructure: Distributed RL Cluster (Ray/RLLib).""" + ax.axis('off') + ax.set_title("Distributed RL Cluster", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Learner / GPU", bbox=dict(boxstyle="round", fc="gold"), ha='center') + ax.text(0.5, 0.5, "Replay Buffer", bbox=dict(fc="lightgrey"), ha='center') + for i in range(3): + ax.text(0.2+i*0.3, 0.2, f"Worker {i+1}", bbox=dict(fc="ivory"), ha='center', fontsize=8) + ax.annotate("", xy=(0.5, 0.45), xytext=(0.2+i*0.3, 0.25), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.75), xytext=(0.5, 0.55), arrowprops=dict(arrowstyle="->")) + +def plot_neuroevolution_topology(ax): + """Evolutionary RL: Topology Evolution (NEAT).""" + ax.axis('off') + ax.set_title("Neuroevolution Topology", fontsize=12, fontweight='bold') + nodes = [(0.2, 0.5), (0.5, 0.8), (0.5, 0.2), (0.8, 0.5)] + for p in nodes: ax.text(p[0], p[1], "", bbox=dict(boxstyle="circle", fc="white")) + # Edges + ax.annotate("", xy=nodes[1], xytext=nodes[0], arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=nodes[2], xytext=nodes[0], arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=nodes[3], xytext=nodes[1], arrowprops=dict(arrowstyle="->")) + # Mutation + ax.text(0.5, 0.5, "New Node", bbox=dict(boxstyle="circle", fc="yellow"), ha='center', fontsize=7) + ax.annotate("", xy=(0.5, 0.5), xytext=nodes[0], arrowprops=dict(arrowstyle="->", color='red', ls='--')) + +def plot_ewc_elastic_weights(ax): + """Continual RL: Elastic Weight Consolidation (EWC).""" + ax.axis('off') + ax.set_title("EWC Elastic Constraint", fontsize=12, fontweight='bold') + ax.add_patch(plt.Circle((0.3, 0.5), 0.2, color='blue', alpha=0.2, label="Task A")) + ax.add_patch(plt.Circle((0.7, 0.5), 0.2, color='red', alpha=0.2, label="Task B")) + ax.annotate("", xy=(0.5, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0.3")) + ax.text(0.5, 0.7, "Spring Constraint", color='darkgreen', ha='center', fontsize=9) + +def plot_successor_features(ax): + """Theory: Successor Features (SF).""" + ax.axis('off') + ax.set_title(r"Successor Features $\psi$", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, r"Features $\phi(s)$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.8, 0.5, r"SF $\psi(s)$", bbox=dict(fc="gold"), ha='center') + ax.annotate(r"$\sum \gamma^t \phi(s_t)$", xy=(0.7, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", lw=2)) + +def plot_adversarial_state_noise(ax): + r"""Safety: Adversarial State Noise ($s + \delta$).""" + ax.axis('off') + ax.set_title("Adversarial Perturbation", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, "State $s$", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.5, "+", fontsize=20, ha='center') + ax.text(0.8, 0.5, r"Noise $\delta$", bbox=dict(fc="salmon"), ha='center') + ax.annotate("Target: Wrong Action!", xy=(0.5, 0.2), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="->", color='red')) + +def plot_behavioral_cloning_il(ax): + """Imitation: Behavioral Cloning (BC).""" + ax.axis('off') + ax.set_title("Behavioral Cloning Flow", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Expert Data\n$(s^*, a^*)$", bbox=dict(fc="gold"), ha='center', fontsize=8) + ax.text(0.5, 0.5, "Supervised\nLearning", bbox=dict(fc="ivory"), ha='center', fontsize=8) + ax.text(0.9, 0.5, r"Clone Policy\n$\pi_{BC}$", bbox=dict(fc="lightgrey"), ha='center', fontsize=8) + ax.annotate("", xy=(0.35, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_relational_graph_state(ax): + """Relational RL: Graph-based State Representation.""" + ax.axis('off') + ax.set_title("Relational Graph State", fontsize=12, fontweight='bold') + pos = {1: (0.3, 0.7), 2: (0.7, 0.7), 3: (0.5, 0.3)} + for k, p in pos.items(): + ax.text(p[0], p[1], f"Obj {k}", bbox=dict(boxstyle="round", fc="lightblue"), ha='center') + edges = [(1, 2), (2, 3), (3, 1)] + for u, v in edges: + ax.annotate("relation", xy=pos[v], xytext=pos[u], arrowprops=dict(arrowstyle="-", color='grey', ls=':'), ha='center') + +def plot_quantum_rl_circuit(ax): + """Quantum RL: Parameterized Quantum Circuit (PQC) Policy.""" + ax.axis('off') + ax.set_title("Quantum Policy (PQC)", fontsize=12, fontweight='bold') + ax.plot([0.1, 0.9], [0.7, 0.7], 'k', lw=1) + ax.plot([0.1, 0.9], [0.3, 0.3], 'k', lw=1) + ax.text(0.2, 0.7, r"$|0\rangle$", ha='right') + ax.text(0.2, 0.3, r"$|0\rangle$", ha='right') + # Gates + ax.text(0.4, 0.7, r"$R_y(\theta)$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.6, 0.5, "CNOT", bbox=dict(fc="gold"), ha='center') + ax.plot([0.6, 0.6], [0.3, 0.7], 'k-o') + ax.text(0.8, 0.7, r"$\mathcal{M}$", bbox=dict(boxstyle="square", fc="lightgrey"), ha='center') + +def plot_symbolic_expression_tree(ax): + """Symbolic RL: Policy as a Mathematical Expression Tree.""" + ax.axis('off') + ax.set_title("Symbolic Policy Tree", fontsize=12, fontweight='bold') + nodes = {0:(0.5, 0.8, "+"), 1:(0.3, 0.5, "*"), 2:(0.7, 0.5, "exp"), 3:(0.2, 0.2, "s"), 4:(0.4, 0.2, "2.5"), 5:(0.7, 0.2, "s")} + edges = [(0,1), (0,2), (1,3), (1,4), (2,5)] + for k, (x, y, t) in nodes.items(): + ax.text(x, y, t, bbox=dict(boxstyle="circle", fc="ivory"), ha='center') + for u, v in edges: + ax.annotate("", xy=nodes[v][:2], xytext=nodes[u][:2], arrowprops=dict(arrowstyle="-")) + +def plot_differentiable_physics_gradient(ax): + """Control: Differentiable Physics Gradient Flow.""" + ax.axis('off') + ax.set_title("Diff-Physics Gradient", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Policy", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.5, "Diff-Sim\nDynamics", bbox=dict(fc="gold", boxstyle="round"), ha='center') + ax.text(0.9, 0.5, "Loss", bbox=dict(fc="salmon"), ha='center') + # Forward + ax.annotate("", xy=(0.35, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.5), xytext=(0.65, 0.5), arrowprops=dict(arrowstyle="->")) + # Backward + ax.annotate("$\nabla$ gradient", xy=(0.15, 0.4), xytext=(0.85, 0.4), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=-0.2")) + +def plot_marl_communication_channel(ax): + """MARL: Communication Channel (CommNet/DIAL).""" + ax.axis('off') + ax.set_title("Multi-Agent Comm Channel", fontsize=12, fontweight='bold') + ax.text(0.2, 0.8, "Agent A", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.8, "Agent B", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.2, "Task Goal", bbox=dict(fc="lightgrey"), ha='center') + # Message + ax.annotate("Message $m_{A \to B}$", xy=(0.7, 0.8), xytext=(0.3, 0.8), arrowprops=dict(arrowstyle="->", ls="--", color='purple')) + ax.annotate("", xy=(0.2, 0.45), xytext=(0.2, 0.7), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.45), xytext=(0.8, 0.7), arrowprops=dict(arrowstyle="->")) + +def plot_lagrangian_multiplier_landscape(ax): + """Safety: Lagrangian Constraint Optimization.""" + x = np.linspace(-2, 2, 100); y = np.linspace(-2, 2, 100) + X, Y = np.meshgrid(x, y); Z = X**2 + Y**2 + ax.contour(X, Y, Z, levels=10, alpha=0.3) + ax.axvline(x=0.5, color='red', ls='--', label=r"Constraint $g(s) \leq 0$") + ax.scatter([1.0], [1.0], color='blue', label="Unconstrained Min") + ax.scatter([0.5], [0.0], color='green', label="Constrained Min") + ax.set_title("Lagrangian Constrained Opt", fontsize=12, fontweight='bold') + ax.legend(fontsize=7, loc='upper left') + +def plot_maxq_task_hierarchy(ax): + """HRL: MAXQ Recursive Task Decomposition.""" + ax.axis('off') + ax.set_title("MAXQ Task Hierarchy", fontsize=12, fontweight='bold') + # Levels + ax.text(0.5, 0.9, "Root Task", bbox=dict(fc="gold"), ha='center') + ax.text(0.3, 0.6, "GetFuel", bbox=dict(fc="ivory"), ha='center') + ax.text(0.7, 0.6, "DeliverCargo", bbox=dict(fc="ivory"), ha='center') + ax.text(0.3, 0.3, "Navigate", bbox=dict(fc="lightgrey"), ha='center', fontsize=8) + ax.text(0.7, 0.3, "Unload", bbox=dict(fc="lightgrey"), ha='center', fontsize=8) + # Recursion + ax.annotate("", xy=(0.3, 0.65), xytext=(0.45, 0.85), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.7, 0.65), xytext=(0.55, 0.85), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.3, 0.35), xytext=(0.3, 0.55), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.7, 0.35), xytext=(0.7, 0.55), arrowprops=dict(arrowstyle="->")) + +def plot_react_cycle_thinking(ax): + """Agentic LLM: ReAct Loop (Thought-Action-Observation).""" + ax.axis('off') + ax.set_title(r"ReAct Cycle: $T \to A \to O$", fontsize=12, fontweight='bold') + steps = ["Thought", "Action", "Observation"] + colors = ["ivory", "lightblue", "lightgreen"] + for i, s in enumerate(steps): + angle = 2 * np.pi * i / 3 + x, y = 0.5 + 0.3*np.cos(angle), 0.5 + 0.3*np.sin(angle) + ax.text(x, y, s, bbox=dict(boxstyle="round", fc=colors[i]), ha='center') + # Loop arrows + ax.annotate("", xy=(0.2, 0.5), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0.3")) + ax.annotate("", xy=(0.5, 0.2), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0.3")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.5, 0.2), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=0.3")) + +def plot_synaptic_plasticity_rl(ax): + """Bio-inspired: Synaptic Plasticity (Hebbian RL/STDP).""" + ax.axis('off') + ax.set_title("Synaptic Plasticity RL", fontsize=12, fontweight='bold') + ax.text(0.3, 0.5, "Pre-neuron", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.text(0.7, 0.5, "Post-neuron", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.plot([0.35, 0.65], [0.5, 0.5], 'k', lw=4, label="Synapse $w$") + ax.text(0.5, 0.6, r"$\Delta w \propto \delta \cdot x_{pre} \cdot x_{post}$", color='red', ha='center', fontsize=10) + ax.annotate(r"TD Error $\delta$", xy=(0.5, 0.5), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="->", color='red')) + +def plot_guided_policy_search_gps(ax): + """Control: Guided Policy Search (GPS).""" + ax.axis('off') + ax.set_title("Guided Policy Search (GPS)", fontsize=12, fontweight='bold') + ax.plot([0.1, 0.9], [0.7, 0.8], 'b', label=r"Optimal Trajectory $\tau^*$") + ax.plot([0.1, 0.9], [0.6, 0.6], 'r--', label=r"Current Policy $\pi_\theta$") + ax.annotate("Minimize KL", xy=(0.5, 0.6), xytext=(0.5, 0.72), arrowprops=dict(arrowstyle="<->")) + ax.legend(fontsize=8, loc='lower right') + +def plot_sim2real_jitter_latency(ax): + """Robotics: Sim-to-Real Jitter & Latency Analysis.""" + t = np.linspace(0, 10, 100) + ideal = np.sin(t) + jitter = ideal + 0.2*np.random.randn(100) + ax.plot(t, ideal, 'g', alpha=0.5, label="Simulator (Ideal)") + ax.step(t + 0.3, jitter, 'r', label="Real Robot (Latency+Jitter)") + ax.set_title("Sim-to-Real Temporal Mismatch", fontsize=12, fontweight='bold') + ax.set_xlabel("Time (s)") + ax.legend(fontsize=8) + +def plot_ddpg_deterministic_gradient(ax): + """Deterministic Policy Gradient (DDPG).""" + ax.axis('off') + ax.set_title("DDPG Gradient Flow", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, r"$\pi_\theta(s)$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.5, r"$Q_w(s, a)$", bbox=dict(fc="gold"), ha='center') + ax.annotate(r"$\nabla_\theta J \approx \nabla_a Q(s,a)|_{a=\pi(s)} \nabla_\theta \pi_\theta(s)$", xy=(0.5, 0.2), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="->", color='red'), ha='center', fontsize=9) + ax.annotate("action", xy=(0.7, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_dreamer_latent_rollout(ax): + """Model-Based RL: Dreamer Latent imagination.""" + ax.axis('off') + ax.set_title("Dreamer Latent imagination", fontsize=12, fontweight='bold') + for i in range(3): + ax.text(0.2 + i*0.3, 0.5, f"$z_{i}$", bbox=dict(boxstyle="circle", fc="lightgreen"), ha='center') + if i < 2: + ax.annotate("", xy=(0.35 + i*0.3, 0.5), xytext=(0.25 + i*0.3, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.3 + i*0.3, 0.7, r"$\hat{a}$", ha='center') + ax.text(0.5, 0.2, r"Policy $\pi(z)$ learned in latent space", fontsize=9, ha='center') + +def plot_unreal_auxiliary_tasks(ax): + """Deep RL: UNREAL Architecture (Auxiliary Tasks).""" + ax.axis('off') + ax.set_title("UNREAL Auxiliary Tasks", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Base Agent (A3C)", bbox=dict(fc="ivory"), ha='center') + tasks = ["Pixel Control", "Value Replay", "Reward Prediction"] + for i, t in enumerate(tasks): + ax.text(0.2 + i*0.3, 0.4, t, bbox=dict(fc="orange", alpha=0.3), ha='center', fontsize=8) + ax.annotate("", xy=(0.2+i*0.3, 0.5), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", ls=':')) + ax.text(0.5, 0.1, "Shared Representation Learning", fontweight='bold', ha='center', fontsize=9) + +def plot_iql_expectile_loss(ax): + """Offline RL: Implicit Q-Learning (IQL) Expectile.""" + x = np.linspace(-2, 2, 100) + tau = 0.8 + loss = np.where(x > 0, tau * x**2, (1-tau) * x**2) + ax.plot(x, loss, color='purple', lw=2) + ax.set_title(r"IQL Expectile Loss $L_\tau$", fontsize=12, fontweight='bold') + ax.axvline(0, color='black', alpha=0.3) + ax.text(1, 1, r"$\tau=0.8$", color='purple') + +def plot_prioritized_sweeping(ax): + """Model-Based: Prioritized Sweeping.""" + ax.axis('off') + ax.set_title("Prioritized Sweeping", fontsize=12, fontweight='bold') + ax.text(0.2, 0.8, "State $s$", bbox=dict(fc="white"), ha='center') + ax.text(0.8, 0.2, "Priority Queue", bbox=dict(boxstyle="sawtooth", fc="gold"), ha='center') + ax.annotate(r"TD Error $|\delta|$", xy=(0.7, 0.3), xytext=(0.3, 0.7), arrowprops=dict(arrowstyle="->", color='red')) + ax.text(0.5, 0.5, "Update most affected states first", rotation=-35, fontsize=8) + +def plot_dagger_expert_loop(ax): + """Imitation: DAgger (Dataset Aggregation).""" + ax.axis('off') + ax.set_title("DAgger Expert Loop", fontsize=12, fontweight='bold') + ax.text(0.2, 0.7, r"Learner $\pi_\theta$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.7, r"Expert $\pi^*$", bbox=dict(fc="gold"), ha='center') + ax.text(0.5, 0.3, r"Dataset $\mathcal{D}$", bbox=dict(boxstyle="round", fc="ivory"), ha='center') + ax.annotate("Collect", xy=(0.5, 0.4), xytext=(0.2, 0.6), arrowprops=dict(arrowstyle="->")) + ax.annotate("Relabel", xy=(0.8, 0.6), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="<-")) + ax.annotate("Train", xy=(0.25, 0.65), xytext=(0.4, 0.35), arrowprops=dict(arrowstyle="->", color='blue')) + +def plot_spr_self_prediction(ax): + """Deep RL: Self-Predictive Representations (SPR).""" + ax.axis('off') + ax.set_title("SPR: Self-Prediction", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, "Encoder", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.8, 0.7, "Target Latent", bbox=dict(fc="gold", alpha=0.3), ha='center') + ax.text(0.8, 0.3, "Predicted Latent", bbox=dict(fc="lightblue"), ha='center') + ax.annotate("", xy=(0.7, 0.7), xytext=(0.3, 0.55), arrowprops=dict(arrowstyle="->", ls='--')) + ax.annotate("", xy=(0.7, 0.3), xytext=(0.3, 0.45), arrowprops=dict(arrowstyle="->")) + ax.text(0.9, 0.5, "Consistency Loss", rotation=90, color='red', fontsize=8) + +def plot_joint_action_space(ax): + """MARL: Joint Action Space $A_1 \times A_2$.""" + ax.set_title(r"Joint Action Space $A_1 \times A_2$", fontsize=12, fontweight='bold') + for x in range(3): + for y in range(3): + ax.scatter(x, y, color='blue', alpha=0.5) + ax.text(x, y+0.1, f"($a^k_{x}, a^j_{y}$)", fontsize=7, ha='center') + ax.set_xlabel("Agent 1 Actions") + ax.set_ylabel("Agent 2 Actions") + ax.set_xticks([0,1,2]); ax.set_yticks([0,1,2]) + +def plot_dec_pomdp_graph(ax): + """MARL: Dec-POMDP Formal Model.""" + ax.axis('off') + ax.set_title("Dec-POMDP Model", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Global State $s$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.2, 0.4, "Obs $o_1$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.4, "Obs $o_2$", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.1, "Joint Reward $r$", bbox=dict(fc="gold"), ha='center') + ax.annotate("", xy=(0.2, 0.5), xytext=(0.45, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.55, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.45, 0.15), xytext=(0.2, 0.35), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.55, 0.15), xytext=(0.8, 0.35), arrowprops=dict(arrowstyle="->")) + +def plot_bisimulation_metric(ax): + """Theory: State Bisimulation Metric.""" + ax.axis('off') + ax.set_title("Bisimulation Metric", fontsize=12, fontweight='bold') + ax.text(0.3, 0.6, "$s_1$", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.text(0.7, 0.6, "$s_2$", bbox=dict(boxstyle="circle", fc="white"), ha='center') + ax.annotate("$d(s_1, s_2)$", xy=(0.65, 0.6), xytext=(0.35, 0.6), arrowprops=dict(arrowstyle="<->", color='purple')) + ax.text(0.5, 0.2, "States are equivalent if rewards and\ntransitions to equivalent states match", ha='center', fontsize=8) + +def plot_reward_shaping_phi(ax): + """Theory: Potential-Based Reward Shaping.""" + ax.axis('off') + ax.set_title("Potential-Based Reward Shaping", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, "$s$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.8, 0.5, "$s'$", bbox=dict(fc="ivory"), ha='center') + ax.annotate("", xy=(0.7, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.7, r"$\gamma \Phi(s') - \Phi(s)$", color='blue', ha='center') + ax.text(0.5, 0.3, "Added to environmental reward $r$", fontsize=8, ha='center') + +def plot_transfer_rl_source_target(ax): + """Training: Transfer RL (Source to Target).""" + ax.axis('off') + ax.set_title("Transfer RL: Source to Target", fontsize=12, fontweight='bold') + ax.text(0.3, 0.7, r"Source Task $\mathcal{T}_A$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.7, 0.3, r"Target Task $\mathcal{T}_B$", bbox=dict(fc="lightgreen"), ha='center') + ax.annotate("Knowledge Transfer\n(Weights/Expert Data)", xy=(0.6, 0.4), xytext=(0.4, 0.6), arrowprops=dict(arrowstyle="->", lw=2, color='orange'), ha='center') + +def plot_multi_task_backbone(ax): + """Deep RL: Multi-Task Architecture.""" + ax.axis('off') + ax.set_title("Multi-Task Backbone Arch", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "State Input", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.5, 0.5, "Shared Backbone", bbox=dict(fc="cornflowerblue"), ha='center') + ax.text(0.2, 0.2, "Task 1 Head", bbox=dict(fc="orange", alpha=0.5), ha='center') + ax.text(0.8, 0.2, "Task N Head", bbox=dict(fc="orange", alpha=0.5), ha='center') + ax.annotate("", xy=(0.5, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="<-")) + ax.annotate("", xy=(0.25, 0.3), xytext=(0.45, 0.45), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.3), xytext=(0.55, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_contextual_bandit_pipeline(ax): + """Bandits: Contextual Bandit Pipeline.""" + ax.axis('off') + ax.set_title("Contextual Bandit Pipeline", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, r"Context $x$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.5, r"Policy $\pi(a|x)$", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.9, 0.5, r"Reward $r$", bbox=dict(fc="gold"), ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_regret_bounds_theoretical(ax): + """Theory: Regret Upper/Lower Bounds.""" + t = np.linspace(1, 100, 100) + ax.plot(t, np.sqrt(t), label=r"Upper Bound $O(\sqrt{T})$", color='red') + ax.plot(t, np.log(t), label=r"Optimal Regret $O(\log T)$", color='blue') + ax.set_title("Theoretical Regret Bounds", fontsize=12, fontweight='bold') + ax.set_xlabel("Time $T$") + ax.set_ylabel("Cumulative Regret") + ax.legend() + +def plot_soft_q_heatmap(ax): + """Value-based: Soft Q-Learning Heatmap.""" + data = np.random.randn(10, 10) + soft_q = np.exp(data) / np.sum(np.exp(data)) + im = ax.imshow(soft_q, cmap='hot') + plt.colorbar(im, ax=ax) + ax.set_title("Soft Q Boltzmann Probabilities", fontsize=12, fontweight='bold') + +def plot_ad_rl_pipeline(ax): + """Robotics: Autonomous Driving RL Pipeline.""" + ax.axis('off') + ax.set_title("Autonomous Driving RL Pipeline", fontsize=12, fontweight='bold') + modules = ["Sensors", "Perception (CNN)", "RL Policy", "Actuators"] + for i, m in enumerate(modules): + ax.text(0.25 + (i%2)*0.5, 0.7 - (i//2)*0.5, m, bbox=dict(fc="ivory"), ha='center') + ax.annotate("", xy=(0.7, 0.7), xytext=(0.3, 0.7), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.35), xytext=(0.75, 0.6), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.3, 0.2), xytext=(0.7, 0.2), arrowprops=dict(arrowstyle="<-")) + +def plot_action_grad_comparison(ax): + """Policy: Stochastic vs Deterministic Gradients.""" + ax.axis('off') + ax.set_title("Action Gradient Types", fontsize=12, fontweight='bold') + ax.text(0.5, 0.7, r"Stochastic: $\nabla \log \pi(a|s) Q(s,a)$", color='blue', ha='center') + ax.text(0.5, 0.3, r"Deterministic: $\nabla_a Q(s,a) \nabla \pi(s)$", color='red', ha='center') + ax.text(0.5, 0.5, "vs", fontweight='bold', ha='center') + +def plot_irl_feature_matching(ax): + """IRL: Feature Expectation Matching.""" + ax.axis('off') + ax.set_title("IRL: Feature Expectation Matching", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, r"Expert $\mu(\pi^*)$", bbox=dict(fc="gold"), ha='center') + ax.text(0.8, 0.5, r"Learner $\mu(\pi)$", bbox=dict(fc="lightblue"), ha='center') + ax.annotate(r"$||\mu(\pi^*) - \mu(\pi)||_2 \leq \epsilon$", xy=(0.5, 0.2), ha='center', color='red') + ax.annotate("", xy=(0.65, 0.5), xytext=(0.35, 0.5), arrowprops=dict(arrowstyle="<->", ls='--')) + +def plot_apprenticeship_learning_loop(ax): + """Imitation: Apprenticeship Learning Loop.""" + ax.axis('off') + ax.set_title("Apprenticeship Learning Loop", fontsize=12, fontweight='bold') + nodes = ["Expert Demos", "Reward Learning", "Agent Policy", "Environment"] + for i, n in enumerate(nodes): + ax.text(0.5, 0.9 - i*0.25, n, bbox=dict(fc="ivory"), ha='center') + if i < 3: ax.annotate("", xy=(0.5, 0.7 - i*0.25), xytext=(0.5, 0.8 - i*0.25), arrowprops=dict(arrowstyle="->")) + ax.annotate("feedback", xy=(0.3, 0.9), xytext=(0.3, 0.15), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5")) + +def plot_active_inference_loop(ax): + """Theoretical: Active Inference / Free Energy Loop.""" + ax.axis('off') + ax.set_title("Active Inference Loop", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Internal Model (Generative)", bbox=dict(fc="cornflowerblue", alpha=0.3), ha='center') + ax.text(0.5, 0.2, "External Environment", bbox=dict(fc="lightgrey"), ha='center') + ax.annotate("Action (Active Charge)", xy=(0.8, 0.25), xytext=(0.8, 0.75), arrowprops=dict(arrowstyle="<-", color='red')) + ax.annotate("Perception (Surprise Min)", xy=(0.2, 0.75), xytext=(0.2, 0.25), arrowprops=dict(arrowstyle="<-", color='blue')) + ax.text(0.5, 0.5, r"$\min F = D_{KL}(q||p)$", ha='center', fontweight='bold') + +def plot_bellman_residual_landscape(ax): + """Theory: Bellman Residual Landscape.""" + X, Y = np.meshgrid(np.linspace(-2, 2, 20), np.linspace(-2, 2, 20)) + Z = (X**2 + Y**2) + 0.5 * np.sin(3*X) # Non-convex loss + ax.contourf(X, Y, Z, cmap='magma') + ax.set_title("Bellman Residual Landscape", fontsize=12, fontweight='bold') + +def plot_plan_to_explore_map(ax): + """MBRL: Plan-to-Explore Uncertainty Map.""" + data = np.random.rand(10, 10) + im = ax.imshow(data, cmap='YlOrRd') + ax.set_title("Plan-to-Explore Uncertainty", fontsize=12, fontweight='bold') + ax.text(2, 2, "Explored", color='black', fontsize=8) + ax.text(7, 7, "Unknown", color='red', fontweight='bold', fontsize=8) + +def plot_robust_rl_uncertainty_set(ax): + """Safety: Robust RL Uncertainty Set.""" + ax.axis('off') + ax.set_title("Robust RL Uncertainty Set", fontsize=12, fontweight='bold') + circle = plt.Circle((0.5, 0.5), 0.3, color='blue', alpha=0.1) + ax.add_patch(circle) + ax.text(0.5, 0.5, r"$\mathcal{P}$", fontsize=20, ha='center') + ax.text(0.5, 0.1, r"$\min_\pi \max_{P \in \mathcal{P}} \mathbb{E}[R]$", ha='center', fontsize=12) + ax.annotate("Nominal Model", xy=(0.5, 0.5), xytext=(0.2, 0.8), arrowprops=dict(arrowstyle="->")) + +def plot_hpo_bayesian_opt_cycle(ax): + """Training: HPO Bayesian Optimization Cycle.""" + ax.axis('off') + ax.set_title("HPO Bayesian Opt Cycle", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Surrogate Model (GP)", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.2, "RL Objective Function", bbox=dict(fc="ivory"), ha='center') + ax.annotate("Select Hyperparams", xy=(0.7, 0.3), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle="<-")) + ax.annotate("Update Model", xy=(0.3, 0.7), xytext=(0.3, 0.3), arrowprops=dict(arrowstyle="<-")) + +def plot_slate_rl_reco_pipeline(ax): + """Applied: Slate RL / Recommendation Pipeline.""" + ax.axis('off') + ax.set_title("Slate RL Recommendation", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "User State", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.5, 0.5, "Slate Policy", bbox=dict(fc="gold"), ha='center') + ax.text(0.9, 0.5, "Action (Items)", bbox=dict(fc="lightgreen"), ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.2, "Combinatorial Action Space", fontsize=8, ha='center') + +def plot_game_theory_fictitious_play(ax): + """Multi-Agent: Fictitious Play Interaction.""" + ax.axis('off') + ax.set_title("Fictitious Play Interaction", fontsize=12, fontweight='bold') + ax.text(0.2, 0.7, "Agent A (Best Response)", bbox=dict(fc="white"), ha='center') + ax.text(0.8, 0.7, "Agent B (Best Response)", bbox=dict(fc="white"), ha='center') + ax.text(0.5, 0.3, r"Empirical Frequency $\hat{\pi}$", bbox=dict(fc="ivory"), ha='center') + ax.annotate("", xy=(0.45, 0.4), xytext=(0.25, 0.6), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.55, 0.4), xytext=(0.75, 0.6), arrowprops=dict(arrowstyle="->")) + +def plot_universal_rl_framework(ax): + """Conceptual: Universal RL Framework Diagram.""" + ax.axis('off') + ax.set_title("Universal RL Framework", fontsize=12, fontweight='bold') + rect = plt.Rectangle((0.15, 0.15), 0.7, 0.7, fill=False, ls='--') + ax.add_patch(rect) + ax.text(0.5, 0.5, "RL Agent\n(Algorithm + Model + Exp)", ha='center', fontweight='bold') + ax.text(0.5, 0.9, "Problem Context", ha='center', color='grey') + ax.text(0.5, 0.1, "Reward / Evaluation", ha='center', color='grey') + +def plot_offline_density_ratio(ax): + """Offline RL: Density Ratio Estimation $w(s,a)$.""" + x = np.linspace(-3, 3, 100) + pi_e = norm.pdf(x, 0, 1) + pi_b = norm.pdf(x, 1, 1.5) + ax.plot(x, pi_e, label=r"Policy $\pi_e$") + ax.plot(x, pi_b, label=r"Behavior $\pi_b$", ls='--') + ax.fill_between(x, pi_e / (pi_b + 1e-5), alpha=0.1, label="Ratio $w$") + ax.set_title(r"Offline Density Ratio $w(s,a)$", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_continual_task_interference(ax): + """Continual RL: Task Interference Heatmap.""" + data = np.eye(5) + 0.1 * np.random.randn(5, 5) + data[1,0] = -0.5 # Interference + im = ax.imshow(data, cmap='coolwarm', vmin=-1, vmax=1) + plt.colorbar(im, ax=ax) + ax.set_title("Continual Task Interference", fontsize=12, fontweight='bold') + ax.set_xlabel("Previously Learned Tasks"); ax.set_ylabel("Current Task") + +def plot_lyapunov_safe_set(ax): + """Safety: Lyapunov Stability Set.""" + ax.set_title("Lyapunov Safe Set", fontsize=12, fontweight='bold') + theta = np.linspace(0, 2*np.pi, 100) + r = 1 + 0.2 * np.sin(4*theta) + ax.fill(r * np.cos(theta), r * np.sin(theta), color='green', alpha=0.1, label="Invariant Set") + ax.plot(r * np.cos(theta), r * np.sin(theta), color='green') + ax.quiver(0.5, 0.5, -0.4, -0.4, color='red', scale=5, label="Energy Decrease") + ax.legend(fontsize=8); ax.set_xlim(-1.5, 1.5); ax.set_ylim(-1.5, 1.5) + +def plot_molecular_rl_atoms(ax): + """Applied: Molecular RL (Atoms).""" + ax.set_title("Molecular RL (Atom State)", fontsize=12, fontweight='bold') + for _ in range(5): + pos = np.random.rand(2) + circle = plt.Circle(pos, 0.05, color='blue', alpha=0.7) + ax.add_patch(circle) + ax.set_xlim(0, 1); ax.set_ylim(0, 1); ax.axis('off') + ax.text(0.5, -0.05, "States = Atomic Coordinates", ha='center', fontsize=8) + +def plot_moe_multi_task_arch(ax): + """Architecture: MoE for Multi-task.""" + ax.axis('off') + ax.set_title("MoE Multi-task Architecture", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "Gating Network", bbox=dict(fc="orange"), ha='center') + for i in range(3): + ax.text(0.2 + i*0.3, 0.5, f"Expert {i+1}", bbox=dict(fc="ivory"), ha='center') + ax.annotate("", xy=(0.2 + i*0.3, 0.6), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.2, "Joint Output", bbox=dict(fc="lightgrey"), ha='center') + +def plot_cma_es_distribution(ax): + """Direct Policy Search: CMA-ES Distribution.""" + x = np.random.randn(200, 2) + ax.scatter(x[:,0], x[:,1], alpha=0.3, color='grey') + circle = plt.Circle((0, 0), 1.5, fill=False, color='red', lw=2, label="Sample Ellipsoid") + ax.add_patch(circle) + ax.set_title("CMA-ES Policy Search", fontsize=12, fontweight='bold') + ax.legend(fontsize=8) + +def plot_elo_rating_preference(ax): + """Alignment: Elo Rating Preference Plot.""" + x = np.linspace(0, 10, 10) + y = 1000 + 100 * np.log(x + 1) + 20 * np.random.randn(10) + ax.step(x, y, color='purple', where='post') + ax.set_title("Policy Elo Rating vs Experience", fontsize=12, fontweight='bold') + ax.set_xlabel("Relative Training Time"); ax.set_ylabel("Elo Rating") + +def plot_shap_lime_attribution(ax): + """Explainable RL: SHAP/LIME Attribution.""" + ax.set_title("Action Attribution (SHAP)", fontsize=12, fontweight='bold') + feats = ["Dist to Goal", "Velocity", "Agent Pitch", "Sensor 4"] + vals = [0.6, -0.3, 0.1, 0.05] + colors = ['green' if v > 0 else 'red' for v in vals] + ax.barh(feats, vals, color=colors) + ax.set_xlabel("Contribution to Action probability") + +def plot_pearl_context_encoder(ax): + """Meta-RL: Context Encoder (PEARL).""" + ax.axis('off') + ax.set_title("PEARL Context Encoder", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, "Experience batch\n(s, a, r, s')", bbox=dict(fc="ivory"), ha='center', fontsize=8) + ax.text(0.5, 0.5, r"Encoder $q_\phi(z|...)$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.5, "Latent Task $z$", bbox=dict(boxstyle="circle", fc="lightgreen"), ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.7, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_healthcare_rl_pipeline(ax): + """Applied: Healthcare / Medical Therapy.""" + ax.axis('off') + ax.set_title("Medical RL Therapy Pipeline", fontsize=12, fontweight='bold') + blocks = ["Patient History (EHR)", "State Estimator", "Policy (Action = Dose)", "Clinical Outcome"] + for i, b in enumerate(blocks): + ax.text(0.5, 0.9 - i*0.25, b, bbox=dict(fc="pink", alpha=0.3), ha='center') + if i < 3: ax.annotate("", xy=(0.5, 0.7 - i*0.25), xytext=(0.5, 0.8 - i*0.25), arrowprops=dict(arrowstyle="->")) + +def plot_supply_chain_rl(ax): + """Applied: Supply Chain / Inventory RL.""" + ax.axis('off') + ax.set_title("Supply Chain RL Pipeline", fontsize=12, fontweight='bold') + G = nx.DiGraph() + nodes = ["Factory", "Warehouse", "Retailer", "Customer"] + for i, n in enumerate(nodes): + ax.text(0.1 + i*0.27, 0.5, n, bbox=dict(boxstyle="round", fc="ivory"), ha='center') + for i in range(3): + ax.annotate("", xy=(0.28 + i*0.27, 0.5), xytext=(0.2 + i*0.27, 0.5), arrowprops=dict(arrowstyle="->")) + ax.text(0.5, 0.2, "State = Stock Levels, Action = Orders", ha='center', fontsize=8) + +def plot_sysid_safe_loop(ax): + """Robotics: Sim-to-Real SysID Loop.""" + ax.axis('off') + ax.set_title("Sim-to-Real SysID Loop", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Physical System", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.5, "System ID Estimator", bbox=dict(fc="orange", alpha=0.5), ha='center') + ax.text(0.5, 0.2, "Simulation Model", bbox=dict(fc="lightblue"), ha='center') + ax.annotate("Observables", xy=(0.4, 0.6), xytext=(0.4, 0.75), arrowprops=dict(arrowstyle="<-")) + ax.annotate("Update Parameters", xy=(0.6, 0.3), xytext=(0.6, 0.45), arrowprops=dict(arrowstyle="<-")) + +def plot_transformer_world_model(ax): + """Architecture: Transformer World Model.""" + ax.axis('off') + ax.set_title("Transformer World Model", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Sequence of $(s, a, r)$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.5, "Self-Attention Layers", bbox=dict(fc="purple", alpha=0.3), ha='center') + ax.text(0.5, 0.2, "Predicted $s_{t+1}, r_{t+1}$", bbox=dict(fc="lightgreen"), ha='center') + ax.annotate("", xy=(0.5, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.3), xytext=(0.5, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_network_rl(ax): + """Applied: RL for Networking.""" + ax.axis('off') + ax.set_title("Network Traffic RL", fontsize=12, fontweight='bold') + G = nx.Graph() + G.add_edges_from([(0,1), (1,2), (2,3), (3,0)]) + pos = nx.spring_layout(G) + nx.draw(G, pos, ax=ax, node_color='lightblue', with_labels=False) + ax.annotate("RL Router", xy=(pos[1][0], pos[1][1]), xytext=(pos[1][0], pos[1][1]+0.2), arrowprops=dict(arrowstyle="->")) + +def plot_rlhf_ppo_ref(ax): + """Training: RLHF PPO with Reference Policy.""" + ax.axis('off') + ax.set_title("RLHF: PPO with Reference Policy", fontsize=12, fontweight='bold') + ax.text(0.3, 0.8, r"Active Policy $\pi_\theta$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.7, 0.8, r"Ref Policy $\pi_{ref}$", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.5, 0.5, "KL Penalty", bbox=dict(boxstyle="sawtooth", fc="red", alpha=0.3), ha='center') + ax.text(0.5, 0.2, "Reward Model $r(s,a)$", bbox=dict(fc="gold"), ha='center') + ax.annotate("", xy=(0.4, 0.6), xytext=(0.3, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.6, 0.6), xytext=(0.7, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("Total Reward", xy=(0.5, 0.4), xytext=(0.5, 0.3), arrowprops=dict(arrowstyle="<-")) + +def plot_psro_meta_game(ax): + """Multi-Agent: PSRO Meta-Game Tree.""" + ax.axis('off') + ax.set_title("PSRO Meta-Game Update", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Meta-Game Matrix", bbox=dict(fc="ivory"), ha='center') + ax.text(0.2, 0.5, "Best Response", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.5, "Nash Equilibrium", bbox=dict(fc="lightgreen"), ha='center') + ax.text(0.5, 0.2, "Add Oracle Policy", bbox=dict(fc="gold"), ha='center', fontweight='bold') + ax.annotate("", xy=(0.3, 0.6), xytext=(0.45, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.7, 0.6), xytext=(0.55, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.3), xytext=(0.3, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_dial_comm_channel(ax): + """Multi-Agent: DIAL Comm Channel.""" + ax.axis('off') + ax.set_title("DIAL: Differentiable Comm", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, "Agent 1", bbox=dict(boxstyle="circle", fc="lightblue"), ha='center') + ax.text(0.8, 0.5, "Agent 2", bbox=dict(boxstyle="circle", fc="lightblue"), ha='center') + ax.annotate("Message $m$ (Differentiable)", xy=(0.7, 0.52), xytext=(0.3, 0.52), arrowprops=dict(arrowstyle="->", lw=2, color='orange')) + ax.annotate("Gradient $\\nabla m$", xy=(0.3, 0.48), xytext=(0.7, 0.48), arrowprops=dict(arrowstyle="->", lw=1, color='blue', ls='--')) + +def plot_fqi_batch_loop(ax): + """Batch RL: Fitted Q-Iteration (FQI).""" + ax.axis('off') + ax.set_title("Fitted Q-Iteration Loop", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, r"Dataset $\mathcal{D}$", bbox=dict(boxstyle="round", fc="ivory"), ha='center') + ax.text(0.5, 0.5, "Supervised Regressor", bbox=dict(fc="orange", alpha=0.3), ha='center') + ax.text(0.5, 0.2, "Updated $Q_{k+1}$", bbox=dict(fc="lightgreen"), ha='center') + ax.annotate("", xy=(0.5, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.3), xytext=(0.5, 0.45), arrowprops=dict(arrowstyle="->")) + ax.annotate("Bootstrap", xy=(0.8, 0.3), xytext=(0.8, 0.7), arrowprops=dict(arrowstyle="<-", connectionstyle="arc3,rad=-0.5")) + +def plot_cmdp_feasible_set(ax): + """Safety RL: CMDP Feasible Set.""" + ax.set_title("CMDP Feasible Region", fontsize=12, fontweight='bold') + circle = plt.Circle((0, 0), 1, alpha=0.2, color='green', label="Constrained Feasible Set") + ax.add_patch(circle) + ax.axhline(0.7, color='red', ls='--', label=r"Constraint $J \leq C$") + ax.text(0, -0.3, r"Optimized Policy $\pi^*$", color='blue', fontweight='bold', ha='center') + ax.set_xlim(-1.5, 1.5); ax.set_ylim(-1.5, 1.5) + ax.legend(fontsize=8) + +def plot_mpc_vs_rl_horizon(ax): + """Control: MPC vs RL Comparison.""" + ax.axis('off') + ax.set_title("MPC vs RL Planning", fontsize=12, fontweight='bold') + ax.text(0.25, 0.8, "MPC", fontweight='bold') + ax.text(0.75, 0.8, "RL", fontweight='bold') + ax.text(0.25, 0.5, "Receding Horizon\nPlanning at every step", ha='center', fontsize=8) + ax.text(0.75, 0.5, "Direct Mapping from\nState to Action (Policy)", ha='center', fontsize=8) + ax.text(0.5, 0.2, "Convergent when Model is Exact", color='grey', ha='center', fontsize=7) + +def plot_l2o_meta_pipeline(ax): + """AutoML: Learning to Optimize (L2O).""" + ax.axis('off') + ax.set_title("Learning to Optimize (L2O)", fontsize=12, fontweight='bold') + ax.text(0.5, 0.7, "Optimizer (RL Policy)", bbox=dict(fc="cornflowerblue"), ha='center') + ax.text(0.5, 0.3, "Optimizee (Deep Net)", bbox=dict(fc="lightgrey"), ha='center') + ax.annotate(r"Step $\Delta w$", xy=(0.5, 0.4), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="->")) + ax.annotate(r"Gradient $\nabla L$", xy=(0.2, 0.6), xytext=(0.2, 0.4), arrowprops=dict(arrowstyle="->", color='red')) + +def plot_chip_placement_rl(ax): + """Applied: RL for Chip Placement.""" + ax.set_title("RL for Chip Placement", fontsize=12, fontweight='bold') + ax.grid(True, ls='--', alpha=0.3) + for _ in range(8): + pos = np.random.rand(2) + rect = plt.Rectangle(pos, 0.1, 0.1, facecolor='lightblue', edgecolor='blue', alpha=0.7) + ax.add_patch(rect) + ax.set_xlim(0, 1); ax.set_ylim(0, 1) + ax.text(0.5, -0.15, "Optimizing Macro Placement on Silicon", ha='center', fontsize=8) + +def plot_compiler_mlgo(ax): + """Applied: RL for Compiler Optimization (MLGO).""" + ax.axis('off') + ax.set_title("MLGO: Compiler RL", fontsize=12, fontweight='bold') + G = nx.DiGraph() + G.add_edges_from([(0,1), (0,2), (1,3), (2,3)]) + pos = {0: (0.5, 0.9), 1: (0.3, 0.6), 2: (0.7, 0.6), 3: (0.5, 0.3)} + nx.draw(G, pos, ax=ax, node_color='lightgreen', with_labels=False) + ax.text(0.5, 0.1, "Control Flow Graph (CFG) + Inline Policy", ha='center', fontsize=8) + +def plot_theorem_proving_rl(ax): + """Applied: RL for Theorem Proving.""" + ax.axis('off') + ax.set_title("RL for Theorem Proving", fontsize=12, fontweight='bold') + ax.text(0.5, 0.9, "Target Theorem", bbox=dict(fc="ivory"), ha='center') + ax.text(0.3, 0.5, "Proof Step $a$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.7, 0.5, "Heuristic $V(s)$", bbox=dict(fc="gold"), ha='center') + ax.text(0.5, 0.2, "Verified Proof Tree", ha='center', fontsize=8) + ax.annotate("", xy=(0.35, 0.6), xytext=(0.45, 0.8), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.65, 0.6), xytext=(0.55, 0.8), arrowprops=dict(arrowstyle="->")) + +def plot_diffusion_ql_loop(ax): + """Modern: Diffusion-QL Offline RL.""" + ax.axis('off') + ax.set_title("Diffusion-QL Training", fontsize=12, fontweight='bold') + ax.text(0.2, 0.5, r"Noise $\epsilon$", ha='center') + ax.text(0.5, 0.5, r"Denoising MLP\n$\pi_\theta(a|s, k)$", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.8, 0.5, "Action $a$", ha='center') + ax.annotate("", xy=(0.35, 0.5), xytext=(0.25, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.65, 0.5), xytext=(1.0, 0.5), arrowprops=dict(arrowstyle="<-")) + ax.text(0.5, 0.2, "Policy as a Reverse Diffusion Process", fontsize=8, ha='center') + +def plot_fairness_rl_pareto(ax): + """Principles: Fairness-aware RL Pareto.""" + ax.set_title("Fairness-Reward Pareto Frontier", fontsize=12, fontweight='bold') + x = np.linspace(0.1, 1, 100) + y = 1 - x**2 + ax.plot(x, y, color='purple', lw=3, label="Pareto Frontier") + ax.fill_between(x, 0, y, color='purple', alpha=0.1) + ax.set_xlabel("Reward $R$"); ax.set_ylabel("Fairness Metric $F$") + ax.legend(fontsize=8) + +def plot_dp_rl_noise(ax): + """Principles: Differentially Private RL.""" + ax.axis('off') + ax.set_title("Differentially Private RL", fontsize=12, fontweight='bold') + ax.text(0.3, 0.5, r"Algorithm $\mathcal{A}$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.5, r"$\mathcal{N}(0, \sigma^2 \mathbb{I})$", bbox=dict(fc="red", alpha=0.3), ha='center') + ax.text(0.7, 0.5, r"Privacy Budget $\epsilon, \delta$", bbox=dict(fc="lightgrey"), ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.6, 0.5), xytext=(0.45, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_smart_agriculture_rl(ax): + """Applied: Smart Agriculture RL.""" + ax.axis('off') + ax.set_title("Smart Agriculture RL", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Soil/Weather Sensors", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.5, 0.5, "Irrigation Policy", bbox=dict(fc="gold"), ha='center') + ax.text(0.5, 0.2, "Yield Optimization", bbox=dict(fc="lightgreen"), ha='center') + ax.annotate("", xy=(0.5, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.3), xytext=(0.5, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_climate_rl_grid(ax): + """Applied: Climate Science RL.""" + ax.set_title("Climate Mitigation RL (Grid)", fontsize=12, fontweight='bold') + data = np.random.randn(10, 10) + im = ax.imshow(data, cmap='coolwarm') + ax.set_xlabel("Longitude"); ax.set_ylabel("Latitude") + ax.text(5, 5, "Carbon Sequestration\nControl Map", ha='center', color='white', fontweight='bold', fontsize=8) + +def plot_ai_education_tracing(ax): + """Applied: Intelligent Tutoring Systems RL.""" + ax.axis('off') + ax.set_title("AI Education (Knowledge Tracing)", fontsize=12, fontweight='bold') + nodes = ["Concept 1", "Concept 2", "Student State $S_t$", "Next Problem $a_t$"] + for i, n in enumerate(nodes): + ax.text(0.2 + (i%2)*0.6, 0.7 - (i//2)*0.4, n, bbox=dict(fc="pink", alpha=0.3), ha='center') + ax.annotate("", xy=(0.6, 0.5), xytext=(0.4, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_decision_sde_flow(ax): + """Modern: Decision SDEs.""" + ax.set_title(r"Decision SDE Flow $dX_t = f(X_t, u_t)dt + gdW_t$", fontsize=10, fontweight='bold') + t = np.linspace(0, 1, 100) + for _ in range(5): + path = np.cumsum(np.random.normal(0, 0.1, size=100)) + ax.plot(t, path + 0.5*t, alpha=0.5) + ax.set_xlabel("Continuous Time $t$") + +def plot_diff_physics_brax(ax): + """Control: Differentiable Physics (Brax).""" + ax.axis('off') + ax.set_title(r"Differentiable physics $\nabla_{u} \mathcal{L}$", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Physics Engine (Jacobian)", bbox=dict(fc="orange", alpha=0.1), ha='center') + ax.text(0.5, 0.5, "Simulator Layer", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.5, 0.2, "Policy Update", bbox=dict(fc="blue", alpha=0.1), ha='center') + ax.annotate("", xy=(0.5, 0.4), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="<-", color='red', label="Grads")) + +def plot_beamforming_rl(ax): + """Applied: RL for Beamforming.""" + ax.axis('off') + ax.set_title("Wireless Beamforming RL", fontsize=12, fontweight='bold') + ax.add_patch(plt.Circle((0.2, 0.5), 0.05, color='black')) + theta = np.linspace(-np.pi/4, np.pi/4, 100) + r = np.cos(4*theta) + ax.plot(0.2 + r*np.cos(theta), 0.5 + r*np.sin(theta), color='orange', label="Main Lobe") + ax.text(0.8, 0.5, "User Device", bbox=dict(boxstyle="round", fc="lightgrey"), ha='center') + +def plot_quantum_error_correction_rl(ax): + """Applied: Quantum Error Correction RL.""" + ax.axis('off') + ax.set_title("Quantum Error Correction RL", fontsize=12, fontweight='bold') + ax.text(0.1, 0.5, "Syndrome $S$", bbox=dict(fc="ivory"), ha='center') + ax.text(0.5, 0.5, "Decoder Agent", bbox=dict(boxstyle="round4", fc="purple", alpha=0.2), ha='center') + ax.text(0.9, 0.5, "Recovery $P$", bbox=dict(fc="gold"), ha='center') + ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->")) + +def plot_mean_field_rl(ax): + """Multi-Agent: Mean Field RL.""" + ax.axis('off') + ax.set_title("Mean Field RL Interaction", fontsize=12, fontweight='bold') + x = np.random.randn(50) + ax.text(0.2, 0.5, "Single Agent $i$", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.8, 0.5, r"Mean State $\overline{s}$", bbox=dict(fc="white"), ha='center', fontweight='bold') + ax.annotate("", xy=(0.7, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="<->")) + ax.text(0.5, 0.2, r"Population Limit $N \rightarrow \infty$", ha='center', fontsize=8) + +def plot_goal_gan_hrl(ax): + """HRL: Goal-GAN Pipeline.""" + ax.axis('off') + ax.set_title("Goal-GAN Curriculum", fontsize=12, fontweight='bold') + ax.text(0.2, 0.7, "Goal Generator\n(GAN Ref)", bbox=dict(fc="gold"), ha='center') + ax.text(0.8, 0.7, "RL Policy\n(Worker)", bbox=dict(fc="lightblue"), ha='center') + ax.text(0.5, 0.3, "Goal Label (Success/Fail)", bbox=dict(fc="ivory"), ha='center') + ax.annotate("Set Goal $g$", xy=(0.7, 0.7), xytext=(0.3, 0.7), arrowprops=dict(arrowstyle="->")) + ax.annotate("Train GAN", xy=(0.3, 0.4), xytext=(0.5, 0.35), arrowprops=dict(arrowstyle="->")) + +def plot_jepa_arch(ax): + """Modern: JEPA (Joint Embedding Predictive Architecture).""" + ax.axis('off') + ax.set_title("JEPA: Predictive Architecture", fontsize=12, fontweight='bold') + ax.text(0.2, 0.2, "Context $x$", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.8, 0.2, "Target $y$", bbox=dict(fc="lightgrey"), ha='center') + ax.text(0.2, 0.6, "Encoder $E_x$", bbox=dict(fc="cornflowerblue"), ha='center') + ax.text(0.8, 0.6, "Encoder $E_y$", bbox=dict(fc="cornflowerblue"), ha='center') + ax.text(0.5, 0.8, "Predictor $P$", bbox=dict(fc="orange", alpha=0.3), ha='center') + for i in [0.2, 0.8]: + ax.annotate("", xy=(i, 0.5), xytext=(i, 0.3), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.4, 0.75), xytext=(0.25, 0.65), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.6, 0.75), xytext=(0.75, 0.65), arrowprops=dict(arrowstyle="->")) + +def plot_cql_penalty_surface(ax): + """Offline RL: CQL Value Penalty.""" + X, Y = np.meshgrid(np.linspace(-3, 3, 20), np.linspace(-3, 3, 20)) + Z = (X**2 + Y**2) - 2 * np.exp(- (X**2 + Y**2)) # CQL lower bound + ax.contourf(X, Y, Z, cmap='viridis') + ax.set_title("CQL Value Penalty Landscape", fontsize=12, fontweight='bold') + +def plot_cyber_attack_defense(ax): + """Applied: Cybersecurity RL Game.""" + ax.axis('off') + ax.set_title("Cybersecurity Attack-Defense RL", fontsize=12, fontweight='bold') + ax.text(0.2, 0.7, "Attacker Agent", bbox=dict(fc="red", alpha=0.2), ha='center', fontweight='bold') + ax.text(0.8, 0.7, "Defender Agent", bbox=dict(fc="blue", alpha=0.2), ha='center', fontweight='bold') + ax.text(0.5, 0.3, "Network Infrastructure", bbox=dict(fc="grey", alpha=0.3), ha='center') + ax.annotate("Intrusion", xy=(0.4, 0.4), xytext=(0.2, 0.6), arrowprops=dict(arrowstyle="->", color='red')) + ax.annotate("Mitigation", xy=(0.6, 0.4), xytext=(0.8, 0.6), arrowprops=dict(arrowstyle="->", color='blue')) + + +def plot_smart_grid_rl(ax): + """Applied: Smart Grid Supply/Demand.""" + ax.axis('off') + ax.set_title("Smart Grid RL Management", fontsize=12, fontweight='bold') + ax.text(0.2, 0.8, "Renewables", ha='center') + ax.text(0.8, 0.8, "Consumers", ha='center') + ax.text(0.5, 0.5, "RL Dispatcher", bbox=dict(fc="gold"), ha='center') + ax.text(0.5, 0.2, "Energy Storage", bbox=dict(fc="lightgrey"), ha='center') + ax.annotate("", xy=(0.4, 0.55), xytext=(0.25, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.75, 0.75), xytext=(0.6, 0.6), arrowprops=dict(arrowstyle="<-")) + ax.annotate("", xy=(0.5, 0.3), xytext=(0.5, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_quantum_tomography_rl(ax): + """Applied: Quantum State Tomography.""" + ax.axis('off') + ax.set_title("Quantum state Tomography RL", fontsize=12, fontweight='bold') + ax.text(0.5, 0.8, "Quantum State $\\rho$", bbox=dict(boxstyle="circle", fc="purple", alpha=0.2), ha='center') + ax.text(0.5, 0.5, "Measurement $M$", ha='center') + ax.text(0.5, 0.2, "RL Estimator", bbox=dict(fc="lightblue"), ha='center') + ax.annotate("", xy=(0.5, 0.6), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->")) + ax.annotate("", xy=(0.5, 0.3), xytext=(0.5, 0.45), arrowprops=dict(arrowstyle="->")) + +def plot_absolute_encyclopedia_map(ax): + """Conceptual: Absolute Universal Encyclopedia Map.""" + ax.axis('off') + ax.set_title("Absolute Universal RL Pillar Map", fontsize=14, fontweight='bold', color='darkblue') + categories = ["Foundational", "Model-Free", "Model-Based", "Advanced Paradigms", "Analysis/Safety", "Applied Pipelines"] + for i, c in enumerate(categories): + angle = 2 * np.pi * i / 6 + ax.text(0.5 + 0.35*np.cos(angle), 0.5 + 0.35*np.sin(angle), c, bbox=dict(fc="ivory", lw=2), ha='center', fontsize=9) + ax.text(0.5, 0.5, "Reinforcement\nLearning\nGraphical\nLibrary", ha='center', fontweight='bold', fontsize=12) + for i in range(6): + angle = 2 * np.pi * i / 6 + ax.annotate("", xy=(0.5 + 0.25*np.cos(angle), 0.5 + 0.25*np.sin(angle)), xytext=(0.5, 0.5), arrowprops=dict(arrowstyle="->", alpha=0.3)) + +def plot_actor_critic_arch(ax): + """Actor-Critic: Three-network diagram (TD3 - actor + two critics).""" + ax.axis('off') + ax.set_title("TD3 Architecture Diagram", fontsize=12, fontweight='bold') + + # State input + ax.text(0.1, 0.5, r"State" + "\n" + r"$s$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.5", fc="lightblue")) + + # Networks + net_props = dict(boxstyle="square,pad=0.8", fc="lightgreen", ec="black") + ax.text(0.5, 0.8, r"Actor $\pi_\phi$", ha="center", va="center", bbox=net_props) + ax.text(0.5, 0.5, r"Critic 1 $Q_{\theta_1}$", ha="center", va="center", bbox=net_props) + ax.text(0.5, 0.2, r"Critic 2 $Q_{\theta_2}$", ha="center", va="center", bbox=net_props) + + # Outputs + ax.text(0.8, 0.8, "Action $a$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.3", fc="coral")) + ax.text(0.8, 0.35, "Min Q-value", ha="center", va="center", bbox=dict(boxstyle="round,pad=0.3", fc="gold")) + + # Connections + kwargs = dict(arrowstyle="->", lw=1.5) + ax.annotate("", xy=(0.38, 0.8), xytext=(0.15, 0.55), arrowprops=kwargs) # S -> Actor + ax.annotate("", xy=(0.38, 0.5), xytext=(0.15, 0.5), arrowprops=kwargs) # S -> C1 + ax.annotate("", xy=(0.38, 0.2), xytext=(0.15, 0.45), arrowprops=kwargs) # S -> C2 + ax.annotate("", xy=(0.73, 0.8), xytext=(0.62, 0.8), arrowprops=kwargs) # Actor -> Action + ax.annotate("", xy=(0.68, 0.35), xytext=(0.62, 0.5), arrowprops=kwargs) # C1 -> Min + ax.annotate("", xy=(0.68, 0.35), xytext=(0.62, 0.2), arrowprops=kwargs) # C2 -> Min + +def plot_epsilon_decay(ax): + """Exploration: ε-Greedy Strategy Decay Curve.""" + episodes = np.arange(0, 1000) + epsilon = np.maximum(0.01, np.exp(-0.005 * episodes)) # Exponential decay + + ax.plot(episodes, epsilon, color='purple', lw=2) + ax.set_title(r"$\epsilon$-Greedy Decay Curve", fontsize=12, fontweight='bold') + ax.set_xlabel("Episodes") + ax.set_ylabel(r"Probability $\epsilon$") + ax.grid(True, linestyle='--', alpha=0.6) + ax.fill_between(episodes, epsilon, color='purple', alpha=0.1) + +def plot_learning_curve(ax): + """Advanced / Misc: Learning Curve with Confidence Bands.""" + steps = np.linspace(0, 1e6, 100) + # Simulate a learning curve converging to a maximum + mean_return = 100 * (1 - np.exp(-5e-6 * steps)) + np.random.normal(0, 2, len(steps)) + std_dev = 15 * np.exp(-2e-6 * steps) # Variance decreases as policy stabilizes + + ax.plot(steps, mean_return, color='blue', lw=2, label="PPO (Mean)") + ax.fill_between(steps, mean_return - std_dev, mean_return + std_dev, color='blue', alpha=0.2, label="±1 Std Dev") + + ax.set_title("Learning Curve (Return vs Steps)", fontsize=12, fontweight='bold') + ax.set_xlabel("Environment Steps") + ax.set_ylabel("Average Episodic Return") + ax.legend(loc="lower right") + ax.grid(True, linestyle='--', alpha=0.6) + +def main(): + # Figure 1: MDP & Environment (7 plots) + fig1, gs1 = setup_figure("RL: MDP & Environment", 2, 4) + + plot_agent_env_loop(fig1.add_subplot(gs1[0, 0])) + plot_mdp_graph(fig1.add_subplot(gs1[0, 1])) + plot_trajectory(fig1.add_subplot(gs1[0, 2])) + plot_continuous_space(fig1.add_subplot(gs1[0, 3])) + plot_reward_landscape(fig1, gs1) # projection='3d' handled inside + plot_discount_decay(fig1.add_subplot(gs1[1, 1])) + # row 5 (State Transition Graph) is basically plot_mdp_graph + + # Layout handled by constrained_layout=True + + # Figure 2: Value, Policy & Dynamic Programming + fig2, gs2 = setup_figure("RL: Value, Policy & Dynamic Programming", 2, 4) + plot_value_heatmap(fig2.add_subplot(gs2[0, 0])) + plot_action_value_q(fig2.add_subplot(gs2[0, 1])) + plot_policy_arrows(fig2.add_subplot(gs2[0, 2])) + plot_advantage_function(fig2.add_subplot(gs2[0, 3])) + plot_backup_diagram(fig2.add_subplot(gs2[1, 0])) # Policy Eval + plot_policy_improvement(fig2.add_subplot(gs2[1, 1])) + plot_value_iteration_backup(fig2.add_subplot(gs2[1, 2])) + plot_policy_iteration_cycle(fig2.add_subplot(gs2[1, 3])) + + # Layout handled by constrained_layout=True + + # Figure 3: Monte Carlo & Temporal Difference + fig3, gs3 = setup_figure("RL: Monte Carlo & Temporal Difference", 2, 4) + plot_mc_backup(fig3.add_subplot(gs3[0, 0])) + plot_mcts(fig3.add_subplot(gs3[0, 1])) + plot_importance_sampling(fig3.add_subplot(gs3[0, 2])) + plot_td_backup(fig3.add_subplot(gs3[0, 3])) + plot_nstep_td(fig3.add_subplot(gs3[1, 0])) + plot_eligibility_traces(fig3.add_subplot(gs3[1, 1])) + plot_sarsa_backup(fig3.add_subplot(gs3[1, 2])) + plot_q_learning_backup(fig3.add_subplot(gs3[1, 3])) + + # Layout handled by constrained_layout=True + + # Figure 4: TD Extensions & Function Approximation + fig4, gs4 = setup_figure("RL: TD Extensions & Function Approximation", 2, 4) + plot_double_q(fig4.add_subplot(gs4[0, 0])) + plot_dueling_dqn(fig4.add_subplot(gs4[0, 1])) + plot_prioritized_replay(fig4.add_subplot(gs4[0, 2])) + plot_rainbow_dqn(fig4.add_subplot(gs4[0, 3])) + plot_linear_fa(fig4.add_subplot(gs4[1, 0])) + plot_nn_layers(fig4.add_subplot(gs4[1, 1])) + plot_computation_graph(fig4.add_subplot(gs4[1, 2])) + plot_target_network(fig4.add_subplot(gs4[1, 3])) + + # Layout handled by constrained_layout=True + + # Figure 5: Policy Gradients, Actor-Critic & Exploration + fig5, gs5 = setup_figure("RL: Policy Gradients, Actor-Critic & Exploration", 2, 4) + plot_policy_gradient_flow(fig5.add_subplot(gs5[0, 0])) + plot_ppo_clip(fig5.add_subplot(gs5[0, 1])) + plot_trpo_trust_region(fig5.add_subplot(gs5[0, 2])) + plot_actor_critic_arch(fig5.add_subplot(gs5[0, 3])) + plot_a3c_multi_worker(fig5.add_subplot(gs5[1, 0])) + plot_sac_arch(fig5.add_subplot(gs5[1, 1])) + plot_softmax_exploration(fig5.add_subplot(gs5[1, 2])) + plot_ucb_confidence(fig5.add_subplot(gs5[1, 3])) + + # Layout handled by constrained_layout=True + + # Figure 6: Hierarchical, Model-Based & Offline RL + fig6, gs6 = setup_figure("RL: Hierarchical, Model-Based & Offline", 2, 4) + plot_options_framework(fig6.add_subplot(gs6[0, 0])) + plot_feudal_networks(fig6.add_subplot(gs6[0, 1])) + plot_world_model(fig6.add_subplot(gs6[0, 2])) + plot_model_planning(fig6.add_subplot(gs6[0, 3])) + plot_offline_rl(fig6.add_subplot(gs6[1, 0])) + plot_cql_regularization(fig6.add_subplot(gs6[1, 1])) + plot_epsilon_decay(fig6.add_subplot(gs6[1, 2])) # placeholder/spacer + plot_intrinsic_motivation(fig6.add_subplot(gs6[1, 3])) + + # Layout handled by constrained_layout=True + + # Figure 7: Multi-Agent, IRL & Meta-RL + fig7, gs7 = setup_figure("RL: Multi-Agent, IRL & Meta-RL", 2, 4) + plot_multi_agent_interaction(fig7.add_subplot(gs7[0, 0])) + plot_ctde(fig7.add_subplot(gs7[0, 1])) + plot_payoff_matrix(fig7.add_subplot(gs7[0, 2])) + plot_irl_reward_inference(fig7.add_subplot(gs7[0, 3])) + plot_gail_flow(fig7.add_subplot(gs7[1, 0])) + plot_meta_rl_nested_loop(fig7.add_subplot(gs7[1, 1])) + plot_task_distribution(fig7.add_subplot(gs7[1, 2])) + + # Layout handled by constrained_layout=True + + # Figure 8: Advanced / Miscellaneous Topics + fig8, gs8 = setup_figure("RL: Advanced & Miscellaneous", 2, 4) + plot_replay_buffer(fig8.add_subplot(gs8[0, 0])) + plot_state_visitation(fig8.add_subplot(gs8[0, 1])) + plot_regret_curve(fig8.add_subplot(gs8[0, 2])) + plot_attention_weights(fig8.add_subplot(gs8[0, 3])) + plot_diffusion_policy(fig8.add_subplot(gs8[1, 0])) + plot_gnn_rl(fig8.add_subplot(gs8[1, 1])) + plot_latent_space(fig8.add_subplot(gs8[1, 2])) + plot_convergence_log(fig8.add_subplot(gs8[1, 3])) + + # Figure 9: Specialized & Modern RL (Advanced Gallery) + fig9, gs9 = setup_figure("RL: Specialized & Modern (Absolute Completeness)", 3, 4) + # Row 1 + plot_rl_taxonomy_tree(fig9.add_subplot(gs9[0, 0])) + plot_rl_as_inference_pgm(fig9.add_subplot(gs9[0, 1])) + plot_distributional_rl_atoms(fig9.add_subplot(gs9[0, 2])) + plot_her_goal_relabeling(fig9.add_subplot(gs9[0, 3])) + # Row 2 + plot_dyna_q_flow(fig9.add_subplot(gs9[1, 0])) + plot_noisy_nets_parameters(fig9.add_subplot(gs9[1, 1])) + plot_icm_curiosity(fig9.add_subplot(gs9[1, 2])) + plot_v_trace_impala(fig9.add_subplot(gs9[1, 3])) + # Row 3 + plot_qmix_mixing_net(fig9.add_subplot(gs9[2, 0])) + plot_saliency_heatmaps(fig9.add_subplot(gs9[2, 1])) + plot_tsne_state_embeddings(fig9.add_subplot(gs9[2, 2])) + plot_action_selection_noise(fig9.add_subplot(gs9[2, 3])) + + # Figure 10: Evaluation, Safety & Alignment + fig10, gs10 = setup_figure("RL: Evaluation, Safety & Alignment", 2, 4) + plot_success_rate_curve(fig10.add_subplot(gs10[0, 0])) + plot_performance_profiles_rliable(fig10.add_subplot(gs10[0, 1])) + plot_hyperparameter_sensitivity(fig10.add_subplot(gs10[0, 2])) + plot_action_persistence(fig10.add_subplot(gs10[0, 3])) + plot_safety_shielding(fig10.add_subplot(gs10[1, 0])) + plot_automated_curriculum(fig10.add_subplot(gs10[1, 1])) + plot_domain_randomization(fig10.add_subplot(gs10[1, 2])) + plot_rlhf_flow(fig10.add_subplot(gs10[1, 3])) + + # Figure 11: Transformer & Specific MB Architecture + fig11, gs11 = setup_figure("RL: Transformers & Specific MB Architecture", 1, 3) + plot_decision_transformer_tokens(fig11.add_subplot(gs11[0, 0])) + plot_muzero_search_tree(fig11.add_subplot(gs11[0, 1])) + plot_policy_distillation(fig11.add_subplot(gs11[0, 2])) + + # Special Handle for Loss Landscape in Dashboard if needed (but it's 3D) + # We skip it in the main dashboard or add it to a single 3D fig + fig_loss = plt.figure(figsize=(10, 8)) + gs_loss = GridSpec(1, 1, figure=fig_loss) + plot_loss_landscape(fig_loss, gs_loss) + + plt.show() + +def save_all_graphs(output_dir="graphs"): + """Saves each of the 74 RL components as a separate PNG file.""" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Component-to-Function Mapping (Total 74 entries as per e.md rows) + mapping = { + "Agent-Environment Interaction Loop": plot_agent_env_loop, + "Markov Decision Process (MDP) Tuple": plot_mdp_graph, + "State Transition Graph": plot_mdp_graph, + "Trajectory / Episode Sequence": plot_trajectory, + "Continuous State/Action Space Visualization": plot_continuous_space, + "Reward Function / Landscape": plot_reward_landscape, + "Discount Factor (gamma) Effect": plot_discount_decay, + "State-Value Function V(s)": plot_value_heatmap, + "Action-Value Function Q(s,a)": plot_action_value_q, + "Policy pi(s) or pi(a|s)": plot_policy_arrows, + "Advantage Function A(s,a)": plot_advantage_function, + "Optimal Value Function V* / Q*": plot_value_heatmap, + "Policy Evaluation Backup": plot_backup_diagram, + "Policy Improvement": plot_policy_improvement, + "Value Iteration Backup": plot_value_iteration_backup, + "Policy Iteration Full Cycle": plot_policy_iteration_cycle, + "Monte Carlo Backup": plot_mc_backup, + "Monte Carlo Tree (MCTS)": plot_mcts, + "Importance Sampling Ratio": plot_importance_sampling, + "TD(0) Backup": plot_td_backup, + "Bootstrapping (general)": plot_td_backup, + "n-step TD Backup": plot_nstep_td, + "TD(lambda) & Eligibility Traces": plot_eligibility_traces, + "SARSA Update": plot_sarsa_backup, + "Q-Learning Update": plot_q_learning_backup, + "Expected SARSA": plot_expected_sarsa_backup, + "Double Q-Learning / Double DQN": plot_double_q, + "Dueling DQN Architecture": plot_dueling_dqn, + "Prioritized Experience Replay": plot_prioritized_replay, + "Rainbow DQN Components": plot_rainbow_dqn, + "Linear Function Approximation": plot_linear_fa, + "Neural Network Layers (MLP, CNN, RNN, Transformer)": plot_nn_layers, + "Computation Graph / Backpropagation Flow": plot_computation_graph, + "Target Network": plot_target_network, + "Policy Gradient Theorem": plot_policy_gradient_flow, + "REINFORCE Update": plot_reinforce_flow, + "Baseline / Advantage Subtraction": plot_advantage_scaled_grad, + "Trust Region (TRPO)": plot_trpo_trust_region, + "Proximal Policy Optimization (PPO)": plot_ppo_clip, + "Actor-Critic Architecture": plot_actor_critic_arch, + "Advantage Actor-Critic (A2C/A3C)": plot_a3c_multi_worker, + "Soft Actor-Critic (SAC)": plot_sac_arch, + "Twin Delayed DDPG (TD3)": plot_actor_critic_arch, + "epsilon-Greedy Strategy": plot_epsilon_decay, + "Softmax / Boltzmann Exploration": plot_softmax_exploration, + "Upper Confidence Bound (UCB)": plot_ucb_confidence, + "Intrinsic Motivation / Curiosity": plot_intrinsic_motivation, + "Entropy Regularization": plot_entropy_bonus, + "Options Framework": plot_options_framework, + "Feudal Networks / Hierarchical Actor-Critic": plot_feudal_networks, + "Skill Discovery": plot_skill_discovery, + "Learned Dynamics Model": plot_world_model, + "Model-Based Planning": plot_model_planning, + "Imagination-Augmented Agents (I2A)": plot_imagination_rollout, + "Offline Dataset": plot_offline_rl, + "Conservative Q-Learning (CQL)": plot_cql_regularization, + "Multi-Agent Interaction Graph": plot_multi_agent_interaction, + "Centralized Training Decentralized Execution (CTDE)": plot_ctde, + "Cooperative / Competitive Payoff Matrix": plot_payoff_matrix, + "Reward Inference": plot_irl_reward_inference, + "Generative Adversarial Imitation Learning (GAIL)": plot_gail_flow, + "Meta-RL Architecture": plot_meta_rl_nested_loop, + "Task Distribution Visualization": plot_task_distribution, + "Experience Replay Buffer": plot_replay_buffer, + "State Visitation / Occupancy Measure": plot_state_visitation, + "Learning Curve": plot_learning_curve, + "Regret / Cumulative Regret": plot_regret_curve, + "Attention Mechanisms (Transformers in RL)": plot_attention_weights, + "Diffusion Policy": plot_diffusion_policy, + "Graph Neural Networks for RL": plot_gnn_rl, + "World Model / Latent Space": plot_latent_space, + "Convergence Analysis Plots": plot_convergence_log, + "RL Algorithm Taxonomy": plot_rl_taxonomy_tree, + "Probabilistic Graphical Model (RL as Inference)": plot_rl_as_inference_pgm, + "Distributional RL (C51 / Categorical)": plot_distributional_rl_atoms, + "Hindsight Experience Replay (HER)": plot_her_goal_relabeling, + "Dyna-Q Architecture": plot_dyna_q_flow, + "Noisy Networks (Parameter Noise)": plot_noisy_nets_parameters, + "Intrinsic Curiosity Module (ICM)": plot_icm_curiosity, + "V-trace (IMPALA)": plot_v_trace_impala, + "QMIX Mixing Network": plot_qmix_mixing_net, + "Saliency Maps / Attention on State": plot_saliency_heatmaps, + "Action Selection Noise (OU vs Gaussian)": plot_action_selection_noise, + "t-SNE / UMAP State Embeddings": plot_tsne_state_embeddings, + "Loss Landscape Visualization": plot_loss_landscape, + "Success Rate vs Steps": plot_success_rate_curve, + "Hyperparameter Sensitivity Heatmap": plot_hyperparameter_sensitivity, + "Action Persistence (Frame Skipping)": plot_action_persistence, + "MuZero Dynamics Search Tree": plot_muzero_search_tree, + "Policy Distillation": plot_policy_distillation, + "Decision Transformer Token Sequence": plot_decision_transformer_tokens, + "Performance Profiles (rliable)": plot_performance_profiles_rliable, + "Safety Shielding / Barrier Functions": plot_safety_shielding, + "Automated Curriculum Learning": plot_automated_curriculum, + "Domain Randomization": plot_domain_randomization, + "RL with Human Feedback (RLHF)": plot_rlhf_flow, + "Successor Representations (SR)": plot_successor_representations, + "Maximum Entropy IRL": plot_maxent_irl_trajectories, + "Information Bottleneck": plot_information_bottleneck, + "Evolutionary Strategies Population": plot_es_population_distribution, + "Control Barrier Functions (CBF)": plot_cbf_safe_set, + "Count-based Exploration Heatmap": plot_count_based_exploration, + "Thompson Sampling Posteriors": plot_thompson_sampling, + "Adversarial RL Interaction": plot_adversarial_rl_interaction, + "Hierarchical Subgoal Trajectory": plot_hierarchical_subgoals, + "Offline Action Distribution Shift": plot_offline_distribution_shift, + "Random Network Distillation (RND)": plot_rnd_curiosity, + "Batch-Constrained Q-learning (BCQ)": plot_bcq_offline_constraint, + "Population-Based Training (PBT)": plot_pbt_evolution, + "Recurrent State Flow (DRQN/R2D2)": plot_recurrent_state_flow, + "Belief State in POMDPs": plot_belief_state_pomdp, + "Multi-Objective Pareto Front": plot_pareto_front_morl, + "Differential Value (Average Reward RL)": plot_differential_value_average_reward, + "Distributed RL Cluster (Ray/RLLib)": plot_distributed_rl_cluster, + "Neuroevolution Topology Evolution": plot_neuroevolution_topology, + "Elastic Weight Consolidation (EWC)": plot_ewc_elastic_weights, + "Successor Features (SF)": plot_successor_features, + "Adversarial State Noise (Perception)": plot_adversarial_state_noise, + "Behavioral Cloning (Imitation)": plot_behavioral_cloning_il, + "Relational Graph State Representation": plot_relational_graph_state, + "Quantum RL Circuit (PQC)": plot_quantum_rl_circuit, + "Symbolic Policy Tree": plot_symbolic_expression_tree, + "Differentiable Physics Gradient Flow": plot_differentiable_physics_gradient, + "MARL Communication Channel": plot_marl_communication_channel, + "Lagrangian Constraint Landscape": plot_lagrangian_multiplier_landscape, + "MAXQ Task Hierarchy": plot_maxq_task_hierarchy, + "ReAct Agentic Cycle": plot_react_cycle_thinking, + "Synaptic Plasticity RL": plot_synaptic_plasticity_rl, + "Guided Policy Search (GPS)": plot_guided_policy_search_gps, + "Sim-to-Real Jitter & Latency": plot_sim2real_jitter_latency, + "Deterministic Policy Gradient (DDPG) Flow": plot_ddpg_deterministic_gradient, + "Dreamer Latent Imagination": plot_dreamer_latent_rollout, + "UNREAL Auxiliary Tasks": plot_unreal_auxiliary_tasks, + "Implicit Q-Learning (IQL) Expectile": plot_iql_expectile_loss, + "Prioritized Sweeping": plot_prioritized_sweeping, + "DAgger Expert Loop": plot_dagger_expert_loop, + "Self-Predictive Representations (SPR)": plot_spr_self_prediction, + "Joint Action Space": plot_joint_action_space, + "Dec-POMDP Formal Model": plot_dec_pomdp_graph, + "Bisimulation Metric": plot_bisimulation_metric, + "Potential-Based Reward Shaping": plot_reward_shaping_phi, + "Transfer RL: Source to Target": plot_transfer_rl_source_target, + "Multi-Task Backbone Arch": plot_multi_task_backbone, + "Contextual Bandit Pipeline": plot_contextual_bandit_pipeline, + "Theoretical Regret Bounds": plot_regret_bounds_theoretical, + "Soft Q Boltzmann Probabilities": plot_soft_q_heatmap, + "Autonomous Driving RL Pipeline": plot_ad_rl_pipeline, + "Policy action gradient comparison": plot_action_grad_comparison, + "IRL: Feature Expectation Matching": plot_irl_feature_matching, + "Apprenticeship Learning Loop": plot_apprenticeship_learning_loop, + "Active Inference Loop": plot_active_inference_loop, + "Bellman Residual Landscape": plot_bellman_residual_landscape, + "Plan-to-Explore Uncertainty Map": plot_plan_to_explore_map, + "Robust RL Uncertainty Set": plot_robust_rl_uncertainty_set, + "HPO Bayesian Opt Cycle": plot_hpo_bayesian_opt_cycle, + "Slate RL Recommendation": plot_slate_rl_reco_pipeline, + "Fictitious Play Interaction": plot_game_theory_fictitious_play, + "Universal RL Framework Diagram": plot_universal_rl_framework, + "Offline Density Ratio Estimator": plot_offline_density_ratio, + "Continual Task Interference Heatmap": plot_continual_task_interference, + "Lyapunov Stability Safe Set": plot_lyapunov_safe_set, + "Molecular RL (Atom Coordinates)": plot_molecular_rl_atoms, + "MoE Multi-task Architecture": plot_moe_multi_task_arch, + "CMA-ES Policy Search": plot_cma_es_distribution, + "Elo Rating Preference Plot": plot_elo_rating_preference, + "Explainable RL (SHAP Attribution)": plot_shap_lime_attribution, + "PEARL Context Encoder": plot_pearl_context_encoder, + "Medical RL Therapy Pipeline": plot_healthcare_rl_pipeline, + "Supply Chain RL Pipeline": plot_supply_chain_rl, + "Sim-to-Real SysID Loop": plot_sysid_safe_loop, + "Transformer World Model": plot_transformer_world_model, + "Network Traffic RL": plot_network_rl, + "RLHF: PPO with Reference Policy": plot_rlhf_ppo_ref, + "PSRO Meta-Game Update": plot_psro_meta_game, + "DIAL: Differentiable Comm": plot_dial_comm_channel, + "Fitted Q-Iteration Loop": plot_fqi_batch_loop, + "CMDP Feasible Region": plot_cmdp_feasible_set, + "MPC vs RL Planning": plot_mpc_vs_rl_horizon, + "Learning to Optimize (L2O)": plot_l2o_meta_pipeline, + "Smart Grid RL Management": plot_smart_grid_rl, + "Quantum State Tomography RL": plot_quantum_tomography_rl, + "Absolute Universal RL Pillar Map": plot_absolute_encyclopedia_map, + "RL for Chip Placement": plot_chip_placement_rl, + "RL Compiler Optimization (MLGO)": plot_compiler_mlgo, + "RL for Theorem Proving": plot_theorem_proving_rl, + "Diffusion-QL Offline RL": plot_diffusion_ql_loop, + "Fairness-reward Pareto Frontier": plot_fairness_rl_pareto, + "Differentially Private RL": plot_dp_rl_noise, + "Smart Agriculture RL": plot_smart_agriculture_rl, + "Climate Mitigation RL (Grid)": plot_climate_rl_grid, + "AI Education (Knowledge Tracing)": plot_ai_education_tracing, + "Decision SDE Flow": plot_decision_sde_flow, + "Differentiable physics (Brax)": plot_diff_physics_brax, + "Wireless Beamforming RL": plot_beamforming_rl, + "Quantum Error Correction RL": plot_quantum_error_correction_rl, + "Mean Field RL Interaction": plot_mean_field_rl, + "Goal-GAN Curriculum": plot_goal_gan_hrl, + "JEPA: Predictive Architecture": plot_jepa_arch, + "CQL Value Penalty Landscape": plot_cql_penalty_surface, + "Cybersecurity Attack-Defense RL": plot_cyber_attack_defense + } + + import sys + + for name, func in mapping.items(): + # Sanitize filename + filename = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()).strip('_') + filename = re.sub(r'_+', '_', filename) + ".png" + filepath = os.path.join(output_dir, filename) + + print(f"Generating: {filename} ...") + + plt.close('all') + + if func in [plot_reward_landscape, plot_loss_landscape]: + fig = plt.figure(figsize=(10, 8)) + gs = GridSpec(1, 1, figure=fig) + func(fig, gs) + plt.savefig(filepath, bbox_inches='tight', dpi=100) + plt.close(fig) + continue + + fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True) + func(ax) + plt.savefig(filepath, bbox_inches='tight', dpi=100) + plt.close(fig) + + print(f"\n[SUCCESS] Saved {len(mapping)} graphs to '{output_dir}/' directory.") + +if __name__ == "__main__": + import sys + if "--save" in sys.argv: + save_all_graphs() + else: + main() \ No newline at end of file