import os

import torch

from audiocaptioner import AudioCaptioner
from data_module import AudiostockDataset
from utils import *


def infer(input_filename, *, checkpoint='checkpoints/ZRIUE-BEST.pt',
          train_list='audiostock-train-240k.txt', beam_size=5,
          max_neighbor_tokens=150):
    """Generate and print a caption for a single audio file.

    Builds an ``AudioCaptioner`` and loads weights from ``checkpoint``. It then
    computes neighbours for ``input_filename`` against the candidate files
    listed in ``train_list``. The tokenized neighbour text is embedded after
    the audio prefix, and a caption is decoded with beam search.

    Args:
        input_filename: Path of the audio file to caption.
        checkpoint: Checkpoint path handed to ``CheckpointManager.get_checkpoint``.
        train_list: Whitespace-separated list of candidate files used for the
            neighbour search. It is also passed as the dataset ``split``.
        beam_size: Beam width passed to ``generate_beam``.
        max_neighbor_tokens: Maximum number of neighbour tokens appended
            after the audio prefix.

    Returns:
        The post-processed caption string. It is also printed.
    """
    # NOTE(review): presumably selects GPU 0 when available; get_device is
    # defined in utils and not visible here.
    device = get_device(0)

    # connect to GCS (checkpoint storage backend)
    gcs = CheckpointManager()

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)

    # Architecture hyper-parameters. They must match the configuration the
    # checkpoint was trained with, or load_state_dict will fail.
    prefix_dim = 512
    prefix_length = 10
    prefix_length_clip = 10
    num_layers = 8

    model = AudioCaptioner(
        prefix_length,
        clip_length=prefix_length_clip,
        prefix_size=prefix_dim,
        num_layers=num_layers,
    ).to(device)
    model.load_state_dict(gcs.get_checkpoint(checkpoint))
    print(f'Loaded from {checkpoint}')
    model.eval()

    # Empty dataset root. On the training cluster this was
    # '/graft1/datasets/kechen/audiostock-full'.
    dataset_path = ''

    # Close the list file deterministically (the previous one-liner leaked it).
    with open(train_list, 'r') as f:
        train_files = f.read().split()
    train_dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=train_list,
        factor=1.0,
        verbose=False,
        file_list=train_files,
    )

    print('Reading in file', input_filename)
    dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=None,
        factor=1.0,
        verbose=False,
        file_list=[input_filename],  # manually override file list
    )

    # Everything below is inference only, so disable autograd bookkeeping,
    # including for the neighbour precomputation that runs the model.
    with torch.no_grad():
        dataset.precompute_neighbors(model, candidate_set=train_dataset)
        waveform = dataset.read_wav(input_filename).unsqueeze(0).to(
            device, dtype=torch.float32)

        prefix_embed = model.create_prefix(waveform, 1)

        # NOTE(review): the id is the basename up to the FIRST '.', so
        # "a.b.mp3" maps to "a". Kept as-is because it must match however
        # precompute_neighbors keys id2neighbor; confirm.
        audio_id = os.path.basename(input_filename).split('.')[0]
        neighbor_ids = preproc(dataset.id2neighbor[audio_id], tokenizer,
                               stop=False)
        # Truncate before the device copy to avoid transferring unused tokens.
        neighbor_tokens = torch.tensor(
            neighbor_ids, dtype=torch.int64)[:max_neighbor_tokens].to(device)
        neighbor_embed = model.gpt.transformer.wte(neighbor_tokens)

        # Condition the decoder on [audio prefix | neighbour-text embeddings].
        prefix_embed = torch.cat([prefix_embed, neighbor_embed.unsqueeze(0)],
                                 dim=1)
        candidates = generate_beam(model, tokenizer, embed=prefix_embed,
                                   beam_size=beam_size)

    generated_text = postproc(candidates[0])
    print('=======================================')
    print(generated_text)
    return generated_text


if __name__ == '__main__':
    infer('../MusicCaptioning/sample_inputs/sisters.mp3')