File size: 2,603 Bytes
f96995c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from pathlib import Path
import argparse
import random
import time
import os
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm, trange
import hydra
from omegaconf import DictConfig, OmegaConf
import yaml
from datetime import datetime
import numpy as np
from PIL import Image
import warp as wp
import matplotlib.pyplot as plt
import multiprocess as mp
import torch
import torch.backends.cudnn
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import logging

from pgnd.utils import get_root, mkdir
from modules_planning.planning_env import RobotPlanningEnv

# Base directory used to resolve the config, checkpoint and log paths in main().
# NOTE(review): presumably the project/repository root derived from this file's
# location by the project helper `get_root` — confirm against pgnd.utils.
root: Path = get_root(__file__)
# Configure the root logger at WARNING level (logging.basicConfig is a no-op if
# the root logger already has handlers).
logging.basicConfig(level=logging.WARNING)


def main(args):
    """Load a trained run's config and checkpoint, then run the planning env.

    Steps:
    - Read the YAML config at ``root / args.config``.
    - Apply planning-time simulator overrides.
    - Seed the Python, NumPy and PyTorch RNGs from the config's ``seed`` field.
    - Create a timestamped experiment directory under
      ``root / 'log' / <cfg.train.name> / 'plan'``.
    - Construct a ``RobotPlanningEnv`` and call ``start()`` then ``join()`` on it.

    Args:
        args: Parsed CLI namespace. Reads ``config``, ``iteration``,
            ``text_prompts``, ``no_annotation``, ``bimanual`` and
            ``construct_target``. ``args.seed`` is NOT consulted.

    Raises:
        ValueError: if ``args.iteration`` cannot be converted to ``int``.
        RuntimeError: if a multiprocessing start method other than 'spawn'
            was already set in this interpreter.
    """
    # Request the 'spawn' start method before the env is created.
    # NOTE(review): presumably required because child processes use CUDA,
    # which does not survive fork — confirm against RobotPlanningEnv.
    # Skip the call when 'spawn' is already active, so a repeated call does not
    # raise. A conflicting method still raises RuntimeError, as before.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')

    config_path = root / args.config
    # CLoader needs PyYAML built with libyaml; fall back to the pure-Python
    # equivalent otherwise. NOTE: both are PyYAML's non-safe loaders (they can
    # construct arbitrary Python objects), so only load trusted config files.
    yaml_loader = getattr(yaml, 'CLoader', yaml.Loader)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml_loader)
    cfg = OmegaConf.create(config)

    # Planning-time simulator overrides, applied regardless of the values the
    # training config was saved with.
    cfg.sim.num_steps = 1000
    cfg.sim.gripper_forcing = False
    cfg.sim.uniform = True

    # argparse may deliver this as a string (e.g. "--iteration 5000"); the
    # ':06d' format below only accepts integers.
    iteration = int(args.iteration)
    # Checkpoints are looked up next to the config: <config dir>/ckpt/NNNNNN.pt
    ckpt_path = config_path.parent / 'ckpt' / f'{iteration:06d}.pt'

    # Seed every RNG this script touches from the config's seed.
    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # One fresh output directory per run, named by local timestamp.
    datetime_now = datetime.now().strftime('%y%m%d-%H%M%S')
    exp_root: Path = root / 'log' / cfg.train.name / 'plan' / datetime_now
    mkdir(exp_root, overwrite=cfg.overwrite, resume=cfg.resume)

    env = RobotPlanningEnv(
        cfg,
        exp_root=exp_root,
        ckpt_path=ckpt_path,
        resolution=(848, 480),
        capture_fps=30,
        record_fps=0,
        text_prompts=args.text_prompts,
        show_annotation=(not args.no_annotation),
        use_robot=True,
        bimanual=args.bimanual,
        gripper_enable=True,
        debug=True,
        construct_target=args.construct_target,
    )

    # Block until the environment finishes.
    env.start()
    env.join()


# CLI entry point: parse arguments and run main() with gradient tracking
# disabled in this process.
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', type=str, default='log/cloth/train/hydra.yaml',
                            help='training config YAML, relative to the project root')
    # type=int: main() formats the iteration with ':06d', which rejects strings.
    arg_parser.add_argument('--iteration', type=int, default=100000,
                            help='checkpoint iteration to load (<config dir>/ckpt/NNNNNN.pt)')
    arg_parser.add_argument('--text_prompts', type=str, default='green towel.',
                            help='text prompt(s) forwarded to RobotPlanningEnv')
    arg_parser.add_argument('--seed', type=int, default=42,
                            help='currently unused: main() seeds from the config `seed` field')
    arg_parser.add_argument('--no_annotation', action='store_true',
                            help='pass show_annotation=False to RobotPlanningEnv')
    arg_parser.add_argument('--bimanual', action='store_true',
                            help='pass bimanual=True to RobotPlanningEnv')
    arg_parser.add_argument('--construct_target', action='store_true',
                            help='pass construct_target=True to RobotPlanningEnv')
    args = arg_parser.parse_args()

    with torch.no_grad():
        main(args)