DietNerf-Demo / app.py
osanseviero's picture
osanseviero HF staff
Update app.py
39dcc54
raw history blame
No virus
6.32 kB
import json
import math
import random
import os
import gdown
import streamlit as st
from demo.src.models import load_trained_model
from demo.src.utils import predict_to_image, render_predict_from_pose
# Set the page title for the app.
st.set_page_config(page_title="DietNeRF")

# Per-scene model metadata, keyed by scene name (consumed by ``select_model``).
# JSON text is UTF-8, so decode it explicitly rather than relying on the
# platform's default encoding.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Directory in which downloaded model files are stored.
MODEL_DIR = "models"
# Scene names offered in the UI; the selected one is used as a key into ``cfg``.
SCENES_LIST = ["Mic", "Chair", "Lego", "Drums", "Ship", "Hotdog"]
# random_index = random.randint(0, len(SCENES_LIST) - 1)
def select_model(obj_select):
    """Look up the model names and Google Drive file IDs for one scene.

    Args:
        obj_select: Scene name; used as a top-level key of the module-level
            ``cfg`` mapping.

    Returns:
        A 4-tuple ``(DIET_NERF_MODEL_NAME, DIET_NERF_FILE_ID,
        NERF_MODEL_NAME, NERF_FILE_ID)`` read from ``cfg[obj_select]``.

    Raises:
        KeyError: If the scene or one of the four fields is missing
            (assuming ``cfg`` is a dict of dicts, as loaded from JSON).
    """
    scene_cfg = cfg[obj_select]
    return (
        scene_cfg["DIET_NERF_MODEL_NAME"],
        scene_cfg["DIET_NERF_FILE_ID"],
        scene_cfg["NERF_MODEL_NAME"],
        scene_cfg["NERF_FILE_ID"],
    )
pi = math.pi

st.title("DietNeRF")

# Raw HTML for the sidebar. The lines inside the literals are deliberately kept
# at column 0: they are passed to Markdown rendering as-is.
_SIDEBAR_BANNER_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p class="aligncenter">
<img src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png" width="420" height="400"/>
</p>
"""
_SIDEBAR_LINKS_HTML = """
<p style='text-align: center'>
<a href="https://github.com/codestella/putting-nerf-on-a-diet" target="_blank">GitHub</a> | <a href="https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745" target="_blank">Project Report</a>
</p>
"""

# Centered project image, followed by the GitHub / project-report links.
st.sidebar.markdown(_SIDEBAR_BANNER_HTML, unsafe_allow_html=True)
st.sidebar.markdown(_SIDEBAR_LINKS_HTML, unsafe_allow_html=True)

st.sidebar.header("SELECT YOUR VIEW DIRECTION")

# The three sliders below supply the (theta, phi, radius) arguments that are
# later passed to ``render_predict_from_pose``.

# theta: selectable over [-pi, pi], starting at 0.
theta = st.sidebar.slider(
    "Rotation (Left to Right)",
    value=0.0,
    min_value=-pi,
    max_value=pi,
    step=0.5,
    help="Rotational angle in Vertical direction (Theta)",
)

# phi: selectable over [0, pi/2], starting at 1.0.
phi = st.sidebar.slider(
    "Rotation (Bottom to Top)",
    value=1.0,
    min_value=0.0,
    max_value=0.5 * pi,
    step=0.1,
    help="Rotational angle in Horizontal direction (Phi)",
)

# radius: selectable over [2, 6] in whole steps, starting at 3.
radius = st.sidebar.slider(
    "Distance (Close to Far)",
    value=3.0,
    min_value=2.0,
    max_value=6.0,
    step=1.0,
    help="Distance between object and the viewer (Radius)",
)
# Introductory blurb shown at the top of the main page.
caption = (
    "`DietNeRF` achieves state-of-the-art few-shot learning capacity in 3D model reconstruction. "
    "Thanks to the 2D supervision by `CLIP (aka. Semantic Consistency Loss)`, "
    "it can render novel and challenging views with `ONLY 8 training images`, "
    "**outperforming** original [NeRF](https://www.matthewtancik.com/nerf)!"
)
st.markdown(caption)

# Pointer to the detailed DietNeRF-vs-NeRF comparison in the project report.
st.markdown(
    "> 📒 **NOTE**: To get a detailed comparison of differences in results between `DietNeRF` and `NeRF`, you can take a look at the "
    "[Experimental Results](https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745#0f6bc8f1008d4765b9b4635999626d4b) "
    "section in our project report."
)

# Scene picker; defaults to the first entry of SCENES_LIST.
obj_select = st.selectbox("Select a Scene", SCENES_LIST, index=0)

# These module-level names are read by the download/load helpers defined
# further down in the file.
(
    DIET_NERF_MODEL_NAME,
    DIET_NERF_FILE_ID,
    NERF_MODEL_NAME,
    NERF_FILE_ID,
) = select_model(obj_select)
@st.cache(show_spinner=False)
def download_diet_nerf_model():
    """Download the selected scene's DietNeRF model file from Google Drive.

    Reads the module-level ``MODEL_DIR``, ``DIET_NERF_MODEL_NAME`` and
    ``DIET_NERF_FILE_ID``. Creates ``MODEL_DIR`` if it does not exist and
    asks ``gdown`` to save the file as ``MODEL_DIR/DIET_NERF_MODEL_NAME``.

    Returns:
        None.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    target_path = os.path.join(MODEL_DIR, DIET_NERF_MODEL_NAME)
    gdown.download(
        f"https://drive.google.com/uc?id={DIET_NERF_FILE_ID}",
        target_path,
        quiet=False,
    )
    print(f"Model downloaded from google drive: {target_path}")
# @st.cache(show_spinner=False)
# def download_nerf_model():
# nerf_model_path = os.path.join(MODEL_DIR, NERF_MODEL_NAME)
# url = f"https://drive.google.com/uc?id={NERF_FILE_ID}"
# gdown.download(url, nerf_model_path, quiet=False)
# print(f"Model downloaded from google drive: {nerf_model_path}")
@st.cache(show_spinner=False, allow_output_mutation=True)
def fetch_diet_nerf_model():
    """Load the selected scene's trained DietNeRF model from ``MODEL_DIR``.

    Reads the module-level ``MODEL_DIR`` and ``DIET_NERF_MODEL_NAME``.

    Returns:
        The ``(model, state)`` pair produced by ``load_trained_model``.
    """
    loaded = load_trained_model(MODEL_DIR, DIET_NERF_MODEL_NAME)
    # Unpack explicitly so anything other than a two-item result fails here.
    model, state = loaded
    return model, state
# @st.cache(show_spinner=False, allow_output_mutation=True)
# def fetch_nerf_model():
# model, state = load_trained_model(MODEL_DIR, NERF_MODEL_NAME)
# return model, state
# Download the selected scene's model file only if it is not already on disk.
diet_nerf_model_path = os.path.join(MODEL_DIR, DIET_NERF_MODEL_NAME)
if not os.path.isfile(diet_nerf_model_path):
    download_diet_nerf_model()
# nerf_model_path = os.path.join(MODEL_DIR, NERF_MODEL_NAME)
# if not os.path.isfile(nerf_model_path):
#     download_nerf_model()

diet_nerf_model, diet_nerf_state = fetch_diet_nerf_model()
# nerf_model, nerf_state = fetch_nerf_model()

st.markdown("")

# Render the view for the pose chosen in the sidebar. Two spinners are nested
# so that both the short status and the longer info message are shown while
# rendering runs.
with st.spinner("Rendering view..."):
    with st.spinner(
        ":information_source: **INFO**: It may take around 30-50 seconds to render the view. "
        "In the meantime, why don't you take a look at our "
        "[project report](https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745), "
        "if you haven't already :slightly_smiling_face:"
    ):
        # Only the first returned value is used; the second is discarded.
        dn_pred_color, _ = render_predict_from_pose(diet_nerf_state, theta, phi, radius)
        dn_im = predict_to_image(dn_pred_color)
        # dn_w, _ = dn_im.size
        # dn_new_w = int(2 * dn_w)
        # dn_im = dn_im.resize(size=(dn_new_w, dn_new_w))
        # n_pred_color, _ = render_predict_from_pose(nerf_state, theta, phi, radius)
        # n_im = predict_to_image(n_pred_color)
        # n_w, _ = n_im.size
        # n_new_w = int(2 * n_w)
        # n_im = n_im.resize(size=(n_new_w, n_new_w))

# diet_nerf_col, nerf_col = st.beta_columns([1, 1])
st.markdown(
    "> 📒 **NOTE**: The rendered view does not fully reflect the true quality of the view generated by the model "
    "because it has been downsampled to speed up the process."
)
st.markdown(f"""<h4 style='text-align: center'>Rendered view for {obj_select}</h4>""", unsafe_allow_html=True)
st.image(dn_im, use_column_width=True)
# nerf_col.markdown("""<h4 style='text-align: center'>NeRF</h4>""", unsafe_allow_html=True)
# nerf_col.image(n_im, use_column_width=True)
# st.markdown(
#     "> 📒 NOTE: The views may look similar to you but see the "
#     "[Experimental Results](https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745#0f6bc8f1008d4765b9b4635999626d4b) "
#     "section in our report to get a detailed comparison of differences between `DietNeRF` and `NeRF`."
# )