osanseviero commited on
Commit
d00a1b3
1 Parent(s): 471c873
Files changed (3) hide show
  1. README.md +78 -0
  2. pipeline.py +40 -0
  3. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - asteroid
4
+ - audio
5
+ - ConvTasNet
6
+ - audio-to-audio
7
+ datasets:
8
+ - Libri1Mix
9
+ - enh_single
10
+ license: cc-by-sa-3.0
11
+ ---
12
+
13
+ ## Asteroid model `JorisCos/ConvTasNet_Libri1Mix_enhsignle_16k`
14
+
15
+ Description:
16
+
17
+ This model was trained by Joris Cosentino using the librimix recipe in [Asteroid](https://github.com/asteroid-team/asteroid).
18
+ It was trained on the `enh_single` task of the Libri1Mix dataset.
19
+
20
+ Training config:
21
+
22
+ ```yml
23
+ data:
24
+ n_src: 1
25
+ sample_rate: 16000
26
+ segment: 3
27
+ task: enh_single
28
+ train_dir: data/wav16k/min/train-360
29
+ valid_dir: data/wav16k/min/dev
30
+ filterbank:
31
+ kernel_size: 32
32
+ n_filters: 512
33
+ stride: 16
34
+ masknet:
35
+ bn_chan: 128
36
+ hid_chan: 512
37
+ mask_act: relu
38
+ n_blocks: 8
39
+ n_repeats: 3
40
+ n_src: 1
41
+ skip_chan: 128
42
+ optim:
43
+ lr: 0.001
44
+ optimizer: adam
45
+ weight_decay: 0.0
46
+ training:
47
+ batch_size: 6
48
+ early_stop: true
49
+ epochs: 200
50
+ half_lr: true
51
+ num_workers: 4
52
+
53
+ ```
54
+
55
+
56
+ Results:
57
+
58
+ On Libri1Mix min test set :
59
+ ```yml
60
+ si_sdr: 14.743051006476085
61
+ si_sdr_imp: 11.293269700616385
62
+ sdr: 15.300522933671061
63
+ sdr_imp: 11.797860134458015
64
+ sir: Infinity
65
+ sir_imp: NaN
66
+ sar: 15.300522933671061
67
+ sar_imp: 11.797860134458015
68
+ stoi: 0.9310514162434267
69
+ stoi_imp: 0.13513159270288563
70
+ ```
71
+
72
+
73
+ License notice:
74
+
75
+ This work "ConvTasNet_Libri1Mix_enhsignle_16k" is a derivative of [LibriSpeech ASR corpus](http://www.openslr.org/12) by Vassil Panayotov,
76
+ used under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/); of The WSJ0 Hipster Ambient Mixtures
77
+ dataset by [Whisper.ai](http://wham.whisper.ai/), used under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) (Research only).
78
+ "ConvTasNet_Libri1Mix_enhsignle_16k" is licensed under [Attribution-ShareAlike 3.0 Unported](https://creativecommons.org/licenses/by-sa/3.0/) by Joris Cosentino
pipeline.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Union
2
+ from PIL import Image
3
+
4
+ import os
5
+ import json
6
+ import numpy as np
7
+
8
+ from fastai.learner import load_learner
9
+
10
+
11
+ class PreTrainedPipeline():
12
+ def __init__(self, path=""):
13
+ # IMPLEMENT_THIS
14
+ # Preload all the elements you are going to need at inference.
15
+ # For instance your model, processors, tokenizer that might be needed.
16
+ # This function is only called once, so do all the heavy processing I/O here"""
17
+ self.model = BaseModel.from_pretrained("")
18
+ self.sampling_rate = self.model.sample_rate
19
+
20
+ def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
21
+ """
22
+ Args:
23
+ inputs (:obj:`np.array`):
24
+ The raw waveform of audio received. By default sampled at `self.sampling_rate`.
25
+ The shape of this array is `T`, where `T` is the time axis
26
+ Return:
27
+ A :obj:`tuple` containing:
28
+ - :obj:`np.array`:
29
+ The return shape of the array must be `C'`x`T'`
30
+ - a :obj:`int`: the sampling rate as an int in Hz.
31
+ - a :obj:`List[str]`: the annotation for each out channel.
32
+ This can be the name of the instruments for audio source separation
33
+ or some annotation for speech enhancement. The length must be `C'`.
34
+ """
35
+ separated = separate.numpy_separate(self.model, inputs.reshape((1, 1, -1)))
36
+ # FIXME: how to deal with multiple sources?
37
+ out = separated[0]
38
+ n = out.shape[0]
39
+ labels = [f"label_{i}" for i in range(n)]
40
+ return separated[0], int(self.model.sample_rate), labels
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd8ddefe95a35761f8a48643a618eba908572d04d33208a8ed5451fb5a4378d0
3
+ size 20130704