HongFangzhou commited on
Commit
7c8c180
1 Parent(s): fe189ee

update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -50
app.py CHANGED
@@ -52,7 +52,8 @@ def add_text(rgb, caption):
52
  return rgb
53
 
54
  config = "3DTopia/configs/default.yaml"
55
- local_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
 
56
  if os.path.exists(local_ckpt):
57
  ckpt = local_ckpt
58
  else:
@@ -62,56 +63,65 @@ configs = OmegaConf.load(config)
62
  os.makedirs("tmp", exist_ok=True)
63
  print("download finish")
64
 
65
- if ckpt.endswith(".ckpt"):
66
- model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
67
- elif ckpt.endswith(".safetensors"):
68
- print("download finish")
69
- model = get_obj_from_str(configs.model["target"])(**configs.model.params)
70
- print("download finish")
71
- model_ckpt = load_file(ckpt)
72
- print("download finish")
73
- model.load_state_dict(model_ckpt)
 
 
 
 
 
 
 
 
 
74
  print("download finish")
75
- else:
76
- raise NotImplementedError
77
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
78
- model = model.to(device)
79
- print("download finish")
80
- sampler = DDIMSampler(model)
81
-
82
- img_size = configs.model.params.unet_config.params.image_size
83
- channels = configs.model.params.unet_config.params.in_channels
84
- shape = [channels, img_size, img_size * 3]
85
-
86
- pose_folder = '3DTopia/assets/sample_data/pose'
87
- poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
88
- batch_rays_list = []
89
- H = 128
90
- ratio = 512 // H
91
- for p in poses_fname:
92
- c2w = np.loadtxt(p).reshape(4, 4)
93
- c2w[:3, 3] *= 2.2
94
- c2w = np.array([
95
- [1, 0, 0, 0],
96
- [0, 0, -1, 0],
97
- [0, 1, 0, 0],
98
- [0, 0, 0, 1]
99
- ]) @ c2w
100
-
101
- k = np.array([
102
- [560 / ratio, 0, H * 0.5],
103
- [0, 560 / ratio, H * 0.5],
104
- [0, 0, 1]
105
- ])
106
-
107
- rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
108
- coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
109
- coords = torch.reshape(coords, [-1,2]).long()
110
- rays_o = rays_o[coords[:, 0], coords[:, 1]]
111
- rays_d = rays_d[coords[:, 0], coords[:, 1]]
112
- batch_rays = torch.stack([rays_o, rays_d], 0)
113
- batch_rays_list.append(batch_rays)
114
- batch_rays_list = torch.stack(batch_rays_list, 0)
115
 
116
  print("download finish")
117
  def marching_cube(b, text, global_info):
 
52
  return rgb
53
 
54
  config = "3DTopia/configs/default.yaml"
55
+ # local_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
56
+ local_ckpt = "/data/3DTopia_all/3DTopia_code/checkpoints/model.safetensors"
57
  if os.path.exists(local_ckpt):
58
  ckpt = local_ckpt
59
  else:
 
63
  os.makedirs("tmp", exist_ok=True)
64
  print("download finish")
65
 
66
+ import sys
67
+ import traceback
68
+
69
+ try:
70
+ if ckpt.endswith(".ckpt"):
71
+ model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
72
+ elif ckpt.endswith(".safetensors"):
73
+ print("download finish")
74
+ model = get_obj_from_str(configs.model["target"])(**configs.model.params)
75
+ print("download finish")
76
+ model_ckpt = load_file(ckpt)
77
+ print("download finish")
78
+ model.load_state_dict(model_ckpt)
79
+ print("download finish")
80
+ else:
81
+ raise NotImplementedError
82
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
83
+ model = model.to(device)
84
  print("download finish")
85
+ sampler = DDIMSampler(model)
86
+
87
+ img_size = configs.model.params.unet_config.params.image_size
88
+ channels = configs.model.params.unet_config.params.in_channels
89
+ shape = [channels, img_size, img_size * 3]
90
+
91
+ pose_folder = '3DTopia/assets/sample_data/pose'
92
+ poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
93
+ batch_rays_list = []
94
+ H = 128
95
+ ratio = 512 // H
96
+ for p in poses_fname:
97
+ c2w = np.loadtxt(p).reshape(4, 4)
98
+ c2w[:3, 3] *= 2.2
99
+ c2w = np.array([
100
+ [1, 0, 0, 0],
101
+ [0, 0, -1, 0],
102
+ [0, 1, 0, 0],
103
+ [0, 0, 0, 1]
104
+ ]) @ c2w
105
+
106
+ k = np.array([
107
+ [560 / ratio, 0, H * 0.5],
108
+ [0, 560 / ratio, H * 0.5],
109
+ [0, 0, 1]
110
+ ])
111
+
112
+ rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
113
+ coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
114
+ coords = torch.reshape(coords, [-1,2]).long()
115
+ rays_o = rays_o[coords[:, 0], coords[:, 1]]
116
+ rays_d = rays_d[coords[:, 0], coords[:, 1]]
117
+ batch_rays = torch.stack([rays_o, rays_d], 0)
118
+ batch_rays_list.append(batch_rays)
119
+ batch_rays_list = torch.stack(batch_rays_list, 0)
120
+ except Exception as e:
121
+ print(e)
122
+ print(traceback.format_exc())
123
+ print(sys.exc_info()[2])
124
+
125
 
126
  print("download finish")
127
  def marching_cube(b, text, global_info):