import os
import sys

import torch
from transformers import GPT2Tokenizer

from audiocaptioner import AudioCaptioner
from data_module import AudiostockDataset
from utils import *  # provides get_device, CheckpointManager, preproc, postproc, generate_beam

def infer(input_filename):
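    """Generate a text caption for a single audio file using a trained AudioCaptioner."""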
    device = get_device(0)
    # connect to GCS
    gcs = CheckpointManager()
    # create and/or load model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)
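    # model hyperparameters; these must match the configuration used to train the checkpoint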
    prefix_dim = 512
    prefix_length = 10
    prefix_length_clip = 10
    num_layers = 8
    checkpoint = 'checkpoints/ZRIUE-BEST.pt'
    model = AudioCaptioner(prefix_length, clip_length=prefix_length_clip,
                           prefix_size=prefix_dim, num_layers=num_layers).to(device)
    model.load_state_dict(gcs.get_checkpoint(checkpoint))
    print(f'Loaded from {checkpoint}')
    model.eval()
    # read in the wav file and precompute neighbors
    #dataset_path = '/graft1/datasets/kechen/audiostock-full'
    dataset_path = ''
    with open('audiostock-train-240k.txt', 'r') as f:
        train_file_list = f.read().split()
    # the training split serves as the candidate pool for neighbor retrieval below
    train_dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split='audiostock-train-240k.txt',
        factor=1.0,
        verbose=False,
        file_list=train_file_list
    )
    print('Reading in file', input_filename)
    dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=None,
        factor=1.0,
        verbose=False,
        file_list=[input_filename] # manually override file list
    )
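    # retrieve the nearest training caption for the input file; this fills dataset.id2neighbor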
    dataset.precompute_neighbors(model, candidate_set=train_dataset)
    waveform = dataset.read_wav(input_filename).unsqueeze(0).to(device, dtype=torch.float32)
    # predict
    with torch.no_grad():
        # project the waveform into a prefix in GPT-2's embedding space
        prefix_embed = model.create_prefix(waveform, 1)
        # embed the retrieved neighbor caption as extra conditioning, truncated to 150 tokens
        track_id = os.path.basename(input_filename).split('.')[0]
        neighbor_tokens = preproc(dataset.id2neighbor[track_id], tokenizer, stop=False)
        tweet_tokens = torch.tensor(neighbor_tokens, dtype=torch.int64).to(device)[:150]
        tweet_embed = model.gpt.transformer.wte(tweet_tokens)
        prefix_embed = torch.cat([prefix_embed, tweet_embed.unsqueeze(0)], dim=1)
        # beam-search decode and keep the top-scoring candidate
        candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
        generated_text = postproc(candidates[0])
    print('=======================================')
    print(generated_text)

if __name__ == '__main__':
    # default to the bundled sample clip; allow overriding via a command-line argument
    input_path = sys.argv[1] if len(sys.argv) > 1 else '../MusicCaptioning/sample_inputs/sisters.mp3'
    infer(input_path)