DietNerf-Demo / app.py
hassiahk's picture
Fix Chair model link
074dc08
raw history blame
No virus
3.39 kB
import os
import builtins
import math
import json
import streamlit as st
import gdown
from demo.src.models import load_trained_model
from demo.src.utils import render_predict_from_pose, predict_to_image
# Streamlit requires set_page_config to be the first st.* call, so keep it at the top.
st.set_page_config(page_title="DietNeRF")

# Per-scene checkpoint configuration; select_model() reads
# cfg[scene]["DIET_NERF_MODEL_NAME"] and cfg[scene]["DIET_NERF_FILE_ID"].
# JSON is UTF-8, so decode it explicitly rather than with the platform default.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Local directory where downloaded checkpoints are stored.
MODEL_DIR = "models"
def select_model():
    """Show the scene picker and look up the chosen scene's checkpoint info.

    Returns:
        tuple[str, str]: the checkpoint file name (config key
        ``DIET_NERF_MODEL_NAME``) and the Google Drive file id (config key
        ``DIET_NERF_FILE_ID``) for the scene selected in the UI.
    """
    scene = st.selectbox("Select a Scene", ("Mic", "Chair", "Lego", "Ship", "Hotdog"))
    scene_cfg = cfg[scene]
    return scene_cfg["DIET_NERF_MODEL_NAME"], scene_cfg["DIET_NERF_FILE_ID"]
st.title("DietNeRF")

# Short blurb shown under the title (fixed typo: "Consisteny" -> "Consistency").
caption = (
    "DietNeRF achieves SoTA few-shot learning capacity in 3D model reconstruction. "
    "Thanks to the 2D supervision by CLIP (aka. _Semantic Consistency Loss_), "
    "it can render novel and challenging views with ONLY 8 training images, "
    "outperforming original NeRF!"
)
st.markdown(caption)
st.markdown("")

# Scene picker; the returned name/id drive the download and load steps below.
DIET_NERF_MODEL_NAME, DIET_NERF_FILE_ID = select_model()
@st.cache(show_spinner=False)
def download_model():
    """Download the selected scene's checkpoint from Google Drive.

    The file is saved to ``os.path.join(MODEL_DIR, DIET_NERF_MODEL_NAME)``,
    using the Drive file id in ``DIET_NERF_FILE_ID``. Both names are module
    globals assigned from ``select_model()``.

    Raises:
        RuntimeError: if gdown reports failure by returning ``None``, or if
            the checkpoint file is not on disk after the call.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    _model_path = os.path.join(MODEL_DIR, DIET_NERF_MODEL_NAME)
    url = f"https://drive.google.com/uc?id={DIET_NERF_FILE_ID}"
    downloaded = gdown.download(url, _model_path, quiet=False)
    # Some gdown versions return None on failure instead of raising.
    # Raise so the failure is visible. st.cache does not memoize
    # exceptions, so a later rerun retries instead of replaying a
    # cached "success".
    if downloaded is None or not os.path.isfile(_model_path):
        raise RuntimeError(f"Failed to download model from google drive: {url}")
    print(f"Model downloaded from google drive: {_model_path}")
@st.cache(show_spinner=False, allow_output_mutation=True)
def fetch_model():
    """Load the selected scene's trained DietNeRF model from ``MODEL_DIR``.

    Returns:
        tuple: the ``(model, state)`` pair produced by ``load_trained_model``.
        The output is cached with mutation allowed, so the same objects are
        reused across reruns.
    """
    loaded_model, loaded_state = load_trained_model(MODEL_DIR, DIET_NERF_MODEL_NAME)
    return loaded_model, loaded_state
# Download the checkpoint only if it is not already on disk, then load it.
# download_model() and fetch_model() are both wrapped in st.cache.
model_path = os.path.join(MODEL_DIR, DIET_NERF_MODEL_NAME)
if not os.path.isfile(model_path):
    download_model()

model, state = fetch_model()
pi = math.pi

# Sidebar branding: centred project logo.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
}
</style>
<p class="aligncenter">
<img src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png" width="420" height="400"/>
</p>
""",
    unsafe_allow_html=True,
)

# Sidebar branding: links to the code and the project report.
st.sidebar.markdown(
    """
<p style='text-align: center'>
<a href="https://github.com/codestella/putting-nerf-on-a-diet" target="_blank">GitHub</a> | <a href="https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745" target="_blank">Project Report</a>
</p>
""",
    unsafe_allow_html=True,
)

# Camera pose controls; theta/phi/radius are passed to render_predict_from_pose.
st.sidebar.header("SELECT YOUR VIEW DIRECTION")
theta = st.sidebar.slider(
    "Theta",
    min_value=-pi,
    max_value=pi,
    step=0.5,
    value=0.0,
    help="Rotational angle in Horizontal direction",
)
phi = st.sidebar.slider(
    "Phi",
    min_value=0.0,
    max_value=pi / 2,
    step=0.1,
    value=1.0,
    help="Rotational angle in Vertical direction",
)
radius = st.sidebar.slider(
    "Radius",
    min_value=2.0,
    max_value=6.0,
    step=1.0,
    value=3.0,
    help="Distance between object and the viewer",
)
st.markdown("")

# Render the novel view for the chosen pose while showing progress messages.
with st.spinner("Rendering View..."):
    with st.spinner("It may take 2-3 mins. So, why don't you read our report in the meantime"):
        pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
        im = predict_to_image(pred_color)
        # Upscale 2x for display. Scale each dimension from the image's own
        # size so a non-square render keeps its aspect ratio (previously the
        # output was forced to a 2*width square). Assumes `im` is a PIL-style
        # image exposing `.size` and `.resize(size=...)` - TODO confirm.
        w, h = im.size
        im = im.resize(size=(2 * w, 2 * h))

st.image(im, use_column_width="auto")