Spaces:
Sleeping
Sleeping
fix: the pretrained model weights are not loaded.
Browse files- app.py +8 -4
- citydreamer/inference.py +4 -12
app.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
# @Author: Haozhe Xie
|
5 |
# @Date: 2024-03-02 16:30:00
|
6 |
# @Last Modified by: Haozhe Xie
|
7 |
-
# @Last Modified at: 2024-03-03
|
8 |
# @Email: root@haozhexie.com
|
9 |
|
10 |
import gradio as gr
|
@@ -51,6 +51,7 @@ def get_models(file_name):
|
|
51 |
if torch.cuda.is_available():
|
52 |
model = torch.nn.DataParallel(model).cuda().eval()
|
53 |
|
|
|
54 |
return model
|
55 |
|
56 |
|
@@ -60,15 +61,17 @@ def get_city_layout():
|
|
60 |
return hf, seg
|
61 |
|
62 |
|
63 |
-
def get_generated_city(radius, altitude, azimuth):
|
64 |
# The import must be done after CUDA extension compilation
|
65 |
import citydreamer.inference
|
66 |
|
67 |
return citydreamer.inference.generate_city(
|
68 |
get_generated_city.fgm,
|
69 |
get_generated_city.bgm,
|
70 |
-
get_generated_city.hf,
|
71 |
-
get_generated_city.seg,
|
|
|
|
|
72 |
radius,
|
73 |
altitude,
|
74 |
azimuth,
|
@@ -89,6 +92,7 @@ def main(debug):
|
|
89 |
gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
|
90 |
gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
|
91 |
gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
|
|
|
92 |
],
|
93 |
[gr.Image(type="numpy", label="Generated City")],
|
94 |
title=title,
|
|
|
4 |
# @Author: Haozhe Xie
|
5 |
# @Date: 2024-03-02 16:30:00
|
6 |
# @Last Modified by: Haozhe Xie
|
7 |
+
# @Last Modified at: 2024-03-03 16:01:01
|
8 |
# @Email: root@haozhexie.com
|
9 |
|
10 |
import gradio as gr
|
|
|
51 |
if torch.cuda.is_available():
|
52 |
model = torch.nn.DataParallel(model).cuda().eval()
|
53 |
|
54 |
+
model.load_state_dict(ckpt["gancraft_g"], strict=False)
|
55 |
return model
|
56 |
|
57 |
|
|
|
61 |
return hf, seg
|
62 |
|
63 |
|
64 |
+
def get_generated_city(radius, altitude, azimuth, map_center):
|
65 |
# The import must be done after CUDA extension compilation
|
66 |
import citydreamer.inference
|
67 |
|
68 |
return citydreamer.inference.generate_city(
|
69 |
get_generated_city.fgm,
|
70 |
get_generated_city.bgm,
|
71 |
+
get_generated_city.hf.copy(),
|
72 |
+
get_generated_city.seg.copy(),
|
73 |
+
map_center,
|
74 |
+
map_center,
|
75 |
radius,
|
76 |
altitude,
|
77 |
azimuth,
|
|
|
92 |
gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
|
93 |
gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
|
94 |
gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
|
95 |
+
gr.Slider(1440, 6752, value=2656, step=5, label="Map Center (px)"),
|
96 |
],
|
97 |
[gr.Image(type="numpy", label="Generated City")],
|
98 |
title=title,
|
citydreamer/inference.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
# @Author: Haozhe Xie
|
5 |
# @Date: 2024-03-02 16:30:00
|
6 |
# @Last Modified by: Haozhe Xie
|
7 |
-
# @Last Modified at: 2024-03-03
|
8 |
# @Email: root@haozhexie.com
|
9 |
|
10 |
import copy
|
@@ -53,7 +53,7 @@ CONSTANTS = {
|
|
53 |
}
|
54 |
|
55 |
|
56 |
-
def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
|
57 |
cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
|
58 |
seg, building_stats = get_instance_seg_map(seg)
|
59 |
# Generate latent codes
|
@@ -63,15 +63,6 @@ def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
|
|
63 |
bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
|
64 |
bgm.output_device,
|
65 |
)
|
66 |
-
# Random choose the center of the patch
|
67 |
-
cy = (
|
68 |
-
np.random.randint(seg.shape[0] - CONSTANTS["EXTENDED_VOL_SIZE"])
|
69 |
-
+ CONSTANTS["EXTENDED_VOL_SIZE"] // 2
|
70 |
-
)
|
71 |
-
cx = (
|
72 |
-
np.random.randint(seg.shape[1] - CONSTANTS["EXTENDED_VOL_SIZE"])
|
73 |
-
+ CONSTANTS["EXTENDED_VOL_SIZE"] // 2
|
74 |
-
)
|
75 |
# Generate local image patch of the height field and seg map
|
76 |
part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, CONSTANTS["EXTENDED_VOL_SIZE"])
|
77 |
# Generate local image patch of the height field and seg map
|
@@ -98,9 +89,10 @@ def generate_city(fgm, bgm, hf, seg, radius, altitude, azimuth):
|
|
98 |
bg_z,
|
99 |
building_zs,
|
100 |
)
|
101 |
-
|
102 |
np.uint8
|
103 |
)
|
|
|
104 |
|
105 |
|
106 |
def get_orbit_camera_position(radius, altitude, azimuth):
|
|
|
4 |
# @Author: Haozhe Xie
|
5 |
# @Date: 2024-03-02 16:30:00
|
6 |
# @Last Modified by: Haozhe Xie
|
7 |
+
# @Last Modified at: 2024-03-03 15:59:00
|
8 |
# @Email: root@haozhexie.com
|
9 |
|
10 |
import copy
|
|
|
53 |
}
|
54 |
|
55 |
|
56 |
+
def generate_city(fgm, bgm, hf, seg, cx, cy, radius, altitude, azimuth):
|
57 |
cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
|
58 |
seg, building_stats = get_instance_seg_map(seg)
|
59 |
# Generate latent codes
|
|
|
63 |
bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
|
64 |
bgm.output_device,
|
65 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
# Generate local image patch of the height field and seg map
|
67 |
part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, CONSTANTS["EXTENDED_VOL_SIZE"])
|
68 |
# Generate local image patch of the height field and seg map
|
|
|
89 |
bg_z,
|
90 |
building_zs,
|
91 |
)
|
92 |
+
img = ((img.cpu().numpy().squeeze().transpose((1, 2, 0)) / 2 + 0.5) * 255).astype(
|
93 |
np.uint8
|
94 |
)
|
95 |
+
return img
|
96 |
|
97 |
|
98 |
def get_orbit_camera_position(radius, altitude, azimuth):
|