# Generate ABC-notation tunes from a text description using the
# sander-wood/text-to-music seq2seq model.
import os
import time
import torch
import random
import argparse
from unidecode import unidecode
from samplings import top_p_sampling, temperature_sampling
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def generate_abc(args):
    """Generate ABC-notation tunes from the text in ``input_text.txt``.

    Loads the ``sander-wood/text-to-music`` seq2seq model, samples
    ``args.num_tunes`` tunes token by token with top-p + temperature
    sampling, echoes each token as it is produced, and writes the collected
    tunes to a timestamped ``.abc`` file under ``output_tunes/``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``num_tunes``, ``max_length``, ``top_p``,
        ``temperature`` and ``seed`` (see ``get_args``).
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0), '\n')
    else:
        print('No GPU available, using the CPU instead.\n')
        device = torch.device("cpu")

    num_tunes = args.num_tunes
    max_length = args.max_length
    top_p = args.top_p
    temperature = args.temperature
    seed = args.seed

    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    # NOTE: iterate a view of the namespace instead of rebinding `args`
    # (the original shadowed its own parameter with `args = vars(args)`).
    for key, value in vars(args).items():
        print(key + ': ' + str(value))

    with open('input_text.txt') as f:
        text = unidecode(f.read())
    print("\n" + " INPUT TEXT ".center(60, "#"))
    print('\n' + text + '\n')

    tokenizer = AutoTokenizer.from_pretrained('sander-wood/text-to-music')
    model = AutoModelForSeq2SeqLM.from_pretrained('sander-wood/text-to-music')
    model = model.to(device)
    model.eval()  # inference only; disable dropout etc.

    input_ids = tokenizer(text,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length)['input_ids'].to(device)
    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id

    random.seed(seed)
    tunes = ""
    print(" OUTPUT TUNES ".center(60, "#"))
    for n_idx in range(num_tunes):
        print("\nX:" + str(n_idx + 1) + "\n", end="")
        tunes += "X:" + str(n_idx + 1) + "\n"
        decoder_input_ids = torch.tensor([[decoder_start_token_id]])
        for t_idx in range(max_length):
            # Derive a fresh per-step seed so each token draw is
            # reproducible for a given --seed but differs between steps.
            if seed is not None:
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None
            # Inference-only forward pass: no gradients needed, and
            # skipping autograd avoids accumulating activation memory.
            with torch.no_grad():
                outputs = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids.to(device))
            probs = outputs.logits[0][-1]
            probs = torch.nn.Softmax(dim=-1)(probs).cpu().detach().numpy()
            sampled_id = temperature_sampling(probs=top_p_sampling(probs,
                                                                   top_p=top_p,
                                                                   seed=n_seed,
                                                                   return_probs=True),
                                              seed=n_seed,
                                              temperature=temperature)
            decoder_input_ids = torch.cat(
                (decoder_input_ids, torch.tensor([[sampled_id]])), 1)
            if sampled_id != eos_token_id:
                sampled_token = tokenizer.decode([sampled_id])
                print(sampled_token, end="")
                tunes += sampled_token
            else:
                tunes += '\n'
                break

    # Ensure the output directory exists: the original raised
    # FileNotFoundError on a fresh checkout where 'output_tunes/' is absent.
    os.makedirs('output_tunes', exist_ok=True)
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    with open('output_tunes/' + timestamp + '.abc', 'w') as f:
        f.write(unidecode(tunes))
def get_args(parser):
    """Register the generation hyperparameter flags on *parser* and parse argv.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend with the script's command-line options.

    Returns
    -------
    argparse.Namespace
        The parsed command-line arguments.
    """
    # (flag, type, default, help) specification for every CLI option.
    option_spec = [
        ('-num_tunes', int, 3,
         'the number of independently computed returned tunes'),
        ('-max_length', int, 1024,
         'integer to define the maximum length in tokens of each tune'),
        ('-top_p', float, 0.9,
         'float to define the tokens that are within the sample operation of text generation'),
        ('-temperature', float, 1.,
         'the temperature of the sampling operation'),
        ('-seed', int, None,
         'seed for randomstate'),
    ]
    for flag, flag_type, default, help_text in option_spec:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the CLI options, then run generation.
    cli_parser = argparse.ArgumentParser()
    generate_abc(get_args(cli_parser))