import sys
import os
import librosa
import numpy as np
import torch
import audio_to_text.captioning.models
import audio_to_text.captioning.models.encoder
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.utils.train_util as train_util


def load_model(config, checkpoint):
    """Build the encoder-decoder captioning model and load its trained weights."""
    # The checkpoint is loaded on CPU; it contains the model weights and the vocabulary.
    ckpt = torch.load(checkpoint, map_location="cpu")
    encoder_cfg = config["model"]["encoder"]
    encoder = train_util.init_obj(
        audio_to_text.captioning.models.encoder,
        encoder_cfg
    )
    if "pretrained" in encoder_cfg:
        pretrained = encoder_cfg["pretrained"]
        train_util.load_pretrained_model(encoder,
                                         pretrained,
                                         sys.stdout.write)
    decoder_cfg = config["model"]["decoder"]
    # If the config does not fix the vocabulary size, infer it from the checkpoint's vocabulary.
    if "vocab_size" not in decoder_cfg["args"]:
        decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
    decoder = train_util.init_obj(
        audio_to_text.captioning.models.decoder,
        decoder_cfg
    )
    if "word_embedding" in decoder_cfg:
        decoder.load_word_embedding(**decoder_cfg["word_embedding"])
    if "pretrained" in decoder_cfg:
        pretrained = decoder_cfg["pretrained"]
        train_util.load_pretrained_model(decoder,
                                         pretrained,
                                         sys.stdout.write)
    # Assemble the full encoder-decoder model and load the checkpoint weights.
    model = train_util.init_obj(
        audio_to_text.captioning.models,
        config["model"],
        encoder=encoder,
        decoder=decoder
    )
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    return {
        "model": model,
        "vocabulary": ckpt["vocabulary"]
    }


def decode_caption(word_ids, vocabulary):
    """Turn a sequence of word ids into a caption string, stopping at the <end> token."""
    candidate = []
    for word_id in word_ids:
        word = vocabulary[word_id]
        if word == "<end>":
            break
        elif word == "<start>":
            continue
        candidate.append(word)
    candidate = " ".join(candidate)
    return candidate
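
# Tiny illustration of decode_caption with a hypothetical vocabulary (not the model's):
#
#   decode_caption([0, 3, 4, 1, 5],
#                  ["<start>", "<end>", "<pad>", "a", "dog", "barks"])
#   -> "a dog"   (the <start> token is skipped and decoding stops at <end>)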


class AudioCapModel(object):
    """Thin inference wrapper around a trained audio captioning model."""

    def __init__(self, weight_dir, device="cpu"):
        # weight_dir must contain the training config (config.yaml) and the checkpoint (swa.pth).
        config = os.path.join(weight_dir, "config.yaml")
        self.config = train_util.parse_config_or_kwargs(config)
        checkpoint = os.path.join(weight_dir, "swa.pth")
        resumed = load_model(self.config, checkpoint)
        model = resumed["model"]
        self.vocabulary = resumed["vocabulary"]
        self.model = model.to(device)
        self.device = device

    def caption(self, audio_list):
        """Caption raw 32 kHz waveforms, or a single audio file given by its path."""
        if isinstance(audio_list, np.ndarray):
            audio_list = [audio_list]
        elif isinstance(audio_list, str):
            # A file path was given: load it resampled to the model's 32 kHz rate.
            audio_list = [librosa.load(audio_list, sr=32000)[0]]

        captions = []
        for wav in audio_list:
            inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
            wav_len = torch.LongTensor([len(wav)])
            input_dict = {
                "mode": "inference",
                "wav": inputwav,
                "wav_len": wav_len,
                "specaug": False,
                "sample_method": "beam",
            }
            # Decode with beam search; no gradients are needed at inference time.
            with torch.no_grad():
                out_dict = self.model(input_dict)
            caption_batch = [decode_caption(seq, self.vocabulary)
                             for seq in out_dict["seq"].cpu().numpy()]
            captions.extend(caption_batch)
        return captions

    def __call__(self, audio_list):
        return self.caption(audio_list)
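

# Usage sketch: "pretrained_weights/" and "example.wav" below are placeholder names,
# not files shipped with this module; the directory is assumed to contain the
# config.yaml and swa.pth that AudioCapModel expects.
if __name__ == "__main__":
    captioner = AudioCapModel("pretrained_weights/")
    # A file path is loaded and resampled to 32 kHz internally.
    print(captioner("example.wav"))
    # A raw waveform (numpy array sampled at 32 kHz) can be passed directly as well.
    wav, _ = librosa.load("example.wav", sr=32000)
    print(captioner(wav))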