ccolas commited on
Commit
2300540
1 Parent(s): e775f6d

integrate model download

Browse files
app.py CHANGED
@@ -18,7 +18,7 @@ from PIL import Image
18
 
19
 
20
  st.set_page_config(
21
- page_title="TastyPiano",
22
  page_icon="🎹",
23
  )
24
 
 
18
 
19
 
20
  st.set_page_config(
21
+ page_title="Test",
22
  page_icon="🎹",
23
  )
24
 
src/music2cocktailrep/training/latent_translation/setup_trained_model.py CHANGED
@@ -5,11 +5,22 @@ from src.music2cocktailrep.training.latent_translation.vae_model import get_gml_
5
  from src.music.config import TRANSLATION_VAE_CHKP_PATH
6
  from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
7
  import os
 
 
8
 
9
-
10
  rep_keys = get_bunch_of_rep_keys()['custom']
11
 
12
  def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
 
 
 
 
 
 
 
 
 
13
  with open(checkpoint_path + 'params.json', 'r') as f:
14
  params = json.load(f)
15
 
@@ -28,7 +39,6 @@ def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
28
  def normalize_music_input(input):
29
  return (input - stats_music[0]) / stats_music[1]
30
 
31
- model_path = checkpoint_path + 'checkpoints_best_eval.save'
32
  model.load_state_dict(torch.load(model_path))
33
  model.eval()
34
 
 
5
  from src.music.config import TRANSLATION_VAE_CHKP_PATH
6
  from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
7
  import os
8
+ from huggingface_hub import hf_hub_download
9
+ from shutil import copy
10
 
11
+ TOKEN = os.environ['secret']
12
  rep_keys = get_bunch_of_rep_keys()['custom']
13
 
14
  def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
15
+ # download translation model
16
+ repo_id = "ccolas/translation_vae"
17
+ filename = "checkpoints_best_eval_old.save"
18
+ downloaded_path = hf_hub_download(repo_id=repo_id,
19
+ filename=filename,
20
+ repo_type='model',
21
+ use_auth_token=TOKEN)
22
+ model_path = checkpoint_path + 'checkpoints_best_eval.save'
23
+ copy(downloaded_path, model_path)
24
  with open(checkpoint_path + 'params.json', 'r') as f:
25
  params = json.load(f)
26
 
 
39
  def normalize_music_input(input):
40
  return (input - stats_music[0]) / stats_music[1]
41
 
 
42
  model.load_state_dict(torch.load(model_path))
43
  model.eval()
44