Tim-gubski committed on
Commit
c9ce2d9
·
verified ·
1 Parent(s): f8c0a29

changed model

Browse files
Files changed (1) hide show
  1. audio2hero.py +64 -0
audio2hero.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import librosa
3
+ from transformers import Pop2PianoForConditionalGeneration, Pop2PianoProcessor, Pop2PianoTokenizer, Pop2PianoConfig
4
+ import pretty_midi
5
+ from transformers import AutoConfig
6
+ from model_generate import generate
7
+ import torch
8
+ from post_processor import post_process
9
+ import tempfile
10
+ import shutil
11
+
12
def generate_midi(song_path, output_dir=None):
    """Transcribe an audio file to MIDI, post-process it, and zip the result.

    Loads the "Tim-gubski/Audio2Hero" Pop2Piano checkpoint and the
    "sweetcocoa/pop2piano" processor, runs generation on the audio at
    ``song_path``, writes the first decoded MIDI object to a temporary file,
    hands that file to ``post_process``, and finally zips the folder
    ``{output_dir}/{song_name}``.

    Args:
        song_path: Path to the input audio file.
        output_dir: Directory that receives the results. Defaults to
            "./Outputs" when ``None``.

    Returns:
        The path of the created archive, ``"{output_dir}/{song_name}.zip"``.
    """
    # Function-scope import so this block does not depend on the file header.
    from pathlib import Path

    if output_dir is None:
        output_dir = "./Outputs"

    print("Loading Model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Pop2PianoForConditionalGeneration.from_pretrained("Tim-gubski/Audio2Hero").to(device)
    model.eval()
    # The processor's batch_decode is the only decoding entry point used
    # below, so no separate tokenizer instance is loaded.
    processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")

    print("Processing Song...")
    # Load the input audio, resampled to 44.1 kHz, and extract model features.
    audio, sr = librosa.load(song_path, sr=44100)
    inputs = processor(audio=audio, sampling_rate=sr, return_tensors="pt")

    print("Generating output...")
    # return_dict_in_generate is required because `.sequences` is read below.
    model.generation_config.output_logits = True
    model.generation_config.return_dict_in_generate = True
    model_output = model.generate(inputs["input_features"].to(device))

    tokenizer_output = processor.batch_decode(
        token_ids=model_output.sequences.cpu(),
        feature_extractor_output=inputs,
    )

    # Name of the song without directory or final extension, using the
    # platform's path rules (also works for files that have no extension).
    # NOTE(review): post_process is presumably what creates
    # "{output_dir}/{song_name}" -- confirm it derives the name the same way.
    song_name = Path(song_path).stem
    song_dir = f"{output_dir}/{song_name}"

    # The context manager removes the temporary directory even when
    # post-processing or archiving raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_midi_path = f"{temp_dir}/temp.mid"
        tokenizer_output["pretty_midi_objects"][0].write(temp_midi_path)

        print("Post Processing...")
        post_process(song_path, temp_midi_path, output_dir)

        # Zip the song folder into "{output_dir}/{song_name}.zip".
        shutil.make_archive(song_dir, "zip", song_dir)

    print("Done.")

    return f"{song_dir}.zip"
56
+
57
+
58
def parse_cli_args(argv):
    """Split command-line arguments into ``(song_path, output_dir)``.

    Args:
        argv: Arguments after the script name (i.e. ``sys.argv[1:]``).

    Returns:
        A ``(song_path, output_dir)`` tuple; ``output_dir`` is ``None`` when
        it was not supplied, so ``generate_midi`` falls back to its default.

    Raises:
        SystemExit: If no song path was given.
    """
    if not argv:
        sys.exit("Usage: python audio2hero.py <song_path> [output_dir]")
    song_path = argv[0]
    # output_dir is optional; extra trailing arguments are ignored as before.
    output_dir = argv[1] if len(argv) > 1 else None
    return song_path, output_dir


if __name__ == "__main__":
    song_path, output_dir = parse_cli_args(sys.argv[1:])
    generate_midi(song_path, output_dir)
63
+
64
+