Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# | |
# @File: app.py | |
# @Author: Haozhe Xie | |
# @Date: 2024-03-02 16:30:00 | |
# @Last Modified by: Haozhe Xie | |
# @Last Modified at: 2024-03-03 10:39:25 | |
# @Email: root@haozhexie.com | |
import logging | |
import os | |
import torch | |
import gradio as gr | |
import subprocess | |
import urllib.request | |
import ssl | |
import sys | |
# Workaround for:
#   ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one for the whole process, so certificate checks are skipped for
# every HTTPS request made through the default context.
ssl._create_default_https_context = ssl._create_unverified_context

# Put the "citydreamer" directory that sits next to this file on the module
# search path.
sys.path.append(
    os.path.join(os.path.dirname(__file__), "citydreamer")
)

# Import CityDreamer modules (currently disabled)
# import citydreamer.model
# import citydreamer.inference
def setup_runtime_env():
    """Print the installed packages and pip-install every bundled extension.

    Every sub-directory of ``citydreamer/extensions`` (resolved relative to
    this file) is installed by running ``pip install .`` from inside that
    sub-directory. A failing installation is logged as a warning and does not
    stop the remaining extensions from being processed.

    Raises:
        OSError: If the extensions directory cannot be listed.
    """
    # Invoke pip through the running interpreter so packages are installed
    # into the same environment this script executes in.
    pip_cmd = [sys.executable, "-m", "pip"]
    subprocess.call(pip_cmd + ["freeze"])

    ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    for name in sorted(os.listdir(ext_dir)):
        ext_path = os.path.join(ext_dir, name)
        # os.listdir() yields bare names, so the directory test has to be made
        # on the joined path rather than relative to the working directory.
        if not os.path.isdir(ext_path):
            continue

        # The working directory of the child process is selected with ``cwd``.
        exit_code = subprocess.call(pip_cmd + ["install", "."], cwd=ext_path)
        if exit_code != 0:
            logging.warning(
                "Failed to install extension %s (exit code %d)", name, exit_code
            )
def _download_if_missing(url, path):
    """Download *url* to *path* unless *path* already exists."""
    if os.path.exists(path):
        return
    # Fetch under a temporary name and rename on success, so an interrupted
    # transfer is never mistaken for a complete checkpoint on the next run.
    partial_path = path + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, path)


def get_models():
    """Download the pretrained checkpoints if needed and build both generators.

    The checkpoint files are stored in the current working directory.

    Returns:
        A ``(bgm, fgm)`` tuple: the generator built from
        ``CityDreamer-Bgnd.pth`` first, then the one built from
        ``CityDreamer-Fgnd.pth``. Both are in evaluation mode. When CUDA is
        available each generator is wrapped in ``torch.nn.DataParallel`` and
        moved to the GPU.
    """
    base_url = "https://huggingface.co/hzxie/city-dreamer/resolve/main/"
    fgnd_file = "CityDreamer-Fgnd.pth"
    bgnd_file = "CityDreamer-Bgnd.pth"
    _download_if_missing(base_url + fgnd_file, fgnd_file)
    _download_if_missing(base_url + bgnd_file, bgnd_file)

    # map_location="cpu" allows checkpoints that were saved from a GPU to be
    # read on hosts without CUDA; the models are moved to the GPU below.
    bgm_ckpt = torch.load(bgnd_file, map_location="cpu")
    fgm_ckpt = torch.load(fgnd_file, map_location="cpu")

    # NOTE(review): ``citydreamer.model`` must be imported at module level for
    # these lines to work; that import is currently commented out in this file.
    bgm = citydreamer.model.GanCraftGenerator(bgm_ckpt["cfg"])
    fgm = citydreamer.model.GanCraftGenerator(fgm_ckpt["cfg"])
    # NOTE(review): only the "cfg" entry of each checkpoint is used here; no
    # weights are loaded into the generators by this function -- confirm
    # whether that happens elsewhere.

    if torch.cuda.is_available():
        fgm = torch.nn.DataParallel(fgm).cuda()
        bgm = torch.nn.DataParallel(bgm).cuda()

    # Switch to evaluation mode on CPU-only hosts as well, not just on CUDA.
    return bgm.eval(), fgm.eval()
def get_generated_city(radius, altitude, azimuth):
    """Print the three camera parameters on one line, separated by spaces.

    Args:
        radius: Value of the "Camera Radius" input.
        altitude: Value of the "Camera Altitude" input.
        azimuth: Value of the "Camera Azimuth" input.

    Returns:
        None. No image is produced by this function.
    """
    camera_params = (radius, altitude, azimuth)
    print(*camera_params)
def main(debug):
    """Build the Gradio interface for the demo and launch it.

    The page description is the part of ``README.md`` that follows the last
    ``---`` marker, and the article text is the content of ``ARTICLE.md``.
    Both files are read from the current working directory.

    Args:
        debug: Forwarded to ``launch(debug=...)``.

    Raises:
        OSError: If ``README.md`` or ``ARTICLE.md`` cannot be read.
    """
    title = "CityDreamer Demo 🏙️"

    # Read explicitly as UTF-8 so the result does not depend on the locale.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    marker = "---"
    marker_pos = markdown.rfind(marker)
    # rfind() returns -1 when the marker is absent; slicing from -1 + 3 would
    # silently drop the first two characters, so keep the whole text instead.
    desc = markdown if marker_pos < 0 else markdown[marker_pos + len(marker):]

    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    # The three sliders map, in order, to the radius / altitude / azimuth
    # parameters of get_generated_city().
    inputs = [
        gr.Slider(128, 512, value=320, step=5, label="Camera Radius (m)"),
        gr.Slider(256, 512, value=384, step=5, label="Camera Altitude (m)"),
        gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (°)"),
    ]
    outputs = [gr.Image(type="numpy", label="Generated City")]

    app = gr.Interface(
        get_generated_city,
        inputs,
        outputs,
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)
if __name__ == "__main__":
    # Log at INFO level with the level name and a timestamp in front of
    # every message.
    logging.basicConfig(
        level=logging.INFO,
        format="[%(levelname)s] %(asctime)s %(message)s",
    )

    # The extension build step is currently disabled; only the message is
    # logged.
    logging.info("Compile CUDA extensions...")
    # setup_runtime_env()

    # The model download step is currently disabled as well.
    # NOTE(review): get_models() returns (bgm, fgm), while the disabled call
    # below unpacks as (fgm, bgm) -- fix the order before re-enabling it.
    logging.info("Downloading pretrained models...")
    # fgm, bgm = get_models()

    logging.info("Starting the main application...")
    # Debug mode is switched on by setting the environment variable DEBUG=1.
    debug_enabled = os.environ.get("DEBUG") == "1"
    main(debug_enabled)