Spaces:
Runtime error
Runtime error
Remove test
Browse files
train.py
CHANGED
@@ -117,9 +117,7 @@ if __name__ == "__main__":
|
|
117 |
)
|
118 |
|
119 |
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device)
|
120 |
-
test_dataset = VoiceDataset(TEST_FILE, mel_spectrogram, device)
|
121 |
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
122 |
-
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
123 |
|
124 |
# construct model
|
125 |
model = CNNetwork().to(device)
|
@@ -131,7 +129,7 @@ if __name__ == "__main__":
|
|
131 |
|
132 |
|
133 |
# train model
|
134 |
-
train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS
|
135 |
|
136 |
# save model
|
137 |
now = datetime.now()
|
|
|
117 |
)
|
118 |
|
119 |
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device)
|
|
|
120 |
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
|
121 |
|
122 |
# construct model
|
123 |
model = CNNetwork().to(device)
|
|
|
129 |
|
130 |
|
131 |
# train model
|
132 |
+
train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
|
133 |
|
134 |
# save model
|
135 |
now = datetime.now()
|