File size: 3,060 Bytes
19677a1
0d35ba8
19677a1
 
0d35ba8
 
19677a1
 
 
cacc6ce
19677a1
0d35ba8
 
bd597e9
fd0bc4a
bd597e9
 
 
 
 
 
 
 
 
 
 
 
fd0bc4a
 
 
 
bd597e9
 
57a50dc
19677a1
0d35ba8
 
 
 
 
 
 
 
 
 
19677a1
 
 
 
 
 
 
 
0d35ba8
 
 
 
19677a1
 
1adb71b
19677a1
b26d474
19677a1
 
 
 
 
 
591edcc
 
927629c
591edcc
 
 
 
8bc9869
 
1a30119
 
 
bd597e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import builtins
import math
import streamlit as st
import gdown
#from google_drive_downloader import GoogleDriveDownloader as gdd

from demo.src.models import load_trained_model
from demo.src.utils import render_predict_from_pose, predict_to_image
#from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID

st.set_page_config(page_title="DietNeRF Demo")

# Display label -> (model directory, checkpoint name, Google Drive file id).
# Insertion order is the order the options appear in the selectbox.
MODEL_REGISTRY = {
    'Chair': ('models', 'diet_nerf_chair', "17dj0pQieo94TozFv-noSBkXebduij1aM"),
    'Lego': ('models', 'diet_nerf_lego', "1D9I-qIVMPaxuCHfUWPWMHaoLYtAmCjwI"),
    'Ship': ('models', 'diet_nerf_ship', "14ZeJ86ETQr8dtu6CFoxU-ifvniHKo_Dt"),
    'Hot Dog': ('models', 'diet_nerf_hotdog', "11vNlR4lMvV_AVFgVjZmKMrMWGVG7qhNu"),
}


def select_model():
    """Let the user pick an object and return where its checkpoint lives.

    Renders a Streamlit selectbox listing every entry in MODEL_REGISTRY.

    Returns:
        A ``(MODEL_DIR, MODEL_NAME, FILE_ID)`` tuple for the chosen object:
        the local directory, the checkpoint file name inside it, and the
        Google Drive file id used to download it.
    """
    obj_select = st.selectbox("Select Object to Render", tuple(MODEL_REGISTRY))
    # selectbox only returns one of the offered options, so the lookup
    # cannot miss.
    return MODEL_REGISTRY[obj_select]

# Resolve the selected object into the checkpoint location used by the
# download and load steps below.
MODEL_DIR, MODEL_NAME, FILE_ID = select_model()

@st.cache
def download_model():
    """Download the selected checkpoint from Google Drive into MODEL_DIR.

    Reads the module-level MODEL_DIR, MODEL_NAME and FILE_ID produced by
    select_model(). MODEL_DIR is created if it does not exist.

    Returns:
        The local path of the downloaded checkpoint.

    Raises:
        RuntimeError: if gdown reports no output or the checkpoint file is
            not present afterwards (depending on version, gdown may return
            None on failure instead of raising).
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    model_path = os.path.join(MODEL_DIR, MODEL_NAME)
    url = f'https://drive.google.com/uc?id={FILE_ID}'
    downloaded = gdown.download(url, model_path, quiet=False)
    # Fail loudly here rather than letting fetch_model() hit a missing file
    # later with an unrelated-looking error.
    if downloaded is None or not os.path.isfile(model_path):
        raise RuntimeError(
            f'failed to download model from {url} to {model_path}')
    print(f'model downloaded from google drive: {model_path}')
    return model_path


@st.cache(show_spinner=False, allow_output_mutation=True)
def fetch_model():
    """Load the selected checkpoint and hand back its (model, state) pair.

    Wraps load_trained_model() for the module-level MODEL_DIR / MODEL_NAME.
    Results are cached by Streamlit. allow_output_mutation=True turns off
    Streamlit's hashing of the returned objects, which it otherwise uses to
    warn about mutation.
    """
    loaded_model, loaded_state = load_trained_model(MODEL_DIR, MODEL_NAME)
    return loaded_model, loaded_state


# Hit Google Drive only when the checkpoint is not already on disk, then
# load it (the load itself is cached by fetch_model's decorator).
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
if not os.path.isfile(model_path):
    download_model()

model, state = fetch_model()
pi = math.pi
st.sidebar.image("images/diet-nerf.png", width=310)
st.sidebar.header('SELECT YOUR VIEW DIRECTION')
# Camera placement controls; the angle bounds are multiples of pi, so theta
# and phi are in radians. How they map to a camera pose is decided by
# render_predict_from_pose (NOTE(review): confirm the convention there).
theta = st.sidebar.slider("Theta", min_value=-pi, max_value=pi,
                          step=0.5, value=0.)
phi = st.sidebar.slider("Phi", min_value=0., max_value=0.5*pi,
                        step=0.1, value=1.)
radius = st.sidebar.slider("Radius", min_value=2., max_value=6.,
                           step=1., value=3.)

caption = "Diet-NeRF achieves SoTA few-shot learning capacity in 3D model reconstruction. " \
          "Thanks to the 2D supervision by CLIP (aka semantic loss), " \
          "it can render novel and challenging views with ONLY 8 training images, " \
          "outperforming original NeRF!"
st.markdown(f""" <h4> {caption} </h4> """,
            unsafe_allow_html=True)
with st.spinner("Rendering Image (may take 2-3 mins)..."):
    pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
    # Assumes predict_to_image returns a PIL Image (.size is (width, height)).
    im = predict_to_image(pred_color)
    # Upscale 2x for display. Scale each axis by its own size so a
    # non-square render keeps its aspect ratio; square output is unchanged.
    w, h = im.size
    im = im.resize(size=(2 * w, 2 * h))
    st.image(im, use_column_width=True)