import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch


# Summary statistics and a video-length histogram for the Franka trajectory
# dataset. Run as a script; importing this module performs no I/O.

METADATA_PATH = "/storage/ice-shared/ae8803che/hxue/data/dataset/franka/metadata.pt"
SAVE_PATH = "/storage/ice-shared/ae8803che/hxue/data/world_model/results/stats/franka_dist.png"


def iter_trajectories(metadata):
    """Return an iterator over the per-trajectory entries of *metadata*.

    A dict is iterated by its values; any other iterable is iterated directly.
    """
    return iter(metadata.values() if isinstance(metadata, dict) else metadata)


def collect_stats(metadata):
    """Collect trajectory lengths and action dimensions from *metadata*.

    Each entry's length is ``entry['num_frames']`` when that key exists,
    otherwise ``entry['actions'].shape[0]``. The action dimension
    (``entry['actions'].shape[-1]``) is recorded only for entries that have
    an ``'actions'`` key.

    Scanning stops at the first entry that has neither key; its keys are
    printed to help diagnose an unexpected schema.

    Returns:
        A ``(lengths, action_dims)`` tuple: the list of per-trajectory lengths
        (in iteration order) and the set of distinct action dimensions seen.
    """
    lengths = []
    action_dims = set()
    for info in iter_trajectories(metadata):
        if 'num_frames' in info:
            lengths.append(info['num_frames'])
        elif 'actions' in info:
            lengths.append(info['actions'].shape[0])
        else:
            print(f"Keys in info: {info.keys()}")
            break
        # Entries described by 'num_frames' may carry no 'actions' array, so
        # the lookup must be guarded rather than assumed.
        if 'actions' in info:
            action_dims.add(info['actions'].shape[-1])
    return lengths, action_dims


def format_action_dim(action_dims):
    """Return the action dimension to report.

    Returns the single value when every trajectory agrees, ``"n/a"`` when no
    dimension was seen, and otherwise the string form of the set of
    conflicting dimensions.
    """
    if not action_dims:
        return "n/a"
    if len(action_dims) == 1:
        return next(iter(action_dims))
    return str(action_dims)


def summarize_lengths(lengths):
    """Return ``(mean, median)`` of *lengths*.

    Raises:
        ValueError: if *lengths* is empty (the mean would divide by zero).
    """
    if not lengths:
        raise ValueError("no trajectory lengths to summarize")
    return sum(lengths) / len(lengths), float(np.median(lengths))


def plot_length_histogram(lengths, save_path, bins=30):
    """Save a histogram of *lengths* to *save_path*, creating parent dirs."""
    fig = plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=bins, color='skyblue', edgecolor='black')
    # N is the number of samples actually plotted.
    plt.title(f"Franka Video Length Distribution (N={len(lengths)})")
    plt.xlabel("Number of Frames")
    plt.ylabel("Frequency")
    plt.grid(axis='y', alpha=0.75)

    directory = os.path.dirname(save_path)
    if directory:  # os.makedirs("") raises for a bare filename
        os.makedirs(directory, exist_ok=True)
    fig.savefig(save_path)
    # pyplot keeps a reference to every open figure until it is closed.
    plt.close(fig)


def main():
    """Load the metadata, print summary statistics and save the histogram."""
    if not os.path.exists(METADATA_PATH):
        print(f"Error: {METADATA_PATH} not found.", file=sys.stderr)
        sys.exit(1)

    # NOTE(review): recent torch releases default torch.load to
    # weights_only=True; confirm metadata.pt holds only tensors/containers.
    metadata = torch.load(METADATA_PATH)
    num_trajectories = len(metadata)

    lengths, action_dims = collect_stats(metadata)
    if not lengths:
        print("Error: no trajectory lengths found in metadata.", file=sys.stderr)
        sys.exit(1)
    if len(lengths) != num_trajectories:
        print(
            f"Warning: lengths found for only {len(lengths)} "
            f"of {num_trajectories} trajectories."
        )

    avg_len, median_len = summarize_lengths(lengths)

    print(f"Trajectories: {num_trajectories}")
    print(f"Action Dim: {format_action_dim(action_dims)}")
    print(f"Avg. Video Len: {avg_len:.1f}")
    print(f"Median Video Len: {median_len:.1f}")

    plot_length_histogram(lengths, SAVE_PATH)
    print(f"Distribution plot saved to {SAVE_PATH}")


if __name__ == "__main__":
    main()