alexlau commited on
Commit
0d35ba8
1 Parent(s): 7fb9c3e

resolve failed to read downloaded model from gdrive

Browse files
Files changed (5) hide show
  1. app.py +19 -8
  2. demo/src/config.py +2 -2
  3. demo/src/models.py +1 -1
  4. demo/src/utils.py +1 -0
  5. requirements.txt +1 -1
app.py CHANGED
@@ -1,19 +1,27 @@
1
  import os
 
2
  import math
3
  import streamlit as st
4
- from google_drive_downloader import GoogleDriveDownloader as gdd
 
5
 
6
  from demo.src.models import load_trained_model
7
  from demo.src.utils import render_predict_from_pose, predict_to_image
8
  from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
9
 
 
 
10
 
11
- if not os.path.isfile('models'):
12
- model_path = os.path.join(MODEL_DIR, MODEL_NAME)
13
- gdd.download_file_from_google_drive(file_id=FILE_ID,
14
- dest_path=model_path,
15
- unzip=False)
16
- print(f'model downloaded from google drive: {model_path}')
 
 
 
 
17
 
18
 
19
  @st.cache(show_spinner=False, allow_output_mutation=True)
@@ -22,9 +30,12 @@ def fetch_model():
22
  return model, state
23
 
24
 
 
 
 
 
25
  model, state = fetch_model()
26
  pi = math.pi
27
- st.set_page_config(page_title="DietNeRF Demo")
28
  st.sidebar.header('SELECT YOUR VIEW DIRECTION')
29
  theta = st.sidebar.slider("Theta", min_value=0., max_value=2.*pi,
30
  step=0.5, value=0.)
 
1
  import os
2
+ import builtins
3
  import math
4
  import streamlit as st
5
+ import gdown
6
+ #from google_drive_downloader import GoogleDriveDownloader as gdd
7
 
8
  from demo.src.models import load_trained_model
9
  from demo.src.utils import render_predict_from_pose, predict_to_image
10
  from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
11
 
12
+ st.set_page_config(page_title="DietNeRF Demo")
13
+
14
 
15
+ @st.cache
16
+ def download_model():
17
+ os.makedirs(MODEL_DIR, exist_ok=True)
18
+ _model_path = os.path.join(MODEL_DIR, MODEL_NAME)
19
+ # gdd.download_file_from_google_drive(file_id=FILE_ID,
20
+ # dest_path=_model_path,
21
+ # unzip=True)
22
+ url = f'https://drive.google.com/uc?id={FILE_ID}'
23
+ gdown.download(url, _model_path, quiet=False)
24
+ print(f'model downloaded from google drive: {_model_path}')
25
 
26
 
27
  @st.cache(show_spinner=False, allow_output_mutation=True)
 
30
  return model, state
31
 
32
 
33
+ model_path = os.path.join(MODEL_DIR, MODEL_NAME)
34
+ if not os.path.isfile(model_path):
35
+ download_model()
36
+
37
  model, state = fetch_model()
38
  pi = math.pi
 
39
  st.sidebar.header('SELECT YOUR VIEW DIRECTION')
40
  theta = st.sidebar.slider("Theta", min_value=0., max_value=2.*pi,
41
  step=0.5, value=0.)
demo/src/config.py CHANGED
@@ -1,6 +1,6 @@
1
  # for downloading model from google drive
2
- FILE_ID = "1iytA1n2z4go3uVCwE__vIKouTKyIDjEq"
3
- MODEL_DIR = './models'
4
  MODEL_NAME = 'trained_model'
5
 
6
 
 
1
  # for downloading model from google drive
2
+ FILE_ID = "1msgvx_jiI-Fr5BrB-otyCwRMSBcUyu-h"
3
+ MODEL_DIR = 'models'
4
  MODEL_NAME = 'trained_model'
5
 
6
 
demo/src/models.py CHANGED
@@ -5,7 +5,7 @@ from flax.training import checkpoints
5
 
6
  from jaxnerf.nerf import models
7
  from jaxnerf.nerf import utils
8
- from demo.src.config import NerfConfig
9
 
10
  rng = random.PRNGKey(0)
11
  # TODO @Alex: make image size flexible if needed
 
5
 
6
  from jaxnerf.nerf import models
7
  from jaxnerf.nerf import utils
8
+ from demo.src.config import NerfConfig, MODEL_DIR, MODEL_NAME, FILE_ID
9
 
10
  rng = random.PRNGKey(0)
11
  # TODO @Alex: make image size flexible if needed
demo/src/utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from functools import partial
2
  import jax
3
  from jax import random
 
1
+ import os
2
  from functools import partial
3
  import jax
4
  from jax import random
requirements.txt CHANGED
@@ -11,4 +11,4 @@ tensorflow-hub>=0.11.0
11
  transformers==4.8.2
12
  tqdm==4.61.2
13
  streamlit==0.84.1
14
- googledrivedownloader==0.4
 
11
  transformers==4.8.2
12
  tqdm==4.61.2
13
  streamlit==0.84.1
14
+ pygdrive3==0.8.1