File size: 1,645 Bytes
cb4d93d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path
import random
from string import ascii_letters
from miditok import PerTok, TokSequence
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Output directory for generated MIDI files. The path is relative, so it
# resolves against the current working directory at the time it is used.
generated_path = Path("generated")

# MIDI <-> token tokenizer, configured from tokenizer2.json (also a path
# relative to the current working directory). Built once at import time.
midi_tokenizer = PerTok(params="tokenizer2.json")
# Workaround: otherwise the preprocessing will fail. The result is discarded.
# NOTE(review): _create_base_vocabulary is a private miditok method; this may
# break on a miditok upgrade — confirm it is still needed.
_ = midi_tokenizer._create_base_vocabulary()  # workaround, otherwise the preprocessing will fail

# Define which model we want, and download the matching text tokenizer and
# seq2seq model for it. Both are loaded once, at import time.
checkpoint = "JannikAhlers/groove_midi_2"
t5_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

def generate_groove(midi_filename: str, count: int = 1) -> list[str]:
    """Generate ``count`` grooves conditioned on a MIDI file and save them.

    The input MIDI is tokenized with the module-level PerTok tokenizer. The
    token strings are fed to the seq2seq model, and each sampled output is
    decoded back into a MIDI file written under ``generated_path``.

    Args:
        midi_filename: Path of the MIDI file used as the model prompt.
        count: Number of independent samples to generate (default 1).
            Values <= 0 produce no files.

    Returns:
        Paths (as strings) of the written ``.mid`` files, one per sample, in
        generation order.
    """
    # Nothing else in this function creates the output directory, so make
    # sure it exists before any file is written into it.
    generated_path.mkdir(parents=True, exist_ok=True)

    midi_tokens = midi_tokenizer(midi_filename)[0]
    # Limit the prompt to 512 MIDI tokens because the model can't handle
    # longer inputs. Also let the text tokenizer truncate: one MIDI token
    # string may map to several sub-word ids, so the word cap alone does not
    # bound the length of the id sequence.
    tokens_string = " ".join(midi_tokens.tokens[:512])
    inputs = t5_tokenizer(
        tokens_string, return_tensors="pt", truncation=True, max_length=512
    ).input_ids

    out_filenames = []
    for _ in range(count):
        outputs = model.generate(
            inputs, max_new_tokens=1000, do_sample=True, top_k=30, top_p=0.95
        )
        decoded = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # split() rather than split(" "): runs of whitespace or leading and
        # trailing spaces must not turn into empty-string tokens.
        generated = decoded.split()
        # Drop the last event because it might be chopped off mid-token.
        # Slice-deletion is a no-op on an empty result, whereas pop() raises.
        del generated[-1:]
        generated_seq = TokSequence(tokens=generated)

        # Decode tokens back to MIDI and save.
        # NOTE(review): presumably (program, is_drum) — verify against miditok.
        generated_miditok = midi_tokenizer([generated_seq], programs=[(10, True)])
        # The random 16-letter stem is built outside the f-string. Nesting
        # the same quote character inside an f-string is a SyntaxError before
        # Python 3.12.
        stem = "".join(random.choices(ascii_letters, k=16))
        out_path = generated_path / f"{stem}.mid"
        generated_miditok.dump_midi(out_path)
        out_filenames.append(str(out_path))
    return out_filenames