Spaces:
Running
Running
Epsilon617
commited on
Commit
•
fc2d9fe
1
Parent(s):
283e8f1
fix path
Browse files
Prediction_Head/__MACOSX/._best-layer-MERT-v1-95M
DELETED
Binary file (220 Bytes)
|
|
Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc
DELETED
Binary file (1.67 kB)
|
|
Prediction_Head/best-layer-MERT-v1-95M.zip
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c8155db897d77d6896ba5d87e2af5cc335a3fb1dd300356185848982deacad4d
|
3 |
-
size 17025915
|
|
|
|
|
|
|
|
Prediction_Head/best_MTGGenre.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:83b7dcffde10a0dc7ba74341ea56dabec5c5de7cad6a0483708c80f1d893514a
|
3 |
-
size 1759067
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -92,14 +92,14 @@ ID2CLASS = {
|
|
92 |
}
|
93 |
|
94 |
TASKS = ['EMO','GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT']
|
95 |
-
head_dir = '
|
96 |
for task in TASKS:
|
97 |
print('loading', task)
|
98 |
with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
|
99 |
ID2CLASS[task]=json.load(f)
|
100 |
num_class = len(ID2CLASS[task].keys())
|
101 |
CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
|
102 |
-
CLASSIFIERS[task].load_state_dict(torch.load(f'/
|
103 |
CLASSIFIERS[task].to(device)
|
104 |
|
105 |
model.to(device)
|
|
|
92 |
}
|
93 |
|
94 |
TASKS = ['EMO','GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT']
|
95 |
+
head_dir = './Prediction_Head/best-layer-MERT-v1-95M'
|
96 |
for task in TASKS:
|
97 |
print('loading', task)
|
98 |
with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
|
99 |
ID2CLASS[task]=json.load(f)
|
100 |
num_class = len(ID2CLASS[task].keys())
|
101 |
CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
|
102 |
+
CLASSIFIERS[task].load_state_dict(torch.load(f'{head_dir}/{task}.ckpt')['state_dict'])
|
103 |
CLASSIFIERS[task].to(device)
|
104 |
|
105 |
model.to(device)
|