amanmibra commited on
Commit
3e88903
·
1 Parent(s): 44430bc

Remove test

Browse files
Files changed (1) hide show
  1. train.py +1 -3
train.py CHANGED
@@ -117,9 +117,7 @@ if __name__ == "__main__":
117
  )
118
 
119
  train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device)
120
- test_dataset = VoiceDataset(TEST_FILE, mel_spectrogram, device)
121
  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
122
- test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
123
 
124
  # construct model
125
  model = CNNetwork().to(device)
@@ -131,7 +129,7 @@ if __name__ == "__main__":
131
 
132
 
133
  # train model
134
- train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS, test_dataloader=test_dataloader)
135
 
136
  # save model
137
  now = datetime.now()
 
117
  )
118
 
119
  train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device)
 
120
  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
 
121
 
122
  # construct model
123
  model = CNNetwork().to(device)
 
129
 
130
 
131
  # train model
132
+ train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
133
 
134
  # save model
135
  now = datetime.now()