MarcusSu1216 committed on
Commit
48f19ef
1 Parent(s): e2ce104

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +2 -7
train.py CHANGED
@@ -3,8 +3,6 @@ import multiprocessing
3
  import time
4
 
5
  logging.getLogger('matplotlib').setLevel(logging.WARNING)
6
- logging.getLogger('numba').setLevel(logging.WARNING)
7
-
8
  import os
9
  import json
10
  import argparse
@@ -67,15 +65,12 @@ def run(rank, n_gpus, hps):
67
  torch.manual_seed(hps.train.seed)
68
  torch.cuda.set_device(rank)
69
  collate_fn = TextAudioCollate()
70
- all_in_mem = hps.train.all_in_mem # If you have enough memory, turn on this option to avoid disk IO and speed up training.
71
- train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=all_in_mem)
72
  num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
73
- if all_in_mem:
74
- num_workers = 0
75
  train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
76
  batch_size=hps.train.batch_size, collate_fn=collate_fn)
77
  if rank == 0:
78
- eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps, all_in_mem=all_in_mem)
79
  eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False,
80
  batch_size=1, pin_memory=False,
81
  drop_last=False, collate_fn=collate_fn)
 
3
  import time
4
 
5
  logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 
6
  import os
7
  import json
8
  import argparse
 
65
  torch.manual_seed(hps.train.seed)
66
  torch.cuda.set_device(rank)
67
  collate_fn = TextAudioCollate()
68
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
 
69
  num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
 
 
70
  train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
71
  batch_size=hps.train.batch_size, collate_fn=collate_fn)
72
  if rank == 0:
73
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
74
  eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False,
75
  batch_size=1, pin_memory=False,
76
  drop_last=False, collate_fn=collate_fn)