xiexh20 commited on
Commit
9d94b63
1 Parent(s): be9a8e0

delele unnecessary dependency

Browse files
Files changed (4) hide show
  1. app.py +2 -0
  2. dataset/__init__.py +4 -2
  3. dataset/utils.py +47 -0
  4. requirements.txt +2 -1
app.py CHANGED
@@ -128,6 +128,8 @@ def main(cfg: ProjectConfig):
128
  # Setup model
129
  runner = DemoRunner(cfg)
130
 
 
 
131
  # Setup interface
132
  demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
133
  with demo:
 
128
  # Setup model
129
  runner = DemoRunner(cfg)
130
 
131
+ # runner = None # without model initialization, it shows one line of thumbnail
132
+
133
  # Setup interface
134
  demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
135
  with demo:
dataset/__init__.py CHANGED
@@ -15,16 +15,17 @@ from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
15
  from pytorch3d.implicitron.tools.config import expand_args_fields
16
  from pytorch3d.renderer.cameras import CamerasBase
17
  from torch.utils.data import DataLoader
 
18
 
19
  from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
20
- from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
21
  from .utils import DatasetMap
22
- from .r2n2_my import R2N2Sample, collate_batched_meshes
23
 
24
 
25
  def get_dataset(cfg: ProjectConfig):
26
 
27
  if cfg.dataset.type == 'co3dv2':
 
28
  dataset_cfg: CO3DConfig = cfg.dataset
29
  dataloader_cfg: DataloaderConfig = cfg.dataloader
30
 
@@ -100,6 +101,7 @@ def get_dataset(cfg: ProjectConfig):
100
  dataloader_val.batch_sampler.drop_last = False
101
  elif cfg.dataset.type == 'shapenet_r2n2':
102
  # from ..configs.structured import ShapeNetR2N2Config
 
103
  dataset_cfg: ShapeNetR2N2Config = cfg.dataset
104
  # for k in dataset_cfg:
105
  # print(k)
 
15
  from pytorch3d.implicitron.tools.config import expand_args_fields
16
  from pytorch3d.renderer.cameras import CamerasBase
17
  from torch.utils.data import DataLoader
18
+ from pytorch3d.datasets import R2N2, collate_batched_meshes
19
 
20
  from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
 
21
  from .utils import DatasetMap
22
+
23
 
24
 
25
  def get_dataset(cfg: ProjectConfig):
26
 
27
  if cfg.dataset.type == 'co3dv2':
28
+ from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
29
  dataset_cfg: CO3DConfig = cfg.dataset
30
  dataloader_cfg: DataloaderConfig = cfg.dataloader
31
 
 
101
  dataloader_val.batch_sampler.drop_last = False
102
  elif cfg.dataset.type == 'shapenet_r2n2':
103
  # from ..configs.structured import ShapeNetR2N2Config
104
+ from .r2n2_my import R2N2Sample
105
  dataset_cfg: ShapeNetR2N2Config = cfg.dataset
106
  # for k in dataset_cfg:
107
  # print(k)
dataset/utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Iterable, Optional
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ def show_item(item: Dict):
8
+ for key in item.keys():
9
+ value = item[key]
10
+ if torch.is_tensor(value) and value.numel() < 5:
11
+ value_str = value
12
+ elif torch.is_tensor(value):
13
+ value_str = value.shape
14
+ elif isinstance(value, str):
15
+ value_str = ('...' + value[-52:]) if len(value) > 50 else value
16
+ elif isinstance(value, dict):
17
+ value_str = str({k: type(v) for k, v in value.items()})
18
+ else:
19
+ value_str = type(value)
20
+ print(f"{key:<30} {value_str}")
21
+
22
+
23
+ def normalize_to_zero_one(x: torch.Tensor):
24
+ return (x - x.min()) / (x.max() - x.min())
25
+
26
+
27
+ def default(x, d):
28
+ return d if x is None else x
29
+
30
+
31
+ @dataclass
32
+ class DatasetMap:
33
+ train: Optional[Iterable] = None
34
+ val: Optional[Iterable] = None
35
+ test: Optional[Iterable] = None
36
+
37
+
38
+
39
+ def create_grid_points(bound=1.0, res=128):
40
+ x_ = np.linspace(-bound, bound, res)
41
+ y_ = np.linspace(-bound, bound, res)
42
+ z_ = np.linspace(-bound, bound, res)
43
+
44
+ x, y, z = np.meshgrid(x_, y_, z_)
45
+ # print(x.shape, y.shape) # (res, res, res)
46
+ pts = np.concatenate([y.reshape(-1, 1), x.reshape(-1, 1), z.reshape(-1, 1)], axis=-1)
47
+ return pts
requirements.txt CHANGED
@@ -13,4 +13,5 @@ tqdm
13
  transformers
14
  wandb
15
  trimesh
16
- pytorch3d
 
 
13
  transformers
14
  wandb
15
  trimesh
16
+ gradio
17
+ "git+https://github.com/facebookresearch/pytorch3d.git"