# Neural ODE with Flax
This is the result of the project ["Reproduce Neural ODE and SDE"][projectlink] from the [HuggingFace Flax/JAX community week][comweeklink].
<code>main.py</code> trains a ResNet or an OdeNet on the MNIST dataset.
[projectlink]: https://discuss.huggingface.co/t/reproduce-neural-ode-and-neural-sde/7590
[comweeklink]: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#projects
## Dependencies
### JAX and Flax
To install JAX, follow the [official instructions][jaxinstalllink], or simply run
```bash
pip install jax jaxlib
```
To install Flax, run
```bash
pip install flax
```
[jaxinstalllink]: https://github.com/google/jax#installation
TensorFlow Datasets (`tensorflow-datasets`) downloads the MNIST dataset to your environment.
## How to run training
For (small) ResNet training,
```bash
python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64
```
For Neural ODE training,
```bash
python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
```
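The difference between the two models can be sketched in a few lines of JAX. A residual block computes `h + f(h)`, which is one Euler step of the ODE `dh/dt = f(h, t)`, while an ODE block integrates that ODE with a solver. The sketch below is illustrative only: `dynamics`, `residual_block`, `ode_block`, and the single tanh layer are hypothetical stand-ins, not the networks defined in `main.py`, and it uses `jax.experimental.ode.odeint` as one possible solver.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dynamics(h, t, params):
    """Hypothetical vector field f(h, t): one tanh layer standing in for the real network."""
    w, b = params
    return jnp.tanh(h @ w + b)


def residual_block(h, params):
    # ResNet block: h + f(h), i.e. a single Euler step of size 1.
    return h + dynamics(h, 0.0, params)


def ode_block(h, params, t1=1.0):
    # ODE block: integrate dh/dt = f(h, t) from t = 0 to t = t1 with an adaptive solver.
    ts = jnp.array([0.0, t1])
    hs = odeint(dynamics, h, ts, params, rtol=1e-5, atol=1e-5)
    return hs[-1]  # hidden state at the final time


key_w, key_h = jax.random.split(jax.random.PRNGKey(0))
params = (0.1 * jax.random.normal(key_w, (4, 4)), jnp.zeros(4))
h0 = jax.random.normal(key_h, (8, 4))  # batch of 8 hidden states

h_res = residual_block(h0, params)
h_ode = ode_block(h0, params)

# The ODE block is differentiable, so it can be trained like any other layer.
grads = jax.grad(lambda p: jnp.sum(ode_block(h0, p) ** 2))(params)
```

Both blocks map a hidden state to a hidden state of the same shape, which is why one can be swapped for the other in an otherwise identical classifier.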
For Continuous Normalizing Flow,
```bash
python main.py --model=cnf --sample_dataset=circles
```
The sample dataset can be `circles`, `moons`, or `scurve`.
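A continuous normalizing flow tracks the log-density of a sample alongside the sample itself, using the instantaneous change of variables `d log p / dt = -trace(df/dz)`. The following is a minimal sketch of that idea, not the implementation in this repository: `velocity` is a hypothetical linear vector field chosen so the result can be checked by hand, whereas a trained model would use a neural network.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def velocity(z, t, a):
    """Hypothetical linear vector field f(z, t) = a @ z (a real CNF would use a network)."""
    return a @ z


def augmented_dynamics(state, t, a):
    z, _ = state
    dz = velocity(z, t, a)
    # Instantaneous change of variables: d log p / dt = -trace(df/dz).
    jac = jax.jacfwd(velocity)(z, t, a)
    dlogp = -jnp.trace(jac)
    return dz, dlogp


a = jnp.array([[0.5, 0.0], [0.0, -0.2]])
z0 = jnp.array([1.0, 1.0])
ts = jnp.array([0.0, 1.0])

# Integrate the sample and its log-density change together from t = 0 to t = 1.
zs, delta_logp = odeint(
    augmented_dynamics, (z0, jnp.zeros(())), ts, a, rtol=1e-5, atol=1e-5
)
z1 = zs[-1]
# For this linear field the log-density change at t = 1 is -trace(a) = -0.3.
```

Computing the full Jacobian is only practical in low dimensions such as these 2-D toy datasets; larger models usually estimate the trace instead.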
# Sample Results
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
![cnf-viz](https://user-images.githubusercontent.com/72425253/126127269-4c02ee6a-a9a3-4b9f-b380-f8669f58872b.gif)
# Bird Call Generation Score SDE
This section contains the code for the bird call generation score SDE model.
<code>core-sde-sampler.py</code> runs the sampler, which uses pretrained weights to generate bird calls. The weights can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/blob/main/ckpt.flax).
To use different sample generation parameters, change the argument values. For example,
```bash
python main.py --sigma=25 --num_steps=500 --signal_to_noise_ratio=0.10 --etol=1e-5 --sample_batch_size=128 --sample_no=47
```
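To give a sense of what `--sigma` and `--num_steps` could control, here is a minimal sketch that assumes the model uses the variance-exploding SDE `dx = sigma**t dw`, a common choice in score-based generative models. This assumption is not confirmed by the repository, and the names `marginal_prob_std`, `diffusion_coeff`, and `eps` are illustrative only.

```python
import jax.numpy as jnp


def marginal_prob_std(t, sigma):
    """Std of p(x(t) | x(0)) for the assumed SDE dx = sigma**t dw."""
    return jnp.sqrt((sigma ** (2.0 * t) - 1.0) / (2.0 * jnp.log(sigma)))


def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t of the same assumed SDE."""
    return sigma ** t


sigma = 25.0      # value passed as --sigma
num_steps = 500   # value passed as --num_steps
eps = 1e-3        # hypothetical smallest time, avoids t = 0

# Reverse-time grid a sampler would walk, from t = 1 (mostly noise) down to t = eps.
time_steps = jnp.linspace(1.0, eps, num_steps)
stds = marginal_prob_std(time_steps, sigma)
```

Under this assumption, a larger `sigma` means more noise at `t = 1`, and a larger `num_steps` means a finer grid between noise and data.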
Generating audio requires these dependencies:
```bash
pip install librosa
pip install soundfile
```
To train the model from scratch, generate the dataset using this [link](https://www.kaggle.com/ibraheemmoosa/birdsong-spectogram-generation). The dataset is generated on Kaggle, so training requires your Kaggle username and API key, entered in the specified section of the script.
```bash
python main.py --sigma=35 --n_epochs=1000 --batch_size=512 --lr=1e-3 --num_steps=500 --signal_to_noise_ratio=0.15 --etol=1e-5 --sample_batch_size=64 --sample_no=23
```
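If you would rather not hardcode the Kaggle username and API key in the script, one option is to read them from the environment. The sketch below assumes the `KAGGLE_USERNAME` and `KAGGLE_KEY` variables, which the official Kaggle API client also reads. The helper `kaggle_credentials` is hypothetical and is not part of this repository.

```python
import os


def kaggle_credentials(env=os.environ):
    """Return (username, key) from the environment, or raise a clear error."""
    username = env.get("KAGGLE_USERNAME")
    key = env.get("KAGGLE_KEY")
    if not username or not key:
        raise RuntimeError("Set KAGGLE_USERNAME and KAGGLE_KEY before training.")
    return username, key


# Usage, once both variables are exported in the shell:
# username, key = kaggle_credentials()
```

The returned values can then be placed wherever the script expects the credentials.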
Generated samples can be found in the [GitHub repository](https://github.com/mandelbrot-walker/Birdcall-score-sde/tree/main/generated_samples) and in this [Colab notebook](https://colab.research.google.com/drive/1AbF4aIMkSfNs-G__MXzqY7JSrz6qvLYN).