ccolas commited on
Commit
3558258
1 Parent(s): 73d8ad8

Update src/music2cocktailrep/training/latent_translation/setup_trained_model.py

Browse files
src/music2cocktailrep/training/latent_translation/setup_trained_model.py CHANGED
@@ -10,19 +10,19 @@ from shutil import copy
10
  import hashlib
11
 
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
- # TOKEN = os.environ['token']
14
  rep_keys = get_bunch_of_rep_keys()['custom']
15
 
16
  def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
17
  # download translation model
18
- # repo_id = "ccolas/translation_vae"
19
- # filename = "checkpoints_best_eval_old.save"
20
- # downloaded_path = hf_hub_download(repo_id=repo_id,
21
- # filename=filename,
22
- # repo_type='model',
23
- # use_auth_token=TOKEN)
24
  model_path = checkpoint_path + 'checkpoints_best_eval.save'
25
- # copy(downloaded_path, model_path)
26
  with open(checkpoint_path + 'params.json', 'r') as f:
27
  params = json.load(f)
28
 
 
10
  import hashlib
11
 
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ TOKEN = os.environ['token']
14
  rep_keys = get_bunch_of_rep_keys()['custom']
15
 
16
  def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
17
  # download translation model
18
+ repo_id = "ccolas/translation_vae"
19
+ filename = "checkpoints_best_eval_old.save"
20
+ downloaded_path = hf_hub_download(repo_id=repo_id,
21
+ filename=filename,
22
+ repo_type='model',
23
+ use_auth_token=TOKEN)
24
  model_path = checkpoint_path + 'checkpoints_best_eval.save'
25
+ copy(downloaded_path, model_path)
26
  with open(checkpoint_path + 'params.json', 'r') as f:
27
  params = json.load(f)
28