# Evaluate the LipNet phoneme predictor chained with the phoneme-to-text
# translator on random samples from the GRID test set.
import os.path
import sys

base_dir = '..'
sys.path.append(base_dir)

from Trainer import Trainer
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap

# whether to use a transformer with word- or character-level tokenization
WORD_TOKENIZE = False
# whether to drop repeated consecutive phonemes when decoding CTC output
PHONEME_FILTER_PREV = False
# beam width passed to the translator (0 presumably disables beam search)
BEAM_SIZE = 0

# lipnet_weights = 'weights/phoneme-231201-0052/I198000-L00048-W00018-C00005.pt'
# lipnet_weights = 'weights/phoneme-231201-2218/I119000-L00001-W00000-C00000.pt'
lipnet_weights = 'saved-weights/phonemes-231207-2130/I283000-L00683-W01012-C00765.pt'

if WORD_TOKENIZE:
    translator_weights = 'saved-weights/translate-231204-1652/I160-L00047-W00000.pt'
else:
    translator_weights = 'saved-weights/translate-231204-2227/I860-L00000-W00000.pt'

# translator_weights = 'weights/translate-231202-1509/I1560-L00000-W00000.pt'
# translator_weights = 'weights/translate-231204-1709/I220-L00042-W00000.pt'

lipnet_predictor = Trainer(
    write_logs=False, base_dir=base_dir,
    num_workers=0, char_map=CharMap.phonemes
)
lipnet_predictor.load_weights(lipnet_weights)
lipnet_predictor.load_datasets()
dataset = lipnet_predictor.test_dataset

phoneme_translator = TranslatorTrainer(
    write_logs=False, base_dir=base_dir,
    word_tokenize=WORD_TOKENIZE
)
phoneme_translator.load_weights(os.path.join(
    base_dir, translator_weights
))

# manual sanity checks for the translator, kept for reference:
"""
new_phonemes = GridDataset.text_to_phonemes("Do you like fries")
print("PRE_REV_TRANSLATE", [new_phonemes])
pred_text = phoneme_translator.translate(new_phonemes)
print("AFT_REV_TRANSLATE", pred_text)

phoneme_sentence = 'B-IH1-N B-L-UW1 AE1-T EH1-F TH-R-IY1 S-UW1-N'
pred_text = phoneme_translator.translate(phoneme_sentence)
print(f'PRED_TEXT: [{pred_text}]')
"""

total_samples = 1000
total_wer = 0
num_correct = 0
num_phonemes_correct = 0

for k in range(total_samples):
    # `all` (the builtin) is used as a sentinel here so the sample
    # carries targets for every char map (both phonemes and letters)
    sample = dataset.load_random_sample(char_map=all)
    tgt_phonemes = sample['phonemes']
    tgt_text = sample['txt']

    # decode the CTC label arrays back into target sentences
    target_phonemes_sentence = dataset.ctc_arr2txt(
        tgt_phonemes, start=1, char_map=CharMap.phonemes,
        filter_previous=PHONEME_FILTER_PREV
    )
    target_sentence = dataset.ctc_arr2txt(
        tgt_text, start=1, char_map=CharMap.letters,
        filter_previous=False
    )

    # predict phonemes from video, then translate phonemes to text
    pred_phonemes_sentence = lipnet_predictor.predict_sample(sample)[0]
    pred_text = phoneme_translator.translate(
        pred_phonemes_sentence, beam_size=BEAM_SIZE
    )

    match_phonemes = pred_phonemes_sentence == target_phonemes_sentence
    wer = dataset.get_wer(
        [pred_text], [target_sentence], char_map=CharMap.letters
    )[0]
    total_wer += wer

    correct = pred_text == target_sentence
    if correct:
        num_correct += 1
    if match_phonemes:
        num_phonemes_correct += 1

    print(
        f'PRED-PHONEMES [{k}]',
        [pred_phonemes_sentence, target_phonemes_sentence],
        [pred_text, target_sentence],
        correct, wer
    )

avg_wer = total_wer / total_samples
print(f'{num_correct}/{total_samples} samples correct')
print(f'{num_phonemes_correct}/{total_samples} phoneme samples correct')
print(f'average WER: {avg_wer}')