"""Text-to-speech-token inference built on a SeedTTS-style fairseq model.

Each input line is pipe-separated:

    ref_wav|ref_token|test_id|text

The text field is converted to phones, decoded into speech tokens with the
fairseq generator, and written back out as:

    ref_wav|ref_token|test_id.wav|speech_tokens|text

Long inputs can optionally be split into segments via --max-seg-len.
"""

import argparse
import fileinput
import logging
import sys
import time
from collections import namedtuple

from fairseq import utils
from fairseq_cli.generate import get_symbols_to_strip_from_output

from simple_infer import Text2TokenGenerator, dummy_encode_fn

# Mirrors the batch structure produced by make_batches in the base class.
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")


class T2USeedTTS(Text2TokenGenerator):
    def __init__(self, args):
        super().__init__(args)

    def buffered_read(self, input, buffer_size):
        """Read pipe-separated lines and yield them in buffers.

        The last field of each line is the text; it is converted to a
        phone sequence and appended to the record.
        """
        buffer = []
        with fileinput.input(
            files=[input], openhook=fileinput.hook_encoded("utf-8")
        ) as h:
            for src_str in h:
                fields = src_str.strip().split("|")
                phones = self.text2phone(fields[-1])
                buffer.append([fields[0], fields[1], fields[2], fields[3], phones])
                if len(buffer) >= buffer_size:
                    yield buffer
                    buffer = []

        if len(buffer) > 0:
            yield buffer

    def generate_for_text_file_input(self, input):
        """Run batched inference over a pipe-separated input file.

        Returns one record per input line, holding the source info, the
        decoded source tokens, and the n-best hypotheses.
        """
        start_time = time.time()
        total_translate_time = 0

        hypo_outputs = []
        start_id = 0
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            phone_lines = [x[-1] for x in inputs]
            results = []
            for batch in self.make_batches(phone_lines, dummy_encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if self.use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }

                logging.info(f"Processing batch of size: {bsz}")
                translate_start_time = time.time()
                translations = self.task.inference_step(
                    self.generator, self.models, sample, constraints=constraints
                )
                translate_time = time.time() - translate_start_time
                total_translate_time += translate_time
                # No per-sentence constraints are used here; keep an empty
                # list per sample so the record structure stays uniform.
                list_constraints = [[] for _ in range(bsz)]

                for i, (id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)
                ):
                    src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                    results.append(
                        (
                            start_id + id,
                            src_tokens_i,
                            hypos,
                            {
                                "constraints": list_constraints[i],
                                "time": translate_time / len(translations),
                            },
                        )
                    )

            # Emit results in the original input order within this buffer.
            for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                output = {}
                output["src_tokens"] = []

                # Map the running sentence id back to its record in the buffer.
                input_info = inputs[id_ % self.cfg.interactive.buffer_size]
                output["src_info"] = input_info

                if self.src_dict is not None:
                    src_str = self.src_dict.string(
                        src_tokens, self.cfg.common_eval.post_process
                    )
                    if src_str != input_info[-1]:
                        logging.error("input/output mismatch!")
                        logging.error(src_str)
                        logging.error(input_info[-1])
                    output["src_tokens"] = src_str.split()

                output["hypotheses"] = []
                for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                    hypo_str = self.tgt_dict.string(
                        hypo["tokens"].int().cpu(),
                        self.cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            self.generator
                        ),
                    )
                    output["hypotheses"].append(
                        {
                            "hypo_tokens": hypo_str.split(),
                            "alignment": hypo["alignment"],
                        }
                    )

                hypo_outputs.append(output)
                logging.info(f"output records: {len(hypo_outputs)}")

            start_id += len(inputs)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs
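
    # Shape of each record in hypo_outputs (a sketch: the keys match the code
    # above; the field values shown are illustrative):
    #   {
    #       "src_info": [ref_wav, ref_token, test_id, text, phones],
    #       "src_tokens": ["p1", "p2", ...],
    #       "hypotheses": [{"hypo_tokens": ["421", "87", ...], "alignment": ...}],
    #   }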

    def generate_for_long_text_input_file(self, input, max_segment_len=0):
        """Run inference on long inputs by segmenting each phone sequence.

        Delegates to generate_for_long_input_text (provided by the
        Text2TokenGenerator base class) with the given max_segment_len.
        """
        start_time = time.time()
        total_translate_time = 0

        hypo_outputs = []
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            logging.info(f"processing inputs: {len(inputs)}")
            phones = [input_info[-1] for input_info in inputs]
            hypo_tokens, translate_time = self.generate_for_long_input_text(
                phones, max_segment_len=max_segment_len
            )
            total_translate_time += translate_time
            for tok, info in zip(hypo_tokens, inputs):
                hypo_outputs.append({"hypotheses": tok, "src_info": info})

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs
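

# Programmatic usage (a minimal sketch, not part of the original entry point;
# the data path, checkpoint, and input file below are illustrative
# assumptions). The constructor takes fairseq-style CLI arguments, exactly as
# forwarded by parse_known_args() in the __main__ block:
#
#   t2u = T2USeedTTS(["data-bin", "--path", "ckpt.pt", "--input", "test.tsv"])
#   for record in t2u.generate_for_text_file_input(t2u.cfg.interactive.input):
#       print(" ".join(record["hypotheses"][0]["hypo_tokens"]))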


def infer(unk_args, output_file, max_seg_len):
    output_fp = sys.stdout
    if output_file is not None:
        output_fp = open(output_file, "w")

    t2u = T2USeedTTS(unk_args)
    logging.info(f"Using max-seg-len = {max_seg_len}")
    if max_seg_len <= 0:
        speech_tokens_info = t2u.generate_for_text_file_input(
            t2u.cfg.interactive.input
        )
    else:
        logging.info(f"Split long text into segments of length: {max_seg_len}")
        speech_tokens_info = t2u.generate_for_long_text_input_file(
            t2u.cfg.interactive.input, max_segment_len=max_seg_len
        )

    for info in speech_tokens_info:
        # Short inputs carry n-best hypothesis dicts; segmented long inputs
        # carry a flat token list per line.
        if max_seg_len <= 0:
            token_str = " ".join(info["hypotheses"][0]["hypo_tokens"])
        else:
            token_str = " ".join(info["hypotheses"])
        ref_wav, ref_token, test_id, text = info["src_info"][:4]
        test_line = f"{ref_wav}|{ref_token}|{test_id}.wav|{token_str}|{text}"
        output_fp.write(test_line + "\n")

    output_fp.flush()
    if output_fp is not sys.stdout:
        output_fp.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output",
        dest="output",
        required=False,
        default=None,
        help="output file (defaults to stdout)",
    )
    parser.add_argument(
        "--max-seg-len",
        dest="max_seg_len",
        required=False,
        default=0,
        type=int,
        help="max segment length; <= 0 disables long-text segmentation",
    )
    # Any remaining arguments are forwarded to the fairseq-based generator.
    args, unknown_args = parser.parse_known_args()
    infer(unknown_args, args.output, args.max_seg_len)
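
# Example invocation (a sketch; the script name and the fairseq flags beyond
# --output/--max-seg-len are assumptions about the surrounding setup):
#
#   python t2u_seed_tts.py --output tokens.tsv --max-seg-len 0 \
#       data-bin --path checkpoint.pt \
#       --input test.tsv --buffer-size 8 --nbest 1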