Spaces:
Build error
Build error
resolve failed to read downloaded model from gdrive
Browse files- app.py +19 -8
- demo/src/config.py +2 -2
- demo/src/models.py +1 -1
- demo/src/utils.py +1 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -1,19 +1,27 @@
|
|
1 |
import os
|
|
|
2 |
import math
|
3 |
import streamlit as st
|
4 |
-
|
|
|
5 |
|
6 |
from demo.src.models import load_trained_model
|
7 |
from demo.src.utils import render_predict_from_pose, predict_to_image
|
8 |
from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
|
9 |
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
@@ -22,9 +30,12 @@ def fetch_model():
|
|
22 |
return model, state
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
25 |
model, state = fetch_model()
|
26 |
pi = math.pi
|
27 |
-
st.set_page_config(page_title="DietNeRF Demo")
|
28 |
st.sidebar.header('SELECT YOUR VIEW DIRECTION')
|
29 |
theta = st.sidebar.slider("Theta", min_value=0., max_value=2.*pi,
|
30 |
step=0.5, value=0.)
|
|
|
1 |
import os
|
2 |
+
import builtins
|
3 |
import math
|
4 |
import streamlit as st
|
5 |
+
import gdown
|
6 |
+
#from google_drive_downloader import GoogleDriveDownloader as gdd
|
7 |
|
8 |
from demo.src.models import load_trained_model
|
9 |
from demo.src.utils import render_predict_from_pose, predict_to_image
|
10 |
from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
|
11 |
|
12 |
+
st.set_page_config(page_title="DietNeRF Demo")
|
13 |
+
|
14 |
|
15 |
+
@st.cache
|
16 |
+
def download_model():
|
17 |
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
18 |
+
_model_path = os.path.join(MODEL_DIR, MODEL_NAME)
|
19 |
+
# gdd.download_file_from_google_drive(file_id=FILE_ID,
|
20 |
+
# dest_path=_model_path,
|
21 |
+
# unzip=True)
|
22 |
+
url = f'https://drive.google.com/uc?id={FILE_ID}'
|
23 |
+
gdown.download(url, _model_path, quiet=False)
|
24 |
+
print(f'model downloaded from google drive: {_model_path}')
|
25 |
|
26 |
|
27 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
|
|
30 |
return model, state
|
31 |
|
32 |
|
33 |
+
model_path = os.path.join(MODEL_DIR, MODEL_NAME)
|
34 |
+
if not os.path.isfile(model_path):
|
35 |
+
download_model()
|
36 |
+
|
37 |
model, state = fetch_model()
|
38 |
pi = math.pi
|
|
|
39 |
st.sidebar.header('SELECT YOUR VIEW DIRECTION')
|
40 |
theta = st.sidebar.slider("Theta", min_value=0., max_value=2.*pi,
|
41 |
step=0.5, value=0.)
|
demo/src/config.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# for downloading model from google drive
|
2 |
-
FILE_ID = "
|
3 |
-
MODEL_DIR = '
|
4 |
MODEL_NAME = 'trained_model'
|
5 |
|
6 |
|
|
|
1 |
# for downloading model from google drive
|
2 |
+
FILE_ID = "1msgvx_jiI-Fr5BrB-otyCwRMSBcUyu-h"
|
3 |
+
MODEL_DIR = 'models'
|
4 |
MODEL_NAME = 'trained_model'
|
5 |
|
6 |
|
demo/src/models.py
CHANGED
@@ -5,7 +5,7 @@ from flax.training import checkpoints
|
|
5 |
|
6 |
from jaxnerf.nerf import models
|
7 |
from jaxnerf.nerf import utils
|
8 |
-
from demo.src.config import NerfConfig
|
9 |
|
10 |
rng = random.PRNGKey(0)
|
11 |
# TODO @Alex: make image size flexible if needed
|
|
|
5 |
|
6 |
from jaxnerf.nerf import models
|
7 |
from jaxnerf.nerf import utils
|
8 |
+
from demo.src.config import NerfConfig, MODEL_DIR, MODEL_NAME, FILE_ID
|
9 |
|
10 |
rng = random.PRNGKey(0)
|
11 |
# TODO @Alex: make image size flexible if needed
|
demo/src/utils.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from functools import partial
|
2 |
import jax
|
3 |
from jax import random
|
|
|
1 |
+
import os
|
2 |
from functools import partial
|
3 |
import jax
|
4 |
from jax import random
|
requirements.txt
CHANGED
@@ -11,4 +11,4 @@ tensorflow-hub>=0.11.0
|
|
11 |
transformers==4.8.2
|
12 |
tqdm==4.61.2
|
13 |
streamlit==0.84.1
|
14 |
-
|
|
|
11 |
transformers==4.8.2
|
12 |
tqdm==4.61.2
|
13 |
streamlit==0.84.1
|
14 |
+
pygdrive3==0.8.1
|