import random
from pathlib import Path
from string import ascii_letters

from miditok import PerTok, TokSequence
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Directory that generated MIDI files are written into. The path is relative,
# so it resolves against the process's current working directory.
generated_path = Path("generated")

# MIDI tokenizer configured from the local "tokenizer2.json" parameter file.
midi_tokenizer = PerTok(params="tokenizer2.json")
# NOTE(review): a private miditok method is called here and its return value
# discarded, so it is kept only for whatever side effects it has. Presumably
# it rebuilds internal tokenizer state that loading from `params` does not
# restore -- confirm against the installed miditok version before removing.
_ = midi_tokenizer._create_base_vocabulary()

# Checkpoint identifier passed to `from_pretrained`; the text tokenizer and
# the sequence-to-sequence model are both loaded from it at import time.
checkpoint = "JannikAhlers/groove_midi_2"
t5_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)


def generate_groove(midi_filename: str, count: int = 1) -> list[str]:
    """Sample ``count`` MIDI files from the model, conditioned on a MIDI file.

    The input file is tokenized with ``midi_tokenizer``; the first 512 tokens
    of its first token sequence are joined with spaces and encoded with
    ``t5_tokenizer`` to form the model input. For each requested output the
    model samples a new sequence, which is decoded back to text, split into
    tokens, converted to MIDI by ``midi_tokenizer`` and written into
    ``generated_path`` under a random 16-letter file name.

    Args:
        midi_filename: Path of the MIDI file used as the conditioning input.
        count: Number of files to generate. Values below 1 produce no files.

    Returns:
        The paths of the written ``.mid`` files, as strings, in generation
        order.
    """
    # Only the first sequence returned by the tokenizer is used.
    # NOTE(review): presumably one sequence per track -- confirm with miditok.
    midi_tokens = midi_tokenizer(midi_filename)[0]
    tokens_string = " ".join(midi_tokens.tokens[:512])
    inputs = t5_tokenizer(tokens_string, return_tensors="pt").input_ids

    # Create the output directory up front so the first dump cannot fail just
    # because it is missing.
    generated_path.mkdir(parents=True, exist_ok=True)

    out_filenames = []
    for _ in range(count):
        outputs = model.generate(
            inputs,
            max_new_tokens=1000,
            do_sample=True,
            top_k=30,
            top_p=0.95,
        )
        generated = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated = generated.split(" ")
        # The final token is always discarded.
        # NOTE(review): presumably because generation may stop mid-token and
        # leave an incomplete last entry -- confirm this is still wanted.
        generated.pop()
        generated_seq = TokSequence(tokens=generated)

        # NOTE(review): `programs=[(10, True)]` looks like a (program, is_drum)
        # pair selecting a drum track -- confirm against the miditok API.
        generated_miditok = midi_tokenizer([generated_seq], programs=[(10, True)])

        # Build the name outside the f-string: reusing the same quote
        # character inside an f-string only parses on Python 3.12+.
        stem = "".join(random.choices(ascii_letters, k=16))
        out_path = generated_path / f"{stem}.mid"
        generated_miditok.dump_midi(out_path)
        out_filenames.append(str(out_path))

    return out_filenames