hzxie commited on
Commit
e60fd27
1 Parent(s): eada252

fix: the pretrained model weights are not loaded.

Browse files
Files changed (2) hide show
  1. app.py +8 -4
  2. citydreamer/inference.py +4 -12
app.py CHANGED
@@ -4,7 +4,7 @@
4
  # @Author: Haozhe Xie
5
  # @Date: 2024-03-02 16:30:00
6
  # @Last Modified by: Haozhe Xie
7
- # @Last Modified at: 2024-03-03 12:21:20
8
  # @Email: root@haozhexie.com
9
 
10
  import gradio as gr
@@ -51,6 +51,7 @@ def get_models(file_name):
51
  if torch.cuda.is_available():
52
  model = torch.nn.DataParallel(model).cuda().eval()
53
 
 
54
  return model
55
 
56
 
@@ -60,15 +61,17 @@ def get_city_layout():
60
  return hf, seg
61
 
62
 
63
- def get_generated_city(radius, altitude, azimuth):
64
  # The import must be done after CUDA extension compilation
65
  import citydreamer.inference
66
 
67
  return citydreamer.inference.generate_city(
68
  get_generated_city.fgm,
69
  get_generated_city.bgm,
70
- get_generated_city.hf,
71
- get_generated_city.seg,
 
 
72
  radius,
73
  altitude,
74
  azimuth,
@@ -89,6 +92,7 @@ def main(debug):
89
  gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
90
  gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
91
  gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
 
92
  ],
93
  [gr.Image(type="numpy", label="Generated City")],
94
  title=title,
 
4
  # @Author: Haozhe Xie
5
  # @Date: 2024-03-02 16:30:00
6
  # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-03-03 16:01:01
8
  # @Email: root@haozhexie.com
9
 
10
  import gradio as gr
 
51
  if torch.cuda.is_available():
52
  model = torch.nn.DataParallel(model).cuda().eval()
53
 
54
+ model.load_state_dict(ckpt["gancraft_g"], strict=False)
55
  return model
56
 
57
 
 
61
  return hf, seg
62
 
63
 
64
+ def get_generated_city(radius, altitude, azimuth, map_center):
65
  # The import must be done after CUDA extension compilation
66
  import citydreamer.inference
67
 
68
  return citydreamer.inference.generate_city(
69
  get_generated_city.fgm,
70
  get_generated_city.bgm,
71
+ get_generated_city.hf.copy(),
72
+ get_generated_city.seg.copy(),
73
+ map_center,
74
+ map_center,
75
  radius,
76
  altitude,
77
  azimuth,
 
92
  gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
93
  gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
94
  gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
95
+ gr.Slider(1440, 6752, value=2656, step=5, label="Map Center (px)"),
96
  ],
97
  [gr.Image(type="numpy", label="Generated City")],
98
  title=title,
citydreamer/inference.py CHANGED
@@ -4,7 +4,7 @@
4
  # @Author: Haozhe Xie
5
  # @Date: 2024-03-02 16:30:00
6
  # @Last Modified by: Haozhe Xie
7
- # @Last Modified at: 2024-03-03 12:10:18
8
  # @Email: root@haozhexie.com
9
 
10
  import copy
@@ -53,7 +53,7 @@ CONSTANTS = {
53
  }
54
 
55
 
56
- def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
57
  cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
58
  seg, building_stats = get_instance_seg_map(seg)
59
  # Generate latent codes
@@ -63,15 +63,6 @@ def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
63
  bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
64
  bgm.output_device,
65
  )
66
- # Random choose the center of the patch
67
- cy = (
68
- np.random.randint(seg.shape[0] - CONSTANTS["EXTENDED_VOL_SIZE"])
69
- + CONSTANTS["EXTENDED_VOL_SIZE"] // 2
70
- )
71
- cx = (
72
- np.random.randint(seg.shape[1] - CONSTANTS["EXTENDED_VOL_SIZE"])
73
- + CONSTANTS["EXTENDED_VOL_SIZE"] // 2
74
- )
75
  # Generate local image patch of the height field and seg map
76
  part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, CONSTANTS["EXTENDED_VOL_SIZE"])
77
  # Generate local image patch of the height field and seg map
@@ -98,9 +89,10 @@ def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
98
  bg_z,
99
  building_zs,
100
  )
101
- return ((img.cpu().numpy().squeeze().transpose((1, 2, 0)) / 2 + 0.5) * 255).astype(
102
  np.uint8
103
  )
 
104
 
105
 
106
  def get_orbit_camera_position(radius, altitude, azimuth):
 
4
  # @Author: Haozhe Xie
5
  # @Date: 2024-03-02 16:30:00
6
  # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-03-03 15:59:00
8
  # @Email: root@haozhexie.com
9
 
10
  import copy
 
53
  }
54
 
55
 
56
+ def generate_city(fgm, bgm, hf, seg, cx, cy, radius, altitude, azimuth):
57
  cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
58
  seg, building_stats = get_instance_seg_map(seg)
59
  # Generate latent codes
 
63
  bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
64
  bgm.output_device,
65
  )
 
 
 
 
 
 
 
 
 
66
  # Generate local image patch of the height field and seg map
67
  part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, CONSTANTS["EXTENDED_VOL_SIZE"])
68
  # Generate local image patch of the height field and seg map
 
89
  bg_z,
90
  building_zs,
91
  )
92
+ img = ((img.cpu().numpy().squeeze().transpose((1, 2, 0)) / 2 + 0.5) * 255).astype(
93
  np.uint8
94
  )
95
+ return img
96
 
97
 
98
  def get_orbit_camera_position(radius, altitude, azimuth):