juancopi81 committed on
Commit 3c6c416
1 Parent(s): 42bda77

Initial commit

Files changed (9)
  1. .gitignore +1 -0
  2. app.py +133 -0
  3. constants.py +133 -0
  4. model.py +25 -0
  5. packages.txt +4 -0
  6. pyproject.toml +6 -0
  7. requirements.txt +4 -0
  8. string_to_notes.py +137 -0
  9. utils.py +242 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ env/
app.py ADDED
@@ -0,0 +1,133 @@
+ import os
+
+ import gradio as gr
+
+ from utils import (
+     generate_song,
+     remove_last_instrument,
+     regenerate_last_instrument,
+     change_tempo,
+ )
+
+
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+ DESCRIPTION = """
+
+ # 🎵 Multitrack Midi Generator 🎶
+ This interactive application uses an AI model to generate music sequences based on a chosen genre and various user inputs.
+
+ Features:
+ 🎼 Select the genre for the music.
+ 🌡️ Use the "Temperature" slider to adjust the randomness of the music generated (higher values will produce more random outputs).
+ ⏱️ Adjust the "Tempo" slider to change the speed of the music.
+ 🎹 Use the buttons to generate a new song from scratch, continue generation with the current settings, remove the last added instrument, regenerate the last added instrument with a new one, or change the tempo of the current song.
+ Outputs:
+ The app outputs the following:
+
+ 🎧 The audio of the generated song.
+ 📁 A MIDI file of the song.
+ 📊 A plot of the song's sequence.
+ 🎸 A list of the generated instruments.
+ 📝 The text sequence of the song.
+ Enjoy creating your own AI-generated music! 🎵
+ """
+
+ genres = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Row():
+         with gr.Column():
+             temp = gr.Slider(
+                 minimum=0, maximum=1, step=0.05, value=0.75, label="Temperature"
+             )
+             genre = gr.Dropdown(choices=genres, value="POP", label="Select the genre")
+             with gr.Row():
+                 btn_from_scratch = gr.Button("Start from scratch")
+                 btn_continue = gr.Button("Continue Generation")
+                 btn_remove_last = gr.Button("Remove last instrument")
+                 btn_regenerate_last = gr.Button("Regenerate last instrument")
+         with gr.Column():
+             with gr.Box():
+                 audio_output = gr.Video()
+                 midi_file = gr.File()
+             with gr.Row():
+                 qpm = gr.Slider(
+                     minimum=60, maximum=140, step=10, value=120, label="Tempo"
+                 )
+                 btn_qpm = gr.Button("Change Tempo")
+     with gr.Row():
+         with gr.Column():
+             plot_output = gr.Plot()
+         with gr.Column():
+             instruments_output = gr.Markdown("# List of generated instruments")
+     with gr.Row():
+         text_sequence = gr.Text()
+         empty_sequence = gr.Text(visible=False)
+     with gr.Row():
+         num_tokens = gr.Text()
+     btn_from_scratch.click(
+         fn=generate_song,
+         inputs=[genre, temp, empty_sequence, qpm],
+         outputs=[
+             audio_output,
+             midi_file,
+             plot_output,
+             instruments_output,
+             text_sequence,
+             num_tokens,
+         ],
+     )
+     btn_continue.click(
+         fn=generate_song,
+         inputs=[genre, temp, text_sequence, qpm],
+         outputs=[
+             audio_output,
+             midi_file,
+             plot_output,
+             instruments_output,
+             text_sequence,
+             num_tokens,
+         ],
+     )
+     btn_remove_last.click(
+         fn=remove_last_instrument,
+         inputs=[text_sequence, qpm],
+         outputs=[
+             audio_output,
+             midi_file,
+             plot_output,
+             instruments_output,
+             text_sequence,
+             num_tokens,
+         ],
+     )
+     btn_regenerate_last.click(
+         fn=regenerate_last_instrument,
+         inputs=[text_sequence, qpm],
+         outputs=[
+             audio_output,
+             midi_file,
+             plot_output,
+             instruments_output,
+             text_sequence,
+             num_tokens,
+         ],
+     )
+     btn_qpm.click(
+         fn=change_tempo,
+         inputs=[text_sequence, qpm],
+         outputs=[
+             audio_output,
+             midi_file,
+             plot_output,
+             instruments_output,
+             text_sequence,
+             num_tokens,
+         ],
+     )
+
+ demo.launch(debug=True)
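Note that gr.Blocks passes the components listed in `inputs` to the callback as positional arguments, so `generate_song` receives (genre, temp, text_sequence, qpm) in exactly that order (see its signature in utils.py below). A minimal sketch of the same wiring pattern, with illustrative component and function names that are not part of this commit:

    # Sketch: Gradio calls the function with the `inputs` values positionally
    # and spreads the returned value(s) across the `outputs` components.
    import gradio as gr

    def describe(genre: str, temperature: float) -> str:
        return f"genre={genre}, temperature={temperature}"

    with gr.Blocks() as sketch:
        genre = gr.Dropdown(choices=["ROCK", "POP"], value="POP", label="Genre")
        temperature = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.75, label="Temperature")
        summary = gr.Text()
        gr.Button("Describe").click(fn=describe, inputs=[genre, temperature], outputs=summary)

    # sketch.launch()  # uncomment to serve the demo locally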
constants.py ADDED
@@ -0,0 +1,133 @@
+ SAMPLE_RATE = 44100
+
+
+ GM_INSTRUMENTS = [
+     "Acoustic Grand Piano",
+     "Bright Acoustic Piano",
+     "Electric Grand Piano",
+     "Honky-tonk Piano",
+     "Electric Piano 1",
+     "Electric Piano 2",
+     "Harpsichord",
+     "Clavi",
+     "Celesta",
+     "Glockenspiel",
+     "Music Box",
+     "Vibraphone",
+     "Marimba",
+     "Xylophone",
+     "Tubular Bells",
+     "Dulcimer",
+     "Drawbar Organ",
+     "Percussive Organ",
+     "Rock Organ",
+     "Church Organ",
+     "Reed Organ",
+     "Accordion",
+     "Harmonica",
+     "Tango Accordion",
+     "Acoustic Guitar (nylon)",
+     "Acoustic Guitar (steel)",
+     "Electric Guitar (jazz)",
+     "Electric Guitar (clean)",
+     "Electric Guitar (muted)",
+     "Overdriven Guitar",
+     "Distortion Guitar",
+     "Guitar Harmonics",
+     "Acoustic Bass",
+     "Electric Bass (finger)",
+     "Electric Bass (pick)",
+     "Fretless Bass",
+     "Slap Bass 1",
+     "Slap Bass 2",
+     "Synth Bass 1",
+     "Synth Bass 2",
+     "Violin",
+     "Viola",
+     "Cello",
+     "Contrabass",
+     "Tremolo Strings",
+     "Pizzicato Strings",
+     "Orchestral Harp",
+     "Timpani",
+     "String Ensemble 1",
+     "String Ensemble 2",
+     "Synth Strings 1",
+     "Synth Strings 2",
+     "Choir Aahs",
+     "Voice Oohs",
+     "Synth Choir",
+     "Orchestra Hit",
+     "Trumpet",
+     "Trombone",
+     "Tuba",
+     "Muted Trumpet",
+     "French Horn",
+     "Brass Section",
+     "Synth Brass 1",
+     "Synth Brass 2",
+     "Soprano Sax",
+     "Alto Sax",
+     "Tenor Sax",
+     "Baritone Sax",
+     "Oboe",
+     "English Horn",
+     "Bassoon",
+     "Clarinet",
+     "Piccolo",
+     "Flute",
+     "Recorder",
+     "Pan Flute",
+     "Blown Bottle",
+     "Shakuhachi",
+     "Whistle",
+     "Ocarina",
+     "Lead 1 (square)",
+     "Lead 2 (sawtooth)",
+     "Lead 3 (calliope)",
+     "Lead 4 (chiff)",
+     "Lead 5 (charang)",
+     "Lead 6 (voice)",
+     "Lead 7 (fifths)",
+     "Lead 8 (bass + lead)",
+     "Pad 1 (new age)",
+     "Pad 2 (warm)",
+     "Pad 3 (polysynth)",
+     "Pad 4 (choir)",
+     "Pad 5 (bowed)",
+     "Pad 6 (metallic)",
+     "Pad 7 (halo)",
+     "Pad 8 (sweep)",
+     "FX 1 (rain)",
+     "FX 2 (soundtrack)",
+     "FX 3 (crystal)",
+     "FX 4 (atmosphere)",
+     "FX 5 (brightness)",
+     "FX 6 (goblins)",
+     "FX 7 (echoes)",
+     "FX 8 (sci-fi)",
+     "Sitar",
+     "Banjo",
+     "Shamisen",
+     "Koto",
+     "Kalimba",
+     "Bagpipe",
+     "Fiddle",
+     "Shanai",
+     "Tinkle Bell",
+     "Agogo",
+     "Steel Drums",
+     "Woodblock",
+     "Taiko Drum",
+     "Melodic Tom",
+     "Synth Drum",
+     "Reverse Cymbal",
+     "Guitar Fret Noise",
+     "Breath Noise",
+     "Seashore",
+     "Bird Tweet",
+     "Telephone Ring",
+     "Helicopter",
+     "Applause",
+     "Gunshot",
+ ]
model.py ADDED
@@ -0,0 +1,25 @@
+ from typing import Tuple
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ # Initialize the model and tokenizer variables as None
+ tokenizer = None
+ model = None
+
+
+ def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+     """
+     Returns the preloaded model and tokenizer. If they haven't been loaded before, loads them.
+
+     Returns:
+         tuple: A tuple containing the preloaded model and tokenizer.
+     """
+     global model, tokenizer
+     if model is None or tokenizer is None:
+         # Load the tokenizer and the model
+         tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
+         model = AutoModelForCausalLM.from_pretrained(
+             "juancopi81/lmd-8bars-2048-epochs20_v3"
+         )
+     return model, tokenizer
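Because the loaded objects are cached in the module-level globals, repeated calls are cheap. A small usage sketch, assuming the two Hub repositories named above are reachable:

    # Usage sketch: the first call downloads and loads the checkpoint,
    # later calls return the same cached objects.
    from model import get_model_and_tokenizer

    model, tokenizer = get_model_and_tokenizer()
    model_again, tokenizer_again = get_model_and_tokenizer()
    assert model is model_again and tokenizer is tokenizer_again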
packages.txt ADDED
@@ -0,0 +1,4 @@
+ libfluidsynth2
+ build-essential
+ libasound2-dev
+ libjack-dev
pyproject.toml ADDED
@@ -0,0 +1,6 @@
+ [tool.black]
+ exclude = '''
+ (
+ /env
+ )
+ '''
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ note-seq
+ matplotlib
+ transformers
+ pyfluidsynth
string_to_notes.py ADDED
@@ -0,0 +1,137 @@
+ from typing import Optional
+
+ from note_seq.protobuf.music_pb2 import NoteSequence
+ from note_seq.constants import STANDARD_PPQ
+
+
+ def token_sequence_to_note_sequence(
+     token_sequence: str,
+     qpm: float = 120.0,
+     use_program: bool = True,
+     use_drums: bool = True,
+     instrument_mapper: Optional[dict] = None,
+     only_piano: bool = False,
+ ) -> NoteSequence:
+     """
+     Converts a sequence of tokens into a sequence of notes.
+
+     Args:
+         token_sequence (str): The sequence of tokens to convert.
+         qpm (float, optional): The quarter notes per minute. Defaults to 120.0.
+         use_program (bool, optional): Whether to use program. Defaults to True.
+         use_drums (bool, optional): Whether to use drums. Defaults to True.
+         instrument_mapper (Optional[dict], optional): The instrument mapper. Defaults to None.
+         only_piano (bool, optional): Whether to only use piano. Defaults to False.
+
+     Returns:
+         NoteSequence: The resulting sequence of notes.
+     """
+     if isinstance(token_sequence, str):
+         token_sequence = token_sequence.split()
+
+     note_sequence = empty_note_sequence(qpm)
+
+     # Compute note and bar lengths based on the provided QPM
+     note_length_16th = 0.25 * 60 / qpm
+     bar_length = 4.0 * 60 / qpm
+
+     # Render all notes.
+     current_program = 1
+     current_is_drum = False
+     current_instrument = 0
+     track_count = 0
+     for _, token in enumerate(token_sequence):
+         if token == "PIECE_START":
+             pass
+         elif token == "PIECE_END":
+             break
+         elif token == "TRACK_START":
+             current_bar_index = 0
+             track_count += 1
+             pass
+         elif token == "TRACK_END":
+             pass
+         elif token == "KEYS_START":
+             pass
+         elif token == "KEYS_END":
+             pass
+         elif token.startswith("KEY="):
+             pass
+         elif token.startswith("INST"):
+             instrument = token.split("=")[-1]
+             if instrument != "DRUMS" and use_program:
+                 if instrument_mapper is not None:
+                     if instrument in instrument_mapper:
+                         instrument = instrument_mapper[instrument]
+                 current_program = int(instrument)
+                 current_instrument = track_count
+                 current_is_drum = False
+             if instrument == "DRUMS" and use_drums:
+                 current_instrument = 0
+                 current_program = 0
+                 current_is_drum = True
+         elif token == "BAR_START":
+             current_time = current_bar_index * bar_length
+             current_notes = {}
+         elif token == "BAR_END":
+             current_bar_index += 1
+             pass
+         elif token.startswith("NOTE_ON"):
+             pitch = int(token.split("=")[-1])
+             note = note_sequence.notes.add()
+             note.start_time = current_time
+             note.end_time = current_time + 4 * note_length_16th
+             note.pitch = pitch
+             note.instrument = current_instrument
+             note.program = current_program
+             note.velocity = 80
+             note.is_drum = current_is_drum
+             current_notes[pitch] = note
+         elif token.startswith("NOTE_OFF"):
+             pitch = int(token.split("=")[-1])
+             if pitch in current_notes:
+                 note = current_notes[pitch]
+                 note.end_time = current_time
+         elif token.startswith("TIME_DELTA"):
+             delta = float(token.split("=")[-1]) * note_length_16th
+             current_time += delta
+         elif token.startswith("DENSITY="):
+             pass
+         elif token == "[PAD]":
+             pass
+         else:
+             pass
+
+     # Make the instruments right.
+     instruments_drums = []
+     for note in note_sequence.notes:
+         pair = [note.program, note.is_drum]
+         if pair not in instruments_drums:
+             instruments_drums += [pair]
+         note.instrument = instruments_drums.index(pair)
+
+     if only_piano:
+         for note in note_sequence.notes:
+             if not note.is_drum:
+                 note.instrument = 0
+                 note.program = 0
+
+     return note_sequence
+
+
+ def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> NoteSequence:
+     """
+     Creates an empty note sequence.
+
+     Args:
+         qpm (float, optional): The quarter notes per minute. Defaults to 120.0.
+         total_time (float, optional): The total time. Defaults to 0.0.
+
+     Returns:
+         NoteSequence: The empty note sequence.
+     """
+     note_sequence = NoteSequence()
+     note_sequence.tempos.add().qpm = qpm
+     note_sequence.ticks_per_quarter = STANDARD_PPQ
+     note_sequence.total_time = total_time
+     return note_sequence
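The parser above walks a flat token grammar (PIECE_START, TRACK_START, INST=…, BAR_START, NOTE_ON=…, TIME_DELTA=…, NOTE_OFF=…, BAR_END, TRACK_END). A small sketch, assuming note-seq is installed; the token string below is hand-written for illustration, not model output:

    # Sketch: convert a tiny token string into a NoteSequence.
    from string_to_notes import token_sequence_to_note_sequence

    tokens = (
        "PIECE_START TRACK_START INST=0 BAR_START "
        "NOTE_ON=60 TIME_DELTA=4 NOTE_OFF=60 "
        "NOTE_ON=64 TIME_DELTA=4 NOTE_OFF=64 "
        "BAR_END TRACK_END"
    )
    sequence = token_sequence_to_note_sequence(tokens, qpm=120.0)
    print(len(sequence.notes))      # 2: each NOTE_ON becomes one note
    print(sequence.notes[0].pitch)  # 60: quarter note (TIME_DELTA=4 sixteenths) on program 0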
utils.py ADDED
@@ -0,0 +1,242 @@
+ from typing import List, Tuple
+
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import note_seq
+ from matplotlib.figure import Figure
+ from numpy import ndarray
+
+ from constants import GM_INSTRUMENTS, SAMPLE_RATE
+ from string_to_notes import token_sequence_to_note_sequence
+ from model import get_model_and_tokenizer
+
+
+ model, tokenizer = get_model_and_tokenizer()
+
+
+ def create_seed_string(genre: str = "OTHER") -> str:
+     """
+     Creates a seed string for generating a new piece.
+
+     Args:
+         genre (str, optional): The genre of the piece. Defaults to "OTHER".
+
+     Returns:
+         str: The seed string.
+     """
+     seed_string = f"PIECE_START GENRE={genre} TRACK_START"
+     return seed_string
+
+
+ def get_instruments(text_sequence: str) -> List[str]:
+     """
+     Extracts the list of instruments from a text sequence.
+
+     Args:
+         text_sequence (str): The text sequence.
+
+     Returns:
+         List[str]: The list of instruments.
+     """
+     instruments = []
+     parts = text_sequence.split()
+     for part in parts:
+         if part.startswith("INST="):
+             if part[5:] == "DRUMS":
+                 instruments.append("Drums")
+             else:
+                 index = int(part[5:])
+                 instruments.append(GM_INSTRUMENTS[index])
+     return instruments
+
+
+ def generate_new_instrument(
+     seed: str, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, temp: float = 0.75
+ ) -> str:
+     """
+     Generates a new instrument sequence from a given seed and temperature.
+
+     Args:
+         seed (str): The seed string for the generation.
+         tokenizer (PreTrainedTokenizer): The tokenizer used to encode and decode the sequences.
+         model (PreTrainedModel): The pretrained model used for generating the sequences.
+         temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
+
+     Returns:
+         str: The generated instrument sequence.
+     """
+     seed_length = len(tokenizer.encode(seed))
+
+     while True:
+         # Encode the conditioning tokens.
+         input_ids = tokenizer.encode(seed, return_tensors="pt")
+
+         # Generate more tokens.
+         eos_token_id = tokenizer.encode("TRACK_END")[0]
+         generated_ids = model.generate(
+             input_ids,
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=temp,
+             eos_token_id=eos_token_id,
+         )
+         generated_sequence = tokenizer.decode(generated_ids[0])
+
+         # Check if the generated sequence contains "NOTE_ON" beyond the seed
+         new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
+         if "NOTE_ON" in new_generated_sequence:
+             return generated_sequence
+
+
+ def get_outputs_from_string(
+     generated_sequence: str, qpm: int = 120
+ ) -> Tuple[ndarray, str, Figure, str, str]:
+     """
+     Converts a generated sequence into various output formats including audio, MIDI, plot, etc.
+
+     Args:
+         generated_sequence (str): The generated sequence of tokens.
+         qpm (int, optional): The quarter notes per minute. Defaults to 120.
+
+     Returns:
+         Tuple[ndarray, str, Figure, str, str]: The audio waveform, MIDI file name, plot figure,
+         instruments string, and number of tokens string.
+     """
+     instruments = get_instruments(generated_sequence)
+     instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
+     note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
+
+     synth = note_seq.fluidsynth
+     array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
+     int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
+     fig = note_seq.plot_sequence(note_sequence, show_figure=False)
+     num_tokens = str(len(generated_sequence.split()))
+     audio = gr.make_waveform((SAMPLE_RATE, int16_data))
+     note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
+     return audio, "midi_output.mid", fig, instruments_str, num_tokens
+
+
+ def remove_last_instrument(
+     text_sequence: str, qpm: int = 120
+ ) -> Tuple[ndarray, str, Figure, str, str, str]:
+     """
+     Removes the last instrument from a song string and returns the various output formats.
+
+     Args:
+         text_sequence (str): The song string.
+         qpm (int, optional): The quarter notes per minute. Defaults to 120.
+
+     Returns:
+         Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+         instruments string, new song string, and number of tokens string.
+     """
+     # We split the song into tracks by splitting on 'TRACK_START'
+     tracks = text_sequence.split("TRACK_START")
+     # We keep all tracks except the last one
+     modified_tracks = tracks[:-1]
+     # We join the tracks back together, adding back the 'TRACK_START' that was removed by split
+     new_song = "TRACK_START".join(modified_tracks)
+
+     if len(tracks) == 2:
+         # There is only one instrument, so start from scratch
+         audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+             text_sequence=new_song
+         )
+     elif len(tracks) == 1:
+         # No instrument, so start from an empty sequence
+         audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+             text_sequence=""
+         )
+     else:
+         audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
+             new_song, qpm
+         )
+
+     return audio, midi_file, fig, instruments_str, new_song, num_tokens
+
+
+ def regenerate_last_instrument(
+     text_sequence: str, qpm: int = 120
+ ) -> Tuple[ndarray, str, Figure, str, str, str]:
+     """
+     Regenerates the last instrument in a song string and returns the various output formats.
+
+     Args:
+         text_sequence (str): The song string.
+         qpm (int, optional): The quarter notes per minute. Defaults to 120.
+
+     Returns:
+         Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+         instruments string, new song string, and number of tokens string.
+     """
+     last_inst_index = text_sequence.rfind("INST=")
+     if last_inst_index == -1:
+         # No instrument, so start from an empty sequence
+         audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+             text_sequence="", qpm=qpm
+         )
+     else:
+         # Take it from the last instrument and continue generation
+         next_space_index = text_sequence.find(" ", last_inst_index)
+         new_seed = text_sequence[:next_space_index]
+         audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
+             text_sequence=new_seed, qpm=qpm
+         )
+     return audio, midi_file, fig, instruments_str, new_song, num_tokens
+
+
+ def change_tempo(
+     text_sequence: str, qpm: int
+ ) -> Tuple[ndarray, str, Figure, str, str, str]:
+     """
+     Changes the tempo of a song string and returns the various output formats.
+
+     Args:
+         text_sequence (str): The song string.
+         qpm (int): The new quarter notes per minute.
+
+     Returns:
+         Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+         instruments string, text sequence, and number of tokens string.
+     """
+     audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
+         text_sequence, qpm=qpm
+     )
+     return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
+
+
+ def generate_song(
+     genre: str = "OTHER",
+     temp: float = 0.75,
+     text_sequence: str = "",
+     qpm: int = 120,
+     model: AutoModelForCausalLM = model,
+     tokenizer: AutoTokenizer = tokenizer,
+ ) -> Tuple[ndarray, str, Figure, str, str, str]:
+     """
+     Generates a song given a genre, temperature, initial text sequence, and tempo.
+
+     The first four parameters match the order in which the Gradio app passes its inputs.
+
+     Args:
+         genre (str, optional): The genre of the song. Defaults to "OTHER".
+         temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
+         text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
+         qpm (int, optional): The quarter notes per minute. Defaults to 120.
+         model (AutoModelForCausalLM, optional): The pretrained model used for generating the sequences.
+         tokenizer (AutoTokenizer, optional): The tokenizer used to encode and decode the sequences.
+
+     Returns:
+         Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
+         instruments string, generated song string, and number of tokens string.
+     """
+     if text_sequence == "":
+         seed_string = create_seed_string(genre)
+     else:
+         seed_string = text_sequence
+
+     generated_sequence = generate_new_instrument(
+         seed=seed_string, tokenizer=tokenizer, model=model, temp=temp
+     )
+     audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
+         generated_sequence, qpm
+     )
+     return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
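The pipeline can also be driven outside the Gradio UI. A minimal usage sketch, assuming the dependencies from requirements.txt and packages.txt (plus ffmpeg for gr.make_waveform) are installed; the genre value here is just an example:

    # Sketch: generate one track programmatically, then extend the same song.
    from utils import generate_song

    audio, midi_path, fig, instruments, song_tokens, num_tokens = generate_song(
        genre="JAZZ", temp=0.75, text_sequence="", qpm=120
    )
    print(midi_path)    # "midi_output.mid", written to the working directory
    print(instruments)  # one "- <instrument>" line per generated track
    print(num_tokens)   # token count of the generated sequence, as a string

    # Passing the returned token string back in continues the same piece:
    _, _, _, _, longer_song, _ = generate_song(
        genre="JAZZ", temp=0.75, text_sequence=song_tokens, qpm=120
    )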