import os
import time
import torch
import random
import argparse
from unidecode import unidecode
from samplings import top_p_sampling, temperature_sampling
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def generate_abc(args):
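    """Generate ABC-notation tunes conditioned on the text in input_text.txt.

    Uses the sander-wood/text-to-music seq2seq model with top-p and temperature
    sampling, and saves the result to output_tunes/<timestamp>.abc.
    """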

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # suppress TensorFlow log noise (harmless if TensorFlow is not installed)

    if torch.cuda.is_available():    
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0), '\n')
    else:
        print('No GPU available, using the CPU instead.\n')
        device = torch.device("cpu")

    num_tunes = args.num_tunes
    max_length = args.max_length
    top_p = args.top_p
    temperature = args.temperature
    seed = args.seed
    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    for key, value in vars(args).items():
        print(key + ': ' + str(value))

    with open('input_text.txt') as f:
        text = unidecode(f.read())
    print("\n"+" INPUT TEXT ".center(60, "#"))
    print('\n'+text+'\n')

    tokenizer = AutoTokenizer.from_pretrained('sander-wood/text-to-music')
    model = AutoModelForSeq2SeqLM.from_pretrained('sander-wood/text-to-music')
    model = model.to(device)
    model.eval()  # inference mode: disable dropout so sampling reflects the trained model

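    # Encode the conditioning text once; the same encoder input is reused for every tune.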
    input_ids = tokenizer(text,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length)['input_ids'].to(device)
    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id
    random.seed(seed)
    tunes = ""
    print(" OUTPUT TUNES ".center(60, "#"))

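    # Decode each tune autoregressively: feed the tokens sampled so far back in
    # and sample the next token from the model's output distribution.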
    for n_idx in range(num_tunes):
        print("\nX:"+str(n_idx+1)+"\n", end="")
        tunes += "X:"+str(n_idx+1)+"\n"
        decoder_input_ids = torch.tensor([[decoder_start_token_id]])

        for t_idx in range(max_length):
            
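            # Draw a fresh per-step seed from the (seeded) global RNG so each
            # sampling call below is reproducible when -seed is given.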
            if seed is not None:
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                outputs = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids.to(device))
            probs = outputs.logits[0][-1]
            probs = torch.softmax(probs, dim=-1).cpu().numpy()
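            # Restrict to the top-p nucleus of the distribution, then sample
            # from the filtered probabilities with temperature scaling.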
            sampled_id = temperature_sampling(probs=top_p_sampling(probs,
                                                                   top_p=top_p,
                                                                   seed=n_seed,
                                                                   return_probs=True),
                                              seed=n_seed,
                                              temperature=temperature)
            decoder_input_ids = torch.cat((decoder_input_ids, torch.tensor([[sampled_id]])), 1)
            if sampled_id != eos_token_id:
                sampled_token = tokenizer.decode([sampled_id])
                print(sampled_token, end="")
                tunes += sampled_token
            else:
                tunes += '\n'
                break

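    # Write all generated tunes to a timestamped .abc file.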
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    os.makedirs('output_tunes', exist_ok=True)  # avoid a crash if the output directory is missing
    with open('output_tunes/'+timestamp+'.abc', 'w') as f:
        f.write(unidecode(tunes))

def get_args(parser):

    parser.add_argument('-num_tunes', type=int, default=3, help='number of tunes to generate independently')
    parser.add_argument('-max_length', type=int, default=1024, help='maximum length in tokens of each generated tune (and of the truncated input text)')
    parser.add_argument('-top_p', type=float, default=0.9, help='nucleus sampling: keep only the most probable tokens whose cumulative probability exceeds top_p')
    parser.add_argument('-temperature', type=float, default=1., help='temperature of the sampling operation (higher values give more random output)')
    parser.add_argument('-seed', type=int, default=None, help='random seed for reproducible sampling')
    args = parser.parse_args()

    return args
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    args = get_args(parser)
    generate_abc(args)
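
# Example invocation (script name is assumed; expects input_text.txt next to the script):
#   python run_inference.py -num_tunes 3 -max_length 1024 -top_p 0.9 -temperature 1.0 -seed 0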