"""HARP/Gradio app for the Anticipatory Music Transformer (AMT).

Takes a MIDI file containing a short accompaniment (vamp) followed by a melody
line and generates an extended accompaniment that follows the melody, using the
stanford-crfm AMT checkpoints.
"""

import os
from dataclasses import asdict

import spaces  # Hugging Face Spaces ZeroGPU helper; keep imported before torch
import gradio as gr
import torch
from mido import MidiFile
from transformers import AutoModelForCausalLM

from anticipation import ops
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.sample import generate
from anticipation.tokenize import extract_instruments

from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
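
# pyharp provides the HARP integration: ModelCard describes this app to HARP
# clients and build_endpoint exposes process_fn as the processing endpoint.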

SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"

model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the controls to choose model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"],
)


# Models are loaded lazily and cached per worker process, keyed by checkpoint name.
_model_cache = {}


def load_amt_model(model_choice: str):
    """Load and cache the requested AMT checkpoint inside the worker process."""
    if model_choice in _model_cache:
        return _model_cache[model_choice]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_choice == LARGE_MODEL:
        # The large checkpoint is loaded with fp16 weights on CUDA (fp32 on CPU)
        # and low_cpu_mem_usage to keep peak memory down.
        print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
        model = AutoModelForCausalLM.from_pretrained(
            LARGE_MODEL,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device)
    else:
        print(f"Loading {model_choice} ...")
        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)

    _model_cache[model_choice] = model
    return model
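
# Usage sketch (local testing outside the Gradio request path; checkpoint names
# are the constants defined above):
#   model = load_amt_model(SMALL_MODEL)    # downloads and loads on first call
#   model is load_amt_model(SMALL_MODEL)   # True: the second call hits _model_cache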


def find_melody_program(mid, debug=False):
    """Heuristically pick the index of the track most likely to carry the melody.

    For every track with notes, collect the mean pitch, note count, note density
    (notes per tick), and a pitch-variety ratio (stored but not used in the
    ranking), then prefer the track with the highest mean pitch, breaking ties
    by density.
    """
    track_stats = []
    for i, track in enumerate(mid.tracks):
        pitches, times = [], []
        current_time = 0
        for msg in track:
            current_time += getattr(msg, "time", 0)
            if msg.type == "note_on" and msg.velocity > 0:
                pitches.append(msg.note)
                times.append(current_time)
        if pitches:
            mean_pitch = sum(pitches) / len(pitches)
            span = (max(times) - min(times)) or 1
            density = len(pitches) / span
            polyphony = len(set(pitches)) / len(pitches)
            track_stats.append((i, mean_pitch, len(pitches), density, polyphony))

    if not track_stats:
        if debug:
            print("No notes detected in any track.")
        return 0

    # Highest mean pitch first, then highest note density.
    melody_idx = sorted(track_stats, key=lambda x: (-x[1], -x[3]))[0][0]
    return melody_idx
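
# Usage sketch (the file path is illustrative):
#   mid = MidiFile("song.mid")
#   melody_track = find_melody_program(mid, debug=True)   # index into mid.tracks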


def get_program_number(mid, track_index):
    """Return the program of the first program_change on the track, or None if there is none."""
    for msg in mid.tracks[track_index]:
        if msg.type == "program_change":
            return msg.program
    return None
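
# Usage sketch (continuing the example above):
#   program = get_program_number(mid, melody_track)   # e.g. 0 (Acoustic Grand Piano) or None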


def auto_extract_melody(mid, debug=False):
    """Split a parsed MIDI file into (accompaniment events, melody controls).

    The melody track is chosen heuristically; if it declares a program_change,
    that instrument is pulled out of the event stream with extract_instruments.
    Otherwise the full event stream doubles as the melody controls.
    """
    events = midi_to_events(mid)
    melody_track = find_melody_program(mid, debug=debug)
    melody_program = get_program_number(mid, melody_track)

    if debug:
        print(f"Melody Track: {melody_track} | Program: {melody_program}")

    if melody_program is not None:
        # extract_instruments removes the melody instrument from `events` and
        # returns its notes separately as anticipated control tokens.
        events, melody = extract_instruments(events, [melody_program])
    else:
        if debug:
            print("No program number found; using all events as melody.")
        melody = events

    return events, melody
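
# Usage sketch: both returned values are flat token lists in the anticipation
# vocabulary (accompaniment events and anticipated melody controls):
#   events, melody = auto_extract_melody(MidiFile("song.mid"), debug=True)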


@spaces.GPU
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """Generate accompaniment for the whole input MIDI.

    Generation is conditioned on the first `history_length` seconds of the input
    (as the prompt) and on the melody (as anticipated controls). The file is
    parsed with mido.MidiFile before midi_to_events so the converter receives a
    MidiFile object rather than a path string.
    """
    model = load_amt_model(model_choice)

    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        print("No melody detected; using all events")
        melody = all_events

    # Keep only the first `history_length` seconds of the input as the prompt.
    history = ops.clip(all_events, 0, history_length, clip_duration=False)

    # Generate up to the end of the piece, whichever length estimate is longer.
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))

    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=history,
        controls=melody,
        top_p=0.95,
        debug=False,
    )

    # Merge the generated accompaniment with the melody and clip to the piece length.
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True,
    )

    output_midi = "generated_accompaniment_huggingface.mid"
    mid_out = events_to_midi(output_events)
    mid_out.save(output_midi)

    # The second element is an error-message slot; None means success.
    return output_midi, None
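
# Usage sketch (direct call; on a ZeroGPU Space this runs inside the @spaces.GPU
# worker, elsewhere it falls back to CPU):
#   out_path, err = generate_accompaniment("song.mid", MEDIUM_MODEL, history_length=5.0)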


def process_fn(input_midi_path, model_choice, history_length):
    """HARP processing callback.

    Returns (labels dict, MIDI filepath); the HARP client expects the first
    output to be a JSON-serializable object.
    """
    output_midi, error_message = generate_accompaniment(
        input_midi_path,
        model_choice,
        float(history_length),
    )

    if error_message:
        return {"message": error_message}, None

    labels = LabelList()
    return asdict(labels), output_midi
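
# A successful call returns the (empty) label metadata dict plus the generated
# MIDI path; the error branch mirrors that shape with a message and no file.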


with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")

    input_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Input MIDI File",
        type="filepath",
    ).harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
        label="Select AMT Model (Faster vs. Higher Quality)"
    )

    history_slider = gr.Slider(
        minimum=1, maximum=10, step=1, value=5,
        label="Select History Length (seconds)"
    )

    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Generated MIDI Output",
        type="filepath",
    )

    _ = build_endpoint(
        model_card=model_card,
        input_components=[
            input_midi,
            model_dropdown,
            history_slider,
        ],
        output_components=[
            output_labels,
            output_midi,
        ],
        process_fn=process_fn,
    )


demo.launch(share=True, show_error=True, debug=True)