import os
import spaces  # enables ZeroGPU on Hugging Face Spaces
import gradio as gr
import torch
from dataclasses import asdict
from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops
from mido import MidiFile  # parse MIDI explicitly to avoid the 'str' .time error
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

# Model choices
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"

# Model card (new pyharp)
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the controls to choose model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"]
)

# Model cache
_model_cache = {}


def load_amt_model(model_choice: str):
    """Load and cache the AMT model inside the worker process."""
    if model_choice in _model_cache:
        return _model_cache[model_choice]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_choice == LARGE_MODEL:
        print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
        model = AutoModelForCausalLM.from_pretrained(
            LARGE_MODEL,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
    else:
        print(f"Loading {model_choice} ...")
        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
    _model_cache[model_choice] = model
    return model
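
# Hypothetical warm-up (not part of the original app): calling the loader once at
# startup avoids a first-request stall on the model download.
# load_amt_model(MEDIUM_MODEL)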


def find_melody_program(mid, debug=False):
    """Heuristically pick the melody track: highest mean pitch first, then highest note density."""
    track_stats = []
    for i, track in enumerate(mid.tracks):
        pitches, times = [], []
        current_time = 0
        for msg in track:
            current_time += getattr(msg, "time", 0)
            if msg.type == "note_on" and msg.velocity > 0:
                pitches.append(msg.note)
                times.append(current_time)
        if pitches:
            mean_pitch = sum(pitches) / len(pitches)
            span = (max(times) - min(times)) or 1
            density = len(pitches) / span
            polyphony = len(set(pitches)) / len(pitches)
            track_stats.append((i, mean_pitch, len(pitches), density, polyphony))
    if not track_stats:
        if debug:
            print("No notes detected in any track.")
        return 0
    # Sort by mean pitch (descending), then by note density (descending)
    melody_idx = sorted(track_stats, key=lambda x: (-x[1], -x[3]))[0][0]
    return melody_idx


def get_program_number(mid, track_index):
    """Return the program from the first program_change on the given track, or None."""
    for msg in mid.tracks[track_index]:
        if msg.type == "program_change":
            return msg.program
    return None


def auto_extract_melody(mid, debug=False):
    """Detect the melody and split the event stream into (events, melody) for use as inputs/controls."""
    events = midi_to_events(mid)
    melody_track = find_melody_program(mid, debug=debug)
    melody_program = get_program_number(mid, melody_track)
    if debug:
        print(f"Melody track: {melody_track} | Program: {melody_program}")
    if melody_program is not None:
        all_events_copy = events.copy()
        events, melody = extract_instruments(events, [melody_program])
        for e in all_events_copy:
            if hasattr(e, "program") and e.program == melody_program:
                events.append(e)
    else:
        if debug:
            print("No program number found; using all events as melody.")
        melody = events
    return events, melody


# Core generation
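# Added as an assumption for ZeroGPU: CUDA work is only available inside a function
# decorated with spaces.GPU, so the generation entry point requests a GPU slot per call.
@spaces.GPU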
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """
    Generate accompaniment for the entire MIDI input, conditioned on the
    user-selected history length.

    Note: the MIDI is parsed with mido.MidiFile before midi_to_events to avoid
    the 'str' .time error.
    """
    model = load_amt_model(model_choice)

    # Parse the MIDI file, then convert it to an event stream
    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

    # Automatically detect and extract the melody
    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        print("No melody detected; using all events")
        melody = all_events

    # History portion: the first `history_length` seconds of the input
    history = ops.clip(all_events, 0, history_length, clip_duration=False)
    start_time = ops.max_time(history, seconds=True)
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))

    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=history,
        controls=melody,
        top_p=0.95,
        debug=False
    )

    # Combine accompaniment + melody, then clip to the song length
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True
    )

    # Save the output MIDI
    output_midi = "generated_accompaniment_huggingface.mid"
    mid_out = events_to_midi(output_events)
    mid_out.save(output_midi)
    return output_midi, None


# HARP process function (JSON output first)
def process_fn(input_midi_path, model_choice, history_length):
    """
    Return (JSON, MIDI filepath) to satisfy the HARP client's expectation that
    the 0th item is an object.
    """
    output_midi, error_message = generate_accompaniment(
        input_midi_path,
        model_choice,
        float(history_length)
    )
    if error_message:
        # JSON first, then no file
        return {"message": error_message}, None
    labels = LabelList()  # add label entries here if desired
    # JSON first, then the MIDI filepath
    return asdict(labels), output_midi


# Gradio + HARP UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")

    # Inputs
    input_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Input MIDI File",
        type="filepath",
    ).harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
        label="Select AMT Model (Faster vs. Higher Quality)"
    )
    history_slider = gr.Slider(
        minimum=1, maximum=10, step=1, value=5,
        label="Select History Length (seconds)"
    )

    # Outputs (JSON first)
    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Generated MIDI Output",
        type="filepath",
    )

    # Build the HARP endpoint (new pyharp signature)
    _ = build_endpoint(
        model_card=model_card,
        input_components=[
            input_midi,
            model_dropdown,
            history_slider
        ],
        output_components=[
            output_labels,  # JSON first
            output_midi     # MIDI second
        ],
        process_fn=process_fn
    )

# Launch the app
demo.launch(share=True, show_error=True, debug=True)
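
# Local smoke test (hypothetical; not wired into the Space): bypasses Gradio/HARP and
# exercises the generation path directly. Assumes a file named "example_vamp.mid"
# sits next to this script.
#
#   out_path, err = generate_accompaniment("example_vamp.mid", MEDIUM_MODEL, history_length=5.0)
#   print(out_path, err)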