Spaces:
Running
Running
| import logging | |
| import os | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Union | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data.dataset import Dataset | |
# Logger shared by this module; getLogger() called without a name returns the root logger.
log = logging.getLogger()
class AudioCapsData(Dataset):
    """Caption dataset built from a CSV file with ``name`` and ``caption`` columns.

    Every CSV row becomes one item: a dict holding that row's ``name`` and
    ``caption`` values. The audio directory is listed only to report how many
    rows have a ``.wav`` / ``.flac`` file with the same stem; rows are never
    dropped, so ``len(dataset)`` always equals the number of CSV rows.
    """

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        """Read the captions and log how many of them have audio available.

        Args:
            audio_path: Directory that is scanned for ``.wav`` / ``.flac`` files.
            csv_path: CSV file providing the ``name`` and ``caption`` columns.

        Raises:
            OSError: If ``csv_path`` cannot be read or ``audio_path`` cannot be
                listed (e.g. ``FileNotFoundError``).
            KeyError: If the CSV has rows but lacks a ``name`` or ``caption``
                column.
        """
        rows = pd.read_csv(csv_path).to_dict(orient='records')

        # Listing the directory also validates that it exists. Only the stems
        # are needed for membership tests, so no sorting is required.
        audio_stems = {
            Path(f).stem
            for f in os.listdir(audio_path)
            if f.endswith(('.wav', '.flac'))
        }

        # Keep only the two fields consumers read; extra CSV columns are ignored.
        self.data = [{'name': row['name'], 'caption': row['caption']} for row in rows]

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        # str() guards against pandas parsing purely numeric names as numbers.
        num_with_audio = sum(str(item['name']) in audio_stems for item in self.data)
        log.info(f'Found {len(self.data)} captions in {self.csv_path}, '
                 f'{num_with_audio} with matching audio files in {self.audio_path}')

    def __getitem__(self, idx: int) -> dict:
        """Return the ``{'name': ..., 'caption': ...}`` dict stored at ``idx``."""
        return self.data[idx]

    def __len__(self) -> int:
        """Return the number of caption rows loaded from the CSV."""
        return len(self.data)