SceneDiffuser commited on
Commit
9fcef4d
1 Parent(s): 97fb5a2
Files changed (6) hide show
  1. .gitignore +3 -1
  2. .gitmodules +3 -0
  3. app.py +96 -152
  4. interface.py +267 -0
  5. scenediffuser +1 -0
  6. style.css +1 -0
.gitignore CHANGED
@@ -1 +1,3 @@
1
- results
 
 
1
+ __pycache__
2
+ results
3
+ src/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "scenediffuser"]
2
+ path = scenediffuser
3
+ url = https://github.com/scenediffuser/Scene-Diffuser
app.py CHANGED
@@ -1,160 +1,104 @@
1
  import os
 
 
2
  import gradio as gr
3
- import random
4
- import pickle
5
- import numpy as np
6
- import zipfile
7
- import trimesh
8
- from PIL import Image
9
- from huggingface_hub import hf_hub_download
10
 
11
- def pose_generation(scene, count):
12
- assert isinstance(scene, str)
13
- results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/pose_generation/results.pkl')
14
- with open(results_path, 'rb') as f:
15
- results = pickle.load(f)
16
-
17
- images = [Image.fromarray(results[scene][random.randint(0, 19)]) for i in range(count)]
18
- return images
19
 
20
- def pose_generation_mesh(scene, count):
21
- assert isinstance(scene, str)
22
- scene_path = f"./results/pose_generation/mesh_results/{scene}/scene_downsample.ply"
23
- if not os.path.exists(scene_path):
24
- results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/pose_generation/mesh_results.zip')
25
- os.makedirs('./results/pose_generation/', exist_ok=True)
26
- with zipfile.ZipFile(results_path, 'r') as zip_ref:
27
- zip_ref.extractall('./results/pose_generation/')
28
-
29
- res = './results/pose_generation/tmp.glb'
30
- S = trimesh.Scene()
31
- S.add_geometry(trimesh.load(scene_path))
32
- for i in range(count):
33
- rid = random.randint(0, 19)
34
- S.add_geometry(trimesh.load(
35
- f"./results/pose_generation/mesh_results/{scene}/body{rid:0>3d}.ply"
36
- ))
37
- S.export(res)
38
-
39
- return res
40
 
41
- def motion_generation(scene):
42
- assert isinstance(scene, str)
43
- cnt = {
44
- 'MPH1Library': 3,
45
- 'MPH16': 6,
46
- 'N0SittingBooth': 7,
47
- 'N3OpenArea': 5
48
- }[scene]
 
 
 
 
 
 
 
 
 
49
 
50
- res = f"./results/motion_generation/results/{scene}/{random.randint(0, cnt-1)}.gif"
51
- if not os.path.exists(res):
52
- results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/motion_generation/results.zip')
53
- os.makedirs('./results/motion_generation/', exist_ok=True)
54
- with zipfile.ZipFile(results_path, 'r') as zip_ref:
55
- zip_ref.extractall('./results/motion_generation/')
56
-
57
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- def grasp_generation(case_id):
60
- assert isinstance(case_id, str)
61
- res = f"./results/grasp_generation/results/{case_id}/{random.randint(0, 19)}.glb"
62
- if not os.path.exists(res):
63
- results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/grasp_generation/results.zip')
64
- os.makedirs('./results/grasp_generation/', exist_ok=True)
65
- with zipfile.ZipFile(results_path, 'r') as zip_ref:
66
- zip_ref.extractall('./results/grasp_generation/')
67
-
68
- return res
69
 
70
- def path_planning(case_id):
71
- assert isinstance(case_id, str)
72
- results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/path_planning/results.pkl')
73
- with open(results_path, 'rb') as f:
74
- results = pickle.load(f)
75
-
76
- case = results[case_id]
77
- steps = case['step']
78
- image = Image.fromarray(case['image'])
79
- return image, steps
80
-
81
- with gr.Blocks() as demo:
82
- gr.Markdown("# **<p align='center'>Diffusion-based Generation, Optimization, and Planning in 3D Scenes</p>**")
83
- gr.HTML(value="<img src='file/figures/teaser.png' alt='Teaser' width='710px' height='284px' style='display: block; margin: auto;'>")
84
- gr.HTML(value="<p align='center' style='font-size: 1.25em; color: #485fc7;'><a href='' target='_blank'>Paper</a> | <a href='' target='_blank'>Project Page</a> | <a href='' target='_blank'>Github</a></p>")
85
- gr.Markdown("<p align='center'><i>\"SceneDiffuser provides a unified model for solving scene-conditioned generation, optimization, and planning.\"</i></p>")
86
-
87
- ## five task
88
- ## pose generation
89
- with gr.Tab("Pose Generation"):
90
- with gr.Row():
91
- with gr.Column():
92
- input1 = [
93
- gr.Dropdown(choices=['MPH16', 'MPH1Library', 'N0SittingBooth', 'N3OpenArea'], label='Scenes'),
94
- gr.Slider(minimum=1, maximum=4, step=1, label='Count', interactive=True)
95
- ]
96
- button1 = gr.Button("Generate")
97
- with gr.Column():
98
- output1 = [
99
- gr.Gallery(label="Result").style(grid=[1], height="auto")
100
- ]
101
- button1.click(pose_generation, inputs=input1, outputs=output1)
102
-
103
- with gr.Tab("Pose Generation Mesh"):
104
- input11 = [
105
- gr.Dropdown(choices=['MPH16', 'MPH1Library', 'N0SittingBooth', 'N3OpenArea'], label='Scenes'),
106
- gr.Slider(minimum=1, maximum=4, step=1, label='Count', interactive=True)
107
- ]
108
- button11 = gr.Button("Generate")
109
- output11 = gr.Model3D(clear_color=[255, 255, 255, 255], label="Result")
110
- button11.click(pose_generation_mesh, inputs=input11, outputs=output11)
111
-
112
- ## motion generation
113
- with gr.Tab("Motion Generation"):
114
- with gr.Row():
115
- with gr.Column():
116
- input2 = [
117
- gr.Dropdown(choices=['MPH16', 'MPH1Library', 'N0SittingBooth', 'N3OpenArea'], label='Scenes')
118
- ]
119
- button2 = gr.Button("Generate")
120
- with gr.Column():
121
- output2 = gr.Image(label="Result")
122
- button2.click(motion_generation, inputs=input2, outputs=output2)
123
-
124
- ## grasp generation
125
- with gr.Tab("Grasp Generation"):
126
- with gr.Row():
127
- with gr.Column():
128
- input3 = [
129
- gr.Dropdown(choices=['contactdb+apple', 'contactdb+camera', 'contactdb+cylinder_medium', 'contactdb+door_knob', 'contactdb+rubber_duck', 'contactdb+water_bottle', 'ycb+baseball', 'ycb+pear', 'ycb+potted_meat_can', 'ycb+tomato_soup_can'], label='Objects')
130
- ]
131
- button3 = gr.Button("Run")
132
- with gr.Column():
133
- output3 = [
134
- gr.Model3D(clear_color=[255, 255, 255, 255], label="Result")
135
- ]
136
- button3.click(grasp_generation, inputs=input3, outputs=output3)
137
-
138
- ## path planning
139
- with gr.Tab("Path Planing"):
140
- with gr.Row():
141
- with gr.Column():
142
- input4 = [
143
- gr.Dropdown(choices=['scene0603_00_N0pT', 'scene0621_00_cJ4H', 'scene0634_00_48Y3', 'scene0634_00_gIRH', 'scene0637_00_YgjR', 'scene0640_00_BO94', 'scene0641_00_3K6J', 'scene0641_00_KBKx', 'scene0641_00_cb7l', 'scene0645_00_35Hy', 'scene0645_00_47D1', 'scene0645_00_XfLE', 'scene0667_00_DK4F', 'scene0667_00_o7XB', 'scene0667_00_rUMp', 'scene0672_00_U250', 'scene0673_00_Jyw8', 'scene0673_00_u1lJ', 'scene0678_00_QbNL', 'scene0678_00_RrY0', 'scene0678_00_aE1p', 'scene0678_00_hnXu', 'scene0694_00_DgAL', 'scene0694_00_etF5', 'scene0698_00_tT3Q'], label='Scenes'),
144
- ]
145
- button4 = gr.Button("Run")
146
- with gr.Column():
147
- # output4 = gr.Gallery(label="Result").style(grid=[1], height="auto")
148
- output4 = [
149
- gr.Image(label="Result"),
150
- gr.Number(label="Steps", precision=0)
151
- ]
152
- button4.click(path_planning, inputs=input4, outputs=output4)
153
-
154
- ## arm motion planning
155
- with gr.Tab("Arm Motion Planning"):
156
- gr.Markdown('Coming soon!')
157
-
158
- gr.Markdown("<p>Note: Currently, the output results are pre-sampled results. We will deploy a real-time model after we release the code.</p>")
159
-
160
- demo.launch()
1
  import os
2
+ import sys
3
+ sys.path.append('./scenediffuser/')
4
  import gradio as gr
 
 
 
 
 
 
 
5
 
6
+ import interface as IF
 
 
 
 
 
 
 
7
 
8
+ with gr.Blocks(css='style.css') as demo:
9
+ with gr.Column(elem_id="col-container"):
10
+ gr.Markdown("<p align='center' style='font-size: 1.5em;'>Diffusion-based Generation, Optimization, and Planning in 3D Scenes</p>")
11
+ gr.HTML(value="<img src='file/figures/teaser.png' alt='Teaser' width='710px' height='284px' style='display: block; margin: auto;'>")
12
+ gr.HTML(value="<p align='center' style='font-size: 1.2em; color: #485fc7;'><a href='https://arxiv.org/abs/2301.06015' target='_blank'>arXiv</a> | <a href='https://scenediffuser.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/scenediffuser/Scene-Diffuser' target='_blank'>Code</a></p>")
13
+ gr.Markdown("<p align='center'><i>\"SceneDiffuser provides a unified model for solving scene-conditioned generation, optimization, and planning.\"</i></p>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ ## five task
16
+ ## pose generation
17
+ with gr.Tab("Pose Generation"):
18
+ with gr.Row():
19
+ with gr.Column(scale=2):
20
+ selector1 = gr.Dropdown(choices=['MPH16', 'MPH1Library', 'N0SittingBooth', 'N3OpenArea'], label='Scenes', value='MPH16', interactive=True)
21
+ with gr.Row():
22
+ sample1 = gr.Slider(minimum=1, maximum=8, step=1, label='Count', interactive=True, value=1)
23
+ seed1 = gr.Slider(minimum=0, maximum=2 ** 16, step=1, label='Seed', interactive=True, value=2023)
24
+ opt1 = gr.Checkbox(label='Optimizer Guidance', interactive=True, value=True)
25
+ scale1 = gr.Slider(minimum=0.1, maximum=9.9, step=0.1, label='Scale', interactive=True, value=1.1)
26
+ button1 = gr.Button("Run")
27
+ with gr.Column(scale=3):
28
+ image1 = gr.Gallery(label="Image [Result]").style(grid=[1], height="50")
29
+ # model1 = gr.Model3D(clear_color=[255, 255, 255, 255], label="3D Model [Result]")
30
+ input1 = [selector1, sample1, seed1, opt1, scale1]
31
+ button1.click(IF.pose_generation, inputs=input1, outputs=[image1])
32
 
33
+ ## motion generation
34
+ # with gr.Tab("Motion Generation"):
35
+ # with gr.Row():
36
+ # with gr.Column(scale=2):
37
+ # selector2 = gr.Dropdown(choices=['MPH16', 'MPH1Library', 'N0SittingBooth', 'N3OpenArea'], label='Scenes', value='MPH16', interactive=True)
38
+ # with gr.Row():
39
+ # sample2 = gr.Slider(minimum=1, maximum=8, step=1, label='Count', interactive=True, value=1)
40
+ # seed2 = gr.Slider(minimum=0, maximum=2 ** 16, step=1, label='Seed', interactive=True, value=2023)
41
+ # with gr.Row():
42
+ # withstart = gr.Checkbox(label='With Start', interactive=True, value=False)
43
+ # opt2 = gr.Checkbox(label='Optimizer Guidance', interactive=True, value=True)
44
+ # scale_opt2 = gr.Slider(minimum=0.1, maximum=9.9, step=0.1, label='Scale', interactive=True, value=1.1)
45
+ # button2 = gr.Button("Run")
46
+ # with gr.Column(scale=3):
47
+ # image2 = gr.Image(label="Result")
48
+ # input2 = [selector2, sample2, seed2, withstart, opt2, scale_opt2]
49
+ # button2.click(IF.motion_generation, inputs=input2, outputs=image2)
50
+ with gr.Tab("Motion Generation"):
51
+ with gr.Row():
52
+ with gr.Column(scale=2):
53
+ input2 = [
54
+ gr.Dropdown(choices=['MPH16', 'MPH1Library', 'N0SittingBooth', 'N3OpenArea'], label='Scenes')
55
+ ]
56
+ button2 = gr.Button("Generate")
57
+ gr.HTML("<p style='font-size: 0.9em; color: #555555;'>Notes: the output results are pre-sampled results. We will deploy a real-time model for this task soon.</p>")
58
+ with gr.Column(scale=3):
59
+ output2 = gr.Image(label="Result")
60
+ button2.click(IF.motion_generation, inputs=input2, outputs=output2)
61
+
62
+ ## grasp generation
63
+ with gr.Tab("Grasp Generation"):
64
+ with gr.Row():
65
+ with gr.Column(scale=2):
66
+ input3 = [
67
+ gr.Dropdown(choices=['contactdb+apple', 'contactdb+camera', 'contactdb+cylinder_medium', 'contactdb+door_knob', 'contactdb+rubber_duck', 'contactdb+water_bottle', 'ycb+baseball', 'ycb+pear', 'ycb+potted_meat_can', 'ycb+tomato_soup_can'], label='Objects')
68
+ ]
69
+ button3 = gr.Button("Run")
70
+ gr.HTML("<p style='font-size: 0.9em; color: #555555;'>Notes: the output results are pre-sampled results. We will deploy a real-time model for this task soon.</p>")
71
+ with gr.Column(scale=3):
72
+ output3 = [
73
+ gr.Model3D(clear_color=[255, 255, 255, 255], label="Result")
74
+ ]
75
+ button3.click(IF.grasp_generation, inputs=input3, outputs=output3)
76
+
77
+ ## path planning
78
+ with gr.Tab("Path Planing"):
79
+ with gr.Row():
80
+ with gr.Column(scale=2):
81
+ selector4 = gr.Dropdown(choices=['scene0603_00', 'scene0621_00', 'scene0626_00', 'scene0634_00', 'scene0637_00', 'scene0640_00', 'scene0641_00', 'scene0645_00', 'scene0653_00', 'scene0667_00', 'scene0672_00', 'scene0673_00', 'scene0678_00', 'scene0694_00', 'scene0698_00'], label='Scenes', value='scene0621_00', interactive=True)
82
+ mode4 = gr.Radio(choices=['Sampling', 'Planning'], value='Sampling', label='Mode', interactive=True)
83
+ with gr.Row():
84
+ sample4 = gr.Slider(minimum=1, maximum=8, step=1, label='Count', interactive=True, value=1)
85
+ seed4 = gr.Slider(minimum=0, maximum=2 ** 16, step=1, label='Seed', interactive=True, value=2023)
86
+ with gr.Box():
87
+ opt4 = gr.Checkbox(label='Optimizer Guidance', interactive=True, value=True)
88
+ scale_opt4 = gr.Slider(minimum=0.02, maximum=4.98, step=0.02, label='Scale', interactive=True, value=1.0)
89
+ with gr.Box():
90
+ pla4 = gr.Checkbox(label='Planner Guidance', interactive=True, value=True)
91
+ scale_pla4 = gr.Slider(minimum=0.02, maximum=0.98, step=0.02, label='Scale', interactive=True, value=0.2)
92
+ button4 = gr.Button("Run")
93
+ with gr.Column(scale=3):
94
+ image4 = gr.Gallery(label="Image [Result]").style(grid=[1], height="50")
95
+ number4 = gr.Number(label="Steps", precision=0)
96
+ gr.HTML("<p style='font-size: 0.9em; color: #555555;'>Notes: 1. It may take a long time to do planning in <b>Planning</b> mode. 2. The <span style='color: #cc0000;'>red</span> balls represent the planning result, starting with the lightest red ball and ending with the darkest red ball. The <span style='color: #00cc00;'>green</span> ball indicates the target position.</p>")
97
+ input4 = [selector4, mode4, sample4, seed4, opt4, scale_opt4, pla4, scale_pla4]
98
+ button4.click(IF.path_planning, inputs=input4, outputs=[image4, number4])
99
 
100
+ ## arm motion planning
101
+ with gr.Tab("Arm Motion Planning"):
102
+ gr.Markdown('Coming soon!')
 
 
 
 
 
 
 
103
 
104
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
interface.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import torch
4
+ import hydra
5
+ import numpy as np
6
+ import zipfile
7
+
8
+ from typing import Any
9
+ from hydra import compose, initialize
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from utils.misc import compute_model_dim
14
+ from datasets.base import create_dataset
15
+ from datasets.misc import collate_fn_general, collate_fn_squeeze_pcd_batch
16
+ from models.base import create_model
17
+ from models.visualizer import create_visualizer
18
+ from models.environment import create_enviroment
19
+
20
+ def pretrain_pointtrans_weight_path():
21
+ return hf_hub_download('SceneDiffuser/SceneDiffuser', 'weights/POINTTRANS_C_32768/model.pth')
22
+
23
+ def model_weight_path(task, has_observation=False):
24
+ if task == 'pose_gen':
25
+ return hf_hub_download('SceneDiffuser/SceneDiffuser', 'weights/2022-11-09_11-22-52_PoseGen_ddm4_lr1e-4_ep100/ckpts/model.pth')
26
+ elif task == 'motion_gen' and has_observation == True:
27
+ return hf_hub_download('SceneDiffuser/SceneDiffuser', 'weights//ckpts/model.pth')
28
+ elif task == 'motion_gen' and has_observation == False:
29
+ return hf_hub_download('SceneDiffuser/SceneDiffuser', 'weights//ckpts/model.pth')
30
+ elif task == 'path_planning':
31
+ return hf_hub_download('SceneDiffuser/SceneDiffuser', 'weights/2022-11-25_20-57-28_Path_ddm4_LR1e-4_E100_REL/ckpts/model.pth')
32
+ else:
33
+ raise Exception('Unexcepted task.')
34
+
35
+ def pose_motion_data_path():
36
+ zip_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'hf_data/pose_motion.zip')
37
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
38
+ zip_ref.extractall(os.path.dirname(zip_path))
39
+
40
+ rpath = os.path.join(os.path.dirname(zip_path), 'pose_motion')
41
+
42
+ return (
43
+ os.path.join(rpath, 'PROXD_temp'),
44
+ os.path.join(rpath, 'models_smplx_v1_1/models/'),
45
+ os.path.join(rpath, 'PROX'),
46
+ os.path.join(rpath, 'PROX/V02_05')
47
+ )
48
+
49
+ def path_planning_data_path():
50
+ zip_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'hf_data/path_planning.zip')
51
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
52
+ zip_ref.extractall(os.path.dirname(zip_path))
53
+
54
+ return os.path.join(os.path.dirname(zip_path), 'path_planning')
55
+
56
+ def load_ckpt(model: torch.nn.Module, path: str) -> None:
57
+ """ load ckpt for current model
58
+
59
+ Args:
60
+ model: current model
61
+ path: save path
62
+ """
63
+ assert os.path.exists(path), 'Can\'t find provided ckpt.'
64
+
65
+ saved_state_dict = torch.load(path)['model']
66
+ model_state_dict = model.state_dict()
67
+
68
+ for key in model_state_dict:
69
+ if key in saved_state_dict:
70
+ model_state_dict[key] = saved_state_dict[key]
71
+ ## model is trained with ddm
72
+ if 'module.'+key in saved_state_dict:
73
+ model_state_dict[key] = saved_state_dict['module.'+key]
74
+
75
+ model.load_state_dict(model_state_dict)
76
+
77
+ def _sampling(cfg: DictConfig, scene: str) -> Any:
78
+ ## compute modeling dimension according to task
79
+ cfg.model.d_x = compute_model_dim(cfg.task)
80
+
81
+ if cfg.gpu is not None:
82
+ device = f'cuda:{cfg.gpu}'
83
+ else:
84
+ device = 'cpu'
85
+
86
+ dataset = create_dataset(cfg.task.dataset, 'test', cfg.slurm, case_only=True, specific_scene=scene)
87
+
88
+ if cfg.model.scene_model.name == 'PointTransformer':
89
+ collate_fn = collate_fn_squeeze_pcd_batch
90
+ else:
91
+ collate_fn = collate_fn_general
92
+
93
+ dataloader = dataset.get_dataloader(
94
+ batch_size=1,
95
+ collate_fn=collate_fn,
96
+ shuffle=True,
97
+ )
98
+
99
+ ## create model and load ckpt
100
+ model = create_model(cfg, slurm=cfg.slurm, device=device)
101
+ model.to(device=device)
102
+ load_ckpt(model, path=model_weight_path(cfg.task.name, cfg.task.has_observation if 'has_observation' in cfg.task else False))
103
+
104
+ ## create visualizer and visualize
105
+ visualizer = create_visualizer(cfg.task.visualizer)
106
+ results = visualizer.visualize(model, dataloader)
107
+ return results
108
+
109
+ def _planning(cfg: DictConfig, scene: str) -> Any:
110
+ ## compute modeling dimension according to task
111
+ cfg.model.d_x = compute_model_dim(cfg.task)
112
+
113
+ if cfg.gpu is not None:
114
+ device = f'cuda:{cfg.gpu}'
115
+ else:
116
+ device = 'cpu'
117
+
118
+ dataset = create_dataset(cfg.task.dataset, 'test', cfg.slurm, case_only=True, specific_scene=scene)
119
+
120
+ if cfg.model.scene_model.name == 'PointTransformer':
121
+ collate_fn = collate_fn_squeeze_pcd_batch
122
+ else:
123
+ collate_fn = collate_fn_general
124
+
125
+ dataloader = dataset.get_dataloader(
126
+ batch_size=1,
127
+ collate_fn=collate_fn,
128
+ shuffle=True,
129
+ )
130
+
131
+ ## create model and load ckpt
132
+ model = create_model(cfg, slurm=cfg.slurm, device=device)
133
+ model.to(device=device)
134
+ load_ckpt(model, path=model_weight_path(cfg.task.name, cfg.task.has_observation if 'has_observation' in cfg.task else False))
135
+
136
+ ## create environment for planning task and run
137
+ env = create_enviroment(cfg.task.env)
138
+ results = env.run(model, dataloader)
139
+ return results
140
+
141
+
142
+ ## interface for five task
143
+ ## real-time model: pose generation, path planning
144
+ def pose_generation(scene, count, seed, opt, scale) -> Any:
145
+ scene_model_weight_path = pretrain_pointtrans_weight_path()
146
+ data_dir, smpl_dir, prox_dir, vposer_dir = pose_motion_data_path()
147
+ override_config = [
148
+ "diffuser=ddpm",
149
+ "model=unet",
150
+ f"model.scene_model.pretrained_weights={scene_model_weight_path}",
151
+ "task=pose_gen",
152
+ "task.visualizer.name=PoseGenVisualizerHF",
153
+ f"task.visualizer.ksample={count}",
154
+ f"task.dataset.data_dir={data_dir}",
155
+ f"task.dataset.smpl_dir={smpl_dir}",
156
+ f"task.dataset.prox_dir={prox_dir}",
157
+ f"task.dataset.vposer_dir={vposer_dir}",
158
+ ]
159
+
160
+ if opt == True:
161
+ override_config += [
162
+ "optimizer=pose_in_scene",
163
+ "optimizer.scale_type=div_var",
164
+ f"optimizer.scale={scale}",
165
+ "optimizer.vposer=false",
166
+ "optimizer.contact_weight=0.02",
167
+ "optimizer.collision_weight=1.0"
168
+ ]
169
+
170
+ initialize(config_path="./scenediffuser/configs", version_base=None)
171
+ config = compose(config_name="default", overrides=override_config)
172
+
173
+ random.seed(seed)
174
+ np.random.seed(seed)
175
+ torch.manual_seed(seed)
176
+ torch.cuda.manual_seed(seed)
177
+ torch.cuda.manual_seed_all(seed)
178
+
179
+ res = _sampling(config, scene)
180
+
181
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
182
+ return res
183
+
184
+ def motion_generation(scene):
185
+ assert isinstance(scene, str)
186
+ cnt = {
187
+ 'MPH1Library': 3,
188
+ 'MPH16': 6,
189
+ 'N0SittingBooth': 7,
190
+ 'N3OpenArea': 5
191
+ }[scene]
192
+
193
+ res = f"./results/motion_generation/results/{scene}/{random.randint(0, cnt-1)}.gif"
194
+ if not os.path.exists(res):
195
+ results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/motion_generation/results.zip')
196
+ os.makedirs('./results/motion_generation/', exist_ok=True)
197
+ with zipfile.ZipFile(results_path, 'r') as zip_ref:
198
+ zip_ref.extractall('./results/motion_generation/')
199
+
200
+ return res
201
+
202
+ def grasp_generation(case_id):
203
+ assert isinstance(case_id, str)
204
+ res = f"./results/grasp_generation/results/{case_id}/{random.randint(0, 19)}.glb"
205
+ if not os.path.exists(res):
206
+ results_path = hf_hub_download('SceneDiffuser/SceneDiffuser', 'results/grasp_generation/results.zip')
207
+ os.makedirs('./results/grasp_generation/', exist_ok=True)
208
+ with zipfile.ZipFile(results_path, 'r') as zip_ref:
209
+ zip_ref.extractall('./results/grasp_generation/')
210
+
211
+ return res
212
+
213
+ def path_planning(scene, mode, count, seed, opt, scale_opt, pla, scale_pla):
214
+
215
+ scene_model_weight_path = pretrain_pointtrans_weight_path()
216
+ data_dir = path_planning_data_path()
217
+
218
+ override_config = [
219
+ "diffuser=ddpm",
220
+ "model=unet",
221
+ "model.use_position_embedding=true",
222
+ f"model.scene_model.pretrained_weights={scene_model_weight_path}",
223
+ "task=path_planning",
224
+ "task.visualizer.name=PathPlanningRenderingVisualizerHF",
225
+ f"task.visualizer.ksample={count}",
226
+ f"task.dataset.data_dir={data_dir}",
227
+ "task.dataset.repr_type=relative",
228
+ "task.env.name=PathPlanningEnvWrapperHF",
229
+ "task.env.inpainting_horizon=16",
230
+ "task.env.robot_top=3.0",
231
+ "task.env.env_adaption=false"
232
+ ]
233
+
234
+ if opt == True:
235
+ override_config += [
236
+ "optimizer=path_in_scene",
237
+ "optimizer.scale_type=div_var",
238
+ "optimizer.continuity=false",
239
+ f"optimizer.scale={scale_opt}",
240
+ ]
241
+ if pla == True:
242
+ override_config += [
243
+ "planner=greedy_path_planning",
244
+ f"planner.scale={scale_pla}",
245
+ "planner.scale_type=div_var",
246
+ "planner.greedy_type=all_frame_exp"
247
+ ]
248
+
249
+ initialize(config_path="./scenediffuser/configs", version_base=None)
250
+ config = compose(config_name="default", overrides=override_config)
251
+
252
+ random.seed(seed)
253
+ np.random.seed(seed)
254
+ torch.manual_seed(seed)
255
+ torch.cuda.manual_seed(seed)
256
+ torch.cuda.manual_seed_all(seed)
257
+
258
+ if mode == 'Sampling':
259
+ img = _sampling(config, scene)
260
+ res = (img, 0)
261
+ elif mode == 'Planning':
262
+ res = _planning(config, scene)
263
+ else:
264
+ res = (None, 0)
265
+
266
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
267
+ return res
scenediffuser ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit 2e6055e4aba5807f8ff81b5eaa4b171b93306067
style.css ADDED
@@ -0,0 +1 @@
 
1
+ #col-container {max-width: 1000px; margin-left: auto; margin-right: auto;}