File size: 7,542 Bytes
052c3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c6da25
 
052c3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ad4c6
 
052c3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

# SSLMOS model
# modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper)

import torch
import torch.nn as nn
from .modules import Projection


class SSLMOS(torch.nn.Module):
    """MOS predictor built on a self-supervised (SSL) speech upstream.

    The waveform is encoded by an s3prl upstream; the hidden states of one
    selected layer feed a "mean net" ``Projection`` head that produces
    per-frame scores. Optionally, a listener-dependent (LD) branch
    concatenates a learned listener embedding to every frame and feeds the
    result to a second ``Projection`` head.

    Args:
        model_input: Unused; accepted so every model shares one signature.
        ssl_module: Backend used to load the SSL model. Only ``"s3prl"`` is
            supported; anything else raises ``NotImplementedError``.
        s3prl_name: s3prl upstream name; must be listed by
            ``S3PRLUpstream.available_names()``.
        ssl_model_output_dim: Feature size of the selected SSL layer.
        ssl_model_layer_idx: Index into the list of SSL hidden states.
        mean_net_dnn_dim, mean_net_output_type, mean_net_output_dim,
        mean_net_output_step, mean_net_range_clipping: Forwarded to the
            mean-net ``Projection``.
        use_listener_modeling: Enable the listener-dependent branch.
        num_listeners: Size of the listener embedding table (required when
            listener modeling is enabled).
        listener_emb_dim: Listener embedding size (required when listener
            modeling is enabled).
        use_mean_listener: Stored on the instance; not read in this class.
        decoder_type: LD decoder type; only ``"ffn"`` is supported.
        decoder_dnn_dim, output_type, range_clipping: Forwarded to the LD
            decoder ``Projection``.
        num_domains: Unused; accepted for signature compatibility.

    Raises:
        NotImplementedError: unsupported ``ssl_module`` or ``decoder_type``.
        ValueError: unknown ``s3prl_name``, or missing listener settings when
            listener modeling is enabled.
    """

    def __init__(
        self,
        # dummy, for signature need
        model_input: str,
        # model related
        ssl_module: str,
        s3prl_name: str,
        ssl_model_output_dim: int,
        ssl_model_layer_idx: int,
        # mean net related
        mean_net_dnn_dim: int = 64,
        mean_net_output_type: str = "scalar",
        mean_net_output_dim: int = 5,
        mean_net_output_step: float = 0.25,
        mean_net_range_clipping: bool = True,
        # listener related
        use_listener_modeling: bool = False,
        num_listeners: int = None,
        listener_emb_dim: int = None,
        use_mean_listener: bool = True,
        # decoder related
        decoder_type: str = "ffn",
        decoder_dnn_dim: int = 64,
        output_type: str = "scalar",
        range_clipping: bool = True,
        # dummy
        num_domains: int = None,
    ):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.use_mean_listener = use_mean_listener
        self.output_type = output_type
        self.use_listener_modeling = use_listener_modeling

        # define ssl model
        if ssl_module != "s3prl":
            raise NotImplementedError
        from s3prl.nn import S3PRLUpstream

        if s3prl_name not in S3PRLUpstream.available_names():
            # Fail fast: otherwise self.ssl_model would be missing and the
            # error would only surface later inside forward().
            raise ValueError(f"Unknown s3prl upstream name: {s3prl_name}")
        self.ssl_model = S3PRLUpstream(s3prl_name)
        self.ssl_model_layer_idx = ssl_model_layer_idx

        # default uses ffn type mean net
        self.mean_net_dnn = Projection(
            ssl_model_output_dim,
            mean_net_dnn_dim,
            nn.ReLU,
            mean_net_output_type,
            mean_net_output_dim,
            mean_net_output_step,
            mean_net_range_clipping,
        )

        # listener modeling related
        if use_listener_modeling:
            if num_listeners is None or listener_emb_dim is None:
                raise ValueError(
                    "num_listeners and listener_emb_dim are required when "
                    "use_listener_modeling is True"
                )
            self.num_listeners = num_listeners
            self.listener_embeddings = nn.Embedding(
                num_embeddings=num_listeners, embedding_dim=listener_emb_dim
            )
            # define decoder
            self.decoder_type = decoder_type
            if decoder_type != "ffn":
                raise NotImplementedError
            # The decoder sees SSL features with the listener embedding
            # concatenated along the feature axis.
            decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim
            # Arguments follow the same positional order as the mean-net
            # Projection above. NOTE(review): the LD head reuses the mean
            # net's output_dim/output_step — confirm this is intended.
            self.decoder_dnn = Projection(
                decoder_dnn_input_dim,
                decoder_dnn_dim,
                nn.ReLU,
                output_type,
                mean_net_output_dim,
                mean_net_output_step,
                range_clipping,
            )

    def _ssl_forward(self, waveform, waveform_lengths):
        """Run the SSL upstream; return the selected layer's outputs and lengths."""
        all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model(
            waveform, waveform_lengths
        )
        idx = self.ssl_model_layer_idx
        return all_encoder_outputs[idx], all_encoder_outputs_lens[idx]

    def get_num_params(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, inputs):
        """Calculate forward propagation.

        Args:
            inputs: dict with
                "waveform": (batch, time) tensor,
                "waveform_lengths": (batch) tensor,
                "listener_idxs": (batch) tensor, read only when listener
                    modeling is enabled.

        Returns:
            dict with "waveform_lengths", "frame_lengths" (SSL frame lengths
            for masked loss calculation), "mean_scores" (mean-net output) and
            "ld_scores" (LD decoder output, or None when listener modeling is
            disabled).
        """
        waveform = inputs["waveform"]
        waveform_lengths = inputs["waveform_lengths"]

        # ssl model forward; outputs presumably (batch, frames, feat_dim)
        encoder_outputs, encoder_outputs_lens = self._ssl_forward(
            waveform, waveform_lengths
        )

        # The mean net always consumes the plain SSL features; its input size
        # is ssl_model_output_dim, so it must not see the listener embedding.
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs
        )  # [batch, frames, 1 (scalar) / 5 (categorical)]

        decoder_outputs = None
        if self.use_listener_modeling:
            listener_embs = self.listener_embeddings(
                inputs["listener_idxs"]
            )  # (batch, emb_dim)
            # Tile over the SSL *frame* axis (not the waveform sample axis) so
            # the embedding lines up with every encoder frame.
            num_frames = encoder_outputs.size(1)
            listener_embs = listener_embs.unsqueeze(1).expand(-1, num_frames, -1)
            decoder_inputs = torch.cat(
                [encoder_outputs, listener_embs], dim=-1
            )  # concat along feature dimension
            decoder_outputs = self.decoder_dnn(
                decoder_inputs
            )  # [batch, frames, 1 (scalar) / 5 (categorical)]

        # return lengths for masked loss calculation
        return {
            "waveform_lengths": waveform_lengths,
            "frame_lengths": encoder_outputs_lens,
            "mean_scores": mean_net_outputs,
            "ld_scores": decoder_outputs,
        }

    def mean_net_inference(self, inputs):
        """Predict one utterance-level score per item with the mean net.

        Returns a dict with "ssl_embeddings" (selected SSL layer outputs) and
        "scores" (per-frame mean-net outputs averaged over dim 1).
        NOTE: the average includes padded frames; frame lengths are not used.
        """
        encoder_outputs, _ = self._ssl_forward(
            inputs["waveform"], inputs["waveform_lengths"]
        )
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs, inference=True
        )  # [batch, frames, 1 (scalar) / 5 (categorical)]
        mean_net_outputs = mean_net_outputs.squeeze(-1)
        scores = torch.mean(mean_net_outputs, dim=1)  # [batch]

        return {"ssl_embeddings": encoder_outputs, "scores": scores}

    def mean_net_inference_p1(self, waveform, waveform_lengths):
        """First stage of split inference: return the selected SSL layer outputs."""
        encoder_outputs, _ = self._ssl_forward(waveform, waveform_lengths)
        return encoder_outputs

    def mean_net_inference_p2(self, encoder_outputs):
        """Second stage of split inference: score SSL outputs with the mean net.

        NOTE(review): unlike ``mean_net_inference`` this does not pass
        ``inference=True`` to the mean net — confirm that is intended.
        """
        mean_net_outputs = self.mean_net_dnn(
            encoder_outputs
        )  # [batch, frames, 1 (scalar) / 5 (categorical)]
        mean_net_outputs = mean_net_outputs.squeeze(-1)
        scores = torch.mean(mean_net_outputs, dim=1)

        return scores

    def get_ssl_embeddings(self, inputs):
        """Return the selected SSL layer outputs for ``inputs["waveform"]``."""
        encoder_outputs, _ = self._ssl_forward(
            inputs["waveform"], inputs["waveform_lengths"]
        )
        return encoder_outputs