# JannikAhlers's picture
# finish functionality
# cb4d93d
from pathlib import Path
import random
from string import ascii_letters
from miditok import PerTok, TokSequence
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Module-level setup: everything below runs once at import time and is shared
# by every call to generate_groove().

# Relative directory that generated MIDI files are written into.
generated_path = Path("generated")
# MIDI tokenizer, configured from the local "tokenizer2.json" parameter file.
midi_tokenizer = PerTok(params="tokenizer2.json")
# Workaround: call the tokenizer's private vocabulary builder once and discard
# the result; without this call the preprocessing will fail.
_ = midi_tokenizer._create_base_vocabulary() # workaround, otherwise the preprocessing will fail
# Define which model we want; the matching text tokenizer and the seq2seq model
# are both loaded (downloaded if necessary) from this same checkpoint.
checkpoint = "JannikAhlers/groove_midi_2"
t5_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
def generate_groove(midi_filename: str, count: int = 1) -> list[str]:
    """Generate MIDI files with the seq2seq model, prompted by an input MIDI file.

    The input file is tokenized with the module-level ``midi_tokenizer``, the
    tokens are fed to the module-level ``model`` as a space-separated string,
    and each sampled output is converted back to MIDI and written into
    ``generated_path`` under a random 16-letter file name.

    Args:
        midi_filename: Path of the MIDI file used as the prompt.
        count: Number of files to generate. Each one is sampled
            independently from the same prompt.

    Returns:
        The paths (as strings) of the generated ``.mid`` files, in the order
        they were written. Empty when ``count`` is zero or negative.
    """
    midi_tokens = midi_tokenizer(midi_filename)[0]
    # Limit length to 512, because the tokenizer can't handle longer inputs.
    tokens_string = " ".join(midi_tokens.tokens[:512])
    inputs = t5_tokenizer(tokens_string, return_tensors="pt").input_ids

    # Make sure the output directory exists; otherwise writing the first file
    # fails on a fresh checkout.
    generated_path.mkdir(parents=True, exist_ok=True)

    out_filenames = []
    for _ in range(count):
        outputs = model.generate(
            inputs,
            max_new_tokens=1000,
            do_sample=True,
            top_k=30,
            top_p=0.95,
        )
        generated = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated = generated.split(" ")
        # Delete the last event because it might be chopped off mid-token.
        generated.pop()
        generated_seq = TokSequence(tokens=generated)

        # Convert the token sequence back to MIDI.
        # NOTE(review): (10, True) is presumably (program, is_drum) — confirm
        # against the miditok documentation.
        generated_miditok = midi_tokenizer([generated_seq], programs=[(10, True)])

        # Build the random name outside the f-string: reusing the same quote
        # character inside an f-string is a SyntaxError before Python 3.12.
        random_stem = "".join(random.choices(ascii_letters, k=16))
        out_path = generated_path / f"{random_stem}.mid"

        # save file
        generated_miditok.dump_midi(out_path)
        out_filenames.append(str(out_path))
    return out_filenames