# GravityLLM / tools /make_synthetic_dataset.py
# lzanardos9's picture
# Upload 20 files
# b7720f0 verified
import argparse
import json
import random
from pathlib import Path
# Stem classes paired with a fixed "leadness" score, as (class_name, leadness)
# tuples. The order is significant: it determines the numeric suffix of each
# stem id and which stem ends up first in the generated stem list.
CLASSES = list(
    {
        "lead_vocal": 0.95,
        "kick": 0.25,
        "bass": 0.35,
        "pad": 0.10,
        "synth_lead": 0.70,
        "fx": 0.05,
    }.items()
)
# Per-style bed settings keyed by style name. Each row below is
# (style, room_preset, lufs); every style shares the "iamf" layout.
STYLE_PRESETS = {
    style_name: {"layout": "iamf", "room_preset": room_name, "lufs": loudness}
    for style_name, room_name, loudness in (
        ("club", "club_medium", -14.0),
        ("cinematic", "cinema_large", -16.0),
        ("live_stage", "stage_live", -15.0),
    )
}
def build_example(seed: int) -> dict:
    """Build one deterministic prompt/completion pair for the given seed.

    All randomness comes from a private ``random.Random(seed)``, so the same
    seed always yields the same example.

    Args:
        seed: Seed for the example's private random generator.

    Returns:
        A dict with two string values: ``"prompt"`` (an instruction header
        followed by the JSON-encoded input payload) and ``"completion"``
        (the JSON-encoded target scene).
    """
    rng = random.Random(seed)

    # NOTE: the order of rng draws below is part of the output contract.
    # Reordering any of them changes what a given seed produces.
    style = rng.choice(list(STYLE_PRESETS))
    bpm = rng.choice([90, 100, 110, 120, 128, 140])
    energy = round(rng.uniform(0.35, 0.95), 2)
    max_objects = rng.choice([8, 10, 12])

    # One stem per entry in CLASSES; the id is the class's first letter
    # followed by its 1-based position.
    stems = [
        {
            "id": f"{class_name[:1]}{position}",
            "class": class_name,
            "lufs": round(rng.uniform(-25.0, -10.0), 1),
            "transient": round(rng.uniform(0.05, 0.98), 2),
            "band_energy": {
                band: round(rng.uniform(0.05, 0.9), 2)
                for band in ("low", "mid", "high")
            },
            "leadness": leadness,
        }
        for position, (class_name, leadness) in enumerate(CLASSES, start=1)
    ]

    # Drawn after the stems, matching the original draw sequence.
    section = rng.choice(["intro", "verse", "break", "drop"])

    # Fixed placement shared by the anchor rule, the emitted object and its
    # motion keyframes.
    anchor_position = {"az_deg": 0, "el_deg": 10, "dist_m": 1.6}

    payload = {
        "target_format": "iamf",
        "max_objects": max_objects,
        "style": style,
        "section": section,
        "global": {"bpm": bpm, "energy": energy},
        "stems": stems,
        "rules": [
            {"type": "anchor", "track_class": "lead_vocal", **anchor_position},
            {"type": "mono_low_end", "hz_below": 120},
            {"type": "width_pref", "track_class": "pad", "min_width": 0.75},
        ],
    }

    preset = STYLE_PRESETS[style]
    bed = {
        "layout": preset["layout"],
        "loudness_target_lufs": preset["lufs"],
        "room_preset": preset["room_preset"],
    }

    # The completion contains a single object: the first stem, labelled as
    # the lead vocal and held at the anchor position for both keyframes.
    lead_object = {
        "id": stems[0]["id"],
        "class": "lead_vocal",
        **anchor_position,
        "width": 0.15,
        "gain_db": 0.0,
        "reverb_send": 0.18,
        "early_reflections": 0.2,
        "motion": [{"t": keyframe, **anchor_position} for keyframe in (0.0, 1.0)],
    }

    output = {
        "version": "1.0",
        "bed": bed,
        "objects": [lead_object],
        "constraints_applied": [
            "anchor:lead_vocal@0/10/1.6",
            "mono_low_end<120Hz",
            "pad_width>=0.75",
        ],
    }

    prompt_header = "GravityLLM: Output ONLY valid JSON matching the Spatial9Scene schema.\n\n"
    prompt = prompt_header + "INPUT:\n" + json.dumps(payload, indent=2)
    completion = json.dumps(output, indent=2)
    return {"prompt": prompt, "completion": completion}
def main() -> None:
    """Command-line entry point: write a JSONL file of synthetic examples.

    Options:
        --output: Destination JSONL path (default ``data/synthetic_train.jsonl``).
            Missing parent directories are created.
        --count: Number of examples to write (default 25).
        --seed: Seed of the first example (default 42); example ``i`` is
            built with ``seed + i``.
    """
    cli = argparse.ArgumentParser(description="Generate a small synthetic GravityLLM dataset.")
    cli.add_argument("--output", type=Path, default=Path("data/synthetic_train.jsonl"))
    cli.add_argument("--count", type=int, default=25)
    cli.add_argument("--seed", type=int, default=42)
    options = cli.parse_args()

    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)

    first_seed = options.seed
    with destination.open("w", encoding="utf-8") as sink:
        # One JSON object per line, using consecutive seeds from first_seed.
        sink.writelines(
            json.dumps(build_example(example_seed), ensure_ascii=False) + "\n"
            for example_seed in range(first_seed, first_seed + options.count)
        )

    print(f"Wrote {options.count} examples to {destination}")


if __name__ == "__main__":
    main()