| | import torch |
| | from torch.utils.data import Dataset, DataLoader |
| | import numpy as np |
| | import os |
| | import cv2 |
| | import matplotlib.pyplot as plt |
| | import math |
| | from torch.nn.modules.batchnorm import _BatchNorm |
| | from collections import OrderedDict |
| | from torch.optim.lr_scheduler import LambdaLR |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | import h5py |
| | import fnmatch |
| | from torchvision import transforms |
| | import pickle |
| | from tqdm import tqdm |
# Largest uint8 value as a float (255.0); used to rescale [0, 1] images to [0, 255].
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
| |
|
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Save one train-vs-validation curve image per metric into ``ckpt_dir``.

    Args:
        train_history: List of summary dicts; the keys of the first dict
            determine which metrics are plotted.
        validation_history: List of summary dicts containing the same keys.
        num_epochs: Both curves are spread evenly over ``[0, num_epochs - 1]``
            on the x-axis, whatever their lengths.
        ckpt_dir: Output directory; files are named
            ``train_val_{key}_seed_{seed}.png``.
        seed: Included in the output file names.

    Metric values may be 0-d / single-element tensors, numpy scalars or plain
    numbers (anything ``float()`` accepts). Does nothing if ``train_history``
    is empty.
    """
    if not train_history:
        return
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f"train_val_{key}_seed_{seed}.png")
        train_values = [float(summary[key]) for summary in train_history]
        val_values = [float(summary[key]) for summary in validation_history]
        fig = plt.figure()
        try:
            plt.plot(
                np.linspace(0, num_epochs - 1, len(train_values)),
                train_values,
                label="train",
            )
            plt.plot(
                np.linspace(0, num_epochs - 1, len(val_values)),
                val_values,
                label="validation",
            )
            plt.legend()
            plt.title(key)
            # Lay out after title/legend exist so they are taken into account.
            plt.tight_layout()
            plt.savefig(plot_path)
        finally:
            # Without closing, every call leaks one open figure per metric.
            plt.close(fig)
    print(f"Saved plots to {ckpt_dir}")
| |
|
| |
|
def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Convert a channel-first image tensor to a channel-last uint8 numpy array.

    Args:
        input_tensor: Image tensor with channels on dim 1, e.g. Bx3xHxW.
        range_min: ``-1`` (default) means the input is in [-1, 1] and is
            first mapped to [0, 1]; any other value means the input is
            already in [0, 1].

    Returns:
        A numpy array with the channel dim moved last (e.g. BxHxWx3),
        values in [0, 255], dtype uint8. Out-of-range values are clamped.
    """
    # detach(): .numpy() refuses tensors that require grad.
    # float(): .numpy() cannot convert bfloat16, and this keeps the math in
    # floating point regardless of the input dtype.
    image = input_tensor.detach().float()
    if range_min == -1:
        image = (image + 1.0) / 2.0
    image = image.clamp(0, 1).cpu().numpy()
    # Move dim 1 (channels) to the end: (B, C, d2, ..., dk) -> (B, d2, ..., dk, C).
    ndim = image.ndim
    image = image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    # +0.5 rounds to the nearest integer instead of truncating.
    return (image * _UINT8_MAX_F + 0.5).astype(np.uint8)
| |
|
| |
|
def kl_divergence(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from a standard normal.

    Args:
        mu: Means, shape (B, D); a 4-D input is viewed as (B, D), which only
            works when its trailing dims multiply to 1 (e.g. (B, D, 1, 1)).
        logvar: Log-variances, same shape convention as ``mu``.

    Returns:
        ``(total_kld, dimension_wise_kld, mean_kld)``:
        per-sample sum over D averaged over the batch (shape (1,)),
        per-dimension average over the batch (shape (D,)),
        per-sample mean over D averaged over the batch (shape (1,)).
    """
    assert mu.size(0) != 0

    def _as_2d(t):
        if t.dim() != 4:
            return t
        return t.view(t.size(0), t.size(1))

    mu = _as_2d(mu)
    logvar = _as_2d(logvar)

    # Element-wise closed-form KL for diagonal Gaussians vs N(0, I).
    elementwise = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    per_sample_sum = elementwise.sum(1)
    per_sample_avg = elementwise.mean(1)
    return (
        per_sample_sum.mean(0, True),
        elementwise.mean(0),
        per_sample_avg.mean(0, True),
    )
| |
|
| |
|
class RandomShiftsAug(nn.Module):
    """Random integer-pixel translation augmentation for image batches.

    Each sample is replicate-padded by ``pad_h`` / ``pad_w`` pixels and then
    resampled at an offset drawn uniformly from ``[0, 2 * pad]`` pixels in
    the padded image, i.e. a shift in ``[-pad, pad]`` relative to the
    original. All dims between the batch dim and (H, W) share the same shift
    for a given sample.

    Args:
        pad_h: Maximum vertical shift in pixels.
        pad_w: Maximum horizontal shift in pixels.
    """

    def __init__(self, pad_h, pad_w):
        super().__init__()
        self.pad_h = pad_h
        self.pad_w = pad_w
        print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")

    def forward(self, x):
        """Randomly shift ``x`` of shape (N, ..., H, W); returns the same shape.

        ``x`` must be floating point (``F.grid_sample`` does not accept
        integer input).
        """
        original_shape = x.shape
        n, h, w = x.shape[0], x.shape[-2], x.shape[-1]
        # Fold the middle dims into channels. reshape (not view) so that
        # non-contiguous inputs, e.g. after transpose/permute, are accepted.
        x = x.reshape(n, -1, h, w)
        # F.pad order for the last two dims: (left, right, top, bottom).
        padding = (self.pad_w, self.pad_w, self.pad_h, self.pad_h)
        x = F.pad(x, padding, mode="replicate")

        h_pad, w_pad = h + 2 * self.pad_h, w + 2 * self.pad_w
        eps_h = 1.0 / h_pad
        eps_w = 1.0 / w_pad

        # Normalized pixel-centre coordinates of the padded image
        # (align_corners=False convention). Keeping the first h / w means an
        # un-shifted grid samples the top-left h x w window.
        arange_h = torch.linspace(
            -1.0 + eps_h, 1.0 - eps_h, h_pad, device=x.device, dtype=x.dtype
        )[:h]
        arange_w = torch.linspace(
            -1.0 + eps_w, 1.0 - eps_w, w_pad, device=x.device, dtype=x.dtype
        )[:w]

        arange_h = arange_h.unsqueeze(1).repeat(1, w).unsqueeze(2)  # (h, w, 1)
        arange_w = arange_w.unsqueeze(1).repeat(1, h).unsqueeze(2)  # (w, h, 1)

        # Last dim is (x, y) = (width coord, height coord), as grid_sample expects.
        base_grid = torch.cat([arange_w.transpose(1, 0), arange_h], dim=2)  # (h, w, 2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)  # (n, h, w, 2)

        # Integer pixel offsets in [0, 2 * pad], converted to normalized units.
        # They stay in x.dtype: grid_sample requires grid and input to share a
        # dtype, and a float32 cast here breaks half / bfloat16 inputs.
        shift_h = torch.randint(
            0, 2 * self.pad_h + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        )
        shift_w = torch.randint(
            0, 2 * self.pad_w + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        )
        shift_h *= 2.0 / h_pad
        shift_w *= 2.0 / w_pad

        grid = base_grid + torch.cat([shift_w, shift_h], dim=3)
        x = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        return x.view(original_shape)
| |
|
| |
|
def get_norm_stats(state, action):
    """Compute per-dimension normalization statistics along axis 0.

    Args:
        state: Array-like of states stacked along axis 0 (passed to ``np.array``).
        action: Array-like of actions stacked along axis 0.

    Returns:
        dict of squeezed numpy arrays, in this key order: ``action_mean``,
        ``action_std``, ``action_max``, ``action_min``, ``qpos_mean``,
        ``qpos_std``. Both standard deviations (unbiased, torch default) are
        clipped from below at 1e-2.
    """
    states = torch.from_numpy(np.array(state))
    actions = torch.from_numpy(np.array(action))
    std_floor = 1e-2

    def _to_numpy(t):
        return t.numpy().squeeze()

    return {
        "action_mean": _to_numpy(actions.mean(dim=[0], keepdim=True)),
        "action_std": _to_numpy(actions.std(dim=[0], keepdim=True).clamp(min=std_floor)),
        "action_max": _to_numpy(torch.amax(actions, dim=[0], keepdim=True)),
        "action_min": _to_numpy(torch.amin(actions, dim=[0], keepdim=True)),
        "qpos_mean": _to_numpy(states.mean(dim=[0], keepdim=True)),
        "qpos_std": _to_numpy(states.std(dim=[0], keepdim=True).clamp(min=std_floor)),
    }
| |
|
class EpisodicDataset_Unified_Multiview(Dataset):
    """Random (multi-view images, state, action chunk, instruction) samples from HDF5 episodes.

    Every item picks the episode file ``data_path_list[idx % len]``, a random
    timestep inside it, and a random ``*.pt`` instruction tensor from the
    ``instructions`` directory next to that file.

    Args:
        data_path_list: Paths of the episode ``.hdf5`` files.
        camera_names: Keys under ``observations/images`` to decode, in order.
        chunk_size: Number of future actions returned per sample (zero-padded
            past the end of the episode).
        stats: Normalization stats; ``"qpos_mean"`` / ``"qpos_std"`` are used here.
        img_aug: If True, apply ColorJitter to every camera image with
            probability 0.25 per sample.
        samples_per_episode: ``len(self)`` is
            ``len(data_path_list) * samples_per_episode`` (default 16), so one
            pass over the dataset visits each file that many times.
    """

    def __init__(
        self,
        data_path_list,
        camera_names,
        chunk_size,
        stats,
        img_aug=False,
        samples_per_episode=16,
    ):
        super().__init__()
        self.data_path_list = data_path_list
        self.camera_names = camera_names
        self.chunk_size = chunk_size
        self.norm_stats = stats
        self.img_aug = img_aug
        self.samples_per_episode = samples_per_episode
        self.ColorJitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01
        )

    def __len__(self):
        return len(self.data_path_list) * self.samples_per_episode

    @staticmethod
    def _load_instruction(example_path):
        """Load one randomly chosen ``*.pt`` file from ``<dir of example_path>/instructions``."""
        instruction_dir = os.path.join(os.path.dirname(example_path), "instructions")
        # Sorted so that the choice is reproducible under a fixed numpy seed
        # (os.listdir order is filesystem-dependent).
        instruction_files = sorted(
            name for name in os.listdir(instruction_dir) if fnmatch.fnmatch(name, "*.pt")
        )
        if not instruction_files:
            raise FileNotFoundError(f"No '*.pt' instruction files found in {instruction_dir}")
        instruction_file = os.path.join(instruction_dir, np.random.choice(instruction_files))
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted data.
        return torch.load(instruction_file, weights_only=False)

    def __getitem__(self, path_index):
        """Return ``(image_data, qpos_data, action_data, is_pad, instruction_data)``.

        ``path_index`` only selects the episode file; the timestep and the
        instruction are re-sampled on every call.

        Returns:
            image_data: float tensor (1, num_cameras, 3, H, W) scaled to [0, 1]
                (channels as decoded by OpenCV, i.e. not converted from BGR).
            qpos_data: float tensor ``qpos[index:index + 1]`` normalized with
                ``qpos_mean`` / ``qpos_std``.
            action_data: float tensor (chunk_size, *action.shape[1:]), zeros
                past the episode end.
            is_pad: bool tensor (chunk_size,), True where action_data is padding.
            instruction_data: the loaded instruction tensor averaged over dim 0.
        """
        path_index = path_index % len(self.data_path_list)
        example_path = self.data_path_list[path_index]
        # Sampled before the timestep to keep the original np.random call order.
        instruction = self._load_instruction(example_path)

        # Single open: read the trajectories, sample a timestep, decode its images.
        with h5py.File(example_path, "r") as f:
            # NOTE(review): the two keys are read swapped (observations/qpos ->
            # action, action -> qpos); load_data_unified reads them the same way
            # for the stats. Kept for consistency — confirm against the data format.
            action = f["observations"]["qpos"][()]
            qpos = f["action"][()]
            episode_len = action.shape[0]
            index = np.random.randint(0, episode_len)

            images_group = f["observations"]["images"]
            camera_list = []
            for camera_name in self.camera_names:
                encoded = images_group[camera_name][index]
                image = cv2.imdecode(np.frombuffer(encoded, np.uint8), cv2.IMREAD_COLOR)
                if image is None:
                    raise ValueError(
                        f"Failed to decode image '{camera_name}' at step {index} of {example_path}"
                    )
                camera_list.append(image)

        obs_qpos = qpos[index:index + 1]
        obs_img = np.stack(camera_list, axis=0)  # (num_cameras, H, W, 3)

        # Future action chunk starting at `index`, zero-padded to chunk_size.
        action_len = min(self.chunk_size, episode_len - index)
        gt_action = np.zeros((self.chunk_size, *action.shape[1:]))
        gt_action[:action_len] = action[index:index + action_len]
        is_pad = np.zeros(self.chunk_size)
        is_pad[action_len:] = 1

        image_data = torch.from_numpy(obs_img).unsqueeze(0).float()
        image_data = image_data.permute(0, 1, 4, 2, 3)  # -> (1, num_cameras, 3, H, W)
        qpos_data = torch.from_numpy(obs_qpos).float()
        action_data = torch.from_numpy(gt_action).float()
        is_pad = torch.from_numpy(is_pad).bool()
        instruction_data = instruction.mean(0).float()

        image_data = image_data / 255.0
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
        if self.img_aug and random.random() < 0.25:
            for t in range(image_data.shape[0]):
                for i in range(image_data.shape[1]):
                    image_data[t, i] = self.ColorJitter(image_data[t, i])
        return image_data, qpos_data.float(), action_data, is_pad, instruction_data
| |
|
def load_data_unified(
    data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
    camera_names=None,
    batch_size_train=32,
    chunk_size=100,
    img_aug=False,
    fintune=False,
    num_workers=8,
    pretrain_stats_path=None,
):
    """Build the training DataLoader over every ``.hdf5`` file under ``data_dir``.

    Args:
        data_dir: Root directory, searched recursively (symlinks followed).
        camera_names: Camera keys to load; defaults to
            ``['cam_high', 'cam_left_wrist', 'cam_right_wrist']``.
        batch_size_train: Training batch size.
        chunk_size: Action-chunk length per sample.
        img_aug: Enable ColorJitter augmentation in the dataset.
        fintune: If True, load normalization stats from ``pretrain_stats_path``
            instead of computing them from this data (the per-file stats scan
            is skipped).
        num_workers: DataLoader worker processes (default 8).
        pretrain_stats_path: Pickled stats used when ``fintune`` is True;
            defaults to the pretraining checkpoint's ``dataset_stats.pkl``.

    Returns:
        ``(train_loader, None, None, stats)``; the two ``None`` entries are
        placeholders, no other loaders are built here.

    Raises:
        FileNotFoundError: If no ``.hdf5`` files exist under ``data_dir``.
    """
    if camera_names is None:
        camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
    if pretrain_stats_path is None:
        pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'

    # os.walk order is filesystem-dependent; sort for reproducible runs.
    hdf5_paths = sorted(
        os.path.join(root, filename)
        for root, _, files in os.walk(data_dir, followlinks=True)
        for filename in files
        if filename.endswith('.hdf5')
    )
    print(f"Loading data from {data_dir} with {len(hdf5_paths)} episodes and batch size {batch_size_train}")
    if not hdf5_paths:
        raise FileNotFoundError(f"No .hdf5 episode files found under {data_dir}")

    if fintune:
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(pretrain_stats_path, 'rb') as f:
            stats = pickle.load(f)
        print(f"Loaded stats from {pretrain_stats_path}")
    else:
        state_list = []
        action_list = []
        for p in tqdm(hdf5_paths, desc="Data statics collection"):
            with h5py.File(p, 'r') as f:
                # NOTE(review): keys read swapped (observations/qpos -> action,
                # action -> state), matching how the dataset's __getitem__ reads
                # them; confirm against the recording format.
                action_list.append(f['observations']['qpos'][()])
                state_list.append(f['action'][()])
        states = np.concatenate(state_list, axis=0)
        actions = np.concatenate(action_list, axis=0)
        stats = get_norm_stats(states, actions)

    for key, value in stats.items():
        print(f"{key}: {value}")

    train_dataset = EpisodicDataset_Unified_Multiview(
        data_path_list=hdf5_paths,
        camera_names=camera_names,
        chunk_size=chunk_size,
        stats=stats,
        img_aug=img_aug,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, None, None, stats
| |
|
def compute_dict_mean(epoch_dicts):
    """Average every key over a non-empty list of dicts.

    Keys (and their order) come from the first dict; values only need to
    support ``+`` with ``0`` and division by an int (numbers, tensors, ...).
    """
    count = len(epoch_dicts)
    return {
        key: sum(d[key] for d in epoch_dicts) / count
        for key in epoch_dicts[0]
    }
| |
|
| |
|
def detach_dict(d):
    """Return a new dict with every value replaced by ``value.detach()``."""
    return {key: value.detach() for key, value in d.items()}
| |
|
| |
|
| | |
| | |
| | |
| | import random |
| |
|
| |
|
def set_seed(seed):
    """Seed Python ``random``, NumPy's global RNG and PyTorch (CPU and CUDA)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
| |
|
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """Build a LambdaLR with linear warmup followed by cosine decay.

    For the first ``num_warmup_steps`` steps the LR multiplier rises linearly
    from 0 to 1. Afterwards it is
    ``max(0, 0.5 * (1 + cos(2 * pi * num_cycles * progress)))``, where
    ``progress`` goes from 0 to 1 over the remaining steps. With the default
    ``num_cycles=0.5`` this is a half-cosine from 1 down to 0.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose learning rate is scheduled.
        num_warmup_steps (int): Number of linear-warmup steps.
        num_training_steps (int): Total number of training steps.
        num_cycles (float, optional): Number of cosine waves; defaults to 0.5.
        last_epoch (int, optional): Index of the last epoch when resuming; defaults to -1.

    Returns:
        torch.optim.lr_scheduler.LambdaLR implementing the schedule.
    """
    # Guard both denominators against zero.
    warmup_denom = float(max(1, num_warmup_steps))
    decay_denom = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_denom
            cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
            return max(0.0, cosine)
        return float(current_step) / warmup_denom

    return LambdaLR(optimizer, lr_lambda, last_epoch)
| |
|
| |
|
def get_constant_schedule(optimizer, last_epoch: int = -1) -> LambdaLR:
    """Build a LambdaLR that keeps the optimizer's learning rate unchanged.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose learning rate is scheduled.
        last_epoch (int, optional): Index of the last epoch when resuming; defaults to -1.

    Returns:
        torch.optim.lr_scheduler.LambdaLR whose multiplier is always 1.
    """

    def _constant_factor(_step):
        return 1

    return LambdaLR(optimizer, _constant_factor, last_epoch=last_epoch)
| |
|
| |
|
def normalize_data(action_data, stats, norm_type, data_type="action"):
    """Normalize ``action_data`` using the ``stats`` entries for ``data_type``.

    Args:
        action_data: Tensor to normalize; the stats are converted to float32
            on its device and broadcast against it.
        stats: Mapping containing ``f"{data_type}_min"`` / ``_max`` (for
            "minmax") or ``_mean`` / ``_std`` (for "gaussian"). Values may be
            numpy arrays, lists or tensors.
        norm_type: ``"minmax"`` maps [min, max] to [-1, 1]; ``"gaussian"``
            computes ``(x - mean) / std``. Any other value returns
            ``action_data`` unchanged.
        data_type: Key prefix in ``stats`` (e.g. ``"action"`` or ``"qpos"``).

    Returns:
        The normalized tensor (or ``action_data`` itself for other norm types).
    """

    def _stat(name):
        # as_tensor accepts numpy arrays, lists and tensors alike.
        return torch.as_tensor(
            stats[f"{data_type}_{name}"], dtype=torch.float32, device=action_data.device
        )

    if norm_type == "minmax":
        data_min = _stat("min")
        data_range = _stat("max") - data_min
        # Constant dimensions (max == min) would give 0/0 = NaN; use a unit
        # range there so those values map to -1 instead.
        data_range = torch.where(data_range == 0, torch.ones_like(data_range), data_range)
        return (action_data - data_min) / data_range * 2 - 1
    if norm_type == "gaussian":
        return (action_data - _stat("mean")) / _stat("std")
    return action_data
| |
|
| |
|
def convert_weight(obj):
    """Return an OrderedDict copy of ``obj`` with one leading ``"module."`` removed from each key.

    Presumably used to load state dicts saved from a DataParallel /
    DistributedDataParallel wrapper into an unwrapped model. Only a single
    prefix occurrence is stripped; if two keys collide, the later one wins.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in obj.items()
    )
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the loader with augmentation enabled, run one full
    # pass over it (tqdm reports throughput) and print details of batch 0.
    loader, _, _, stats = load_data_unified(
        data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
        camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        batch_size_train=32,
        chunk_size=100,
        img_aug=True,
    )

    progress = tqdm(loader, desc="Data loading")
    for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate(progress):
        if i != 0:
            continue
        print(f"Batch {i}:")
        print(f"Image data shape: {image_data.shape} {image_data.max()}")
        print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}")
        print(f"Action data shape: {action_data.shape} {action_data.max()}")
        print(f"Is pad shape: {is_pad.shape}")
        print(f"Instruction data shape: {instruction_data.shape}")
| | |