JacobLinCool's picture
Upload folder using huggingface_hub
975d5cb verified
from datasets import load_dataset, Audio
# Passed as ``num_proc`` to the filter/map calls in this script.
N_PROC = None

# Load the parsed chart dataset and immediately drop the columns that the
# rest of this script never references.
ds = load_dataset("JacobLinCool/taiko-1000-parsed").remove_columns(
    ["tja", "hard", "normal", "easy", "ura"]
)
def filter_out_broken(example):
    """Return True when the example's audio array can be accessed.

    Used as a ``filter`` predicate: rows for which reading
    ``example["audio"]["array"]`` raises an error are reported as broken
    (False) so they get dropped instead of aborting the whole run.

    Only ``Exception`` subclasses are treated as "broken audio".
    ``KeyboardInterrupt`` and ``SystemExit`` are deliberately allowed to
    propagate so the job can still be interrupted while filtering.

    Args:
        example: A dataset row expected to carry an ``"audio"`` entry with
            an ``"array"`` field.

    Returns:
        True if the access succeeded, False if it raised an error.
    """
    try:
        # The lookup itself is the check; the value is not needed.
        _ = example["audio"]["array"]
    except Exception:
        return False
    return True
# Drop rows whose audio cannot be accessed, then cast the audio column to a
# 16 kHz ``Audio`` feature for the rows that remain.
ds = ds.filter(
    filter_out_broken,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
).cast_column("audio", Audio(sampling_rate=16000))
def build_beat_and_downbeat_labels(example):
    """
    Extract beat and downbeat times from the "oni" chart segments.

    - Downbeats: the timestamp of every segment (start of each measure).
    - Beats: ``measure_num`` evenly spaced times per segment, starting at
      the segment timestamp and spaced by the beat duration implied by the
      BPM and the time-signature denominator.

    The BPM of a segment is resolved in this order:

    1. the BPM of the segment's first note;
    2. if the segment has no notes, the BPM of the first note of the
       nearest later segment that has notes;
    3. if that value is missing or not positive, the most recent valid BPM
       used by an earlier segment (so trailing empty measures keep the
       song's tempo instead of jumping to an unrelated default);
    4. 120 BPM when no valid BPM has been seen yet.

    Args:
        example: A dataset row providing ``example["metadata"]["TITLE"]``
            and ``example["oni"]["segments"]``. Each segment provides
            ``timestamp``, ``measure_num``, ``measure_den`` and ``notes``;
            each note provides ``bpm``.

    Returns:
        A dict with ``title`` plus the sorted, de-duplicated ``beats`` and
        ``downbeats`` lists (times in seconds).
    """
    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    # Single backward pass: for every segment, remember the BPM of the first
    # note of the nearest later segment that has notes (None if there is
    # none). This replaces a forward scan per empty segment.
    upcoming_bpm = [None] * len(segments)
    next_bpm = None
    for idx in range(len(segments) - 1, -1, -1):
        upcoming_bpm[idx] = next_bpm
        later_notes = segments[idx]["notes"]
        if later_notes:
            next_bpm = later_notes[0]["bpm"]

    beats = []
    downbeats = []
    # Starts at the default tempo and is replaced by every valid BPM seen.
    last_valid_bpm = 120.0

    for segment, lookahead_bpm in zip(segments, upcoming_bpm):
        seg_timestamp = segment["timestamp"]
        measure_num = segment["measure_num"]  # numerator (e.g., 4 in 4/4)
        measure_den = segment["measure_den"]  # denominator (e.g., 4 in 4/4)
        notes = segment["notes"]

        # Downbeat is the start of each measure.
        downbeats.append(seg_timestamp)

        # Own first note wins; empty segments borrow from the next segment
        # that has notes.
        bpm = notes[0]["bpm"] if notes else lookahead_bpm
        if bpm is None or bpm <= 0:
            bpm = last_valid_bpm
        else:
            last_valid_bpm = bpm

        # One quarter note lasts 60/BPM seconds; scale by the time-signature
        # denominator (4 = quarter, 8 = eighth, etc.).
        beat_duration = (60.0 / bpm) * (4.0 / measure_den)

        beats.extend(
            seg_timestamp + beat_idx * beat_duration
            for beat_idx in range(measure_num)
        )

    # Sort and deduplicate (in case of overlapping segments).
    return {
        "title": title,
        "beats": sorted(set(beats)),
        "downbeats": sorted(set(downbeats)),
    }
# Add title/beats/downbeats to every example, drop the "oni" and "metadata"
# columns they were computed from, and switch the output format to torch.
ds = ds.map(
    build_beat_and_downbeat_labels,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
    remove_columns=["oni", "metadata"],
).with_format("torch")
if __name__ == "__main__":
    # When run as a script, print the processed dataset first, then the
    # feature schema of its "train" split.
    print(ds)
    train_features = ds["train"].features
    print(train_features)