import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import math
from torch.nn.modules.batchnorm import _BatchNorm
from collections import OrderedDict
from torch.optim.lr_scheduler import LambdaLR
import torch.nn as nn
from torch.nn import functional as F
import h5py
import fnmatch
import random
from torchvision import transforms
import pickle
from tqdm import tqdm
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
# save training curves
for key in train_history[0]:
plot_path = os.path.join(ckpt_dir, f"train_val_{key}_seed_{seed}.png")
plt.figure()
train_values = [summary[key].item() for summary in train_history]
val_values = [summary[key].item() for summary in validation_history]
plt.plot(
np.linspace(0, num_epochs - 1, len(train_history)),
train_values,
label="train",
)
plt.plot(
np.linspace(0, num_epochs - 1, len(validation_history)),
val_values,
label="validation",
)
plt.tight_layout()
plt.legend()
plt.title(key)
plt.savefig(plot_path)
print(f"Saved plots to {ckpt_dir}")
def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts a tensor in [-1, 1] to a uint8 image in range [0..255].
    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1].
        range_min: Minimum of the input range; -1 (default) rescales the input
            from [-1, 1] to [0, 1], any other value assumes it is already in [0, 1].
    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
if range_min == -1:
input_tensor = (input_tensor.float() + 1.0) / 2.0
ndim = input_tensor.ndim
output_image = input_tensor.clamp(0, 1).cpu().numpy()
output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)
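# Illustrative sketch: tensor2numpy takes a float tensor with channels in dim 1
# and returns a channels-last uint8 array. The batch/image sizes below are
# assumptions for demonstration only.
def _example_tensor2numpy():
    dummy = torch.rand(2, 3, 64, 64) * 2.0 - 1.0  # Bx3xHxW in [-1, 1]
    img = tensor2numpy(dummy)  # BxHxWx3, uint8, values in [0, 255]
    assert img.shape == (2, 64, 64, 3) and img.dtype == np.uint8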
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
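# Illustrative sketch: kl_divergence returns the per-sample KL summed over
# dimensions (averaged over the batch), the per-dimension KL, and the mean KL.
# The batch/latent sizes below are assumptions for demonstration only.
def _example_kl_divergence():
    mu = torch.zeros(8, 32)      # latent means
    logvar = torch.zeros(8, 32)  # log-variances; zeros -> unit-variance posterior
    total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
    # A standard-normal posterior has zero KL to the standard-normal prior.
    assert total_kld.shape == (1,) and dim_wise_kld.shape == (32,) and mean_kld.shape == (1,)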
class RandomShiftsAug(nn.Module):
def __init__(self, pad_h, pad_w):
super().__init__()
self.pad_h = pad_h
self.pad_w = pad_w
print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")
def forward(self, x):
        original_shape = x.shape
n, h, w = x.shape[0], x.shape[-2], x.shape[-1] # n,T,M,C,H,W
x = x.view(n, -1, h, w) # n,T*M*C,H,W
padding = (
self.pad_w,
self.pad_w,
self.pad_h,
self.pad_h,
) # left, right, top, bottom padding
x = F.pad(x, padding, mode="replicate")
h_pad, w_pad = h + 2 * self.pad_h, w + 2 * self.pad_w
eps_h = 1.0 / h_pad
eps_w = 1.0 / w_pad
arange_h = torch.linspace(
-1.0 + eps_h, 1.0 - eps_h, h_pad, device=x.device, dtype=x.dtype
)[:h]
arange_w = torch.linspace(
-1.0 + eps_w, 1.0 - eps_w, w_pad, device=x.device, dtype=x.dtype
)[:w]
arange_h = arange_h.unsqueeze(1).repeat(1, w).unsqueeze(2) # h w 1
arange_w = arange_w.unsqueeze(1).repeat(1, h).unsqueeze(2) # w h 1
base_grid = torch.cat([arange_w.transpose(1, 0), arange_h], dim=2) # [H, W, 2]
base_grid = base_grid.unsqueeze(0).repeat(
n, 1, 1, 1
) # Repeat for batch [B, H, W, 2]
shift_h = torch.randint(
0, 2 * self.pad_h + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
).float()
shift_w = torch.randint(
0, 2 * self.pad_w + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
).float()
shift_h *= 2.0 / h_pad
shift_w *= 2.0 / w_pad
grid = base_grid + torch.cat([shift_w, shift_h], dim=3)
x = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        return x.view(original_shape)
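# Illustrative sketch: RandomShiftsAug replicate-pads each frame and resamples
# it at a randomly shifted grid, keeping the input layout. The (n, T, M, C, H, W)
# layout and sizes below are assumptions based on the comments above.
def _example_random_shifts():
    aug = RandomShiftsAug(pad_h=4, pad_w=4)
    frames = torch.rand(2, 1, 3, 3, 96, 96)  # n, T, M (views), C, H, W
    shifted = aug(frames)
    assert shifted.shape == frames.shape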
def get_norm_stats(state, action):
all_qpos_data = torch.from_numpy(np.array(state))
all_action_data = torch.from_numpy(np.array(action))
# normalize action data
action_mean = all_action_data.mean(dim=[0], keepdim=True)
action_std = all_action_data.std(dim=[0], keepdim=True)
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
action_max = torch.amax(all_action_data, dim=[0], keepdim=True)
action_min = torch.amin(all_action_data, dim=[0], keepdim=True)
# normalize qpos data
qpos_mean = all_qpos_data.mean(dim=[0], keepdim=True)
qpos_std = all_qpos_data.std(dim=[0], keepdim=True)
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
stats = {
"action_mean": action_mean.numpy().squeeze(),
"action_std": action_std.numpy().squeeze(),
"action_max": action_max.numpy().squeeze(),
"action_min": action_min.numpy().squeeze(),
"qpos_mean": qpos_mean.numpy().squeeze(),
"qpos_std": qpos_std.numpy().squeeze(),
}
return stats
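# Illustrative sketch: get_norm_stats expects flat (num_samples, dim) arrays and
# returns per-dimension statistics as numpy arrays. The 14-dim state/action size
# is an assumption for demonstration only.
def _example_get_norm_stats():
    dummy_state = np.random.randn(1000, 14).astype(np.float32)
    dummy_action = np.random.randn(1000, 14).astype(np.float32)
    stats = get_norm_stats(dummy_state, dummy_action)
    assert stats["action_mean"].shape == (14,) and stats["qpos_std"].shape == (14,)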
class EpisodicDataset_Unified_Multiview(Dataset):
    def __init__(self, data_path_list, camera_names, chunk_size, stats, img_aug=False):
        super().__init__()
self.data_path_list = data_path_list
self.camera_names = camera_names
self.chunk_size = chunk_size
self.norm_stats = stats
self.img_aug = img_aug
        self.ColorJitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01)
def __len__(self):
return len(self.data_path_list) * 16
def __getitem__(self, path_index):
        # Expected HDF5 layout per episode: 'observations/qpos' (master joint
        # positions, used as the action target below), 'action' (puppet joint
        # positions, used as the proprioceptive state), and JPEG-encoded frames
        # under 'observations/images/<camera_name>'. Language instructions are
        # stored as *.pt tensors in a sibling 'instructions' directory.
path_index = path_index % len(self.data_path_list) # ensure index is within bounds
example_path = self.data_path_list[path_index]
with h5py.File(example_path, 'r') as f:
action = f['observations']['qpos'][()] # jointStatePosition/master
qpos = f['action'][()] # jointStatePosition/puppet
parent_path = os.path.dirname(example_path)
Instruction_path = os.path.join(parent_path, 'instructions')
# randomly sample instruction file
instruction_files = [f for f in os.listdir(Instruction_path) if fnmatch.fnmatch(f, '*.pt')]
instruction_file = os.path.join(Instruction_path, np.random.choice(instruction_files))
instruction = torch.load(instruction_file, weights_only=False) # num_token * 4096 tensor
        # randomly sample an episode index
episode_len = action.shape[0]
index = np.random.randint(0, episode_len)
obs_qpos = qpos[index:index + 1]
# stack images
with h5py.File(example_path, 'r') as f:
camera_list = []
for camera_name in self.camera_names:
cam_jpeg_code = f['observations']['images'][camera_name][index]
cam_image = cv2.imdecode(np.frombuffer(cam_jpeg_code, np.uint8), cv2.IMREAD_COLOR) # rgb
camera_list.append(cam_image)
obs_img = np.stack(camera_list, axis=0) # shape: (N_views, H, W, C)
original_action_shape = (self.chunk_size, *action.shape[1:])
gt_action = np.zeros(original_action_shape)
action_len = min(self.chunk_size, episode_len - index)
gt_action[:action_len] = action[
index : index + action_len
]
is_pad = np.zeros(self.chunk_size)
is_pad[action_len:] = 1
# construct observations tensor type
        image_data = torch.from_numpy(obs_img).unsqueeze(0).float()  # (1, N_views, H, W, 3)
        image_data = image_data.permute(0, 1, 4, 2, 3)  # (1, N_views, 3, H, W)
        qpos_data = torch.from_numpy(obs_qpos).float()  # (1, 14)
action_data = torch.from_numpy(gt_action).float() # (chunk_size, 14)
is_pad = torch.from_numpy(is_pad).bool() # (chunk_size, )
instruction_data = instruction.mean(0).float() # (4096)
# normalize image and qpos
image_data = image_data / 255.0 # Normalize to [0, 1] T N C H W
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
"qpos_std"
]
if self.img_aug and random.random() < 0.25:
for t in range(image_data.shape[0]):
for i in range(image_data.shape[1]):
                    image_data[t, i] = self.ColorJitter(image_data[t, i])
return image_data, qpos_data.float(), action_data, is_pad, instruction_data
def load_data_unified(
data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
batch_size_train=32,
chunk_size=100,
img_aug=False,
    finetune=False,
):
HDF5_file_path = []
for root, _, files in os.walk(data_dir, followlinks=True):
for filename in files:
if filename.endswith('.hdf5'):
HDF5_file_path.append(os.path.join(root, filename))
print(f"Loading data from {data_dir} with {len(HDF5_file_path)} episodes and batch size {batch_size_train}")
state_list = []
action_list = []
    for p in tqdm(HDF5_file_path, desc="Data statistics collection"):
with h5py.File(p, 'r') as f:
action = f['observations']['qpos'][()]
qpos = f['action'][()]
state_list.append(qpos)
action_list.append(action)
states = np.concatenate(state_list, axis=0)
actions = np.concatenate(action_list, axis=0)
    if finetune:
# load stats from pretrain path 1590 episodes
pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'
with open(pretrain_stats_path, 'rb') as f:
stats = pickle.load(f)
print(f"Loaded stats from {pretrain_stats_path}")
else:
stats = get_norm_stats(states, actions)
for key, value in stats.items():
print(f"{key}: {value}")
train_dataset = EpisodicDataset_Unified_Multiview(
data_path_list=HDF5_file_path,
camera_names=camera_names,
chunk_size=chunk_size,
stats=stats,
img_aug=img_aug,
)
    train_data_loader = DataLoader(
train_dataset,
batch_size=batch_size_train,
shuffle=True,
num_workers=8,
pin_memory=True,
)
    return train_data_loader, None, None, stats
def compute_dict_mean(epoch_dicts):
result = {k: None for k in epoch_dicts[0]}
num_items = len(epoch_dicts)
for k in result:
value_sum = 0
for epoch_dict in epoch_dicts:
value_sum += epoch_dict[k]
result[k] = value_sum / num_items
return result
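# Illustrative sketch: compute_dict_mean averages each key over a list of
# per-step summary dicts. The keys and values below are made up for demonstration.
def _example_compute_dict_mean():
    step_dicts = [
        {"loss": torch.tensor(1.0), "l1": torch.tensor(0.5)},
        {"loss": torch.tensor(3.0), "l1": torch.tensor(1.5)},
    ]
    means = compute_dict_mean(step_dicts)
    assert torch.isclose(means["loss"], torch.tensor(2.0))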
def detach_dict(d):
new_d = dict()
for k, v in d.items():
new_d[k] = v.detach()
return new_d
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([~torch.optim.Optimizer]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (int):
The number of steps for the warmup phase.
num_training_steps (int):
The total number of training steps.
num_cycles (float, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (int, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
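# Illustrative sketch: attach the cosine schedule to an optimizer and call
# scheduler.step() once per training step. The parameters, optimizer, and step
# counts below are placeholders, not the values used by this project.
def _example_cosine_schedule():
    dummy_params = [nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(dummy_params, lr=1e-4)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=10000
    )
    for _ in range(3):
        optimizer.step()
        scheduler.step()  # lr ramps up linearly over warmup, then follows the cosine decay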
def get_constant_schedule(optimizer, last_epoch: int = -1) -> LambdaLR:
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
def normalize_data(action_data, stats, norm_type, data_type="action"):
if norm_type == "minmax":
action_max = torch.from_numpy(stats[data_type + "_max"]).float().to(action_data.device)
action_min = torch.from_numpy(stats[data_type + "_min"]).float().to(action_data.device)
action_data = (action_data - action_min) / (action_max - action_min) * 2 - 1
elif norm_type == "gaussian":
action_mean = torch.from_numpy(stats[data_type + "_mean"]).float().to(action_data.device)
action_std = torch.from_numpy(stats[data_type + "_std"]).float().to(action_data.device)
action_data = (action_data - action_mean) / action_std
return action_data
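# Illustrative sketch: "minmax" maps data into [-1, 1] using the stats' min/max,
# while "gaussian" standardizes with mean/std; any other norm_type returns the
# data unchanged. The stats and shapes below are assumptions for demonstration.
def _example_normalize_data():
    dummy_stats = {
        "action_min": np.zeros(14, dtype=np.float32),
        "action_max": np.ones(14, dtype=np.float32),
    }
    raw = torch.full((1, 100, 14), 0.5)  # batch, chunk, action dim
    normed = normalize_data(raw, dummy_stats, norm_type="minmax", data_type="action")
    assert torch.allclose(normed, torch.zeros_like(normed))  # 0.5 maps to the middle of [-1, 1]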
def convert_weight(obj):
newmodel = OrderedDict()
for k, v in obj.items():
if k.startswith("module."):
newmodel[k[7:]] = v
else:
newmodel[k] = v
return newmodel
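# Illustrative sketch: convert_weight strips the "module." prefix that
# DistributedDataParallel adds to state_dict keys, so a checkpoint can be
# loaded into a non-DDP model. The keys below are made up for demonstration.
def _example_convert_weight():
    ddp_state_dict = OrderedDict(
        [("module.backbone.weight", torch.zeros(3)), ("head.bias", torch.ones(1))]
    )
    cleaned = convert_weight(ddp_state_dict)
    assert list(cleaned.keys()) == ["backbone.weight", "head.bias"]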
if __name__ == "__main__":
train_dataloader,_,_,stats = load_data_unified(
data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
batch_size_train=32,
chunk_size=100,
img_aug=True,
)
for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate(
tqdm(train_dataloader, desc="Data loading")
):
if i == 0:
print(f"Batch {i}:")
print(f"Image data shape: {image_data.shape} {image_data.max()}")
print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}" )
print(f"Action data shape: {action_data.shape} {action_data.max()}")
print(f"Is pad shape: {is_pad.shape}")
print(f"Instruction data shape: {instruction_data.shape}")
continue