# Neural ODE with Flax

This is the result of the project ["Reproduce Neural ODE and SDE"][projectlink] from the [HuggingFace Flax/JAX community week][comweeklink].

`main.py` trains a ResNet or an OdeNet on the MNIST dataset.

[projectlink]: https://discuss.huggingface.co/t/reproduce-neural-ode-and-neural-sde/7590
[comweeklink]: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#projects

## Dependency

### JAX and Flax

To install JAX, follow the instructions [here][jaxinstalllink], or simply run

```bash
pip install jax jaxlib
```

To install Flax, run

```bash
pip install flax
```

[jaxinstalllink]: https://github.com/google/jax#installation

`tensorflow-datasets` will download the MNIST dataset to your environment.

## How to run training

For (small) ResNet training,

```bash
python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64
```

For Neural ODE training,

```bash
python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
```

For Continuous Normalizing Flow,

```bash
python main.py --model=cnf --sample_dataset=circles
```

The sample dataset can be `circles`, `moons`, or `scurve`.

# Sample Results

![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
![cnf-viz](https://user-images.githubusercontent.com/72425253/126127269-4c02ee6a-a9a3-4b9f-b380-f8669f58872b.gif)

# Bird Call generation Score SDE

This is the code for the bird call generation Score SDE model.

`core-sde-sampler.py` runs the sampler, which uses pretrained weights to generate bird calls. The weights can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/blob/main/ckpt.flax).

To use different sample generation parameters, change the argument values.
For example,

```bash
python main.py --sigma=25 --num_steps=500 --signal_to_noise_ratio=0.10 --etol=1e-5 --sample_batch_size=128 --sample_no=47
```

Generating the audio files requires these dependencies:

```bash
pip install librosa
pip install soundfile
```

To train the model from scratch, first generate the dataset using this [link](https://www.kaggle.com/ibraheemmoosa/birdsong-spectogram-generation). The dataset is generated on Kaggle, so your Kaggle username and API key must be filled in at the specified section inside the script before training.

```bash
python main.py --sigma=35 --n_epochs=1000 --batch_size=512 --lr=1e-3 --num_steps=500 --signal_to_noise_ratio=0.15 --etol=1e-5 --sample_batch_size=64 --sample_no=23
```

Generated samples can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/tree/main/generated_samples) and [here](https://colab.research.google.com/drive/1AbF4aIMkSfNs-G__MXzqY7JSrz6qvLYN).
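
As a rough illustration of how the sampling flags above might be consumed, here is a minimal sketch that assumes the script parses its arguments with Python's standard `argparse` module. That is an assumption: the actual script may use a different flag library. The function name `parse_flags` and the default values are hypothetical; the defaults simply copy the values from the sampling command in this README.

```python
import argparse
from typing import List


def parse_flags(argv: List[str]) -> argparse.Namespace:
    """Hypothetical parser mirroring the sampler flags shown in this README.

    The defaults are copied from the example sampling command, not from
    the real script, so treat them as placeholders.
    """
    parser = argparse.ArgumentParser(description="Illustrative sampler flags")
    parser.add_argument("--sigma", type=float, default=25.0)
    parser.add_argument("--num_steps", type=int, default=500)
    parser.add_argument("--signal_to_noise_ratio", type=float, default=0.10)
    parser.add_argument("--etol", type=float, default=1e-5)
    parser.add_argument("--sample_batch_size", type=int, default=128)
    parser.add_argument("--sample_no", type=int, default=47)
    return parser.parse_args(argv)


# Both "--flag=value" and "--flag value" are accepted by argparse.
args = parse_flags(["--sigma=35", "--sample_batch_size=64", "--sample_no", "23"])
print(args.sigma, args.sample_batch_size, args.sample_no, args.num_steps)
```

Under this assumption, a flag written with spaces around `=` (for example `--sample_no = 47`) would be split by the shell into three separate tokens and rejected by the parser, so flags are best written as `--name=value`.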