DiffSpeech / mfa_usr /adapt.py
RayeRen's picture
init
d1b91e7
raw history blame
No virus
10.9 kB
import shutil
import os
import time
from montreal_forced_aligner import __version__
from montreal_forced_aligner.corpus.align_corpus import AlignableCorpus
from montreal_forced_aligner.dictionary import Dictionary, MultispeakerDictionary
from montreal_forced_aligner.aligner import TrainableAligner, PretrainedAligner
from montreal_forced_aligner.models import AcousticModel
from montreal_forced_aligner.config import TEMP_DIR, align_yaml_to_config, load_basic_align, load_command_configuration, \
train_yaml_to_config
from montreal_forced_aligner.utils import get_available_acoustic_languages, get_pretrained_acoustic_path, \
get_available_dict_languages, validate_dictionary_arg
from montreal_forced_aligner.helper import setup_logger, log_config
from montreal_forced_aligner.exceptions import ArgumentError
def load_adapt_config():
    """Load the adaptation training and alignment configurations.

    Reads ``mfa_usr/adapt_config.yaml`` through ``train_yaml_to_config`` and
    then sets, on the first training stage, both ``fmllr_iterations`` and
    ``realignment_iterations`` to every iteration index
    ``0 .. num_iterations - 1``.

    NOTE(review): the YAML path is relative, so it is presumably resolved
    against the current working directory -- confirm against the caller.

    Returns:
        The ``(training_config, align_config)`` pair.
    """
    configs = train_yaml_to_config('mfa_usr/adapt_config.yaml', require_mono=False)
    training_config, align_config = configs

    first_stage = training_config.training_configs[0]
    iteration_count = first_stage.num_iterations
    # Two separate list objects, so the attributes never alias each other.
    first_stage.fmllr_iterations = list(range(iteration_count))
    first_stage.realignment_iterations = list(range(iteration_count))

    return training_config, align_config
class AcousticModel2(AcousticModel):
    """Acoustic model whose adaptation config comes from ``load_adapt_config``."""

    def adaptation_config(self):
        """Return the training configuration produced by ``load_adapt_config``.

        The alignment configuration returned alongside it is discarded.
        """
        training_config, _ = load_adapt_config()
        return training_config
def adapt_model(args, unknown_args=None):
    """Adapt a pretrained acoustic model to a corpus and export the results.

    Steps, in order:

    1. Resolve the temporary working directory for this corpus and build the
       alignment configuration (from ``args.config_path`` or the basic one).
    2. Optionally wipe the working directory (``args.clean``), set up logging
       and load the per-directory command configuration, warning when it was
       left behind by a different or failed run.
    3. Load the acoustic model, derive the adaptation training configuration,
       and build the corpus and pronunciation dictionary.
    4. Train a ``TrainableAligner`` seeded with a ``PretrainedAligner``, save
       the adapted model and export TextGrids.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``temp_directory``, ``corpus_directory`` (may be rewritten when it
            ends with a path separator), ``config_path``, ``disable_mp``,
            ``debug``, ``overwrite``, ``disable_textgrid_cleanup``,
            ``dictionary_path``, ``acoustic_model_path``, ``audio_directory``,
            ``speaker_characters``, ``num_jobs``, ``output_model_path``,
            ``output_directory`` and, optionally, ``clean`` and ``verbose``.
        unknown_args: Extra command-line arguments forwarded to
            ``align_config.update_from_args`` when non-empty.

    Raises:
        Whatever the underlying aligner components raise. The command
        configuration is marked ``dirty`` and saved before the exception
        propagates.
    """
    command = 'align'
    all_begin = time.time()

    if not args.temp_directory:
        temp_dir = TEMP_DIR
    else:
        temp_dir = os.path.expanduser(args.temp_directory)

    corpus_name = os.path.basename(args.corpus_directory)
    if corpus_name == '':
        # A trailing separator yields an empty basename; step up one level.
        args.corpus_directory = os.path.dirname(args.corpus_directory)
        corpus_name = os.path.basename(args.corpus_directory)
    data_directory = os.path.join(temp_dir, corpus_name)

    if args.config_path:
        align_config = align_yaml_to_config(args.config_path)
    else:
        align_config = load_basic_align()
    align_config.use_mp = not args.disable_mp
    align_config.debug = args.debug
    align_config.overwrite = args.overwrite
    align_config.cleanup_textgrids = not args.disable_textgrid_cleanup
    if unknown_args:
        align_config.update_from_args(unknown_args)

    conf_path = os.path.join(data_directory, 'config.yml')
    if getattr(args, 'clean', False) and os.path.exists(data_directory):
        print('Cleaning old directory!')
        shutil.rmtree(data_directory, ignore_errors=True)

    # ``verbose`` is optional on ``args``; resolve it once so the logger level
    # and the aligner flag below agree and neither raises AttributeError.
    verbose = getattr(args, 'verbose', False)
    if verbose:
        log_level = 'debug'
    else:
        log_level = 'info'
    logger = setup_logger(command, data_directory, console_level=log_level)
    logger.debug('ALIGN CONFIG:')
    log_config(logger, align_config)

    conf = load_command_configuration(conf_path, {'dirty': False,
                                                  'begin': all_begin,
                                                  'version': __version__,
                                                  'type': command,
                                                  'corpus_directory': args.corpus_directory,
                                                  'dictionary_path': args.dictionary_path,
                                                  'acoustic_model_path': args.acoustic_model_path})
    # NOTE(review): a changed acoustic model path alone does not trigger the
    # warning; it is only reported (at debug level) when another check fires.
    if conf['dirty'] or conf['type'] != command \
            or conf['corpus_directory'] != args.corpus_directory \
            or conf['version'] != __version__ \
            or conf['dictionary_path'] != args.dictionary_path:
        logger.warning(
            'WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no '
            'weird behavior for previous versions of the temporary directory.')
        if conf['dirty']:
            logger.debug('Previous run ended in an error (maybe ctrl-c?)')
        if conf['type'] != command:
            logger.debug('Previous run was a different subcommand than {} (was {})'.format(command, conf['type']))
        if conf['corpus_directory'] != args.corpus_directory:
            logger.debug('Previous run used source directory '
                         'path {} (new run: {})'.format(conf['corpus_directory'], args.corpus_directory))
        if conf['version'] != __version__:
            logger.debug('Previous run was on {} version (new run: {})'.format(conf['version'], __version__))
        if conf['dictionary_path'] != args.dictionary_path:
            logger.debug('Previous run used dictionary path {} '
                         '(new run: {})'.format(conf['dictionary_path'], args.dictionary_path))
        if conf['acoustic_model_path'] != args.acoustic_model_path:
            logger.debug('Previous run used acoustic model path {} '
                         '(new run: {})'.format(conf['acoustic_model_path'], args.acoustic_model_path))

    os.makedirs(data_directory, exist_ok=True)
    model_directory = os.path.join(data_directory, 'acoustic_models')
    os.makedirs(model_directory, exist_ok=True)

    acoustic_model = AcousticModel2(args.acoustic_model_path, root_directory=model_directory)
    print("| acoustic_model.meta", acoustic_model.meta)
    acoustic_model.log_details(logger)

    training_config = acoustic_model.adaptation_config()
    # Carry the alignment beam settings over to the first training stage.
    training_config.training_configs[0].update({'beam': align_config.beam, 'retry_beam': align_config.retry_beam})
    training_config.update_from_align(align_config)
    logger.debug('ADAPT TRAINING CONFIG:')
    log_config(logger, training_config)

    audio_dir = None
    if args.audio_directory:
        audio_dir = args.audio_directory

    try:
        corpus = AlignableCorpus(args.corpus_directory, data_directory,
                                 speaker_characters=args.speaker_characters,
                                 num_jobs=args.num_jobs, sample_rate=align_config.feature_config.sample_frequency,
                                 logger=logger, use_mp=align_config.use_mp, punctuation=align_config.punctuation,
                                 clitic_markers=align_config.clitic_markers, audio_directory=audio_dir)
        if corpus.issues_check:
            logger.warning('Some issues parsing the corpus were detected. '
                           'Please run the validator to get more information.')
        logger.info(corpus.speaker_utterance_info())

        # A ``.yaml`` dictionary path selects the multi-speaker dictionary;
        # both classes are constructed with the same arguments.
        if args.dictionary_path.lower().endswith('.yaml'):
            dictionary_class = MultispeakerDictionary
        else:
            dictionary_class = Dictionary
        dictionary = dictionary_class(args.dictionary_path, data_directory, logger=logger,
                                      punctuation=align_config.punctuation,
                                      clitic_markers=align_config.clitic_markers,
                                      compound_markers=align_config.compound_markers,
                                      multilingual_ipa=acoustic_model.meta['multilingual_ipa'],
                                      strip_diacritics=acoustic_model.meta.get('strip_diacritics', None),
                                      digraphs=acoustic_model.meta.get('digraphs', None))
        acoustic_model.validate(dictionary)

        begin = time.time()
        previous = PretrainedAligner(corpus, dictionary, acoustic_model, align_config,
                                     temp_directory=data_directory,
                                     debug=getattr(args, 'debug', False), logger=logger)
        a = TrainableAligner(corpus, dictionary, training_config, align_config,
                             temp_directory=data_directory,
                             debug=getattr(args, 'debug', False), logger=logger, pretrained_aligner=previous)
        logger.debug('Setup adapter in {} seconds'.format(time.time() - begin))
        a.verbose = verbose

        begin = time.time()
        a.train()
        logger.debug('Performed adaptation in {} seconds'.format(time.time() - begin))

        begin = time.time()
        a.save(args.output_model_path, root_directory=model_directory)
        a.export_textgrids(args.output_directory)
        logger.debug('Exported TextGrids in {} seconds'.format(time.time() - begin))
        logger.info('All done!')
    except BaseException:
        # BaseException (not Exception) so that an interrupted run (ctrl-c
        # raises KeyboardInterrupt) is also recorded as dirty; always re-raised.
        conf['dirty'] = True
        raise
    finally:
        # Iterate over a copy: removeHandler mutates logger.handlers.
        for handler in logger.handlers[:]:
            handler.close()
            logger.removeHandler(handler)
        conf.save(conf_path)
def validate_args(args, downloaded_acoustic_models, download_dictionaries):
    """Check the command-line arguments and normalise them in place.

    Args:
        args: Parsed namespace. ``corpus_directory`` is checked;
            ``dictionary_path`` is replaced by the result of
            ``validate_dictionary_arg``; ``acoustic_model_path`` is replaced by
            the pretrained model path when it names a downloaded model.
        downloaded_acoustic_models: Names of the available pretrained
            acoustic models.
        download_dictionaries: Passed through to ``validate_dictionary_arg``.

    Raises:
        ArgumentError: If the corpus directory is missing or is not a
            directory, or if the acoustic model is neither a known name nor an
            existing file with the acoustic model extension.
    """
    corpus_directory = args.corpus_directory
    if not os.path.exists(corpus_directory):
        raise ArgumentError('Could not find the corpus directory {}.'.format(corpus_directory))
    if not os.path.isdir(corpus_directory):
        raise ArgumentError('The specified corpus directory ({}) is not a directory.'.format(corpus_directory))

    args.dictionary_path = validate_dictionary_arg(args.dictionary_path, download_dictionaries)

    model_argument = args.acoustic_model_path
    model_key = model_argument.lower()

    # Case 1: the argument names one of the downloaded pretrained models.
    if model_key in downloaded_acoustic_models:
        args.acoustic_model_path = get_pretrained_acoustic_path(model_key)
        return

    # Case 2: neither a known name nor something that looks like a model file.
    if not model_key.endswith(AcousticModel.extension):
        raise ArgumentError(
            'The language \'{}\' is not currently included in the distribution, '
            'please align via training or specify one of the following language names: {}.'.format(
                model_key, ', '.join(downloaded_acoustic_models)))

    # Case 3: it looks like a model file, so it has to exist on disk.
    if not os.path.exists(model_argument):
        raise ArgumentError('The specified model path does not exist: ' + model_argument)
def run_adapt_model(args, unknown_args=None, downloaded_acoustic_models=None, download_dictionaries=None):
    """Entry point: normalise ``args``, validate them and run the adaptation.

    Args:
        args: Parsed command-line namespace; ``speaker_characters`` and
            ``corpus_directory`` are normalised in place before validation.
        unknown_args: Extra arguments forwarded to ``adapt_model``.
        downloaded_acoustic_models: Available pretrained acoustic model names;
            looked up with ``get_available_acoustic_languages`` when ``None``.
        download_dictionaries: Available dictionary names; looked up with
            ``get_available_dict_languages`` when ``None``.
    """
    if downloaded_acoustic_models is None:
        downloaded_acoustic_models = get_available_acoustic_languages()
    if download_dictionaries is None:
        download_dictionaries = get_available_dict_languages()

    try:
        numeric_speaker_characters = int(args.speaker_characters)
    except ValueError:
        # Not a number: keep the original value untouched.
        pass
    else:
        args.speaker_characters = numeric_speaker_characters

    # Drop trailing forward slashes first, then trailing backslashes.
    without_forward_slashes = args.corpus_directory.rstrip('/')
    args.corpus_directory = without_forward_slashes.rstrip('\\')

    validate_args(args, downloaded_acoustic_models, download_dictionaries)
    adapt_model(args, unknown_args)