text2music / run_inference.py
Kyio's picture
Create run_inference.py
1e1e30f
raw
history blame
4.02 kB
import os
import time
import torch
import random
import argparse
from unidecode import unidecode
from samplings import top_p_sampling, temperature_sampling
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def generate_abc(args):
    """Generate ABC-notation tunes from a text prompt and save them to disk.

    The prompt is read from ``input_text.txt`` in the current working
    directory, transliterated with ``unidecode`` and fed to the
    ``sander-wood/text-to-music`` seq2seq checkpoint. Each tune is sampled
    token by token with the ``samplings`` helpers ``top_p_sampling`` and
    ``temperature_sampling``, echoed to stdout as it is produced, and finally
    all tunes are written to ``output_tunes/<timestamp>.abc``.

    Args:
        args: Parsed command-line namespace providing ``num_tunes``,
            ``max_length``, ``top_p``, ``temperature`` and ``seed``.
            ``max_length`` bounds both the tokenized prompt and the number
            of tokens sampled per tune.

    Raises:
        OSError: If ``input_text.txt`` cannot be read or the output file
            cannot be written.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0), '\n')
    else:
        print('No GPU available, using the CPU instead.\n')
        device = torch.device("cpu")

    num_tunes = args.num_tunes
    max_length = args.max_length
    top_p = args.top_p
    temperature = args.temperature
    seed = args.seed

    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    # Iterate over a dict view of the namespace without rebinding `args`.
    for key, value in vars(args).items():
        print(key+': '+str(value))

    with open('input_text.txt') as f:
        text = unidecode(f.read())
    print("\n"+" INPUT TEXT ".center(60, "#"))
    print('\n'+text+'\n')

    tokenizer = AutoTokenizer.from_pretrained('sander-wood/text-to-music')
    model = AutoModelForSeq2SeqLM.from_pretrained('sander-wood/text-to-music')
    model = model.to(device)

    input_ids = tokenizer(text,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length)['input_ids'].to(device)

    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id

    random.seed(seed)
    # Collect the pieces in a list and join once at the end instead of
    # growing a string with repeated `+=`.
    tune_parts = []
    print(" OUTPUT TUNES ".center(60, "#"))

    # Pure inference: disable autograd so no graph is kept for each of the
    # (up to num_tunes * max_length) forward passes.
    with torch.no_grad():
        for n_idx in range(num_tunes):
            header = "X:"+str(n_idx+1)+"\n"
            print("\n"+header, end="")
            tune_parts.append(header)

            # Keep the growing decoder sequence on the model's device so it
            # does not have to be copied over on every step.
            decoder_input_ids = torch.tensor([[decoder_start_token_id]],
                                             device=device)

            for _ in range(max_length):
                if seed is not None:
                    # Draw a fresh per-step seed from the seeded RNG and
                    # reseed with it, so a given `seed` reproduces the run.
                    n_seed = random.randint(0, 1000000)
                    random.seed(n_seed)
                else:
                    n_seed = None

                outputs = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids)
                # Distribution over the vocabulary for the next token only.
                next_token_logits = outputs.logits[0][-1]
                probs = torch.softmax(next_token_logits, dim=-1).cpu().numpy()

                filtered_probs = top_p_sampling(probs,
                                                top_p=top_p,
                                                seed=n_seed,
                                                return_probs=True)
                sampled_id = temperature_sampling(probs=filtered_probs,
                                                  seed=n_seed,
                                                  temperature=temperature)

                if sampled_id == eos_token_id:
                    break

                decoder_input_ids = torch.cat(
                    (decoder_input_ids,
                     torch.tensor([[sampled_id]], device=device)),
                    1)
                sampled_token = tokenizer.decode([sampled_id])
                print(sampled_token, end="")
                tune_parts.append(sampled_token)

            # Terminate every tune with a newline, including tunes that hit
            # max_length without sampling EOS; otherwise the next "X:" header
            # would be glued onto the last line of this tune in the file.
            tune_parts.append('\n')

    tunes = "".join(tune_parts)

    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    # Make sure the output directory exists so a long generation run is not
    # lost to a FileNotFoundError at the very end.
    os.makedirs('output_tunes', exist_ok=True)
    with open('output_tunes/'+timestamp+'.abc', 'w') as f:
        f.write(unidecode(tunes))
def get_args(parser):
    """Register the generation options on ``parser`` and parse the command line.

    Args:
        parser: An ``argparse.ArgumentParser`` to which the options
            ``-num_tunes``, ``-max_length``, ``-top_p``, ``-temperature`` and
            ``-seed`` are added.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    # (flag, type, default, help) for every supported option.
    option_table = (
        ('-num_tunes', int, 3,
         'the number of independently computed returned tunes'),
        ('-max_length', int, 1024,
         'integer to define the maximum length in tokens of each tune'),
        ('-top_p', float, 0.9,
         'float to define the tokens that are within the sample operation of text generation'),
        ('-temperature', float, 1.,
         'the temperature of the sampling operation'),
        ('-seed', int, None,
         'seed for randomstate'),
    )
    for flag, value_type, fallback, description in option_table:
        parser.add_argument(flag, type=value_type, default=fallback, help=description)
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the command-line options, then generate tunes.
    cli_args = get_args(argparse.ArgumentParser())
    generate_abc(cli_args)