Spaces:
Sleeping
Sleeping
File size: 2,303 Bytes
48ac659 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from audiocaptioner import AudioCaptioner
from data_module import AudiostockDataset
from utils import *
def infer(input_filename):
    """Generate and print a text caption for a single audio file.

    Loads the trained AudioCaptioner checkpoint from GCS, precomputes
    nearest-neighbor captions against the training set, then runs
    beam search over GPT-2 to produce the caption.

    Args:
        input_filename: Path to the audio file (e.g. .wav / .mp3) to caption.

    Side effects:
        Prints progress and the generated caption to stdout.
    """
    device = get_device(0)
    # connect to GCS (remote checkpoint storage)
    gcs = CheckpointManager()
    # create and/or load model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)
    prefix_dim = 512          # CLAP/CLIP-style audio embedding size
    prefix_length = 10        # number of prefix tokens fed to GPT-2
    prefix_length_clip = 10
    num_layers = 8            # mapping-network transformer layers
    checkpoint = 'checkpoints/ZRIUE-BEST.pt'
    model = AudioCaptioner(
        prefix_length,
        clip_length=prefix_length_clip,
        prefix_size=prefix_dim,
        num_layers=num_layers,
    ).to(device)
    model.load_state_dict(gcs.get_checkpoint(checkpoint))
    print(f'Loaded from {checkpoint}')
    model.eval()
    # read in the wav file and precompute neighbors
    #dataset_path = '/graft1/datasets/kechen/audiostock-full'
    dataset_path = ''
    # FIX: read the split file through a context manager — the original
    # open(...).read() never closed the file handle.
    with open('audiostock-train-240k.txt', 'r') as split_file:
        train_file_list = split_file.read().split()
    train_dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split='audiostock-train-240k.txt',
        factor=1.0,
        verbose=False,
        file_list=train_file_list,
    )
    print('Reading in file', input_filename)
    dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=None,
        factor=1.0,
        verbose=False,
        file_list=[input_filename]  # manually override file list
    )
    dataset.precompute_neighbors(model, candidate_set=train_dataset)
    waveform = dataset.read_wav(input_filename).unsqueeze(0).to(device, dtype=torch.float32)
    # predict
    with torch.no_grad():
        prefix_embed = model.create_prefix(waveform, 1)
        # neighbor caption is keyed by the file's basename without extension;
        # truncate to 150 tokens so the prompt fits the model's context
        file_id = os.path.basename(input_filename).split('.')[0]
        tweet_tokens = torch.tensor(
            preproc(dataset.id2neighbor[file_id], tokenizer, stop=False),
            dtype=torch.int64,
        ).to(device)[:150]
        tweet_embed = model.gpt.transformer.wte(tweet_tokens)
        prefix_embed = torch.cat([prefix_embed, tweet_embed.unsqueeze(0)], dim=1)
        candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
        generated_text = candidates[0]
    generated_text = postproc(generated_text)
    print('=======================================')
    print(generated_text)
if __name__ == '__main__':
    import sys
    # Generalization: accept an optional audio path on the command line;
    # with no argument, fall back to the original sample input so existing
    # zero-arg invocations behave exactly as before.
    infer(sys.argv[1] if len(sys.argv) > 1 else '../MusicCaptioning/sample_inputs/sisters.mp3')
|