| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
def plot_length_distribution(dataset_name, metadata_path, output_path):
    """Plot a histogram of item lengths from a metadata file and print stats.

    The metadata file is loaded with ``torch.load``; the loaded object is
    iterated and the ``'length'`` entry of every item is collected. A 50-bin
    histogram with dashed mean and median marker lines is written to
    ``output_path``, and count/min/max/mean/median are printed to stdout.

    Args:
        dataset_name: Label used in the plot title and in the printed report.
        metadata_path: Path to the metadata file readable by ``torch.load``.
        output_path: Destination path of the image. Its parent directory is
            created when it does not exist yet.

    Raises:
        ValueError: If the metadata contains no items, since no meaningful
            histogram or statistics can be produced.
    """
    print(f"Loading metadata from {metadata_path}...")
    # NOTE(security): weights_only=False allows arbitrary pickled objects to be
    # deserialized, which can execute code. Only load files from trusted sources.
    metadata = torch.load(metadata_path, weights_only=False)
    lengths = [item['length'] for item in metadata]

    # Fail early with a clear message instead of writing an empty plot and
    # then crashing inside np.min() on an empty sequence.
    if not lengths:
        raise ValueError(
            f"No entries found in {metadata_path}; cannot plot length distribution."
        )

    mean_len = np.mean(lengths)
    median_len = np.median(lengths)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(lengths, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(mean_len, color='red', linestyle='dashed', linewidth=1,
                   label=f'Mean: {mean_len:.1f}')
        ax.axvline(median_len, color='green', linestyle='dashed', linewidth=1,
                   label=f'Median: {median_len:.1f}')
        ax.set_title(f'Video Length Distribution - {dataset_name}')
        ax.set_xlabel('Number of Frames')
        ax.set_ylabel('Frequency')
        ax.grid(axis='y', alpha=0.3)
        ax.legend()

        # os.path.dirname() returns '' for a bare file name, and
        # os.makedirs('') raises FileNotFoundError, so only create the
        # directory when the path actually has one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Plot saved to {output_path}")
    print(f"Stats for {dataset_name}:")
    print(f" Total Trajectories: {len(lengths)}")
    print(f" Min Length: {np.min(lengths)}")
    print(f" Max Length: {np.max(lengths)}")
    print(f" Mean Length: {mean_len:.1f}")
    print(f" Median Length: {median_len:.1f}")
if __name__ == "__main__":
    # Dataset label -> location of its metadata file.
    metadata_files = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/metadata_lite.pt",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/metadata_lite.pt",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/metadata_lite.pt",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/metadata_lite.pt",
    }

    # Produce one length-distribution plot per dataset, in declaration order.
    for name in metadata_files:
        source_file = metadata_files[name]
        figure_file = f"/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/{name}_dist.png"
        plot_length_distribution(name, source_file, figure_file)