Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import argparse | |
| import logging | |
| import multiprocessing | |
| from functools import partial | |
| from pathlib import Path | |
| import faiss | |
| from feature_retrieval import ( | |
| train_index, | |
| FaissIVFFlatTrainableFeatureIndexBuilder, | |
| OnConditionFeatureTransform, | |
| MinibatchKmeansFeatureTransform, | |
| DummyFeatureTransform, | |
| ) | |
| logger = logging.getLogger(__name__) | |
def get_speaker_list(base_path: Path):
    """Return the names of all speaker directories under ``base_path / "waves-16k"``.

    Only sub-directories are reported; regular files inside ``waves-16k``
    are ignored.  The names are returned in sorted order because
    ``Path.iterdir()`` yields entries in arbitrary, filesystem-dependent
    order, and a stable order keeps processing and logs reproducible.

    Args:
        base_path: Data root that is expected to contain a ``waves-16k`` folder.

    Returns:
        A sorted list of speaker directory names (possibly empty).

    Raises:
        FileNotFoundError: If ``base_path / "waves-16k"`` does not exist.
    """
    speakers_path = base_path / "waves-16k"
    if not speakers_path.exists():
        raise FileNotFoundError(f"path {speakers_path} does not exists")
    return sorted(
        speaker_dir.name
        for speaker_dir in speakers_path.iterdir()
        if speaker_dir.is_dir()
    )
def create_indexes_path(base_path: Path) -> Path:
    """Ensure the ``indexes`` folder exists inside ``base_path`` and return it.

    The folder is created if missing; an already existing folder is left
    as it is.  Parent directories are not created, so ``base_path`` itself
    must already exist.

    Args:
        base_path: Directory in which the ``indexes`` folder lives.

    Returns:
        The path ``base_path / "indexes"``.
    """
    target_dir = base_path.joinpath("indexes")
    logger.info("create indexes folder %s", target_dir)
    target_dir.mkdir(exist_ok=True)
    return target_dir
def create_index(
    feature_name: str,
    prefix: str,
    speaker: str,
    base_path: Path,
    indexes_path: Path,
    compress_features_after: int,
    n_clusters: int,
    n_parallel: int,
    train_batch_size: int = 8192,
) -> None:
    """Build the index file for one feature type of one speaker.

    Features are read from ``base_path / feature_name / speaker`` and the
    index is written to
    ``indexes_path / speaker / f"{prefix}{feature_name}.index"``; the actual
    work is delegated to ``train_index``.

    Args:
        feature_name: Name of the feature folder under ``base_path``; also
            used in the index file name.
        prefix: String prepended to the index file name.
        speaker: Speaker folder name.
        base_path: Root directory holding the feature folders.
        indexes_path: Existing directory in which the per-speaker index
            folder is created.
        compress_features_after: When the feature matrix has more rows than
            this value, the MiniBatch k-means transform is selected instead
            of the dummy transform.
        n_clusters: Passed to ``MinibatchKmeansFeatureTransform``.
        n_parallel: Passed to ``MinibatchKmeansFeatureTransform``.
        train_batch_size: Passed to the faiss IVF-flat index builder.

    Raises:
        ValueError: If the speaker's feature directory does not exist.
    """
    speaker_features = base_path / feature_name / speaker
    if not speaker_features.exists():
        raise ValueError(f'features not found by path {speaker_features}')

    speaker_index_dir = indexes_path / speaker
    speaker_index_dir.mkdir(exist_ok=True)
    target_file = speaker_index_dir / f"{prefix}{feature_name}.index"
    logger.debug('index will be save to %s', target_file)

    def exceeds_limit(matrix):
        # Row count decides whether the features get compressed.
        return matrix.shape[0] > compress_features_after

    train_index(
        speaker_features,
        target_file,
        FaissIVFFlatTrainableFeatureIndexBuilder(train_batch_size, distance=faiss.METRIC_L2),
        OnConditionFeatureTransform(
            condition=exceeds_limit,
            on_condition=MinibatchKmeansFeatureTransform(n_clusters, n_parallel),
            otherwise=DummyFeatureTransform(),
        ),
    )
def main() -> None:
    """Command-line entry point: create hubert and whisper indexes per speaker.

    Options:
        --debug: log at DEBUG level instead of INFO.
        --prefix: string prepended to every index file name.
        --speakers: explicit speaker names; when omitted, every directory
            found by ``get_speaker_list`` is processed.
        --compress-features-after: row-count threshold above which features
            are compressed with MiniBatchKMeans.
        --n-clusters: number of centroids used for the compression.
        --n-parallel: number of parallel MiniBatchKMeans jobs.

    Data is read from, and indexes are written below, ``./data_svc``
    resolved against the current working directory.
    """
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # descriptive text has to be passed as ``description``; otherwise it is
    # printed as the program name in the usage line.
    arg_parser = argparse.ArgumentParser(description="crate faiss indexes for feature retrieval")
    arg_parser.add_argument("--debug", action="store_true")
    arg_parser.add_argument("--prefix", default='', help="add prefix to index filename")
    arg_parser.add_argument('--speakers', nargs="+",
                            help="speaker names to create an index. By default all speakers are from data_svc")
    arg_parser.add_argument("--compress-features-after", type=int, default=200_000,
                            help="If the number of features is greater than the value compress "
                                 "feature vectors using MiniBatchKMeans.")
    arg_parser.add_argument("--n-clusters", type=int, default=10_000,
                            help="Number of centroids to which features will be compressed")
    # cpu_count() - 1 is 0 on a single-CPU machine; clamp so the default
    # always asks for at least one job.
    arg_parser.add_argument("--n-parallel", type=int, default=max(1, multiprocessing.cpu_count() - 1),
                            help="Nuber of parallel job of MinibatchKmeans. Default is cpus-1")
    args = arg_parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    base_path = Path(".").absolute() / "data_svc"
    # An explicit --speakers list takes precedence over directory discovery.
    speakers = args.speakers if args.speakers else get_speaker_list(base_path)
    logger.info("got %s speakers: %s", len(speakers), speakers)

    indexes_path = create_indexes_path(base_path)
    create_index_func = partial(
        create_index,
        prefix=args.prefix,
        base_path=base_path,
        indexes_path=indexes_path,
        compress_features_after=args.compress_features_after,
        n_clusters=args.n_clusters,
        n_parallel=args.n_parallel,
    )
    for speaker in speakers:
        # Same two feature types, in the same order, as before.
        for feature_name in ("hubert", "whisper"):
            logger.info("create %s index for speaker %s", feature_name, speaker)
            create_index_func(feature_name=feature_name, speaker=speaker)
    logger.info("done!")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()