import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import math
import random
from torch.nn.modules.batchnorm import _BatchNorm
from collections import OrderedDict
from torch.optim.lr_scheduler import LambdaLR
import torch.nn as nn
from torch.nn import functional as F
import h5py
import fnmatch
from torchvision import transforms
import pickle
from tqdm import tqdm

_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)


def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    # save training curves, one figure per logged key
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f"train_val_{key}_seed_{seed}.png")
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(
            np.linspace(0, num_epochs - 1, len(train_history)),
            train_values,
            label="train",
        )
        plt.plot(
            np.linspace(0, num_epochs - 1, len(validation_history)),
            val_values,
            label="validation",
        )
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
    print(f"Saved plots to {ckpt_dir}")


def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts a float image tensor to a uint8 image in range [0..255].

    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]
            (or [0..1] if range_min is 0).
        range_min: Minimum of the input range; -1 rescales [-1..1] to [0..1] first.

    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    if range_min == -1:
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    ndim = input_tensor.ndim
    output_image = input_tensor.clamp(0, 1).cpu().numpy()
    # move the channel axis to the end: BxCxHxW -> BxHxWxC
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)


def kl_divergence(mu, logvar):
    """KL divergence of a diagonal Gaussian from the standard normal prior."""
    batch_size = mu.size(0)
    assert batch_size != 0
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld


class RandomShiftsAug(nn.Module):
    def __init__(self, pad_h, pad_w):
        super().__init__()
        self.pad_h = pad_h
        self.pad_w = pad_w
        print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")

    def forward(self, x):
        original_shape = x.shape
        n, h, w = x.shape[0], x.shape[-2], x.shape[-1]  # x: (n, T, M, C, H, W)
        x = x.view(n, -1, h, w)  # (n, T*M*C, H, W)
        padding = (
            self.pad_w,
            self.pad_w,
            self.pad_h,
            self.pad_h,
        )  # left, right, top, bottom padding
        x = F.pad(x, padding, mode="replicate")
        h_pad, w_pad = h + 2 * self.pad_h, w + 2 * self.pad_w

        eps_h = 1.0 / h_pad
        eps_w = 1.0 / w_pad
        arange_h = torch.linspace(
            -1.0 + eps_h, 1.0 - eps_h, h_pad, device=x.device, dtype=x.dtype
        )[:h]
        arange_w = torch.linspace(
            -1.0 + eps_w, 1.0 - eps_w, w_pad, device=x.device, dtype=x.dtype
        )[:w]
        arange_h = arange_h.unsqueeze(1).repeat(1, w).unsqueeze(2)  # (h, w, 1)
        arange_w = arange_w.unsqueeze(1).repeat(1, h).unsqueeze(2)  # (w, h, 1)
        base_grid = torch.cat([arange_w.transpose(1, 0), arange_h], dim=2)  # (H, W, 2)
        base_grid = base_grid.unsqueeze(0).repeat(
            n, 1, 1, 1
        )  # repeat for batch: (B, H, W, 2)

        shift_h = torch.randint(
            0, 2 * self.pad_h + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        shift_w = torch.randint(
            0, 2 * self.pad_w + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        shift_h *= 2.0 / h_pad
        shift_w *= 2.0 / w_pad

        grid = base_grid + torch.cat([shift_w, shift_h], dim=3)
        x = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        return x.view(original_shape)
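

# Illustrative sketch only (not referenced elsewhere in this file): how RandomShiftsAug
# is expected to be called on a (B, T, N_views, C, H, W) batch. The tensor sizes and
# pad values below are assumptions for demonstration, not training-config values.
def _example_random_shift_usage():
    """Hedged usage sketch for RandomShiftsAug; all sizes are placeholder assumptions."""
    aug = RandomShiftsAug(pad_h=4, pad_w=4)
    dummy = torch.rand(2, 1, 3, 3, 224, 224)  # (B, T, N_views, C, H, W)
    shifted = aug(dummy)
    # the output keeps the input shape; content is randomly translated per sample
    assert shifted.shape == dummy.shape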


def get_norm_stats(state, action):
    all_qpos_data = torch.from_numpy(np.array(state))
    all_action_data = torch.from_numpy(np.array(action))

    # normalize action data
    action_mean = all_action_data.mean(dim=[0], keepdim=True)
    action_std = all_action_data.std(dim=[0], keepdim=True)
    action_std = torch.clip(action_std, 1e-2, np.inf)  # clipping
    action_max = torch.amax(all_action_data, dim=[0], keepdim=True)
    action_min = torch.amin(all_action_data, dim=[0], keepdim=True)

    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=[0], keepdim=True)
    qpos_std = all_qpos_data.std(dim=[0], keepdim=True)
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # clipping

    stats = {
        "action_mean": action_mean.numpy().squeeze(),
        "action_std": action_std.numpy().squeeze(),
        "action_max": action_max.numpy().squeeze(),
        "action_min": action_min.numpy().squeeze(),
        "qpos_mean": qpos_mean.numpy().squeeze(),
        "qpos_std": qpos_std.numpy().squeeze(),
    }
    return stats
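

# Illustrative sketch only (not called anywhere): what get_norm_stats expects and
# returns. The (num_steps, 14) shapes are placeholder assumptions for a bimanual
# 14-dof setup, not values read from the dataset.
def _example_norm_stats_usage():
    """Hedged sketch: get_norm_stats on random placeholder state/action arrays."""
    dummy_state = np.random.rand(1000, 14).astype(np.float32)   # qpos over all timesteps
    dummy_action = np.random.rand(1000, 14).astype(np.float32)  # actions over all timesteps
    stats = get_norm_stats(dummy_state, dummy_action)
    # keys: action_mean/std/max/min and qpos_mean/std, each of shape (14,)
    print({k: v.shape for k, v in stats.items()})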


class EpisodicDataset_Unified_Multiview(Dataset):
    def __init__(self, data_path_list, camera_names, chunk_size, stats, img_aug=False):
        super().__init__()
        self.data_path_list = data_path_list
        self.camera_names = camera_names
        self.chunk_size = chunk_size
        self.norm_stats = stats
        self.img_aug = img_aug
        self.ColorJitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01
        )

    def __len__(self):
        # each epoch draws 16 random timesteps per episode
        return len(self.data_path_list) * 16

    def __getitem__(self, path_index):
        # qpos = np.concatenate((root['/arm/jointStatePosition/masterLeft'][()], root['/arm/jointStatePosition/masterRight'][()]), axis=-1)
        # actions = np.concatenate((root['/arm/jointStatePosition/puppetLeft']
        path_index = path_index % len(self.data_path_list)  # ensure index is within bounds
        example_path = self.data_path_list[path_index]
        with h5py.File(example_path, 'r') as f:
            action = f['observations']['qpos'][()]  # jointStatePosition/master
            qpos = f['action'][()]  # jointStatePosition/puppet

        parent_path = os.path.dirname(example_path)
        instruction_path = os.path.join(parent_path, 'instructions')
        # randomly sample an instruction file
        instruction_files = [
            f for f in os.listdir(instruction_path) if fnmatch.fnmatch(f, '*.pt')
        ]
        instruction_file = os.path.join(instruction_path, np.random.choice(instruction_files))
        instruction = torch.load(instruction_file, weights_only=False)  # num_token x 4096 tensor

        # randomly sample a timestep index within the episode
        episode_len = action.shape[0]
        index = np.random.randint(0, episode_len)
        obs_qpos = qpos[index:index + 1]

        # stack images from all camera views
        with h5py.File(example_path, 'r') as f:
            camera_list = []
            for camera_name in self.camera_names:
                cam_jpeg_code = f['observations']['images'][camera_name][index]
                cam_image = cv2.imdecode(
                    np.frombuffer(cam_jpeg_code, np.uint8), cv2.IMREAD_COLOR
                )  # rgb
                camera_list.append(cam_image)
            obs_img = np.stack(camera_list, axis=0)  # (N_views, H, W, C)

        # build the action chunk and its padding mask
        original_action_shape = (self.chunk_size, *action.shape[1:])
        gt_action = np.zeros(original_action_shape)
        action_len = min(self.chunk_size, episode_len - index)
        gt_action[:action_len] = action[index:index + action_len]
        is_pad = np.zeros(self.chunk_size)
        is_pad[action_len:] = 1

        # construct observation tensors
        image_data = torch.from_numpy(obs_img).unsqueeze(0).float()  # (1, N_views, H, W, 3)
        image_data = image_data.permute(0, 1, 4, 2, 3)  # (1, N_views, 3, H, W)
        qpos_data = torch.from_numpy(obs_qpos).float()  # (1, 14)
        action_data = torch.from_numpy(gt_action).float()  # (chunk_size, 14)
        is_pad = torch.from_numpy(is_pad).bool()  # (chunk_size,)
        instruction_data = instruction.mean(0).float()  # (4096,)

        # normalize image and qpos
        image_data = image_data / 255.0  # normalize to [0, 1], layout (T, N_views, C, H, W)
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]

        if self.img_aug and random.random() < 0.25:
            for t in range(image_data.shape[0]):
                for i in range(image_data.shape[1]):
                    image_data[t, i] = self.ColorJitter(image_data[t, i])

        return image_data, qpos_data.float(), action_data, is_pad, instruction_data


def load_data_unified(
    data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
    camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
    batch_size_train=32,
    chunk_size=100,
    img_aug=False,
    finetune=False,
):
    hdf5_file_paths = []
    for root, _, files in os.walk(data_dir, followlinks=True):
        for filename in files:
            if filename.endswith('.hdf5'):
                hdf5_file_paths.append(os.path.join(root, filename))
    print(
        f"Loading data from {data_dir} with {len(hdf5_file_paths)} episodes "
        f"and batch size {batch_size_train}"
    )

    # collect per-episode states/actions for normalization statistics
    state_list = []
    action_list = []
    for p in tqdm(hdf5_file_paths, desc="Data statistics collection"):
        with h5py.File(p, 'r') as f:
            action = f['observations']['qpos'][()]
            qpos = f['action'][()]
        state_list.append(qpos)
        action_list.append(action)
    states = np.concatenate(state_list, axis=0)
    actions = np.concatenate(action_list, axis=0)

    if finetune:
        # load stats computed during pretraining (1590 episodes)
        pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'
        with open(pretrain_stats_path, 'rb') as f:
            stats = pickle.load(f)
        print(f"Loaded stats from {pretrain_stats_path}")
    else:
        stats = get_norm_stats(states, actions)
    for key, value in stats.items():
        print(f"{key}: {value}")

    train_dataset = EpisodicDataset_Unified_Multiview(
        data_path_list=hdf5_file_paths,
        camera_names=camera_names,
        chunk_size=chunk_size,
        stats=stats,
        img_aug=img_aug,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    return train_dataloader, None, None, stats


def compute_dict_mean(epoch_dicts):
    result = {k: None for k in epoch_dicts[0]}
    num_items = len(epoch_dicts)
    for k in result:
        value_sum = 0
        for epoch_dict in epoch_dicts:
            value_sum += epoch_dict[k]
        result[k] = value_sum / num_items
    return result


def detach_dict(d):
    new_d = dict()
    for k, v in d.items():
        new_d[k] = v.detach()
    return new_d


def set_seed(seed):
    random.seed(seed)
    # np.random.seed(seed)
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)


def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
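

# Illustrative sketch only (not called anywhere): one plausible way to drive the
# warmup-cosine schedule from a step-based training loop. The AdamW settings and
# warmup length are assumptions, not values taken from this repo.
def _example_scheduler_usage(model: nn.Module, num_training_steps: int = 1000):
    """Hedged sketch: pair get_cosine_schedule_with_warmup with an optimizer."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=num_training_steps
    )
    for _ in range(num_training_steps):
        # ... forward pass and loss.backward() would go here ...
        optimizer.step()       # update parameters
        scheduler.step()       # advance the lr schedule by one step
        optimizer.zero_grad()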


def get_constant_schedule(optimizer, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def normalize_data(action_data, stats, norm_type, data_type="action"):
    if norm_type == "minmax":
        # scale to [-1, 1] using dataset min/max
        action_max = torch.from_numpy(stats[data_type + "_max"]).float().to(action_data.device)
        action_min = torch.from_numpy(stats[data_type + "_min"]).float().to(action_data.device)
        action_data = (action_data - action_min) / (action_max - action_min) * 2 - 1
    elif norm_type == "gaussian":
        # standardize using dataset mean/std
        action_mean = torch.from_numpy(stats[data_type + "_mean"]).float().to(action_data.device)
        action_std = torch.from_numpy(stats[data_type + "_std"]).float().to(action_data.device)
        action_data = (action_data - action_mean) / action_std
    return action_data


def convert_weight(obj):
    # strip the "module." prefix that DataParallel/DistributedDataParallel adds to state-dict keys
    newmodel = OrderedDict()
    for k, v in obj.items():
        if k.startswith("module."):
            newmodel[k[7:]] = v
        else:
            newmodel[k] = v
    return newmodel


if __name__ == "__main__":
    train_dataloader, _, _, stats = load_data_unified(
        data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
        camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        batch_size_train=32,
        chunk_size=100,
        img_aug=True,
    )
    for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate(
        tqdm(train_dataloader, desc="Data loading")
    ):
        if i == 0:
            print(f"Batch {i}:")
            print(f"Image data shape: {image_data.shape} {image_data.max()}")
            print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}")
            print(f"Action data shape: {action_data.shape} {action_data.max()}")
            print(f"Is pad shape: {is_pad.shape}")
            print(f"Instruction data shape: {instruction_data.shape}")
        continue