teticio commited on
Commit
65fa65c
1 Parent(s): e97d748

update for sagemaker

Browse files
Files changed (2) hide show
  1. README.md +15 -12
  2. src/train_unconditional.py +5 -5
README.md CHANGED
@@ -1,19 +1,22 @@
1
  # audio-diffusion
2
  ```bash
 
 
 
3
  python src/audio_to_images.py \
4
- --resolution=256 \
5
- --input_dir=path-to-audio-files \
6
- --output_dir=data
7
  ```
8
  ```bash
9
  accelerate launch src/train_unconditional.py \
10
- --dataset_name="data" \
11
- --resolution=256 \
12
- --output_dir="ddpm-ema-audio-256" \
13
- --train_batch_size=16 \
14
- --num_epochs=100 \
15
- --gradient_accumulation_steps=1 \
16
- --learning_rate=1e-4 \
17
- --lr_warmup_steps=500 \
18
- --mixed_precision=no
19
  ```
 
1
  # audio-diffusion
2
  ```bash
3
+ accelerate config
4
+ ```
5
+ ```bash
6
  python src/audio_to_images.py \
7
+ --resolution 256 \
8
+ --input_dir path-to-audio-files \
9
+ --output_dir data-256
10
  ```
11
  ```bash
12
  accelerate launch src/train_unconditional.py \
13
+ --dataset_name data-256 \
14
+ --resolution 256 \
15
+ --output_dir ddpm-ema-audio-256 \
16
+ --train_batch_size 16 \
17
+ --num_epochs 100 \
18
+ --gradient_accumulation_steps 1 \
19
+ --learning_rate 1e-4 \
20
+ --lr_warmup_steps 500 \
21
+ --mixed_precision no
22
  ```
src/train_unconditional.py CHANGED
@@ -253,7 +253,7 @@ if __name__ == "__main__":
253
  help="A folder containing the training data.",
254
  )
255
  parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
256
- parser.add_argument("--overwrite_output_dir", action="store_true")
257
  parser.add_argument("--cache_dir", type=str, default=None)
258
  parser.add_argument("--resolution", type=int, default=64)
259
  parser.add_argument("--train_batch_size", type=int, default=16)
@@ -269,15 +269,15 @@ if __name__ == "__main__":
269
  parser.add_argument("--adam_beta2", type=float, default=0.999)
270
  parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
271
  parser.add_argument("--adam_epsilon", type=float, default=1e-08)
272
- parser.add_argument("--use_ema", action="store_true", default=True)
273
  parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
274
  parser.add_argument("--ema_power", type=float, default=3 / 4)
275
  parser.add_argument("--ema_max_decay", type=float, default=0.9999)
276
- parser.add_argument("--push_to_hub", action="store_true")
277
- parser.add_argument("--use_auth_token", action="store_true")
278
  parser.add_argument("--hub_token", type=str, default=None)
279
  parser.add_argument("--hub_model_id", type=str, default=None)
280
- parser.add_argument("--hub_private_repo", action="store_true")
281
  parser.add_argument("--logging_dir", type=str, default="logs")
282
  parser.add_argument(
283
  "--mixed_precision",
 
253
  help="A folder containing the training data.",
254
  )
255
  parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
256
+ parser.add_argument("--overwrite_output_dir", type=bool, default=False)
257
  parser.add_argument("--cache_dir", type=str, default=None)
258
  parser.add_argument("--resolution", type=int, default=64)
259
  parser.add_argument("--train_batch_size", type=int, default=16)
 
269
  parser.add_argument("--adam_beta2", type=float, default=0.999)
270
  parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
271
  parser.add_argument("--adam_epsilon", type=float, default=1e-08)
272
+ parser.add_argument("--use_ema", type=bool, default=True)
273
  parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
274
  parser.add_argument("--ema_power", type=float, default=3 / 4)
275
  parser.add_argument("--ema_max_decay", type=float, default=0.9999)
276
+ parser.add_argument("--push_to_hub", type=bool, default=False)
277
+ parser.add_argument("--use_auth_token", type=bool, default=False)
278
  parser.add_argument("--hub_token", type=str, default=None)
279
  parser.add_argument("--hub_model_id", type=str, default=None)
280
+ parser.add_argument("--hub_private_repo", type=bool, default=False)
281
  parser.add_argument("--logging_dir", type=str, default="logs")
282
  parser.add_argument(
283
  "--mixed_precision",