|
import os |
|
import gradio as gr |
|
import json |
|
import logging |
|
import shutil |
|
import tempfile |
|
from pathlib import Path |
|
import numpy as np |
|
from utils import ( |
|
rename_files_remove_spaces, |
|
load_audio_files, |
|
get_stems, |
|
generate_section_variants, |
|
export_section_variants, |
|
edm_arrangement_tab, |
|
) |
|
|
|
from dotenv import load_dotenv |
|
from langsmith import traceable |
|
|
|
# Load settings from a local .env file before anything below runs
# (presumably credentials for the traced calls — confirm).
load_dotenv()

# Module-level logger. INFO is enabled on this logger only; which handlers
# actually emit the records is left to the host application's logging setup.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Working directory for the current session. Starts as "test" and is rebound
# to a fresh temporary directory by process_uploaded_files().
TEMP_DIR = "test"

# section_type -> "config" entry of the variant picked in select_variant().
SELECTED_VARIANTS = {}

# section_type -> variants returned by generate_section_variants().
ALL_VARIANTS = {}

# Result of load_audio_files() for the most recent upload.
UPLOADED_STEMS = {}
|
|
|
|
|
def process_uploaded_files(files, progress=gr.Progress()):
    """Copy uploaded WAV stems into a fresh temp directory and load them.

    Args:
        files: Uploaded items from the file input. Each item may be a plain
            path (``str`` / ``os.PathLike``) or an object that exposes the
            path as ``.name``.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, stem_names)``. ``stem_names`` is an empty list on
        any failure.

    Side effects:
        Rebinds the module globals ``TEMP_DIR`` and ``UPLOADED_STEMS``. On
        failure the directory created by this call is removed and
        ``TEMP_DIR`` is set to ``None``.
    """
    global TEMP_DIR, UPLOADED_STEMS

    if not files:
        return "Error: No files uploaded", []

    # Track the directory created by *this* call separately from TEMP_DIR so
    # that cleanup can never delete a directory this call did not create.
    work_dir = None
    try:
        progress(0, desc="Starting process...")
        work_dir = tempfile.mkdtemp()
        TEMP_DIR = work_dir

        progress(0.2, desc="Copying uploaded files...")
        for file in files:
            # Uploads arrive either as path strings or as file wrappers
            # carrying the path in ``.name``; support both.
            if isinstance(file, (str, os.PathLike)):
                source_path = os.fspath(file)
            else:
                source_path = file.name
            if source_path.lower().endswith(".wav"):
                shutil.copy2(source_path, work_dir)

        progress(0.5, desc="Renaming files...")
        rename_files_remove_spaces(work_dir)

        progress(0.8, desc="Loading audio files...")
        UPLOADED_STEMS = load_audio_files(work_dir)
        if not UPLOADED_STEMS:
            # Nothing usable was uploaded: do not leave an empty directory
            # behind or let later handlers treat it as a valid session.
            shutil.rmtree(work_dir, ignore_errors=True)
            TEMP_DIR = None
            return "Error: No stems loaded", []

        stem_names = get_stems(work_dir)

        progress(1.0, desc="Complete!")
        return f"Successfully loaded {len(stem_names)} stems", stem_names

    except Exception as e:
        logger.exception("Failed to process uploaded files")
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)
            TEMP_DIR = None
        return f"Error occurred: {str(e)}", []
|
|
|
|
|
def generate_section_variants_handler(
    section_type, bpm_value, bars_value, p_value, progress=gr.Progress()
):
    """Generate, export and describe the variants of one section.

    Args:
        section_type: Section name; used as the key in ``ALL_VARIANTS`` and
            as the prefix of the export sub-directory.
        bpm_value: BPM slider value, coerced with ``int()``.
        bars_value: Number-of-bars slider value, coerced with ``int()``.
        p_value: Variation parameter, coerced with ``float()``.
        progress: Gradio progress tracker; also forwarded to
            ``generate_section_variants``.

    Returns:
        A 7-tuple ``(status, variant1_audio, variant2_audio, variant3_audio,
        variant4_audio, descriptions_json, stem_list)``. On any failure every
        element after ``status`` is ``None``.
    """
    global TEMP_DIR, ALL_VARIANTS, UPLOADED_STEMS

    def failure(message):
        # The click handler is wired to seven outputs, so every exit path
        # must return exactly seven values: the status plus six placeholders.
        return (message,) + (None,) * 6

    if not TEMP_DIR or not os.path.exists(TEMP_DIR):
        return failure("Error: No stems loaded. Please upload stems first.")

    try:
        progress(0.1, desc=f"Getting arrangements for {section_type}...")

        variants = generate_section_variants(
            TEMP_DIR,
            UPLOADED_STEMS,
            section_type,
            bpm=int(bpm_value),
            bars=int(bars_value),
            p=float(p_value),
            progress=progress,
        )
        logger.info(f"S1: VARIANTS of section {section_type} : {variants}")

        progress(0.4, desc="Variants generated")

        # Remember the variants so select_variant() can look them up later.
        ALL_VARIANTS[section_type] = variants

        # Announce the export before it starts rather than after it is done.
        progress(0.6, desc="Exporting audio files...")
        variant_output_dir = os.path.join(TEMP_DIR, section_type + "_variants")
        audio_paths = export_section_variants(
            variants, variant_output_dir, section_type
        )
        logger.info(f"S2: AUDIO_PATHS of section {section_type} : {audio_paths}")

        # The UI has four audio slots; a missing variant yields None.
        variant_audio = [audio_paths.get(f"variant{i}") for i in range(1, 5)]

        descriptions = {key: data["description"] for key, data in variants.items()}
        descriptions_json = json.dumps(descriptions, indent=2)

        stem_list = {key: data["stems"] for key, data in variants.items()}

        logger.info(f"S3: DESCRIPTIONS of section {section_type} : {descriptions_json}")

        progress(1.0, desc="Complete!")

        return (
            f"Generated {len(variants)} variants for {section_type}",
            *variant_audio,
            descriptions_json,
            stem_list,
        )

    except Exception as e:
        logger.exception("Failed to generate variants for %s", section_type)
        return failure(f"Error generating variants: {str(e)}")
|
|
|
|
|
@traceable(run_type="chain", name="groq_call")
def select_variant(section_type, variant_num, append):
    """Record the chosen variant of a section and build its display text.

    Args:
        section_type: Key of the section in ``ALL_VARIANTS``.
        variant_num: Number of the variant; looked up as ``variant<num>``.
        append: Accepted for interface compatibility; not used by the body.

    Returns:
        ``(status_message, display_text)``. ``display_text`` is the string
        ``"None"`` whenever the selection could not be made.
    """
    global ALL_VARIANTS, SELECTED_VARIANTS

    try:
        if section_type not in ALL_VARIANTS:
            return f"No variants generated for {section_type} yet", "None"

        candidates = ALL_VARIANTS[section_type]
        variant_key = f"variant{variant_num}"
        if variant_key not in candidates:
            return f"Variant {variant_num} not found for {section_type}", "None"

        chosen = candidates[variant_key]
        # Only the config is stored; it is what the full-track step consumes.
        SELECTED_VARIANTS[section_type] = chosen["config"]

        display_text = f"Selected: Variant {variant_num}"
        if "description" in chosen:
            display_text = display_text + f" - {chosen['description']}"

        logger.info(f"Selected variant for {section_type}: {variant_key}")
        logger.info(f"Variant description: {chosen.get('description', 'No description')}")
        logger.info(f"Stems used: {chosen.get('config', {}).get('stems', [])}")

        status_message = f"Selected variant {variant_num} for {section_type}"
        return status_message, display_text

    except Exception as e:
        failure_message = f"Error selecting variant: {str(e)}"
        logger.error(failure_message)
        return failure_message, "None"
|
|
|
|
|
def generate_full_track(
    crossfade_ms,
    output_track_name,
    *section_flags,
    progress=gr.Progress(),
):
    """Render the selected section variants into one track and export it as WAV.

    Args:
        crossfade_ms: Crossfade passed to ``append`` when joining sections.
        output_track_name: Requested output file name. Only its final path
            component is used, so the file always lands inside ``TEMP_DIR``.
        *section_flags: One truthy/falsy flag per entry of ``sections``, in
            the same order, saying whether that section is included.
        progress: Gradio progress tracker.

    Returns:
        A 3-tuple ``(status, full_track_path, track_summary_json)``. On any
        failure the last two elements are ``None``.
    """
    global TEMP_DIR, SELECTED_VARIANTS, UPLOADED_STEMS, sections

    # The click handler is wired to three outputs, so every exit path must
    # return exactly three values.
    if not TEMP_DIR or not os.path.exists(TEMP_DIR):
        return "Error: No stems loaded", None, None

    try:
        progress(0.1, desc="Preparing to generate full track...")

        section_names = list(sections.keys())

        # Keep only sections that are both ticked and have a selected variant.
        sections_to_include = {
            name: SELECTED_VARIANTS[name]
            for name, include_flag in zip(section_names, section_flags)
            if include_flag and name in SELECTED_VARIANTS
        }

        if not sections_to_include:
            return "Error: No sections selected or available", None, None

        progress(0.3, desc="Creating track structure...")

        # Imported once, lazily, instead of on every loop iteration.
        from utils import create_section_from_json

        final_track = None
        total_sections = len(section_names)

        # Walk the configured order so the track follows the arrangement.
        for position, section_name in enumerate(section_names):
            if section_name not in sections_to_include:
                continue

            progress(
                0.4 + 0.1 * position / total_sections,
                desc=f"Processing {section_name}...",
            )

            variant_config = sections_to_include[section_name]

            # Hand each section its own shallow copy of the stem mapping.
            stems_copy = dict(UPLOADED_STEMS)
            section_audio = create_section_from_json(variant_config, stems_copy)

            if final_track is None:
                final_track = section_audio
            else:
                final_track = final_track.append(section_audio, crossfade=crossfade_ms)

        progress(0.9, desc="Exporting final track...")

        # The name comes from a free-text box: drop any directory part so it
        # cannot escape TEMP_DIR, and fall back to the default when nothing
        # usable is left.
        safe_name = os.path.basename(str(output_track_name or "").strip())
        if safe_name in ("", ".", ".."):
            safe_name = "full_track_output.wav"
        full_track_path = os.path.join(TEMP_DIR, safe_name)
        final_track.export(full_track_path, format="wav")

        sections_list = list(sections_to_include.keys())
        # len() of the track is divided by 1000 to get seconds
        # (presumably the length is in milliseconds — confirm).
        track_duration = len(final_track) / 1000

        track_summary = {
            "Sections included": sections_list,
            "Total sections": len(sections_list),
            "Duration": f"{int(track_duration // 60)}:{int(track_duration % 60):02d}",
            "Crossfade": f"{crossfade_ms} ms",
            "Section Details": {
                section: {
                    "BPM": sections[section]["bpm"],
                    "Bars": sections[section]["bars"],
                    "Volume Automation": sections[section]["volume_automation"],
                    "Curve": sections[section]["curve"],
                }
                for section in sections_list
            },
        }

        progress(1.0, desc="Complete!")

        return (
            "Track generated successfully!",
            full_track_path,
            json.dumps(track_summary, indent=2),
        )

    except Exception as e:
        logger.exception("Failed to generate the full track")
        return f"Error generating track: {str(e)}", None, None
|
|
|
|
|
def generate_full_loop_variants(bpm_value, bars_value, p_value, progress=gr.Progress()):
    """Run the generic variant handler for the fixed ``"full_loop"`` section.

    All arguments are forwarded unchanged and the handler's result is
    returned as-is.
    """
    handler_args = ("full_loop", bpm_value, bars_value, p_value, progress)
    return generate_section_variants_handler(*handler_args)
|
|
|
|
|
def create_section_ui(section_name, params):
    """Build the accordion for one section and return its components.

    Args:
        section_name: Section name; its capitalised form is used in labels.
        params: Mapping that provides ``volume_automation``, ``curve``,
            ``bpm``, ``bars`` and ``p`` for this section.

    Returns:
        Dict of the created components keyed by ``bpm_slider``,
        ``bars_slider``, ``p_slider``, ``generate_btn``, ``status``,
        ``descriptions``, ``variant_audio`` (list of four), ``select_btn``
        (list of four), ``selected_variant`` and ``stem_list``.
    """
    # NOTE(review): container nesting below was reconstructed from a copy of
    # the file that had lost its indentation — confirm against the layout.
    title = section_name.capitalize()
    audio_players = []
    select_radios = []

    with gr.Accordion(f"Generate {title} Variants", open=False):
        with gr.Row():
            # Left column: read-only parameters, sliders and the trigger button.
            with gr.Column(scale=1):
                gr.Markdown(f"### {title} Parameters")
                gr.Markdown(f"**Volume Automation**: {params['volume_automation']}")
                gr.Markdown(f"**Curve Type**: {params['curve']}")

                bpm_slider = gr.Slider(
                    minimum=60,
                    maximum=180,
                    step=1,
                    value=params["bpm"],
                    label="BPM (Beats Per Minute)",
                )
                bars_slider = gr.Slider(
                    minimum=4,
                    maximum=64,
                    step=4,
                    value=params["bars"],
                    label="Number of Bars",
                )
                p_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=params["p"],
                    label="Variation Parameter (p)",
                )

                generate_btn = gr.Button(f"Generate {title} Variants", variant="primary")

            # Right column: status and textual results.
            with gr.Column(scale=2):
                status = gr.Textbox(label="Status", interactive=False)
                descriptions = gr.JSON(label="Variant Descriptions")
                stem_list = gr.JSON(label="Stem List")
                selected_variant = gr.Text(label="Selected Variant", value="None", interactive=False)

        # One column per variant: heading, audio player, single-choice radio.
        with gr.Row():
            for variant_index in range(1, 5):
                with gr.Column():
                    gr.Markdown(f"### Variant {variant_index}")
                    audio_players.append(
                        gr.Audio(label=f"Variant {variant_index}", interactive=True)
                    )
                    select_radios.append(
                        gr.Radio(
                            choices=[f"Select Variant {variant_index}"],
                            label="",
                            value=None,
                            interactive=True,
                        )
                    )

    components = {}
    components["bpm_slider"] = bpm_slider
    components["bars_slider"] = bars_slider
    components["p_slider"] = p_slider
    components["generate_btn"] = generate_btn
    components["status"] = status
    components["descriptions"] = descriptions
    components["variant_audio"] = audio_players
    components["select_btn"] = select_radios
    components["selected_variant"] = selected_variant
    components["stem_list"] = stem_list
    return components
|
|
|
def setup_section_event_handlers(section_name, section_ui, selected_variants_display):
    """Wire one section's generate button and variant radios to their handlers.

    Args:
        section_name: Section name; passed to the handlers as ``section_type``.
        section_ui: Component dict as returned by ``create_section_ui``.
        selected_variants_display: Component that receives the JSON summary
            of the currently selected variants.
    """
    section_type = gr.State(section_name)

    section_ui["generate_btn"].click(
        fn=generate_section_variants_handler,
        inputs=[
            section_type,
            section_ui["bpm_slider"],
            section_ui["bars_slider"],
            section_ui["p_slider"],
        ],
        outputs=[
            section_ui["status"],
            *section_ui["variant_audio"],
            section_ui["descriptions"],
            section_ui["stem_list"],
        ],
    )

    def update_selected_variants_display():
        """Return a JSON summary naming the variant selected for each section."""
        selected = {}
        for name, chosen_config in SELECTED_VARIANTS.items():
            variants = ALL_VARIANTS.get(name)
            if not variants:
                continue

            # select_variant() stores the variant's own "config" object, so
            # an identity match finds the variant that was actually picked
            # instead of whichever variant happens to be checked last.
            label = None
            for v in range(1, 5):
                info = variants.get(f"variant{v}")
                if info is not None and info.get("config") is chosen_config:
                    label = info.get("description", f"Variant {v}")
                    break

            if label is None:
                # The section's variants were regenerated after the pick; the
                # earlier selection is still stored, so keep it visible.
                label = "Variant from an earlier generation"

            selected[name] = {"Selected Variant": label}
        return json.dumps(selected, indent=2)

    for i, select_btn in enumerate(section_ui["select_btn"], start=1):
        variant_num = gr.State(i)
        # Chain the summary refresh after the selection so it always reads
        # the updated SELECTED_VARIANTS; two independent listeners on the
        # same event give no such ordering.
        select_btn.change(
            fn=select_variant,
            inputs=[section_type, variant_num, gr.State(False)],
            outputs=[section_ui["status"], section_ui["selected_variant"]],
        ).then(
            fn=update_selected_variants_display,
            inputs=[],
            outputs=[selected_variants_display],
        )
|
|
|
|
|
|
|
# Section definitions (one entry per section name) are read once at import
# time from the working directory. JSON text is UTF-8, so the encoding is
# stated explicitly instead of relying on the platform default.
with open("final_arrangement.json", "r", encoding="utf-8") as f:
    section_config_json = f.read()

sections = json.loads(section_config_json)
|
|
|
|
|
# Application layout: four tabs (upload, arrangement, per-section variants,
# full track) followed by the event wiring for the upload and track buttons.
# NOTE(review): container nesting below was reconstructed from a copy of the
# file that had lost its indentation — confirm against the intended layout.
with gr.Blocks(title="Interactive Music Track Generator") as demo:
    gr.Markdown("# Interactive Music Track Generator")
    gr.Markdown(
        "Upload your WAV stems, generate variants for each section, and create a full track"
    )

    stem_list = gr.State([])

    with gr.Tab("1. Upload Stems"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload Files")
                gr.Markdown("Drag and drop your WAV stem files here")
                file_input = gr.File(
                    label="WAV Stems", file_count="multiple", file_types=[".wav"]
                )

                upload_btn = gr.Button("Upload and Process Files", variant="primary")

            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", interactive=False)
                stem_display = gr.JSON(label="Available Stems")

    with gr.Tab("1.1. ⏳ Finalise Sections Arrangement (In Progress)"):
        gr.Markdown("### 🎶 Finalise Sections Arrangement")
        gr.Markdown(
            "🎛️ Use the diagram below to adjust the arrangement of your sections. Click on a section to edit its properties."
        )

        with gr.Row():
            gr.Markdown("⏳ In Progress: Adjusting sections...")

        edm_arrangement_tab()

    with gr.Tab("2. Generate Section Variants"):
        gr.Markdown("### Generate and Select Variants for Each Section")
        gr.Markdown("Generate variants for each section and select which one to use in the final track")

        section_uis = {}

        # Live summary: this is the component the per-section handlers update.
        selected_variants_display = gr.JSON(
            label="Selected Variants",
            value={"No variants selected yet": "Generate and select variants in sections below"}
        )

        for section_name, params in sections.items():
            section_uis[section_name] = create_section_ui(section_name, params)
            setup_section_event_handlers(section_name, section_uis[section_name], selected_variants_display)

    with gr.Tab("3. Create Full Track"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Track Settings")
                crossfade_ms = gr.Slider(
                    label="Crossfade Duration (ms)",
                    minimum=0,
                    maximum=2000,
                    value=500,
                    step=100,
                )
                output_track_name = gr.Textbox(
                    label="Output Filename",
                    value="full_track_output.wav",
                    placeholder="e.g., full_track_output.wav",
                )

                gr.Markdown("### Sections to Include")
                section_checkboxes = {}
                for section_name in sections.keys():
                    section_checkboxes[section_name] = gr.Checkbox(
                        label=f"Include {section_name.capitalize()}",
                        value=True
                    )

                gr.Markdown("### Selected Variants Summary")
                # Given its own name so it no longer shadows the live display
                # from tab 2; it is kept in sync by the mirror wiring below.
                selected_variants_summary = gr.JSON(
                    label="Selected Variants",
                    value={"No variants selected yet": "Generate and select variants in Section 2"}
                )

                generate_track_btn = gr.Button(
                    "Generate Full Track", variant="primary", scale=2
                )

            with gr.Column():
                track_status = gr.Textbox(label="Status", interactive=False)
                track_summary = gr.JSON(label="Track Summary")
                full_track_audio = gr.Audio(label="Generated Full Track")

    upload_btn.click(
        fn=process_uploaded_files,
        inputs=[file_input],
        outputs=[upload_status, stem_display],
    )

    generate_track_btn.click(
        fn=generate_full_track,
        inputs=[
            crossfade_ms,
            output_track_name,
            *[section_checkboxes[section] for section in sections.keys()]
        ],
        outputs=[track_status, full_track_audio, track_summary],
    )

    # Copy every update of the tab-2 summary into the tab-3 summary, which
    # no handler writes to directly.
    selected_variants_display.change(
        fn=lambda value: value,
        inputs=[selected_variants_display],
        outputs=[selected_variants_summary],
    )
|
|
|
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|