File size: 3,788 Bytes
c2164fe
 
 
 
 
 
3d1b36b
c2164fe
 
79df973
8b8b671
79df973
8b8b671
 
79df973
8b8b671
79df973
 
 
 
c2164fe
8b8b671
 
 
79df973
c2164fe
 
 
79df973
 
 
8b8b671
 
79df973
8b8b671
c2164fe
79df973
c2164fe
79df973
 
 
 
 
8b8b671
79df973
 
8b8b671
 
79df973
 
8b8b671
79df973
 
e60fd27
79df973
8b8b671
79df973
 
 
 
 
c2164fe
 
e60fd27
79df973
 
 
 
 
 
e60fd27
 
 
 
79df973
 
 
 
c2164fe
 
8b8b671
 
c2164fe
 
79df973
8b8b671
 
c2164fe
8b8b671
 
 
3d1b36b
 
 
 
8b8b671
 
 
 
 
 
 
 
 
c2164fe
 
 
8b8b671
 
 
79df973
f8ec29a
79df973
8b8b671
79df973
 
 
 
 
 
 
 
 
 
8b8b671
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -*- coding: utf-8 -*-
#
# @File:   app.py
# @Author: Haozhe Xie
# @Date:   2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-03-03 16:08:25
# @Email:  root@haozhexie.com

import gradio as gr
import logging
import numpy as np
import os
import ssl
import subprocess
import sys
import torch
import urllib.request

from PIL import Image

# Add the bundled `citydreamer` directory (next to this file) to the module
# search path so the CityDreamer modules can be imported.
sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
# Work around "ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]
# certificate verify failed" by making the default HTTPS context unverified.
# NOTE(review): this disables TLS certificate verification for every HTTPS
# client in the process that relies on the default context (including the
# checkpoint download below); consider scoping an unverified context to the
# download only, or fixing the CA bundle instead.
ssl._create_default_https_context = ssl._create_unverified_context


def setup_runtime_env():
    """Log the CUDA/GCC toolchain versions and build the bundled CUDA extensions.

    Every sub-directory of ``citydreamer/extensions`` (next to this file) is
    installed by running ``pip install .`` from inside that directory, using
    the pip of the currently running interpreter. A failed build is logged
    as an error but does not stop the remaining builds.

    Raises:
        FileNotFoundError: if ``nvcc`` or ``g++`` cannot be found on PATH.
        subprocess.CalledProcessError: if either version query exits non-zero.
    """
    # check_output returns bytes; decode so the log shows readable text rather
    # than a b'...' repr with escaped newlines.
    nvcc_version = subprocess.check_output(["nvcc", "--version"]).decode(
        errors="replace"
    )
    logging.info("CUDA version is %s", nvcc_version)
    gcc_version = subprocess.check_output(["g++", "--version"]).decode(
        errors="replace"
    )
    logging.info("GCC version is %s", gcc_version)

    # Compile CUDA extensions
    ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    # Sorted so the build order does not depend on filesystem listing order.
    for e in sorted(os.listdir(ext_dir)):
        e_dir = os.path.join(ext_dir, e)
        if not os.path.isdir(e_dir):
            continue

        # Use the pip that belongs to this interpreter; a bare "pip" on PATH
        # may install into a different environment.
        ret = subprocess.call([sys.executable, "-m", "pip", "install", "."], cwd=e_dir)
        if ret != 0:
            # Keep going (as before), but do not hide the failure.
            logging.error(
                "Failed to compile CUDA extension %s (pip exit code %d)", e, ret
            )


def get_models(file_name):
    """Download (if needed) and load a CityDreamer generator checkpoint.

    Args:
        file_name: Checkpoint file name. It is used both as the local path
            (relative to the current working directory) and as the file name
            under the ``hzxie/city-dreamer`` Hugging Face repository.

    Returns:
        A ``citydreamer.model.GanCraftGenerator`` built from ``ckpt["cfg"]``,
        switched to eval mode, with weights loaded from ``ckpt["gancraft_g"]``.
        When CUDA is available it is wrapped in ``torch.nn.DataParallel`` and
        moved to the GPU.
    """
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download to a temporary name and rename only on success. Otherwise an
        # interrupted download would leave a truncated checkpoint that the
        # os.path.exists() check above treats as complete on the next run.
        tmp_name = file_name + ".part"
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
            tmp_name,
        )
        os.replace(tmp_name, file_name)

    use_cuda = torch.cuda.is_available()
    # Without CUDA, remap tensors to the CPU in case the checkpoint was saved
    # from GPU tensors; otherwise torch.load cannot deserialize them.
    ckpt = torch.load(file_name, map_location=None if use_cuda else "cpu")
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
    # Inference only. Previously eval() was applied on the CUDA path alone,
    # leaving a CPU model in training mode.
    model.eval()

    # NOTE(review): strict=False silently ignores missing/unexpected keys. If the
    # checkpoint was saved from a DataParallel model, its keys likely carry a
    # "module." prefix that will not match the unwrapped CPU model; confirm.
    model.load_state_dict(ckpt["gancraft_g"], strict=False)
    return model


def get_city_layout():
    """Load the New York City height field and segmentation map.

    Both files are read from ``assets/`` relative to the current working
    directory.

    Returns:
        A ``(hf, seg)`` tuple of NumPy arrays. ``hf`` holds the pixels of
        ``NYC-HghtFld.png`` as stored; ``seg`` holds ``NYC-SegMap.png`` after
        conversion to palette ("P") mode.
    """
    # Each image is closed as soon as its pixels are copied into an array.
    with Image.open("assets/NYC-HghtFld.png") as height_img:
        height_field = np.array(height_img)
    with Image.open("assets/NYC-SegMap.png") as seg_img:
        seg_map = np.array(seg_img.convert("P"))
    return height_field, seg_map


def get_generated_city(radius, altitude, azimuth, map_center):
    """Gradio callback: render the pre-loaded city from the given camera pose.

    The foreground/background models and the layout arrays are read from
    attributes of this function (``fgm``, ``bgm``, ``hf``, ``seg``), which the
    ``__main__`` block assigns before the app starts. Calling this earlier
    raises ``AttributeError``.

    Args:
        radius: Camera radius (the UI slider labels it in metres).
        altitude: Camera altitude (the UI slider labels it in metres).
        azimuth: Camera azimuth (the UI slider labels it in degrees).
        map_center: Map center in pixels, passed as both the x and the y
            center coordinate.

    Returns:
        The result of ``citydreamer.inference.generate_city``, which the UI
        shows via ``gr.Image(type="numpy")`` (presumably an image array).
    """
    # Deferred: this module can only be imported once the CUDA extensions
    # have been compiled.
    import citydreamer.inference

    state = get_generated_city
    center_x = center_y = map_center
    # Hand over copies of the layout arrays, presumably so generate_city can
    # modify them without corrupting the shared originals; verify.
    height_field = state.hf.copy()
    seg_map = state.seg.copy()
    return citydreamer.inference.generate_city(
        state.fgm, state.bgm, height_field, seg_map,
        center_x, center_y, radius, altitude, azimuth,
    )


def main(debug):
    """Build the CityDreamer Gradio interface and launch it.

    The description is the part of ``README.md`` after its last ``---`` (the
    whole file if there is none). The article is the full ``ARTICLE.md``.
    Both are read relative to the current working directory.

    Args:
        debug: Forwarded to ``launch(debug=...)``.
    """
    title = "CityDreamer Demo 🏙️"
    # Read as UTF-8 explicitly: the default encoding is locale-dependent and
    # the Markdown files may contain non-ASCII text (e.g. emoji).
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    # Skip everything up to the last "---" (e.g. a front-matter block). When no
    # separator exists, rfind() returns -1, and "-1 + 3" would silently drop
    # the first two characters, so use the whole README instead.
    sep = markdown.rfind("---")
    desc = markdown[sep + 3 :] if sep >= 0 else markdown
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    # Requests are queued; api_open=False is passed to queue().
    app.queue(api_open=False)
    app.launch(debug=debug)


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="[%(levelname)s] %(asctime)s %(message)s"
    )

    # The CUDA extensions must be built before any CityDreamer module that
    # depends on them is imported.
    logging.info("Compiling CUDA extensions...")
    setup_runtime_env()

    # Models and layout are attached to the Gradio callback as function
    # attributes, so every request reuses the same loaded objects.
    logging.info("Downloading pretrained models...")
    fgm, bgm = get_models("CityDreamer-Fgnd.pth"), get_models("CityDreamer-Bgnd.pth")
    get_generated_city.fgm, get_generated_city.bgm = fgm, bgm

    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf, get_generated_city.seg = hf, seg

    logging.info("Starting the main application...")
    main(debug=os.getenv("DEBUG") == "1")