# Neural ODE with Flax
This is the result of the project ["Reproduce Neural ODE and SDE"][projectlink] from the [HuggingFace Flax/JAX community week][comweeklink].

<code>main.py</code> trains a ResNet or an OdeNet on the MNIST dataset.

[projectlink]: https://discuss.huggingface.co/t/reproduce-neural-ode-and-neural-sde/7590

[comweeklink]: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#projects

## Dependencies

### JAX and Flax

To install JAX, follow the [official instructions][jaxinstalllink], or simply run:
```bash 
pip install jax jaxlib
```

For Flax installation,
```bash
pip install flax
```

[jaxinstalllink]: https://github.com/google/jax#installation


TensorFlow Datasets (`tensorflow-datasets`) downloads the MNIST dataset into your environment.
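
As a rough sketch of what that download step can look like (the repository's actual loading code may differ), `tfds.load` fetches MNIST on first use and caches it locally. The helper names `normalize_images` and `load_mnist`, and the choice to scale pixels to [0, 1], are assumptions made for illustration.

```python
import numpy as np


def normalize_images(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixel values to float32 in [0, 1] (an assumed preprocessing choice)."""
    return images.astype(np.float32) / 255.0


def load_mnist(split: str = "train"):
    """Download (on first call) and return MNIST images and labels as NumPy arrays."""
    # Imported lazily so that normalize_images works without TensorFlow installed.
    import tensorflow_datasets as tfds

    # batch_size=-1 loads the whole split as a single batch.
    ds = tfds.load("mnist", split=split, batch_size=-1)
    data = tfds.as_numpy(ds)
    return normalize_images(data["image"]), data["label"]
```

Calling `load_mnist("train")` should return images of shape `(60000, 28, 28, 1)` together with their integer labels.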

## How to run training

For (small) ResNet training,
```bash
python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64 
```

For Neural ODE training, 
```bash
python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
```
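
For readers new to the idea, the following is a minimal sketch of an ODE block in JAX, not the model defined in this repository. It replaces a stack of residual layers with a single vector field integrated by `jax.experimental.ode.odeint`. The `dynamics` function, its one-layer `tanh` form, and all shapes are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dynamics(h, t, params):
    """Hypothetical vector field f(h, t): a single tanh layer."""
    w, b = params
    return jnp.tanh(h @ w + b)


def ode_block(h0, params):
    """Integrate dh/dt = f(h, t) from t=0 to t=1 and return the final state."""
    ts = jnp.array([0.0, 1.0])
    hs = odeint(dynamics, h0, ts, params)  # hs[i] is the state at ts[i]
    return hs[-1]


key = jax.random.PRNGKey(0)
dim = 4
params = (0.1 * jax.random.normal(key, (dim, dim)), jnp.zeros(dim))
h0 = jnp.ones((2, dim))  # batch of two hidden states

h1 = ode_block(h0, params)
# odeint is differentiable, so the block can be trained with jax.grad.
grads = jax.grad(lambda p: jnp.sum(ode_block(h0, p)))(params)
```

In a full model, `h0` would come from a feature extractor and `h1` would feed a classifier head.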

For Continuous Normalizing Flow,
```bash
python main.py --model=cnf --sample_dataset=circles
```
The sample dataset can be `circles`, `moons`, or `scurve`.
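
To make the continuous normalizing flow idea concrete, here is a small sketch under stated assumptions, not the repository's implementation. A 2-D point is pushed through an ODE while the change in log-density is tracked with the instantaneous change of variables formula, d log p / dt = -tr(∂f/∂z). The `velocity` field and its weight matrix `w` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def velocity(z, t, w):
    """Hypothetical velocity field for a single 2-D point."""
    return jnp.tanh(w @ z)


def augmented_dynamics(state, t, w):
    """Joint dynamics of the point z and the accumulated change in log-density."""
    z, _ = state
    dz = velocity(z, t, w)
    # Exact trace of the Jacobian df/dz; cheap here because z is 2-D.
    jac = jax.jacfwd(velocity)(z, t, w)
    dlogp = -jnp.trace(jac)
    return dz, dlogp


w = jnp.array([[0.5, -0.3], [0.2, 0.1]])
z0 = jnp.array([1.0, -1.0])
ts = jnp.array([0.0, 1.0])

# The state is a (point, log-density change) pair; odeint accepts pytrees.
zs, delta_logp = odeint(augmented_dynamics, (z0, jnp.zeros(())), ts, w)
z1 = zs[-1]  # transformed point at t = 1
change = delta_logp[-1]  # log p(z1) - log p(z0) along the trajectory
```

For higher-dimensional data the exact trace becomes expensive, and a stochastic trace estimator is a common substitute.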

# Sample Results

![cnf-viz](https://user-images.githubusercontent.com/72425253/126124351-44e00438-055e-4b1c-90ee-758a545dd602.gif)
![cnf-viz](https://user-images.githubusercontent.com/72425253/126124648-dcb3f8f4-396a-447c-96cf-f9304377fa48.gif)
![cnf-viz](https://user-images.githubusercontent.com/72425253/126127269-4c02ee6a-a9a3-4b9f-b380-f8669f58872b.gif)

# Bird Call Generation with Score SDE

This is the code for the bird call generation score SDE model.

<code>core-sde-sampler.py</code> runs the sampler, which uses pretrained weights to generate bird calls. The weights can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/blob/main/ckpt.flax).

To use different sample generation parameters, change the argument values. For example,
```bash
python main.py --sigma=25 --num_steps=500 --signal_to_noise_ratio=0.10 --etol=1e-5 --sample_batch_size=128 --sample_no=47
``` 
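
The role of `--sigma` and `--num_steps` is not spelled out here. Assuming the sampler follows the common variance-exploding score SDE setup, dx = σ^t dw, the flags would act roughly as in the sketch below. The function names, the `eps` value, and the sample shape are assumptions and are not taken from the repository's code.

```python
import jax
import jax.numpy as jnp


def marginal_prob_std(t, sigma):
    """Std of p(x(t) | x(0)) for the SDE dx = sigma**t dw."""
    return jnp.sqrt((sigma ** (2.0 * t) - 1.0) / (2.0 * jnp.log(sigma)))


def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t of the same SDE."""
    return sigma ** t


sigma = 25.0  # corresponds to --sigma=25
num_steps = 500  # corresponds to --num_steps=500
eps = 1e-3  # smallest time; an assumed value that avoids t = 0

# Time grid from t = 1 (almost pure noise) down to t = eps.
time_steps = jnp.linspace(1.0, eps, num_steps)
stds = marginal_prob_std(time_steps, sigma)

# Sampling starts from Gaussian noise scaled by the std at t = 1.
key = jax.random.PRNGKey(0)
x_init = jax.random.normal(key, (4, 8)) * marginal_prob_std(1.0, sigma)
```

A larger `sigma` widens the range of noise scales the model must cover, and more steps give a finer discretization of the reverse-time process.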
To generate the audio files, these dependencies are required:

```bash 
pip install librosa
pip install soundfile
```
To train the model from scratch, generate the dataset using this [link](https://www.kaggle.com/ibraheemmoosa/birdsong-spectogram-generation). The dataset is generated on Kaggle, so your Kaggle username and API key must be filled in at the specified section of the script before training.
```bash
python main.py --sigma=35 --n_epochs=1000 --batch_size=512 --lr=1e-3 --num_steps=500 --signal_to_noise_ratio=0.15 --etol=1e-5 --sample_batch_size=64 --sample_no=23
``` 
Generated samples can be found [here](https://github.com/mandelbrot-walker/Birdcall-score-sde/tree/main/generated_samples)
and [here](https://colab.research.google.com/drive/1AbF4aIMkSfNs-G__MXzqY7JSrz6qvLYN).