import fnmatch
import math
import os
import pickle
import random
from collections import OrderedDict

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)


def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Save one train/validation curve per logged metric."""
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f"train_val_{key}_seed_{seed}.png")
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(
            np.linspace(0, num_epochs - 1, len(train_history)),
            train_values,
            label="train",
        )
        plt.plot(
            np.linspace(0, num_epochs - 1, len(validation_history)),
            val_values,
            label="validation",
        )
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
        plt.close()  # free the figure before plotting the next metric
    print(f"Saved plots to {ckpt_dir}")


def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts an image tensor to a uint8 numpy image in range [0..255].

    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]
            (or [0..1] if ``range_min`` is 0).
        range_min: Lower bound of the input range, either -1 or 0.

    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    if range_min == -1:
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    ndim = input_tensor.ndim
    output_image = input_tensor.clamp(0, 1).cpu().numpy()
    # Move the channel axis to the end: (B, C, ...) -> (B, ..., C).
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)


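# Example for tensor2numpy (hypothetical shapes): two 64x64 RGB frames in [-1, 1]
# become a (2, 64, 64, 3) uint8 array:
#   frames = torch.rand(2, 3, 64, 64) * 2.0 - 1.0
#   tensor2numpy(frames)  # shape (2, 64, 64, 3), dtype uint8

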
def kl_divergence(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and the standard normal prior.

    Returns the batch-mean total KL, the per-dimension KL, and the mean KL.
    """
    batch_size = mu.size(0)
    assert batch_size != 0
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    # Closed form: KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld


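# Optional sanity check for kl_divergence (a sketch, not run on import): the same
# total KL can be obtained with torch.distributions, assuming mu/logvar of shape (B, D):
#   p = torch.distributions.Normal(mu, (0.5 * logvar).exp())
#   q = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
#   torch.distributions.kl_divergence(p, q).sum(1).mean()  # ~= total_kld

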
class RandomShiftsAug(nn.Module):
    """Random shift (pad-and-crop) augmentation for batched images."""

    def __init__(self, pad_h, pad_w):
        super().__init__()
        self.pad_h = pad_h
        self.pad_w = pad_w
        print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")

    def forward(self, x):
        original_shape = x.shape
        n, h, w = x.shape[0], x.shape[-2], x.shape[-1]
        x = x.view(n, -1, h, w)
        padding = (
            self.pad_w,
            self.pad_w,
            self.pad_h,
            self.pad_h,
        )
        x = F.pad(x, padding, mode="replicate")

        h_pad, w_pad = h + 2 * self.pad_h, w + 2 * self.pad_w
        eps_h = 1.0 / h_pad
        eps_w = 1.0 / w_pad

        # Base sampling grid covering the original (unpadded) extent of the padded image.
        arange_h = torch.linspace(
            -1.0 + eps_h, 1.0 - eps_h, h_pad, device=x.device, dtype=x.dtype
        )[:h]
        arange_w = torch.linspace(
            -1.0 + eps_w, 1.0 - eps_w, w_pad, device=x.device, dtype=x.dtype
        )[:w]

        arange_h = arange_h.unsqueeze(1).repeat(1, w).unsqueeze(2)
        arange_w = arange_w.unsqueeze(1).repeat(1, h).unsqueeze(2)

        base_grid = torch.cat([arange_w.transpose(1, 0), arange_h], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)

        # Per-sample integer shift in [0, 2*pad], expressed in normalized grid units.
        shift_h = torch.randint(
            0, 2 * self.pad_h + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        shift_w = torch.randint(
            0, 2 * self.pad_w + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        shift_h *= 2.0 / h_pad
        shift_w *= 2.0 / w_pad

        grid = base_grid + torch.cat([shift_w, shift_h], dim=3)
        x = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        return x.view(original_shape)


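# Usage sketch for RandomShiftsAug (sizes are illustrative): each image in the
# batch is padded by replication and re-sampled at a random integer offset, so
# the output keeps the input shape:
#   aug = RandomShiftsAug(pad_h=4, pad_w=4)
#   out = aug(torch.rand(8, 3, 96, 96))  # out.shape == (8, 3, 96, 96)

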
def get_norm_stats(state, action):
    """Compute per-dimension normalization statistics over all timesteps."""
    all_qpos_data = torch.from_numpy(np.array(state))
    all_action_data = torch.from_numpy(np.array(action))

    action_mean = all_action_data.mean(dim=[0], keepdim=True)
    action_std = all_action_data.std(dim=[0], keepdim=True)
    action_std = torch.clip(action_std, 1e-2, np.inf)  # avoid division by near-zero std
    action_max = torch.amax(all_action_data, dim=[0], keepdim=True)
    action_min = torch.amin(all_action_data, dim=[0], keepdim=True)

    qpos_mean = all_qpos_data.mean(dim=[0], keepdim=True)
    qpos_std = all_qpos_data.std(dim=[0], keepdim=True)
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)

    stats = {
        "action_mean": action_mean.numpy().squeeze(),
        "action_std": action_std.numpy().squeeze(),
        "action_max": action_max.numpy().squeeze(),
        "action_min": action_min.numpy().squeeze(),
        "qpos_mean": qpos_mean.numpy().squeeze(),
        "qpos_std": qpos_std.numpy().squeeze(),
    }

    return stats


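# get_norm_stats expects sequences of per-timestep vectors; with D-dimensional
# actions the returned entries (e.g. stats["action_mean"], stats["action_std"])
# are (D,)-shaped numpy arrays.

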
class EpisodicDataset_Unified_Multiview(Dataset):
    def __init__(self, data_path_list, camera_names, chunk_size, stats, img_aug=False):
        super().__init__()
        self.data_path_list = data_path_list
        self.camera_names = camera_names
        self.chunk_size = chunk_size
        self.norm_stats = stats
        self.img_aug = img_aug
        self.ColorJitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01
        )

    def __len__(self):
        # Each epoch visits every episode 16 times, with a freshly sampled timestep each visit.
        return len(self.data_path_list) * 16

    def __getitem__(self, path_index):
        path_index = path_index % len(self.data_path_list)
        example_path = self.data_path_list[path_index]
        with h5py.File(example_path, 'r') as f:
            # The 'observations/qpos' dataset is read as the action targets and the
            # 'action' dataset as the proprioceptive state input.
            action = f['observations']['qpos'][()]
            qpos = f['action'][()]

        parent_path = os.path.dirname(example_path)
        Instruction_path = os.path.join(parent_path, 'instructions')

        # Sample one pre-computed language-instruction embedding for this episode.
        instruction_files = [f for f in os.listdir(Instruction_path) if fnmatch.fnmatch(f, '*.pt')]
        instruction_file = os.path.join(Instruction_path, np.random.choice(instruction_files))
        instruction = torch.load(instruction_file, weights_only=False)

        episode_len = action.shape[0]
        index = np.random.randint(0, episode_len)
        obs_qpos = qpos[index:index + 1]

        with h5py.File(example_path, 'r') as f:
            camera_list = []
            for camera_name in self.camera_names:
                cam_jpeg_code = f['observations']['images'][camera_name][index]
                cam_image = cv2.imdecode(np.frombuffer(cam_jpeg_code, np.uint8), cv2.IMREAD_COLOR)
                camera_list.append(cam_image)
            obs_img = np.stack(camera_list, axis=0)

        # Build a fixed-length action chunk, zero-padded past the episode end.
        original_action_shape = (self.chunk_size, *action.shape[1:])
        gt_action = np.zeros(original_action_shape)
        action_len = min(self.chunk_size, episode_len - index)
        gt_action[:action_len] = action[index:index + action_len]
        is_pad = np.zeros(self.chunk_size)
        is_pad[action_len:] = 1

        image_data = torch.from_numpy(obs_img).unsqueeze(0).float()
        image_data = image_data.permute(0, 1, 4, 2, 3)  # (1, num_cams, 3, H, W)
        qpos_data = torch.from_numpy(obs_qpos).float()
        action_data = torch.from_numpy(gt_action).float()
        is_pad = torch.from_numpy(is_pad).bool()
        instruction_data = instruction.mean(0).float()

        image_data = image_data / 255.0
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
        if self.img_aug and random.random() < 0.25:
            for t in range(image_data.shape[0]):
                for i in range(image_data.shape[1]):
                    image_data[t, i] = self.ColorJitter(image_data[t, i])
        return image_data, qpos_data.float(), action_data, is_pad, instruction_data


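# Each sample from EpisodicDataset_Unified_Multiview is a tuple
# (image_data, qpos_data, action_data, is_pad, instruction_data) with shapes
# (1, num_cams, 3, H, W), (1, state_dim), (chunk_size, action_dim), (chunk_size,),
# and the instruction embedding averaged over its first axis.

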
def load_data_unified(
    data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
    camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
    batch_size_train=32,
    chunk_size=100,
    img_aug=False,
    fintune=False,
):
    # Recursively collect every HDF5 episode under data_dir.
    HDF5_file_path = []
    for root, _, files in os.walk(data_dir, followlinks=True):
        for filename in files:
            if filename.endswith('.hdf5'):
                HDF5_file_path.append(os.path.join(root, filename))
    print(f"Loading data from {data_dir} with {len(HDF5_file_path)} episodes and batch size {batch_size_train}")

    state_list = []
    action_list = []

    for p in tqdm(HDF5_file_path, desc="Data statistics collection"):
        with h5py.File(p, 'r') as f:
            # Same key mapping as the dataset class: joint positions as action
            # targets, logged commands as the state.
            action = f['observations']['qpos'][()]
            qpos = f['action'][()]
            state_list.append(qpos)
            action_list.append(action)
    states = np.concatenate(state_list, axis=0)
    actions = np.concatenate(action_list, axis=0)

    if fintune:
        # When fine-tuning, reuse the normalization statistics from pre-training.
        pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'
        with open(pretrain_stats_path, 'rb') as f:
            stats = pickle.load(f)
        print(f"Loaded stats from {pretrain_stats_path}")
    else:
        stats = get_norm_stats(states, actions)

    for key, value in stats.items():
        print(f"{key}: {value}")

    train_dataset = EpisodicDataset_Unified_Multiview(
        data_path_list=HDF5_file_path,
        camera_names=camera_names,
        chunk_size=chunk_size,
        stats=stats,
        img_aug=img_aug,
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

    return train_dataloader, None, None, stats


def compute_dict_mean(epoch_dicts):
    """Average each key across a list of dicts (e.g. per-batch loss summaries)."""
    result = {k: None for k in epoch_dicts[0]}
    num_items = len(epoch_dicts)
    for k in result:
        value_sum = 0
        for epoch_dict in epoch_dicts:
            value_sum += epoch_dict[k]
        result[k] = value_sum / num_items
    return result


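# e.g. compute_dict_mean([{"loss": 1.0}, {"loss": 3.0}]) -> {"loss": 2.0}

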
def detach_dict(d):
    """Return a copy of d with every tensor value detached from the graph."""
    new_d = dict()
    for k, v in d.items():
        new_d[k] = v.detach()
    return new_d


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer (`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


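# Usage sketch for get_cosine_schedule_with_warmup (model, lr, and step counts are
# placeholders):
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = get_cosine_schedule_with_warmup(
#       optimizer, num_warmup_steps=500, num_training_steps=100_000)
#   ...  # call scheduler.step() once per optimizer step

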
def get_constant_schedule(optimizer, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.

    Args:
        optimizer (`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def normalize_data(action_data, stats, norm_type, data_type="action"):
    """Normalize a tensor with min-max ([-1, 1]) or gaussian (z-score) statistics."""
    if norm_type == "minmax":
        action_max = torch.from_numpy(stats[data_type + "_max"]).float().to(action_data.device)
        action_min = torch.from_numpy(stats[data_type + "_min"]).float().to(action_data.device)
        action_data = (action_data - action_min) / (action_max - action_min) * 2 - 1
    elif norm_type == "gaussian":
        action_mean = torch.from_numpy(stats[data_type + "_mean"]).float().to(action_data.device)
        action_std = torch.from_numpy(stats[data_type + "_std"]).float().to(action_data.device)
        action_data = (action_data - action_mean) / action_std
    return action_data


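# Inverse of normalize_data (a sketch, assuming the same stats dict):
#   "minmax":   action = (action_norm + 1) / 2 * (action_max - action_min) + action_min
#   "gaussian": action = action_norm * action_std + action_mean

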
def convert_weight(obj):
    """Remove a leading "module." prefix from state_dict keys, if present."""
    newmodel = OrderedDict()
    for k, v in obj.items():
        if k.startswith("module."):
            newmodel[k[7:]] = v
        else:
            newmodel[k] = v
    return newmodel


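# Typical use of convert_weight: strip the "module." prefix that DataParallel/DDP
# adds to state_dict keys before loading into a single-GPU model, e.g.
#   state_dict = convert_weight(torch.load(ckpt_path, map_location="cpu"))
#   model.load_state_dict(state_dict)
# (ckpt_path and model are placeholders.)

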
if __name__ == "__main__":
    train_dataloader, _, _, stats = load_data_unified(
        data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
        camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        batch_size_train=32,
        chunk_size=100,
        img_aug=True,
    )

    # Iterate once over the dataloader to smoke-test loading; print shapes for the
    # first batch only.
    for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate(
        tqdm(train_dataloader, desc="Data loading")
    ):
        if i == 0:
            print(f"Batch {i}:")
            print(f"Image data shape: {image_data.shape} {image_data.max()}")
            print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}")
            print(f"Action data shape: {action_data.shape} {action_data.max()}")
            print(f"Is pad shape: {is_pad.shape}")
            print(f"Instruction data shape: {instruction_data.shape}")