Spaces:
Sleeping
Sleeping
Switch to Gradio.
Browse files- .gitignore +2 -0
- ARTICLE.md +20 -0
- README.md +4 -4
- app.py +78 -30
- assets/style.css +7 -0
- requirements.txt +7 -2
.gitignore
CHANGED
@@ -179,3 +179,5 @@ configs/
|
|
179 |
data/
|
180 |
notebooks/
|
181 |
output/
|
|
|
|
|
|
179 |
data/
|
180 |
notebooks/
|
181 |
output/
|
182 |
+
flagged/
|
183 |
+
*.pth
|
ARTICLE.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##### Citation π
|
2 |
+
|
3 |
+
If our work is useful for your research, please consider citing:
|
4 |
+
|
5 |
+
```bibtex
|
6 |
+
@inproceedings{xie2024citydreamer,
|
7 |
+
title = {City{D}reamer: Compositional Generative Model of Unbounded 3{D} Cities},
|
8 |
+
author = {Xie, Haozhe and
|
9 |
+
Chen, Zhaoxi and
|
10 |
+
Hong, Fangzhou and
|
11 |
+
Liu, Ziwei},
|
12 |
+
booktitle = {CVPR},
|
13 |
+
year = {2024}
|
14 |
+
}
|
15 |
+
```
|
16 |
+
|
17 |
+
##### License π
|
18 |
+
|
19 |
+
This project is licensed under [S-Lab License 1.0](https://huggingface.co/hzxie/city-dreamer/blob/main/LICENSE).
|
20 |
+
Redistribution and use for non-commercial purposes should follow this license.
|
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
title: CityDreamer
|
3 |
emoji: ποΈ
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk:
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: other
|
|
|
1 |
---
|
2 |
title: CityDreamer
|
3 |
emoji: ποΈ
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.41.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: other
|
app.py
CHANGED
@@ -4,54 +4,102 @@
|
|
4 |
# @Author: Haozhe Xie
|
5 |
# @Date: 2024-03-02 16:30:00
|
6 |
# @Last Modified by: Haozhe Xie
|
7 |
-
# @Last Modified at: 2024-03-
|
8 |
# @Email: root@haozhexie.com
|
9 |
|
10 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
13 |
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
def setup_runtime_env():
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
def get_models():
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def get_generated_city(radius, altitude, azimuth):
|
26 |
-
|
27 |
|
28 |
|
29 |
-
def main(
|
30 |
-
|
31 |
-
page_title="CityDreamer Demo",
|
32 |
-
page_icon="ποΈ",
|
33 |
-
)
|
34 |
-
# Main
|
35 |
-
st.write("# CityDreamer Minimal Demo ποΈ")
|
36 |
with open("README.md", "r") as f:
|
37 |
markdown = f.read()
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
if __name__ == "__main__":
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
4 |
# @Author: Haozhe Xie
|
5 |
# @Date: 2024-03-02 16:30:00
|
6 |
# @Last Modified by: Haozhe Xie
|
7 |
+
# @Last Modified at: 2024-03-03 10:25:43
|
8 |
# @Email: root@haozhexie.com
|
9 |
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import torch
|
13 |
+
import gradio as gr
|
14 |
+
import subprocess
|
15 |
+
import urllib.request
|
16 |
+
import ssl
|
17 |
+
import sys
|
18 |
|
19 |
+
# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
|
20 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
21 |
|
22 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
|
23 |
+
# Import CityDreamer modules
|
24 |
+
# import citydreamer.model
|
25 |
+
# import citydreamer.inference
|
26 |
|
27 |
|
28 |
def setup_runtime_env():
|
29 |
+
subprocess.call(["pip", "freeze"])
|
30 |
+
ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
|
31 |
+
for e in os.listdir(ext_dir):
|
32 |
+
if not os.path.isdir(e):
|
33 |
+
continue
|
34 |
+
subprocess.call(["pip", "install", "."], workdir=os.path.join(ext_dir, e))
|
35 |
|
36 |
|
37 |
def get_models():
|
38 |
+
if not os.path.exists("CityDreamer-Fgnd.pth"):
|
39 |
+
urllib.request.urlretrieve(
|
40 |
+
"https://huggingface.co/hzxie/city-dreamer/resolve/main/CityDreamer-Fgnd.pth",
|
41 |
+
"CityDreamer-Fgnd.pth",
|
42 |
+
)
|
43 |
+
if not os.path.exists("CityDreamer-Bgnd.pth"):
|
44 |
+
urllib.request.urlretrieve(
|
45 |
+
"https://huggingface.co/hzxie/city-dreamer/resolve/main/CityDreamer-Bgnd.pth",
|
46 |
+
"CityDreamer-Bgnd.pth",
|
47 |
+
)
|
48 |
+
|
49 |
+
bgm_ckpt = torch.load("CityDreamer-Bgnd.pth")
|
50 |
+
fgm_ckpt = torch.load("CityDreamer-Fgnd.pth")
|
51 |
+
bgm = citydreamer.model.GanCraftGenerator(bgm_ckpt["cfg"])
|
52 |
+
fgm = citydreamer.model.GanCraftGenerator(fgm_ckpt["cfg"])
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
fgm = torch.nn.DataParallel(fgm).cuda().eval()
|
55 |
+
bgm = torch.nn.DataParallel(bgm).cuda().eval()
|
56 |
+
|
57 |
+
return bgm, fgm
|
58 |
|
59 |
|
60 |
def get_generated_city(radius, altitude, azimuth):
|
61 |
+
print(radius, altitude, azimuth)
|
62 |
|
63 |
|
64 |
+
def main(debug):
|
65 |
+
title = "CityDreamer Demo ποΈ"
|
|
|
|
|
|
|
|
|
|
|
66 |
with open("README.md", "r") as f:
|
67 |
markdown = f.read()
|
68 |
+
desc = markdown[markdown.rfind("---") + 3:]
|
69 |
+
with open("ARTICLE.md", "r") as f:
|
70 |
+
arti = f.read()
|
71 |
+
with open("assets/style.css") as f:
|
72 |
+
css = f.read()
|
73 |
|
74 |
+
app = gr.Interface(
|
75 |
+
get_generated_city,
|
76 |
+
[
|
77 |
+
gr.Slider(
|
78 |
+
128, 512, value=320, step=5, label="Camera Radius (m)"
|
79 |
+
),
|
80 |
+
gr.Slider(
|
81 |
+
256, 512, value=384, step=5, label="Camera Altitude (m)"
|
82 |
+
),
|
83 |
+
gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (Β°)"),
|
84 |
+
],
|
85 |
+
[gr.Image(type="numpy", label="Generated City")],
|
86 |
+
title=title,
|
87 |
+
description=desc,
|
88 |
+
article=arti,
|
89 |
+
allow_flagging="never",
|
90 |
+
css=css,
|
91 |
+
)
|
92 |
+
app.queue(api_open=False)
|
93 |
+
app.launch(debug=debug)
|
94 |
|
95 |
|
96 |
if __name__ == "__main__":
|
97 |
+
logging.basicConfig(
|
98 |
+
format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
|
99 |
+
)
|
100 |
+
logging.info("Compile CUDA extensions...")
|
101 |
+
# setup_runtime_env()
|
102 |
+
logging.info("Downloading pretrained models...")
|
103 |
+
# fgm, bgm = get_models()
|
104 |
+
logging.info("Starting the main application...")
|
105 |
+
main(os.getenv("DEBUG") == "1")
|
assets/style.css
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
display: block;
|
3 |
+
}
|
4 |
+
|
5 |
+
p img {
|
6 |
+
display: inline-block;
|
7 |
+
}
|
requirements.txt
CHANGED
@@ -1,2 +1,7 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch==1.12.0
|
3 |
+
torchvision
|
4 |
+
|
5 |
+
numpy
|
6 |
+
opencv-python
|
7 |
+
gradio
|