Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import List, Optional | |
| import pandas as pd | |
| from src.datasets.base_dataset import SimpleAudioFakeDataset | |
| from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset | |
| from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset | |
| from src.datasets.wavefake_dataset import WaveFakeDataset | |
| from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal | |
| from src.datasets.MLAADv3_dataset import MLAADv3 | |
| from src.datasets.MAILABS_dataset import MAILABS | |
| from src.datasets.aihub_dataset import AIHUB | |
| from src.datasets.KoAAD_dataset import KoAAD | |
# Module-scoped logger. Using ``__name__`` instead of the root logger lets
# applications filter or silence this module individually; records still
# propagate to whatever handlers are configured on the root logger.
LOGGER = logging.getLogger(__name__)
class DetectionDataset(SimpleAudioFakeDataset):
    """Dataset that merges several bonafide/spoof corpora into one table.

    Each corpus whose path is supplied is instantiated for the requested
    subset, and the ``samples`` frames of all of them are concatenated into
    ``self.samples``. Balancing and filtering helpers work on the ``label``
    column of that frame, using the values ``"bonafide"`` and ``"spoof"``.
    """

    def __init__(
        self,
        asvspoof_path: Optional[str] = None,
        wavefake_path: Optional[str] = None,
        fakeavceleb_path: Optional[str] = None,
        asvspoof2019_path: Optional[str] = None,
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        """Build the combined sample table and optionally rebalance it.

        Args:
            asvspoof_path: Path handed to ``DeepFakeASVSpoofDataset``.
            wavefake_path: Path handed to ``WaveFakeDataset``.
            fakeavceleb_path: Path handed to ``FakeAVCelebDataset``.
            asvspoof2019_path: Path handed to ``ASVSpoof2019DatasetOriginal``.
            MLAADv3_path: Path handed to ``MLAADv3``.
            MAILABS_path: Path handed to ``MAILABS``.
            AIHUB_path: Path handed to ``AIHUB``.
            KoAAD_path: Path handed to ``KoAAD``.
            subset: Subset name forwarded to the base class and every corpus.
            transform: Forwarded unchanged to the base class.
            oversample: Call ``oversample_dataset`` after merging. Takes
                precedence over ``undersample`` when both are true.
            undersample: Call ``undersample_dataset`` after merging (only
                when ``oversample`` is false).
            return_label: Forwarded unchanged to the base class.
            reduced_number: If truthy, keep at most this many rows, drawn
                with a fixed ``random_state`` of 42.
            return_meta: Forwarded unchanged to the base class.

        Raises:
            ValueError: If every dataset path is ``None``.
            NotImplementedError: Propagated from the balancing helpers when
                bonafide rows outnumber spoof rows.
        """
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )

        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            MLAADv3_path=MLAADv3_path,
            MAILABS_path=MAILABS_path,
            AIHUB_path=AIHUB_path,
            KoAAD_path=KoAAD_path,
            subset=subset,
        )
        if not datasets:
            # pd.concat([]) would raise a ValueError anyway ("No objects to
            # concatenate"); keep the exception type but say what is wrong.
            raise ValueError(
                "DetectionDataset needs at least one dataset path, "
                "but all of them are None."
            )
        self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)

        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info("Using reduced number of samples - %s!", reduced_number)
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,
            )

    def _init_datasets(
        self,
        subset: str,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        MLAADv3_path: Optional[str] = None,
        MAILABS_path: Optional[str] = None,
        AIHUB_path: Optional[str] = None,
        KoAAD_path: Optional[str] = None,
    ) -> List[SimpleAudioFakeDataset]:
        """Instantiate one dataset object per path that is not ``None``.

        The last four parameters previously used ``Optional[str]`` as their
        *default value* rather than as an annotation, so leaving one out
        made the ``is not None`` check pass with a typing object. They now
        default to ``None``.

        Args:
            subset: Subset name passed to every dataset constructor.
            asvspoof_path: Path for ``DeepFakeASVSpoofDataset`` or ``None``.
            wavefake_path: Path for ``WaveFakeDataset`` or ``None``.
            fakeavceleb_path: Path for ``FakeAVCelebDataset`` or ``None``.
            asvspoof2019_path: Path for ``ASVSpoof2019DatasetOriginal`` or
                ``None``.
            MLAADv3_path: Path for ``MLAADv3`` or ``None``.
            MAILABS_path: Path for ``MAILABS`` or ``None``.
            AIHUB_path: Path for ``AIHUB`` or ``None``.
            KoAAD_path: Path for ``KoAAD`` or ``None``.

        Returns:
            The constructed datasets, in the order listed above; possibly
            empty.
        """
        # (path, constructor) pairs in the order the datasets are appended.
        loaders = (
            (asvspoof_path, lambda p: DeepFakeASVSpoofDataset(p, subset=subset)),
            (wavefake_path, lambda p: WaveFakeDataset(p, subset=subset)),
            (fakeavceleb_path, lambda p: FakeAVCelebDataset(p, subset=subset)),
            # This class takes the subset under a different keyword name.
            (
                asvspoof2019_path,
                lambda p: ASVSpoof2019DatasetOriginal(p, fold_subset=subset),
            ),
            (MLAADv3_path, lambda p: MLAADv3(p, subset=subset)),
            (MAILABS_path, lambda p: MAILABS(p, subset=subset)),
            (AIHUB_path, lambda p: AIHUB(p, subset=subset)),
            (KoAAD_path, lambda p: KoAAD(p, subset=subset)),
        )
        return [build(path) for path, build in loaders if path is not None]

    def _group_by_label(self):
        """Group ``self.samples`` by the ``label`` column.

        Groups by the scalar column name rather than a one-element list:
        with a list key, recent pandas versions deprecate scalar group keys
        (``get_group("bonafide")``) in favour of tuples.
        """
        return self.samples.groupby("label")

    def oversample_dataset(self):
        """Duplicate random bonafide rows until both classes are equal in size.

        Raises:
            KeyError: If either label is absent from ``self.samples``.
            NotImplementedError: If bonafide rows outnumber spoof rows.
        """
        by_label = self._group_by_label()
        bona_length = len(by_label.groups["bonafide"])
        spoof_length = len(by_label.groups["spoof"])
        diff_length = spoof_length - bona_length

        if diff_length < 0:
            raise NotImplementedError(
                "Oversampling is only implemented for the case where spoof "
                "samples outnumber bonafide samples."
            )

        if diff_length > 0:
            # Sampling with replacement: the shortfall may exceed the number
            # of distinct bonafide rows available.
            bonafide = by_label.get_group("bonafide").sample(
                diff_length, replace=True
            )
            self.samples = pd.concat([self.samples, bonafide], ignore_index=True)

    def undersample_dataset(self):
        """Shrink the spoof class to the size of the bonafide class.

        Raises:
            KeyError: If either label is absent from ``self.samples``.
            NotImplementedError: If bonafide rows outnumber spoof rows.
        """
        by_label = self._group_by_label()
        bona_length = len(by_label.groups["bonafide"])
        spoof_length = len(by_label.groups["spoof"])

        if spoof_length < bona_length:
            raise NotImplementedError(
                "Undersampling is only implemented for the case where spoof "
                "samples outnumber bonafide samples."
            )

        if spoof_length > bona_length:
            # NOTE(review): replace=True can pick the same spoof row twice
            # even though enough distinct rows exist. Kept as-is to avoid
            # changing sampling results - confirm whether this is intended.
            spoofs = by_label.get_group("spoof").sample(bona_length, replace=True)
            self.samples = pd.concat(
                [by_label.get_group("bonafide"), spoofs], ignore_index=True
            )

    def get_bonafide_only(self):
        """Restrict ``self.samples`` to bonafide rows and return them.

        This replaces ``self.samples`` in place; spoof rows are dropped from
        the dataset afterwards.

        Raises:
            KeyError: If there is no bonafide row.
        """
        self.samples = self._group_by_label().get_group("bonafide")
        return self.samples

    def get_spoof_only(self):
        """Restrict ``self.samples`` to spoof rows and return them.

        This replaces ``self.samples`` in place; bonafide rows are dropped
        from the dataset afterwards.

        Raises:
            KeyError: If there is no spoof row.
        """
        self.samples = self._group_by_label().get_group("spoof")
        return self.samples