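# NOTE: the following is residue from the web file viewer this script was copied from;
# it is not part of the program.
# zjowowen's picture | init space | 079c32c | raw | history blame | 1.66 kB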
from dizoo.classic_control.cartpole.offline_data.collect_dqn_data_config import main_config, create_config
from ding.entry import serial_pipeline_offline
import os
import torch
from torch.utils.data import DataLoader
from ding.config import read_config, compile_config
from ding.utils.data import create_dataset
def train(args):
    """Build the dataset described by the imported config and inspect its first item.

    The function compiles the imported ``main_config`` / ``create_config`` pair,
    creates the dataset from the compiled config, prints the dataset length and
    the ``'action'`` entry of the first element of the first item, and finally
    stacks the ``'action'`` entries of every element of that first item.

    Args:
        args: Object exposing a ``seed`` attribute (the parsed command-line
            namespace when run as a script). ``args.seed`` is forwarded to
            ``compile_config``.

    Returns:
        None. Results are only printed.
    """
    from copy import deepcopy

    config = [main_config, create_config]
    input_cfg = config
    if isinstance(input_cfg, str):
        # Kept so that ``config`` can be swapped for a config file path.
        cfg, create_cfg = read_config(input_cfg)
    else:
        # Copy before mutating: the imported config objects are module-level
        # and shared, so appending '_command' in place would accumulate the
        # suffix on every call of this function.
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    cfg = compile_config(cfg, seed=args.seed, auto=True, create_cfg=create_cfg)

    # Dataset
    dataset = create_dataset(cfg)
    print(len(dataset))

    # Fetch the first item once instead of re-reading it for every step.
    # NOTE(review): presumably each dataset item is one collected episode,
    # i.e. an indexable sequence of per-step dicts -- confirm against the
    # dataset type selected by the config.
    first_episode = dataset[0]
    # print(first_episode)
    print(first_episode[0]['action'])

    # Stacked action of the first collected episode. The result is not used
    # further; the call is kept so that it still fails loudly if the per-step
    # actions cannot be stacked.
    episode_action = torch.stack(
        [first_episode[i]['action'] for i in range(len(first_episode))], dim=0
    )

    # Possible next steps, left disabled:
    # dataloader = DataLoader(dataset, cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
    # for i, train_data in enumerate(dataloader):
    #     print(i, train_data)
    # serial_pipeline_offline(config, seed=args.seed)
if __name__ == "__main__":
    import argparse

    # Script entry point: the only command-line option is the random seed
    # (``--seed`` / ``-s``, integer, default 0), passed on to ``train``.
    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', '-s', type=int, default=0)
    train(cli.parse_args())