Neural ODE with Flax
This is the result of the project "Reproduce Neural ODE and SDE" from the Hugging Face Flax/JAX Community Week.
main.py trains either a ResNet or an OdeNet on the MNIST dataset.
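The difference between the two MNIST models fits in a few lines of plain Python. This is not the project's Flax code. A residual block applies one discrete update, while an ODE block integrates the same dynamics over a time interval. The dynamics f below is a hypothetical stand-in for a learned layer, and the fixed-step RK4 solver is an assumption: main.py may use a different, possibly adaptive, solver.

```python
import math

def f(h, t):
    # Hypothetical dynamics dh/dt = -h; t is unused here, but a learned f may depend on it
    return -h

def resnet_block(h, t=0.0):
    # A residual block applies one discrete update: h_next = h + f(h)
    return h + f(h, t)

def odenet_block(h, t0=0.0, t1=1.0, n_steps=100):
    # An ODE block integrates dh/dt = f(h, t) from t0 to t1 with fixed-step RK4
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(h, t)
        k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = f(h + dt * k3, t + dt)
        h = h + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t += dt
    return h
```

With f(h, t) = -h, the exact solution is h(1) = h(0) * exp(-1). So odenet_block(1.0) returns approximately 0.368, while a single residual step jumps straight to 0.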
Dependencies
JAX and Flax
For JAX installation, please follow the instructions here, or simply run:
pip install jax jaxlib
To install Flax, run:
pip install flax
TensorFlow Datasets (tensorflow-datasets) downloads the MNIST dataset into the environment.
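After installation, you can confirm that the packages are visible to Python by querying their installed versions. The helper below is hypothetical (installed_versions is an illustrative name). It uses only the standard library's importlib.metadata, which requires Python 3.8 or later.

```python
from importlib import metadata

def installed_versions(packages=("jax", "jaxlib", "flax", "tensorflow-datasets")):
    # Map each distribution name to its installed version, or None if it is missing
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```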
How to run training
To train a small ResNet:
python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64
To train a Neural ODE:
python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
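The training commands in this section, including the CNF command below, imply a command-line interface along the following lines. This is a hypothetical reconstruction using Python's argparse. The flag names come from the example commands, but the defaults, the choices lists, and the parser structure are assumptions and were not read from main.py.

```python
import argparse

def build_parser():
    # Hypothetical parser mirroring the flags in the example commands;
    # defaults and choices are assumptions, not values taken from main.py
    parser = argparse.ArgumentParser(description="Train ResNet, OdeNet, or CNF")
    parser.add_argument("--model", choices=["resnet", "odenet", "cnf"], default="resnet")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--n_epoch", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--sample_dataset", choices=["circles", "moons", "scurve"],
                        default="circles")
    return parser
```

For example, build_parser().parse_args(["--model=odenet", "--lr=1e-4"]) yields args.model == "odenet" and args.lr == 1e-4.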
To train a Continuous Normalizing Flow (CNF):
python main.py --model=cnf --sample_dataset=circles
The --sample_dataset flag accepts circles, moons, or scurve.
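A CNF transports samples from a simple base distribution through an ODE dz/dt = f(z, t). It tracks their log-density with the instantaneous change of variables from the Neural ODE paper, d log p(z(t))/dt = -tr(df/dz). The sketch below integrates both quantities with forward Euler for a hypothetical linear field f(z) = A z, whose Jacobian trace is simply tr(A). The project's CNF on circles, moons, or scurve would use a learned network instead, and all function names here are illustrative.

```python
import math

def vector_field(z, t, A):
    # Hypothetical linear dynamics dz/dt = A z; a real CNF uses a learned network
    # (t is unused here but kept to match the general f(z, t) signature)
    return [A[0][0] * z[0] + A[0][1] * z[1],
            A[1][0] * z[0] + A[1][1] * z[1]]

def trace_jacobian(z, t, A):
    # For f(z) = A z the Jacobian df/dz is A itself, so its trace is A[0][0] + A[1][1]
    return A[0][0] + A[1][1]

def std_normal_logpdf(z):
    # Log-density of a 2-D standard normal base distribution
    return -math.log(2.0 * math.pi) - 0.5 * (z[0] ** 2 + z[1] ** 2)

def cnf_forward(z0, A, t1=1.0, n_steps=1000):
    # Jointly integrate dz/dt = f(z, t) and d log p / dt = -tr(df/dz) with forward Euler
    z, logp = list(z0), std_normal_logpdf(z0)
    dt = t1 / n_steps
    for i in range(n_steps):
        t = i * dt
        dz = vector_field(z, t, A)
        logp -= dt * trace_jacobian(z, t, A)
        z = [z[0] + dt * dz[0], z[1] + dt * dz[1]]
    return z, logp
```

For A = diag(0.5, -0.2), the trace is constant, so the log-density drops by exactly 0.3 over t in [0, 1]. This matches the exact pushforward of a standard normal through z -> diag(e^0.5, e^-0.2) z.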
Sample Results
Bird Call Generation Score SDE
This section contains the code for the bird call generation score SDE model.
core-sde-sampler.py runs the sampler, which uses pretrained weights to generate bird calls. The weights can be found here.
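For orientation, here is a plain-Python sketch of a predictor-corrector sampler for a variance-exploding SDE, dx = sigma^t dw, a common setup for score SDE models. Whether core-sde-sampler.py works this way is an assumption. Specifically, the mapping of sigma, num_steps, and signal_to_noise_ratio onto the command-line flags of the same names is a guess. The samples here are 1-D scalars rather than bird calls, score_fn stands in for the pretrained network, and pc_sampler and eps are hypothetical.

```python
import math
import random

def marginal_prob_std(t, sigma):
    # Std of the VE SDE perturbation kernel p_t(x | x0) = N(x0, std(t)^2)
    return math.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def diffusion_coeff(t, sigma):
    # g(t) = sigma**t for dx = sigma**t dw
    return sigma ** t

def pc_sampler(score_fn, sigma=25.0, num_steps=500, snr=0.16,
               batch_size=64, eps=1e-3, seed=0):
    """Predictor-corrector sampling for 1-D samples (hypothetical sketch)."""
    rng = random.Random(seed)
    # Start from the approximate prior N(0, std(1)^2) at t = 1
    x = [rng.gauss(0.0, marginal_prob_std(1.0, sigma)) for _ in range(batch_size)]
    dt = (1.0 - eps) / (num_steps - 1)
    mean = x
    for i in range(num_steps):
        t = 1.0 - i * dt  # time runs backward from 1 to eps
        # Corrector: one Langevin MCMC step sized by the signal-to-noise ratio
        grad = [score_fn(xi, t) for xi in x]
        grad_norm = sum(abs(g) for g in grad) / batch_size
        step = 2.0 * (snr / grad_norm) ** 2  # noise norm is sqrt(dim) = 1 in 1-D
        x = [xi + step * g + math.sqrt(2.0 * step) * rng.gauss(0.0, 1.0)
             for xi, g in zip(x, grad)]
        # Predictor: one reverse-time Euler-Maruyama step
        g2 = diffusion_coeff(t, sigma) ** 2
        mean = [xi + g2 * score_fn(xi, t) * dt for xi in x]
        x = [m + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0) for m in mean]
    # Return the last mean, dropping the noise of the final step
    return mean
```

As a sanity check, if all data sits at a single point c, the exact score is -(x - c) / marginal_prob_std(t, sigma)**2, and the samples should collapse close to c.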
To use different sample generation parameters, change the argument values. For example:
python main.py --sigma=25 --num_steps=500 --signal_to_noise_ratio=0.10 --etol=1e-5 --sample_batch_size=128 --sample_no=47
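Flags of this kind are typically parsed with argparse, and the value must follow the equals sign with no surrounding spaces. A shell splits --sample_batch_size = 128 into three separate tokens, which argparse rejects. The parser below hypothetically mirrors the sampler flags in the example command. Its types and defaults are assumptions, and the real script's parsing may differ.

```python
import argparse

def build_sampler_parser():
    # Hypothetical parser for the sampler flags; defaults repeat the example
    # command's values and are assumptions, not read from the project
    parser = argparse.ArgumentParser(description="Bird call score SDE sampler")
    parser.add_argument("--sigma", type=float, default=25.0)
    parser.add_argument("--num_steps", type=int, default=500)
    parser.add_argument("--signal_to_noise_ratio", type=float, default=0.10)
    parser.add_argument("--etol", type=float, default=1e-5)
    parser.add_argument("--sample_batch_size", type=int, default=128)
    parser.add_argument("--sample_no", type=int, default=47)
    return parser
```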
Generating audio files requires these additional dependencies:
pip install librosa
pip install soundfile
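The project relies on librosa and soundfile for audio output. As a dependency-free illustration of that final step, the sketch below encodes a float waveform in [-1, 1] as a 16-bit mono WAV using only Python's standard wave module. encode_wav is a hypothetical helper. Its 22050 Hz default is librosa's default sampling rate, assumed here rather than taken from the project.

```python
import io
import struct
import wave

def encode_wav(samples, sample_rate=22050):
    # Clip float samples to [-1, 1] and convert them to 16-bit integers
    ints = [int(round(max(-1.0, min(1.0, s)) * 32767)) for s in samples]
    # Native byte order, which is what the wave module expects for writeframes
    frames = struct.pack(f"{len(ints)}h", *ints)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)       # mono
        wf.setsampwidth(2)       # 2 bytes = 16 bits per sample
        wf.setframerate(sample_rate)
        wf.writeframes(frames)
    return buf.getvalue()
```

Saving the result is then a plain binary write, for example open("sample.wav", "wb").write(encode_wav(waveform)).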
To train the model from scratch, first generate the dataset using this link. The dataset is generated on Kaggle, so you must enter your Kaggle username and API key in the designated section of the script before training.
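The script's own credentials section is not reproduced here. As an alternative that the official kaggle package supports, you can supply credentials through the KAGGLE_USERNAME and KAGGLE_KEY environment variables. Set them before kaggle is imported, because the package authenticates on import. set_kaggle_credentials is a hypothetical helper, and the values shown are placeholders.

```python
import os

def set_kaggle_credentials(username, api_key):
    # The kaggle package reads these variables when it authenticates,
    # so call this before "import kaggle"; never commit real keys to a repository
    os.environ["KAGGLE_USERNAME"] = username
    os.environ["KAGGLE_KEY"] = api_key
```

For example: set_kaggle_credentials("your_username", "your_api_key"), followed by import kaggle.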
python main.py --sigma=35 --n_epochs=1000 --batch_size=512 --lr=1e-3 --num_steps=500 --signal_to_noise_ratio=0.15 --etol=1e-5 --sample_batch_size=64 --sample_no=23
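Training a score model of this kind usually minimizes a denoising score matching loss. Each training example is perturbed with Gaussian noise of standard deviation std(t), and the network is trained to output -z/std(t), the score of that perturbation kernel. The sketch below assumes a variance-exploding SDE dx = sigma^t dw, with sigma matching the --sigma flag. The 1-D samples, eps=1e-5, and the dsm_loss name are assumptions; this is not taken from the project's training script.

```python
import math
import random

def marginal_prob_std(t, sigma):
    # Std of the VE SDE perturbation kernel p_t(x | x0) = N(x0, std(t)^2)
    return math.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def dsm_loss(score_fn, data, sigma=35.0, eps=1e-5, seed=0):
    # Denoising score matching on 1-D data: perturb x0 with N(0, std(t)^2) noise
    # and penalize the gap between score * std and -z
    rng = random.Random(seed)
    total = 0.0
    for x0 in data:
        t = rng.uniform(eps, 1.0)   # random diffusion time in [eps, 1]
        z = rng.gauss(0.0, 1.0)
        std = marginal_prob_std(t, sigma)
        x = x0 + std * z            # perturbed sample
        total += (score_fn(x, t) * std + z) ** 2
    return total / len(data)
```

With the exact score of data concentrated at one point, the loss is essentially 0. A model that always outputs 0 scores about 1, the mean of z^2.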