File size: 3,928 Bytes
089d567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128


import logging
import os

import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main

# Number of iterations between two result plots.
show_results_every = 100

# Module-level run options: pick the GPU whenever torch reports one,
# otherwise fall back to the CPU.
run_opts = dict(device="cuda" if torch.cuda.is_available() else "cpu")

class PipelineSLUTask(sb.pretrained.interfaces.Pretrained):
    """Inference wrapper chaining an ASR encoder with an SLU decoder.

    ``encode_file`` turns one audio file into sequence log-probabilities and
    beam-searched token ids; ``decode`` turns token ids into lists of words.
    """

    # NOTE(review): ``encode_file`` also reads ``asr_model``, ``log_softmax``
    # and ``beam_searcher`` from the hparams; they are not listed here --
    # confirm whether they should be.
    HPARAMS_NEEDED = [
        "slu_enc",
        "output_emb",
        "dec",
        "seq_lin",
        "env_corrupt",
        "tokenizer",
    ]
    MODULES_NEEDED = [
        "slu_enc",
        "output_emb",
        "dec",
        "seq_lin",
        "env_corrupt",
    ]

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to the base class."""
        super().__init__(*args, **kwargs)

    def encode_file(self, path):
        """Run the ASR encoder, SLU encoder/decoder and beam search on a file.

        Arguments
        ---------
        path : str
            Location of the audio file; handed to ``self.load_audio``.

        Returns
        -------
        p_seq
            Output of ``hparams.log_softmax`` applied to the ``seq_lin``
            projection of the decoder states. The decoder is fed only the
            BOS token, so this covers a single decoding step.
        wav_lens : torch.Tensor
            Relative length of the single utterance (``[1.0]``), on
            ``self.device``.
        p_tokens
            First element of the value returned by ``hparams.beam_searcher``
            (the predicted token sequences).
        """
        # The decoder is primed with one BOS token (index 0).
        tokens_bos = torch.tensor([[0]]).to(self.device)

        waveform = self.load_audio(path)
        # Fake a batch holding one utterance that spans its full length.
        wavs = waveform.unsqueeze(0)
        # Keep the lengths on the same device as the other inputs so the
        # decoder and the beam searcher never mix CPU and GPU tensors.
        wav_lens = torch.tensor([1.0]).to(self.device)

        with torch.no_grad():
            # ASR encoder forward pass
            asr_encoder_out = self.hparams.asr_model.encode_batch(
                wavs.detach(), wav_lens
            )

            # SLU forward pass
            encoder_out = self.hparams.slu_enc(asr_encoder_out)
            e_in = self.hparams.output_emb(tokens_bos)
            h, _ = self.hparams.dec(e_in, encoder_out, wav_lens)

            # Output layer for seq2seq log-probabilities
            logits = self.hparams.seq_lin(h)
            p_seq = self.hparams.log_softmax(logits)

            # Only the hypotheses are needed; the scores are discarded.
            p_tokens, _ = self.hparams.beam_searcher(encoder_out, wav_lens)
            return p_seq, wav_lens, p_tokens

    def decode(self, p_seq, wav_lens, predicted_tokens):
        """Convert predicted token ids into lists of words.

        Arguments
        ---------
        p_seq
            Unused; accepted so the output of ``encode_file`` can be passed
            straight through.
        wav_lens
            Unused; accepted for the same reason as ``p_seq``.
        predicted_tokens
            Iterable of token-id sequences, one per utterance.

        Returns
        -------
        list
            For each sequence, the text produced by
            ``hparams.tokenizer.decode_ids`` split on single spaces.
        """
        # Decode token terms to words
        return [
            self.hparams.tokenizer.decode_ids(utt_seq).split(" ")
            for utt_seq in predicted_tokens
        ]


from typing import Dict, List, Any

class EndpointHandler:
    """Endpoint entry point: loads the SLU pipeline once, then serves calls."""

    def __init__(self, path=""):
        """Load the hyperparameters and pretrained files found under ``path``.

        Arguments
        ---------
        path : str
            Directory that contains ``better_tokenizer/1986/hyperparams.yaml``.
            An empty string is resolved relative to the current working
            directory.
        """
        # os.path.join keeps an empty ``path`` relative; gluing the parts
        # together with "/" would point at the filesystem root instead.
        hparams_file = os.path.join(
            path, "better_tokenizer", "1986", "hyperparams.yaml"
        )
        with open(hparams_file) as fin:
            # No hyperparameter overrides are applied.
            hparams = load_hyperpyyaml(fin, {})

        run_opts = {
            "device": "cuda" if torch.cuda.is_available() else "cpu"
        }

        # Collect the files the pretrainer needs (on the main process only),
        # then load them onto the selected device.
        run_on_main(hparams["pretrainer"].collect_files)
        hparams["pretrainer"].load_collected(device=run_opts["device"])

        self.pipeline = PipelineSLUTask(
            modules=hparams["modules"],
            hparams=hparams,
            run_opts=run_opts,
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[str]]:
        """Run the pipeline on one request.

        Arguments
        ---------
        data : dict
            Request payload. The value under ``"inputs"`` is handed to
            ``pipeline.encode_file`` as the audio path; when the key is
            missing the payload itself is passed on.

        Returns
        -------
        list
            The value returned by ``pipeline.decode``: for each predicted
            sequence, its decoded words.
        """
        logger = logging.getLogger(__name__)

        inputs = data.get("inputs", data)
        logger.debug("inputs: %s", inputs)

        p_seq, wav_lens, p_tokens = self.pipeline.encode_file(inputs)
        logger.debug("p_seq: %s", p_seq)
        logger.debug("wav_lens: %s", wav_lens)
        logger.debug("p_tokens: %s", p_tokens)

        return self.pipeline.decode(p_seq, wav_lens, p_tokens)