# Laronix_Recording/local/data_preparation.py
import os

import pandas as pd
from datasets import Dataset, load_dataset

# Preprocessed audio directory (silence-trimmed, 16 kHz, normalized; see directory name).
audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40/"
# split_files = {"train": "data/Patient_sil_trim_16k_normed_5_snr_40/train.csv",
#                "test": "data/Patient_sil_trim_16k_normed_5_snr_40/test.csv",
#                "dev": "data/Patient_sil_trim_16k_normed_5_snr_40/dev.csv"}

# Load everything as a single "train" split; the train/dev/test split is done below.
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
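
# NOTE: the "audiofolder" loader scans audio_dir for audio files and, if a
# metadata.csv is present, joins it on its "file_name" column; any extra columns
# (e.g. "transcription") are exposed as dataset features.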


def train_dev_test_split(
    dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1, metadata_output=False, root_dir=None
):
    """Split a Dataset into train / dev / test subsets.

    Input:
        dataset: source Dataset
        dev_rate: fraction of the full dataset used for the dev set
        test_rate: fraction of the full dataset used for the test set
        seed: random seed for reproducible splitting
        metadata_output: if True, write metadata.csv files under root_dir
        root_dir: output root containing the train/, dev/ and test/ subdirectories
    Output:
        (train, dev, test) Dataset objects
    """
    # Carve off the test set first, then split the remainder into train and dev.
    train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
    test = train_dev_test["test"]
    train_dev = train_dev_test["train"]

    if len(train_dev) <= int(len(dataset) * dev_rate):
        # Too few samples remain for a separate train set: keep them all as dev.
        train = Dataset.from_dict({"audio": [], "transcription": []})
        dev = train_dev
    else:
        train_dev = train_dev.train_test_split(
            test_size=int(len(dataset) * dev_rate), seed=seed
        )
        train = train_dev["train"]
        dev = train_dev["test"]

    print(f"Train Size: {len(train)}")
    print(f"Dev Size: {len(dev)}")
    print(f"Test Size: {len(test)}")

    if metadata_output:
        if root_dir is None or not os.path.exists(root_dir):
            raise FileNotFoundError(f"Output root directory does not exist: {root_dir}")

        train_df = pd.DataFrame(train)
        dev_df = pd.DataFrame(dev)
        test_df = pd.DataFrame(test)

        # Create directories for train, dev, and test data
        for split_name in ("train", "dev", "test"):
            os.makedirs(f"{root_dir}/{split_name}", exist_ok=True)

        train_df.to_csv(f"{root_dir}/train/metadata.csv", index=False)
        dev_df.to_csv(f"{root_dir}/dev/metadata.csv", index=False)
        test_df.to_csv(f"{root_dir}/test/metadata.csv", index=False)

    return train, dev, test
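

# Expected layout under root_dir after a run with metadata_output=True (here root_dir=audio_dir):
#   <root_dir>/train/metadata.csv
#   <root_dir>/dev/metadata.csv
#   <root_dir>/test/metadata.csv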
# Split the loaded dataset roughly 80/10/10 and write metadata.csv for each split under audio_dir.
train, dev, test = train_dev_test_split(
    src_dataset, dev_rate=0.1, test_rate=0.1, seed=1, metadata_output=True, root_dir=audio_dir
)
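
# Minimal sketch: once the metadata.csv files above have been written, the splits
# could be reloaded with the generic CSV loader, e.g.:
# splits = load_dataset(
#     "csv",
#     data_files={
#         "train": f"{audio_dir}/train/metadata.csv",
#         "dev": f"{audio_dir}/dev/metadata.csv",
#         "test": f"{audio_dir}/test/metadata.csv",
#     },
# )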