Simon Stolarczyk commited on
Commit
c90d42e
1 Parent(s): 2dcf4b6

Use hub for model.

Browse files
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +10 -2
  2. app.py +10 -2
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -22,15 +22,23 @@ print(os.getcwd())
22
  data_dir = Path('.')
23
  data = load_data(data_dir, 'data.pkl')
24
 
 
 
 
 
 
 
25
  # Default config options
26
  config = default_config()
27
  config['encode_position'] = True
28
 
 
 
29
  # Load our fine-tuned model
30
  learner = music_model_learner(
31
  data,
32
- config=config.copy(),
33
- pretrained_path='model.pth'
34
  )
35
 
36
 
 
22
  data_dir = Path('.')
23
  data = load_data(data_dir, 'data.pkl')
24
 
25
+ from huggingface_hub import hf_hub_download
26
+
27
+ model_cache_path = hf_hub_download(repo_id="psistolar/musicautobot-fine1", filename="model.pth")
28
+
29
+
30
+
31
  # Default config options
32
  config = default_config()
33
  config['encode_position'] = True
34
 
35
+
36
+
37
  # Load our fine-tuned model
38
  learner = music_model_learner(
39
  data,
40
+ config=config.copy(),
41
+ pretrained_path=model_cache_path
42
  )
43
 
44
 
app.py CHANGED
@@ -22,15 +22,23 @@ print(os.getcwd())
22
  data_dir = Path('.')
23
  data = load_data(data_dir, 'data.pkl')
24
 
 
 
 
 
 
 
25
  # Default config options
26
  config = default_config()
27
  config['encode_position'] = True
28
 
 
 
29
  # Load our fine-tuned model
30
  learner = music_model_learner(
31
  data,
32
- config=config.copy(),
33
- pretrained_path='model.pth'
34
  )
35
 
36
 
 
22
  data_dir = Path('.')
23
  data = load_data(data_dir, 'data.pkl')
24
 
25
+ from huggingface_hub import hf_hub_download
26
+
27
+ model_cache_path = hf_hub_download(repo_id="psistolar/musicautobot-fine1", filename="model.pth")
28
+
29
+
30
+
31
  # Default config options
32
  config = default_config()
33
  config['encode_position'] = True
34
 
35
+
36
+
37
  # Load our fine-tuned model
38
  learner = music_model_learner(
39
  data,
40
+ config=config.copy(),
41
+ pretrained_path=model_cache_path
42
  )
43
 
44