Spaces:
Runtime error
Runtime error
Create run_inference.py
Browse files- run_inference.py +100 -0
run_inference.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import argparse
|
6 |
+
from unidecode import unidecode
|
7 |
+
from samplings import top_p_sampling, temperature_sampling
|
8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
9 |
+
|
10 |
+
def generate_abc(args):
    """Generate ABC-notation tunes from the text stored in ``input_text.txt``.

    The text is transliterated with ``unidecode``, tokenized, and fed to the
    ``sander-wood/text-to-music`` seq2seq model. Tokens are sampled one at a
    time (top-p filtering followed by temperature sampling) until the model
    emits its EOS token or ``max_length`` tokens have been produced. Each tune
    is echoed to stdout as it is generated, and all tunes are written to
    ``output_tunes/<timestamp>.abc``.

    Args:
        args: Namespace-like object providing ``num_tunes``, ``max_length``,
            ``top_p``, ``temperature`` and ``seed`` attributes (as produced
            by ``get_args``).
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0), '\n')
    else:
        print('No GPU available, using the CPU instead.\n')
        device = torch.device("cpu")

    num_tunes = args.num_tunes
    max_length = args.max_length
    top_p = args.top_p
    temperature = args.temperature
    seed = args.seed

    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    # Iterate over a dict view of the namespace without rebinding ``args``.
    for key, value in vars(args).items():
        print(key+': '+str(value))

    with open('input_text.txt') as f:
        text = unidecode(f.read())
    print("\n"+" INPUT TEXT ".center(60, "#"))
    print('\n'+text+'\n')

    tokenizer = AutoTokenizer.from_pretrained('sander-wood/text-to-music')
    model = AutoModelForSeq2SeqLM.from_pretrained('sander-wood/text-to-music')
    model = model.to(device)

    input_ids = tokenizer(text,
                          return_tensors='pt',
                          truncation=True,
                          max_length=max_length)['input_ids'].to(device)
    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id

    random.seed(seed)
    tunes = ""
    print(" OUTPUT TUNES ".center(60, "#"))

    for n_idx in range(num_tunes):
        header = "X:"+str(n_idx+1)+"\n"
        print("\n"+header, end="")
        tunes += header
        decoder_input_ids = torch.tensor([[decoder_start_token_id]])

        for _ in range(max_length):
            # With a user seed, derive a fresh per-step seed from the seeded
            # generator so the run is reproducible; otherwise leave the
            # sampling helpers unseeded.
            if seed is not None:
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None

            # Inference only: skip autograd bookkeeping so no graph is built
            # for each decoding step.
            with torch.no_grad():
                outputs = model(input_ids=input_ids,
                                decoder_input_ids=decoder_input_ids.to(device))
            # Distribution over the vocabulary for the next token.
            probs = outputs.logits[0][-1]
            probs = torch.nn.Softmax(dim=-1)(probs).cpu().detach().numpy()

            filtered_probs = top_p_sampling(probs,
                                            top_p=top_p,
                                            seed=n_seed,
                                            return_probs=True)
            sampled_id = temperature_sampling(probs=filtered_probs,
                                              seed=n_seed,
                                              temperature=temperature)
            decoder_input_ids = torch.cat(
                (decoder_input_ids, torch.tensor([[sampled_id]])), 1)

            if sampled_id == eos_token_id:
                break

            sampled_token = tokenizer.decode([sampled_id])
            print(sampled_token, end="")
            tunes += sampled_token

        # Terminate every tune with a newline, including tunes cut off at
        # max_length without an EOS token, so the next "X:" header always
        # starts on its own line in the output file.
        tunes += '\n'

    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs('output_tunes', exist_ok=True)
    with open('output_tunes/'+timestamp+'.abc', 'w') as f:
        f.write(unidecode(tunes))
|
85 |
+
|
86 |
+
def get_args(parser, argv=None):
    """Register the generation options on ``parser`` and parse them.

    Args:
        parser: An ``argparse.ArgumentParser`` to which the options are added.
        argv: Optional list of argument strings to parse. When ``None``
            (the default), the process command line is parsed, exactly as
            before this parameter existed.

    Returns:
        The parsed ``argparse.Namespace`` with ``num_tunes``, ``max_length``,
        ``top_p``, ``temperature`` and ``seed`` attributes.
    """
    parser.add_argument('-num_tunes', type=int, default=3, help='the number of independently computed returned tunes')
    parser.add_argument('-max_length', type=int, default=1024, help='integer to define the maximum length in tokens of each tune')
    parser.add_argument('-top_p', type=float, default=0.9, help='float to define the tokens that are within the sample operation of text generation')
    parser.add_argument('-temperature', type=float, default=1., help='the temperature of the sampling operation')
    parser.add_argument('-seed', type=int, default=None, help='seed for randomstate')

    return parser.parse_args(argv)
|
96 |
+
|
97 |
+
if __name__ == '__main__':
    # Script entry point: parse the command-line options, then generate tunes.
    parser = argparse.ArgumentParser()
    generate_abc(get_args(parser))
|