File size: 4,179 Bytes
19677a1
0d35ba8
19677a1
e4b8bbd
19677a1
0d35ba8
0071656
 
19677a1
 
 
0071656
 
19677a1
13bc063
0d35ba8
e4b8bbd
 
 
0071656
bd597e9
7a6388a
e4b8bbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0071656
bd597e9
19677a1
fd209d1
 
bb689c9
fd209d1
 
 
 
 
 
 
 
 
7a6388a
0d35ba8
 
 
 
 
 
0071656
0d35ba8
0071656
19677a1
 
 
 
 
 
 
 
0d35ba8
 
 
 
19677a1
 
e4b8bbd
6cbae78
 
 
 
 
 
 
 
7a6388a
6cbae78
 
 
 
5272de4
4f54252
c1f7cd5
 
 
4f54252
5272de4
 
0071656
 
 
 
 
 
 
 
 
 
 
13bc063
 
1b312d9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import builtins
import math
import json
import streamlit as st
import gdown

# from google_drive_downloader import GoogleDriveDownloader as gdd

from demo.src.models import load_trained_model
from demo.src.utils import render_predict_from_pose, predict_to_image

# from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID

# Page metadata; kept ahead of every other st.* call in this script.
st.set_page_config(page_title="DietNeRF")

# Per-scene settings. select_model() indexes this mapping by scene label and
# reads the "DIET_NERF_MODEL_NAME" and "FILE_ID" entries.
# The encoding is pinned so parsing does not depend on the platform's
# default locale encoding.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)


def select_model():
    """Ask the user for a scene and look up its checkpoint settings.

    Renders a Streamlit selectbox and uses the chosen label as a key into
    the module-level ``cfg`` mapping.

    Returns:
        A ``(MODEL_DIR, MODEL_NAME, FILE_ID)`` tuple. ``MODEL_DIR`` is always
        ``"models"``; the other two values are the selected scene's
        ``"DIET_NERF_MODEL_NAME"`` and ``"FILE_ID"`` config entries.

    Raises:
        KeyError: if the selected scene, or one of the two entries, is
            missing from ``cfg``.
    """
    scenes = ("Mic", "Chair", "Lego", "Ship", "Hotdog")
    chosen = st.selectbox("Select a scene", scenes)
    scene_cfg = cfg[chosen]
    return "models", scene_cfg["DIET_NERF_MODEL_NAME"], scene_cfg["FILE_ID"]


st.title("DietNeRF")
# Introductory blurb rendered under the title. The fragments are joined
# as-is, so each one carries its own trailing space.
caption = "".join(
    [
        "DietNeRF achieves SoTA few-shot learning capacity in 3D model reconstruction. ",
        "Thanks to the 2D supervision by CLIP (aka semantic loss), ",
        "it can render novel and challenging views with ONLY 8 training images, ",
        "outperforming original NeRF!",
    ]
)
st.markdown(caption)
st.markdown("")
# Scene picker; the three names below are module globals read by
# download_model() and fetch_model().
MODEL_DIR, MODEL_NAME, FILE_ID = select_model()


def download_model():
    """Download the selected scene's checkpoint from Google Drive.

    Reads the module-level ``MODEL_DIR``, ``MODEL_NAME`` and ``FILE_ID``,
    creates ``MODEL_DIR`` if needed and saves the file to
    ``MODEL_DIR/MODEL_NAME``.

    This function is intentionally not wrapped in ``st.cache``: it exists
    only for its side effect, and a cached result would let a later call be
    skipped even though the checkpoint file is (still) missing, e.g. after a
    failed download. The caller already guards the call with an
    ``os.path.isfile`` check, so no caching is needed to avoid re-downloads.

    Raises:
        RuntimeError: if ``gdown.download`` returns ``None`` instead of the
            output path, which is treated here as a failed download.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    _model_path = os.path.join(MODEL_DIR, MODEL_NAME)
    url = f"https://drive.google.com/uc?id={FILE_ID}"
    saved_to = gdown.download(url, _model_path, quiet=False)
    if saved_to is None:
        # Do not report success when nothing was written.
        raise RuntimeError(f"Failed to download model from google drive: {url}")
    print(f"Model downloaded from google drive: {_model_path}")


@st.cache(show_spinner=False, allow_output_mutation=True)
def fetch_model():
    """Load the trained model for the currently selected scene.

    Delegates to ``load_trained_model`` with the module-level ``MODEL_DIR``
    and ``MODEL_NAME``; the result is cached by ``st.cache``.

    Returns:
        The ``(model, state)`` pair produced by ``load_trained_model``.
    """
    net, net_state = load_trained_model(MODEL_DIR, MODEL_NAME)
    return net, net_state


# Shorthand used for the slider bounds in the sidebar below.
pi = math.pi

# Fetch the checkpoint only when it is not already on disk, then load it.
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
if not os.path.isfile(model_path):
    download_model()

model, state = fetch_model()

# Sidebar banner: centred project image, rendered as raw HTML.
st.sidebar.markdown(
    """
<style>
.aligncenter {
    text-align: center;
}
</style>
<p class="aligncenter">
    <img src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png" width="420" height="400"/>
</p>
""",
    unsafe_allow_html=True,
)
# Sidebar links to the repository and the project report.
st.sidebar.markdown(
    """
<p style='text-align: center'>
<a href="https://github.com/codestella/putting-nerf-on-a-diet" target="_blank">GitHub</a> | <a href="https://www.notion.so/DietNeRF-Putting-NeRF-on-a-Diet-4aeddae95d054f1d91686f02bdb74745" target="_blank">Project Report</a>
</p>
    """,
    unsafe_allow_html=True,
)

# Camera pose controls; the three values are passed to the renderer below.
st.sidebar.header("SELECT YOUR VIEW DIRECTION")
theta = st.sidebar.slider(
    "Theta",
    value=0.0,
    min_value=-pi,
    max_value=pi,
    step=0.5,
    help="Rotational angle in Horizontal direction",
)
phi = st.sidebar.slider(
    "Phi",
    value=1.0,
    min_value=0.0,
    max_value=0.5 * pi,
    step=0.1,
    help="Rotational angle in Vertical direction",
)
radius = st.sidebar.slider(
    "Radius",
    value=3.0,
    min_value=2.0,
    max_value=6.0,
    step=1.0,
    help="Distance between object and the viewer",
)

st.markdown("")

# Render the view for the pose chosen in the sidebar and display it.
with st.spinner("Rendering View..."):
    with st.spinner("It may take 2-3 mins. So, why don't you read our report in the meantime"):
        # Only the first element of the returned pair is used here.
        pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
        im = predict_to_image(pred_color)
        # Upscale 2x for display. Each side is scaled from its own length so
        # a non-square render keeps its aspect ratio (square renders are
        # unaffected). Assumes `im.size` is (width, height), as in PIL.
        w, h = im.size
        im = im.resize(size=(int(2 * w), int(2 * h)))
        # Heading spelled like the page title ("DietNeRF").
        st.markdown(
            """<h4 style='text-align: center'>DietNeRF</h4>""", unsafe_allow_html=True
        )
        st.image(im, use_column_width="auto")