File size: 3,469 Bytes
19677a1
0d35ba8
19677a1
 
0d35ba8
0071656
 
19677a1
 
 
0071656
 
19677a1
13bc063
0d35ba8
0071656
bd597e9
5272de4
0071656
bd597e9
0071656
 
 
bd597e9
0071656
 
 
bd597e9
0071656
 
 
fd0bc4a
0071656
 
 
bd597e9
19677a1
fd209d1
 
 
 
 
 
 
 
 
 
 
 
0d35ba8
 
 
 
 
 
 
0071656
0d35ba8
0071656
19677a1
 
 
 
 
 
 
 
0d35ba8
 
 
 
19677a1
 
5272de4
 
 
 
 
0071656
 
 
 
 
 
 
 
 
 
 
13bc063
 
 
8bc9869
 
1a30119
0071656
1a30119
bd597e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import builtins
import math
import streamlit as st
import gdown

# from google_drive_downloader import GoogleDriveDownloader as gdd

from demo.src.models import load_trained_model
from demo.src.utils import render_predict_from_pose, predict_to_image

# from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID

# Streamlit requires set_page_config to be the first st.* command executed on
# the page, so it stays ahead of every other UI call in this script.
st.set_page_config(
    page_title="DietNeRF",
)


def select_model():
    obj_select = st.selectbox("Select a scene", ("Chair", "Lego", "Ship", "Hotdog"))
    if obj_select == "Chair":
        FILE_ID = "17dj0pQieo94TozFv-noSBkXebduij1aM"
        MODEL_DIR = "models"
        MODEL_NAME = "diet_nerf_chair"
    elif obj_select == "Lego":
        FILE_ID = "1D9I-qIVMPaxuCHfUWPWMHaoLYtAmCjwI"
        MODEL_DIR = "models"
        MODEL_NAME = "diet_nerf_lego"
    elif obj_select == "Ship":
        FILE_ID = "14ZeJ86ETQr8dtu6CFoxU-ifvniHKo_Dt"
        MODEL_DIR = "models"
        MODEL_NAME = "diet_nerf_ship"
    elif obj_select == "Hotdog":
        FILE_ID = "11vNlR4lMvV_AVFgVjZmKMrMWGVG7qhNu"
        MODEL_DIR = "models"
        MODEL_NAME = "diet_nerf_hotdog"
    return MODEL_DIR, MODEL_NAME, FILE_ID


# Page header: title, a short description of the method, a spacer, and then
# the scene picker that decides which checkpoint the rest of the app uses.
st.title("DietNeRF")
caption = "".join(
    [
        "Diet-NeRF achieves SoTA few-shot learning capacity in 3D model reconstruction. ",
        "Thanks to the 2D supervision by CLIP (aka semantic loss), ",
        "it can render novel and challenging views with ONLY 8 training images, ",
        "outperforming original NeRF!",
    ]
)
st.markdown(caption)
st.markdown("")  # empty markdown acts as a vertical spacer
MODEL_DIR, MODEL_NAME, FILE_ID = select_model()


@st.cache
def download_model(model_dir=None, model_name=None, file_id=None):
    """Download a DietNeRF checkpoint from Google Drive.

    Every argument defaults to the module-level value chosen by
    ``select_model`` (``MODEL_DIR``, ``MODEL_NAME``, ``FILE_ID``), so the
    original zero-argument call keeps working.

    Args:
        model_dir: Directory to store the checkpoint in (created if missing).
        model_name: File name of the checkpoint inside *model_dir*.
        file_id: Google Drive file id of the checkpoint.

    Returns:
        str: Path the checkpoint was written to.

    Raises:
        RuntimeError: If ``gdown.download`` reports failure by returning None.
    """
    model_dir = MODEL_DIR if model_dir is None else model_dir
    model_name = MODEL_NAME if model_name is None else model_name
    file_id = FILE_ID if file_id is None else file_id

    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, model_name)
    url = f"https://drive.google.com/uc?id={file_id}"
    # gdown may signal failure by returning None instead of raising (older
    # releases do this). Fail loudly here rather than print a success message
    # and let model loading break later with a less helpful error.
    if gdown.download(url, model_path, quiet=False) is None:
        raise RuntimeError(f"Failed to download model from google drive: {url}")
    print(f"Model downloaded from google drive: {model_path}")
    return model_path


@st.cache(show_spinner=False, allow_output_mutation=True)
def fetch_model():
    """Load the trained model for the currently selected scene.

    Reads the module-level ``MODEL_DIR`` / ``MODEL_NAME`` chosen by the scene
    picker. ``allow_output_mutation=True`` tells Streamlit not to hash-check
    the returned objects for mutation between reruns.

    Returns:
        tuple: ``(model, state)`` exactly as unpacked from ``load_trained_model``.
    """
    loaded = load_trained_model(MODEL_DIR, MODEL_NAME)
    trained_model, trained_state = loaded
    return (trained_model, trained_state)


# Resolve where the selected checkpoint should live on disk.
model_path = os.path.join(
    MODEL_DIR,
    MODEL_NAME,
)
if not os.path.isfile(model_path):
    # No local copy yet: fetch it from Google Drive before loading.
    download_model()

# Load the model and its state (cached by Streamlit across reruns).
model, state = fetch_model()
pi = math.pi

# Sidebar branding and external links.
st.sidebar.image("images/diet-nerf-logo.png", width=310)
# The outer literal is single-quoted because the embedded HTML uses
# double-quoted attributes; unsafe_allow_html lets Streamlit render the raw <a> tags.
st.sidebar.markdown(
    '<a href="https://github.com/codestella/putting-nerf-on-a-diet" target="_blank">GitHub</a> | '
    '<a href="https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745" '
    'target="_blank">Project Report</a>',
    unsafe_allow_html=True,
)

# Camera pose controls passed to render_predict_from_pose below:
# theta in [-pi, pi] (horizontal), phi in [0, pi/2] (vertical) and the viewer
# distance in [2, 6]. The angles are presumably radians given the pi bounds;
# confirm against render_predict_from_pose.
st.sidebar.header("SELECT YOUR VIEW DIRECTION")
theta = st.sidebar.slider(
    "Theta", min_value=-pi, max_value=pi, step=0.5, value=0.0, help="Rotational angle in Horizontal direction"
)
phi = st.sidebar.slider(
    "Phi", min_value=0.0, max_value=0.5 * pi, step=0.1, value=1.0, help="Rotational angle in Vertical direction"
)
radius = st.sidebar.slider(
    "Radius", min_value=2.0, max_value=6.0, step=1.0, value=3.0, help="Distance between object and the viewer"
)

st.markdown("")

with st.spinner("Rendering Image, it may take 2-3 mins. So, why don't you read our report in the meantime"):
    pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
    # predict_to_image presumably returns a PIL Image (it exposes .size and
    # .resize); confirm in demo.src.utils.
    im = predict_to_image(pred_color)
    # Upscale 2x for display. Width and height are scaled independently so a
    # non-square render keeps its aspect ratio (square output is unchanged).
    w, h = im.size
    im = im.resize(size=(2 * w, 2 * h))
    st.image(im, use_column_width=True)