asigalov61 commited on
Commit
6924978
1 Parent(s): 2a7e8fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -190,7 +190,7 @@ if __name__ == "__main__":
190
  model = TransformerWrapper(
191
  num_tokens=3088,
192
  max_seq_len=SEQ_LEN,
193
- attn_layers=Decoder(dim=1024, depth=32, heads=8)
194
  )
195
 
196
  model = AutoregressiveWrapper(model)
@@ -203,7 +203,7 @@ if __name__ == "__main__":
203
  print('Loading model checkpoint...')
204
 
205
  model.load_state_dict(
206
- torch.load('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.pth',
207
  map_location='cpu'))
208
  print('=' * 70)
209
 
 
190
  model = TransformerWrapper(
191
  num_tokens=3088,
192
  max_seq_len=SEQ_LEN,
193
+ attn_layers=Decoder(dim=1024, depth=16, heads=8)
194
  )
195
 
196
  model = AutoregressiveWrapper(model)
 
203
  print('Loading model checkpoint...')
204
 
205
  model.load_state_dict(
206
+ torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
207
  map_location='cpu'))
208
  print('=' * 70)
209