aheba31 committed on
Commit
d32c845
1 Parent(s): b3e63a9

add inference class

Browse files
Files changed (1) hide show
  1. inference.py +2 -118
inference.py CHANGED
@@ -1,120 +1,4 @@
1
  import torch
 
2
 
3
- def forward(self, wavs, wav_lens=None):
4
- """Runs the classification"""
5
- return self.classify_batch(wavs, wav_lens)
6
-
7
- def encode_batch(self, wavs, wav_lens=None, normalize=False):
8
- """Encodes the input audio into a single vector embedding.
9
-
10
- The waveforms should already be in the model's desired format.
11
- You can call:
12
- ``normalized = <this>.normalizer(signal, sample_rate)``
13
- to get a correctly converted signal in most cases.
14
-
15
- Arguments
16
- ---------
17
- wavs : torch.tensor
18
- Batch of waveforms [batch, time, channels] or [batch, time]
19
- depending on the model. Make sure the sample rate is fs=16000 Hz.
20
- wav_lens : torch.tensor
21
- Lengths of the waveforms relative to the longest one in the
22
- batch, tensor of shape [batch]. The longest one should have
23
- relative length 1.0 and others len(waveform) / max_length.
24
- Used for ignoring padding.
25
- normalize : bool
26
- If True, it normalizes the embeddings with the statistics
27
- contained in mean_var_norm_emb.
28
-
29
- Returns
30
- -------
31
- torch.tensor
32
- The encoded batch
33
- """
34
- # Manage single waveforms in input
35
- if len(wavs.shape) == 1:
36
- wavs = wavs.unsqueeze(0)
37
-
38
- # Assign full length if wav_lens is not assigned
39
- if wav_lens is None:
40
- wav_lens = torch.ones(wavs.shape[0], device=self.device)
41
-
42
- # Storing waveform in the specified device
43
- wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
44
- wavs = wavs.float()
45
-
46
- # Computing features and embeddings
47
- feats = self.mods.compute_features(wavs)
48
- feats = self.mods.mean_var_norm(feats, wav_lens)
49
- embeddings = self.mods.embedding_model(feats, wav_lens)
50
- if normalize:
51
- embeddings = self.hparams.mean_var_norm_emb(
52
- embeddings, torch.ones(embeddings.shape[0], device=self.device)
53
- )
54
- return embeddings
55
-
56
- def classify_batch(self, wavs, wav_lens=None):
57
- """Performs classification on the top of the encoded features.
58
-
59
- It returns the posterior probabilities, the index and, if the label
60
- encoder is specified, it also returns the text label.
61
-
62
- Arguments
63
- ---------
64
- wavs : torch.tensor
65
- Batch of waveforms [batch, time, channels] or [batch, time]
66
- depending on the model. Make sure the sample rate is fs=16000 Hz.
67
- wav_lens : torch.tensor
68
- Lengths of the waveforms relative to the longest one in the
69
- batch, tensor of shape [batch]. The longest one should have
70
- relative length 1.0 and others len(waveform) / max_length.
71
- Used for ignoring padding.
72
-
73
- Returns
74
- -------
75
- out_prob
76
- The log posterior probabilities of each class ([batch, N_class])
77
- score
78
- It is the value of the log-posterior for the best class ([batch,])
79
- index
80
- The indexes of the best class ([batch,])
81
- text_lab
82
- List with the text labels corresponding to the indexes.
83
- (label encoder should be provided).
84
- """
85
- emb = self.encode_batch(wavs, wav_lens)
86
- out_prob = self.mods.classifier(emb).squeeze(1)
87
- score, index = torch.max(out_prob, dim=-1)
88
- text_lab = self.hparams.label_encoder.decode_torch(index)
89
- return out_prob, score, index, text_lab
90
-
91
-
92
- def classify_file(self, path):
93
- """Classifies the given audio file into the given set of labels.
94
-
95
- Arguments
96
- ---------
97
- path : str
98
- Path to audio file to classify.
99
-
100
- Returns
101
- -------
102
- out_prob
103
- The log posterior probabilities of each class ([batch, N_class])
104
- score:
105
- It is the value of the log-posterior for the best class ([batch,])
106
- index
107
- The indexes of the best class ([batch,])
108
- text_lab:
109
- List with the text labels corresponding to the indexes.
110
- (label encoder should be provided).
111
- """
112
- waveform = self.load_audio(path)
113
- # Fake a batch:
114
- batch = waveform.unsqueeze(0)
115
- rel_length = torch.tensor([1.0])
116
- emb = self.encode_batch(batch, rel_length)
117
- out_prob = self.mods.classifier(emb).squeeze(1)
118
- score, index = torch.max(out_prob, dim=-1)
119
- text_lab = self.hparams.label_encoder.decode_torch(index)
120
- return out_prob, score, index, text_lab
 
1
  import torch
2
+ from speechbrain.pretrained import Pretrained
3
 
4
+ class EncoderClassifier(Pretrained):