Epsilon617 commited on
Commit
fc2d9fe
1 Parent(s): 283e8f1
Prediction_Head/__MACOSX/._best-layer-MERT-v1-95M DELETED
Binary file (220 Bytes)
 
Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc DELETED
Binary file (1.67 kB)
 
Prediction_Head/best-layer-MERT-v1-95M.zip DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c8155db897d77d6896ba5d87e2af5cc335a3fb1dd300356185848982deacad4d
3
- size 17025915
 
 
 
 
Prediction_Head/best_MTGGenre.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:83b7dcffde10a0dc7ba74341ea56dabec5c5de7cad6a0483708c80f1d893514a
3
- size 1759067
 
 
 
 
app.py CHANGED
@@ -92,14 +92,14 @@ ID2CLASS = {
92
  }
93
 
94
  TASKS = ['EMO','GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT']
95
- head_dir = '/home/chenghua/nanshen/Yizhi/MERT_Universal/Prediction_Head/best-layer-MERT-v1-95M'
96
  for task in TASKS:
97
  print('loading', task)
98
  with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
99
  ID2CLASS[task]=json.load(f)
100
  num_class = len(ID2CLASS[task].keys())
101
  CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
102
- CLASSIFIERS[task].load_state_dict(torch.load(f'/home/chenghua/nanshen/Yizhi/MERT_Universal/Prediction_Head/best-layer-MERT-v1-95M/{task}.ckpt')['state_dict'])
103
  CLASSIFIERS[task].to(device)
104
 
105
  model.to(device)
 
92
  }
93
 
94
  TASKS = ['EMO','GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT']
95
+ head_dir = './Prediction_Head/best-layer-MERT-v1-95M'
96
  for task in TASKS:
97
  print('loading', task)
98
  with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
99
  ID2CLASS[task]=json.load(f)
100
  num_class = len(ID2CLASS[task].keys())
101
  CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
102
+ CLASSIFIERS[task].load_state_dict(torch.load(f'{head_dir}/{task}.ckpt')['state_dict'])
103
  CLASSIFIERS[task].to(device)
104
 
105
  model.to(device)