|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import argparse | 
					
						
						|  | import librosa | 
					
						
						|  | import logging | 
					
						
						|  | import soundfile as sf | 
					
						
						|  | import sys | 
					
						
						|  | from pathlib import Path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sub_modules = ["", "semantic_tokenizer/f40ms", "semantic_detokenizer"] | 
					
						
						|  | for sub in sub_modules: | 
					
						
						|  | sys.path.append(str((Path(__file__).parent / sub).absolute())) | 
					
						
						|  |  | 
					
						
						|  | from semantic_tokenizer.f40ms.simple_tokenizer_infer import SpeechTokenizer, TOKENIZER_CFG_NAME | 
					
						
						|  | from semantic_detokenizer.chunk_infer import SpeechDetokenizer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ReconstructionPipeline: | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | detok_vocoder: str, | 
					
						
						|  | tokenizer_cfg_name: str = TOKENIZER_CFG_NAME, | 
					
						
						|  | tokenizer_cfg_path: str = str( | 
					
						
						|  | (Path(__file__).parent / "semantic_tokenizer/f40ms/config").absolute() | 
					
						
						|  | ), | 
					
						
						|  | tokenizer_ckpt: str = str( | 
					
						
						|  | ( | 
					
						
						|  | Path(__file__).parent / "semantic_tokenizer/f40ms/ckpt/model.pt" | 
					
						
						|  | ).absolute() | 
					
						
						|  | ), | 
					
						
						|  | detok_model_cfg: str = str( | 
					
						
						|  | (Path(__file__).parent / "semantic_detokenizer/ckpt/model.yaml").absolute() | 
					
						
						|  | ), | 
					
						
						|  | detok_ckpt: str = str( | 
					
						
						|  | (Path(__file__).parent / "semantic_detokenizer/ckpt/model.pt").absolute() | 
					
						
						|  | ), | 
					
						
						|  | detok_vocab: str = str( | 
					
						
						|  | ( | 
					
						
						|  | Path(__file__).parent / "semantic_detokenizer/ckpt/vocab_4096.txt" | 
					
						
						|  | ).absolute() | 
					
						
						|  | ), | 
					
						
						|  | ): | 
					
						
						|  | self.tokenizer_cfg_name = tokenizer_cfg_name | 
					
						
						|  | self.tokenizer = SpeechTokenizer( | 
					
						
						|  | ckpt_path=tokenizer_ckpt, | 
					
						
						|  | cfg_path=tokenizer_cfg_path, | 
					
						
						|  | cfg_name=self.tokenizer_cfg_name, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.device = "cuda:0" | 
					
						
						|  | self.detoker = SpeechDetokenizer( | 
					
						
						|  | vocoder_path=detok_vocoder, | 
					
						
						|  | model_cfg=detok_model_cfg, | 
					
						
						|  | ckpt_file=detok_ckpt, | 
					
						
						|  | vocab_file=detok_vocab, | 
					
						
						|  | device=self.device, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.token_chunk_len = 75 | 
					
						
						|  | self.chunk_cond_proportion = 0.3 | 
					
						
						|  | self.chunk_look_ahead = 10 | 
					
						
						|  | self.max_ref_duration = 4.5 | 
					
						
						|  | self.ref_audio_cut_from_head = False | 
					
						
						|  |  | 
					
						
						|  | def reconstruct(self, ref_wav, input_wav): | 
					
						
						|  | ref_wavs_list = [] | 
					
						
						|  | raw_ref_wav, sr = librosa.load(ref_wav, sr=16000) | 
					
						
						|  | ref_wavs_list.append(raw_ref_wav) | 
					
						
						|  |  | 
					
						
						|  | raw_input_wav, sr = librosa.load(input_wav, sr=16000) | 
					
						
						|  | ref_wavs_list.append(raw_input_wav) | 
					
						
						|  |  | 
					
						
						|  | token_list, token_info_list = self.tokenizer.extract( | 
					
						
						|  | ref_wavs_list | 
					
						
						|  | ) | 
					
						
						|  | ref_tokens = token_info_list[0]["reduced_unit_sequence"] | 
					
						
						|  | input_tokens = token_info_list[1]["reduced_unit_sequence"] | 
					
						
						|  | logging.info("tokens for ref wav: %s are [%s]" % (ref_wav, ref_tokens)) | 
					
						
						|  | logging.info("tokens for input wav: %s are [%s]" % (input_wav, input_tokens)) | 
					
						
						|  |  | 
					
						
						|  | generated_wave, target_sample_rate = self.detoker.chunk_generate( | 
					
						
						|  | ref_wav, | 
					
						
						|  | ref_tokens.split(), | 
					
						
						|  | input_tokens.split(), | 
					
						
						|  | self.token_chunk_len, | 
					
						
						|  | self.chunk_cond_proportion, | 
					
						
						|  | self.chunk_look_ahead, | 
					
						
						|  | self.max_ref_duration, | 
					
						
						|  | self.ref_audio_cut_from_head, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if generated_wave is None: | 
					
						
						|  | logging.info("generation FAILED") | 
					
						
						|  | return None, None | 
					
						
						|  | return generated_wave, target_sample_rate | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(args): | 
					
						
						|  |  | 
					
						
						|  | reconsturctor = ReconstructionPipeline( | 
					
						
						|  | detok_vocoder=args.detok_vocoder, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | generated_wave, target_sample_rate = reconsturctor.reconstruct(args.ref_wav, args.input_wav) | 
					
						
						|  | with open(args.output_wav, "wb") as f: | 
					
						
						|  | sf.write(f.name, generated_wave, target_sample_rate) | 
					
						
						|  | logging.info(f"write output to: {f.name}") | 
					
						
						|  |  | 
					
						
						|  | logging.info("Finished") | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | parser = argparse.ArgumentParser() | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--tokenizer-ckpt", | 
					
						
						|  | required=False, | 
					
						
						|  | help="path to ckpt", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--tokenizer-cfg-path", | 
					
						
						|  | required=False, | 
					
						
						|  | default="semantic_tokenizer/f40ms/config", | 
					
						
						|  | help="path to config", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--detok-ckpt", | 
					
						
						|  | required=False, | 
					
						
						|  | help="path to ckpt", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--detok-model-cfg", | 
					
						
						|  | required=False, | 
					
						
						|  | help="path to model_cfg", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--detok-vocab", | 
					
						
						|  | required=False, | 
					
						
						|  | help="path to vocab", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--detok-vocoder", | 
					
						
						|  | required=True, | 
					
						
						|  | help="path to vocoder", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--ref-wav", | 
					
						
						|  | required=True, | 
					
						
						|  | help="path to ref wav", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--output-wav", | 
					
						
						|  | required=True, | 
					
						
						|  | help="path to output reconstructed wav", | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--input-wav", | 
					
						
						|  | required=True, | 
					
						
						|  | help="input wav to reconstruction", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | args = parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  | main(args) | 
					
						
						|  |  |