Stanislaw Szymanowicz commited on
Commit
a6b395f
1 Parent(s): 595d8bd

Load model from hub

Browse files
Files changed (2) hide show
  1. app.py +16 -13
  2. config.yaml +66 -0
app.py CHANGED
@@ -25,19 +25,9 @@ import gradio as gr
25
 
26
  import rembg
27
 
28
- def main():
29
 
30
- # ============= model loading ==========
31
- def load_model(device):
32
- experiment_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
33
- "model_file", "objaverse")
34
- # load cfg
35
- training_cfg = OmegaConf.load(os.path.join(experiment_path, ".hydra", "config.yaml"))
36
- # load model
37
- model = GaussianSplatPredictor(training_cfg)
38
- ckpt_loaded = torch.load(os.path.join(experiment_path, "model_latest.pth"), map_location=device)
39
- model.load_state_dict(ckpt_loaded["model_state_dict"])
40
- return model, training_cfg
41
 
42
  if torch.cuda.is_available():
43
  device = "cuda:0"
@@ -45,7 +35,20 @@ def main():
45
  device = "cpu"
46
  torch.cuda.set_device(device)
47
 
48
- model, model_cfg = load_model(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  model.to(device)
50
 
51
  # ============= image preprocessing =============
 
25
 
26
  import rembg
27
 
28
+ from huggingface_hub import hf_hub_download
29
 
30
+ def main():
 
 
 
 
 
 
 
 
 
 
31
 
32
  if torch.cuda.is_available():
33
  device = "cuda:0"
 
35
  device = "cpu"
36
  torch.cuda.set_device(device)
37
 
38
+ model_cfg = OmegaConf.load(
39
+ os.path.join(
40
+ os.path.dirname(os.path.abspath(__file__)),
41
+ "config.yaml"
42
+ ))
43
+
44
+ model_path = hf_hub_download(repo_id="szymanowiczs/splatter-image-multi-category-v1",
45
+ filename="model_latest.pth")
46
+
47
+
48
+ model = GaussianSplatPredictor(model_cfg)
49
+
50
+ ckpt_loaded = torch.load(model_path, map_location=device)
51
+ model.load_state_dict(ckpt_loaded["model_state_dict"])
52
  model.to(device)
53
 
54
  # ============= image preprocessing =============
config.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb:
2
+ project: gs_pred
3
+ cam_embd:
4
+ embedding: null
5
+ encode_embedding: null
6
+ dimension: 0
7
+ method: null
8
+ general:
9
+ device: 0
10
+ random_seed: 0
11
+ num_devices: 2
12
+ mixed_precision: true
13
+ data:
14
+ training_resolution: 128
15
+ fov: 49.134342641202636
16
+ subset: -1
17
+ input_images: 1
18
+ znear: 0.8
19
+ zfar: 3.2
20
+ category: objaverse
21
+ white_background: true
22
+ origin_distances: false
23
+ opt:
24
+ iterations: 50001
25
+ base_lr: 6.34584421e-05
26
+ batch_size: 16
27
+ betas:
28
+ - 0.9
29
+ - 0.999
30
+ loss: l2
31
+ imgs_per_obj: 4
32
+ ema:
33
+ use: true
34
+ update_every: 10
35
+ update_after_step: 100
36
+ beta: 0.9999
37
+ lambda_lpips: 0.33814373
38
+ start_lpips_after: 0
39
+ step_lr_at: -1
40
+ model:
41
+ max_sh_degree: 1
42
+ inverted_x: false
43
+ inverted_y: true
44
+ name: SingleUNet
45
+ opacity_scale: 1.0
46
+ opacity_bias: -2.0
47
+ scale_scale: 0.01
48
+ scale_bias: 0.02
49
+ xyz_scale: 0.1
50
+ xyz_bias: 0.0
51
+ depth_scale: 1.0
52
+ depth_bias: 0.0
53
+ network_without_offset: false
54
+ network_with_offset: true
55
+ attention_resolutions:
56
+ - 16
57
+ cross_view_attention: true
58
+ isotropic: false
59
+ base_dim: 128
60
+ num_blocks: 4
61
+ logging:
62
+ ckpt_iterations: 1000
63
+ val_log: 10000
64
+ loss_log: 10
65
+ loop_log: 10000
66
+ render_log: 10000