| import torch |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import os |
| import sys |
|
|
| |
| sys.path.append("/storage/ice-shared/ae8803che/hxue/data/world_model") |
| from wm.model.diffusion.flow_matching import FlowMatchScheduler |
|
|
def plot_sigma_curve(
    num_inference_steps=1000,
    shift=5.0,
    output_dir="/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_flow_matching",
):
    """Plot a FlowMatchScheduler's sigma schedule and training weights.

    Configures a fresh ``FlowMatchScheduler`` with
    ``set_timesteps(num_inference_steps, training=True, shift=shift)`` and
    writes two PNG files into ``output_dir``:

    * ``sigma_vs_index.png`` -- ``scheduler.sigmas`` against the step index,
      overlaid with a straight 1 -> 0 line as a no-shift reference.
    * ``weight_vs_t.png`` -- ``scheduler.linear_timesteps_weights`` against
      ``scheduler.timesteps``.

    Args:
        num_inference_steps: Step count passed to ``set_timesteps``.
        shift: Shift value passed to ``set_timesteps``.
        output_dir: Directory that receives the PNGs; created if missing.

    Returns:
        Tuple ``(sigma_path, weight_path)`` with the paths of the saved files.
    """
    scheduler = FlowMatchScheduler()
    scheduler.set_timesteps(
        num_inference_steps=num_inference_steps, training=True, shift=shift
    )

    # The original .numpy() calls imply these are torch tensors; detach/cpu
    # make the conversion also work if they are on a GPU or track gradients.
    sigmas = scheduler.sigmas.detach().cpu().numpy()
    timesteps = scheduler.timesteps.detach().cpu().numpy()
    weights = scheduler.linear_timesteps_weights.detach().cpu().numpy()

    # Size the x-axis and the reference line from the actual schedule rather
    # than assuming it has exactly num_inference_steps entries.
    num_sigmas = len(sigmas)
    step_indices = np.arange(num_sigmas)
    linear_sigmas = np.linspace(1.0, 0.0, num_sigmas)

    os.makedirs(output_dir, exist_ok=True)
    sigma_path = os.path.join(output_dir, "sigma_vs_index.png")
    weight_path = os.path.join(output_dir, "weight_vs_t.png")

    # --- Figure 1: sigma schedule vs. step index ---
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(step_indices, sigmas, label="Sigma (Noise Level)", color="blue")
    ax.plot(step_indices, linear_sigmas, "r--", alpha=0.5, label="Linear (No Shift)")
    # Raw string: "\s" is an invalid escape in a normal string literal.
    ax.set_title(
        rf"Sigma ($\sigma$) vs Step Index (0-{num_sigmas - 1})"
        f"\nWan Shift={shift}, num_steps={num_inference_steps}"
    )
    ax.set_xlabel("Index")
    ax.set_ylabel("Sigma Value (0=Clean, 1=Noise)")
    ax.grid(True, which="both", linestyle="--", alpha=0.5)
    ax.legend()
    fig.savefig(sigma_path)
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    print(f"Plot saved to {sigma_path}")

    # --- Figure 2: training weight vs. training timestep ---
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps, weights, color="green", label="Training Weight")
    ax.set_title("Training Weight vs Training Timestep ($t$)\nGaussian-like Weighting")
    # Raw string: "\i" is an invalid escape in a normal string literal.
    ax.set_xlabel(r"Training Timestep ($t \in [0, 1000]$)")
    ax.set_ylabel("Weight Value")
    ax.grid(True, which="both", linestyle="--", alpha=0.5)
    ax.legend()
    fig.savefig(weight_path)
    plt.close(fig)
    print(f"Plot saved to {weight_path}")

    return sigma_path, weight_path
|
|
if __name__ == "__main__":
    # Script entry point: produce both diagnostic plots with default settings.
    # Importing this module does not trigger any plotting.
    plot_sigma_curve()
|
|