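"""Text-to-speech-token (T2U) inference script.

Converts input text lines to phone sequences with a multilingual G2P frontend,
then decodes the phones into discrete speech units with a fairseq translation
checkpoint. Long inputs can optionally be split at boundary phones (see
PHONE_SPLITTER) and decoded segment by segment.

Example invocation (script name and input path are placeholders; any unknown
arguments are forwarded to the fairseq interactive-generation parser):

    python t2u_generate.py --output units.txt --max-seg-len 50 --input sentences.txt
"""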
import argparse
import ast
import fileinput
import logging
import os
import sys
import time
from collections import namedtuple
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import import_models
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output

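# Phone symbols used as split points when segmenting long phone sequences.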
PHONE_SPLITTER = {"[SIL]", "[CM]", "[PD]", "[QN]", "[EX]"}

current_root = Path(__file__).absolute().parent
sys.path.append(str(current_root))
sys.path.append(str(current_root.parent / "thirdparty/G2P"))
from G2P_processors import MultilingualG2P

# Register the local model definitions with fairseq under a package-style namespace.
relative_path = Path(current_root.name)
namespace = str(relative_path / "models").replace("/", ".")
import_models(str(current_root / "models"), namespace)

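# Optionally patch torch for Ascend NPU execution; torch_npu's transfer_to_npu
# redirects CUDA calls to the NPU backend.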
TOKENIZE_ON_NPU = os.environ.get("TOKENIZE_ON_NPU")
if TOKENIZE_ON_NPU == "1":
    import torch_npu  # noqa: F401
    from torch_npu.contrib import transfer_to_npu  # noqa: F401

    logging.info("Applying patches for NPU")

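# Route all log records through a single console handler with a uniform format.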
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Iterate over a copy: removing handlers while iterating the live list skips entries.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)


Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")

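# Default arguments for the fairseq interactive-generation parser: binarized
# phone/unit data, the T2U checkpoint, and beam-search settings. Extra
# arguments passed to Text2TokenGenerator are appended to these.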
DEFAULT_T2U_ARGS = [
    str(current_root) + "/data_bin",
    "--path",
    str(current_root) + "/ckpt/40ms.checkpoint15.pt",
    "--batch-size",
    "1",
    "--buffer-size",
    "2",
    "--beam",
    "5",
    "--max-len-b",
    "1024",
    "--source-lang",
    "ph",
    "--target-lang",
    "tgt.unit",
]


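# Inputs are already space-separated phone strings, so the interactive encode hook is a no-op.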
def dummy_encode_fn(x):
    return x


class Text2TokenGenerator:
    """Phone-to-speech-unit generator built on a fairseq translation model."""

    def __init__(self, args=None) -> None:
        self._initialize(args)

    def _initialize(self, args):
        # Combine the default arguments with any caller-supplied overrides and
        # parse them with the fairseq interactive-generation parser.
        t2u_args = DEFAULT_T2U_ARGS
        if args is not None and len(args) > 0:
            t2u_args = t2u_args + args
        parser = options.get_interactive_generation_parser()
        t2u_fairseq_args = options.parse_args_and_arch(
            parser=parser, input_args=t2u_args
        )
        cfg: FairseqConfig = convert_namespace_to_omegaconf(t2u_fairseq_args)

        utils.import_user_module(cfg.common)

        if cfg.interactive.buffer_size < 1:
            cfg.interactive.buffer_size = 1
        if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
            cfg.dataset.batch_size = 1

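        # Sanity checks on the decoding options before building the task and models.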
        assert (
            not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        self.cfg = cfg
        logging.info(self.cfg)

        if (
            self.cfg.common.seed is not None
            and not self.cfg.generation.no_seed_provided
        ):
            np.random.seed(self.cfg.common.seed)
            utils.set_torch_seed(self.cfg.common.seed)
        self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu

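        # Set up the task (dictionaries, data handling) and load the model
        # ensemble from the checkpoint path.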
        self.task = tasks.setup_task(self.cfg.task)

        overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
        logging.info("loading model(s) from {}".format(self.cfg.common_eval.path))
        self.models, _model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

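        # Prepare each model for inference: optional fp16 cast, move to GPU when
        # available, then fairseq's prepare_for_inference_ hook.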
        for model in self.models:
            if model is None:
                continue
            if self.cfg.common.fp16:
                model.half()
            if (
                self.use_cuda
                and not self.cfg.distributed_training.pipeline_model_parallel
            ):
                model.cuda()
            model.prepare_for_inference_(cfg)

        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
        self.bpe = self.task.build_bpe(cfg.bpe)

        self.align_dict = None

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

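        # Chinese G2P frontend used by text2phone to turn raw text into phones.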
        self.language = "zh"
        self.mG2P = MultilingualG2P(
            "wenet", remove_interjections=False, remove_erhua=False
        )

    def text2phone(self, text):
        phones, norm_text = self.mG2P.text_normalization_and_g2p(
            text, self.language, with_lang_prefix=True, normalize_punct=True
        )
        return " ".join(phones)

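    # Read the input file lazily: convert each line to phones and yield buffers
    # of up to buffer_size phone strings.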
    def buffered_read(self, input, buffer_size):
        buffer = []
        with fileinput.input(
            files=[input], openhook=fileinput.hook_encoded("utf-8")
        ) as h:
            for src_str in h:
                phones = self.text2phone(src_str.strip())
                buffer.append(phones)
                if len(buffer) >= buffer_size:
                    yield buffer
                    buffer = []

        if len(buffer) > 0:
            yield buffer

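    # Build inference batches the way fairseq's interactive generation does,
    # including optional tab-separated decoding constraints.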
    def make_batches(self, lines, encode_fn):
        def encode_fn_target(x):
            return encode_fn(x)

        if self.cfg.generation.constraints:
            # Strip tab-delimited constraints, if present, from the input lines.
            batch_constraints = [list() for _ in lines]
            for i, line in enumerate(lines):
                if "\t" in line:
                    lines[i], *batch_constraints[i] = line.split("\t")

            # Convert each constraint phrase into a tensor of target-dictionary ids.
            for i, constraint_list in enumerate(batch_constraints):
                batch_constraints[i] = [
                    self.task.target_dictionary.encode_line(
                        encode_fn_target(constraint),
                        append_eos=False,
                        add_if_not_exist=False,
                    )
                    for constraint in constraint_list
                ]

        if self.cfg.generation.constraints:
            constraints_tensor = pack_constraints(batch_constraints)
        else:
            constraints_tensor = None

        tokens, lengths = self.task.get_interactive_tokens_and_lengths(lines, encode_fn)

        itr = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(
                tokens, lengths, constraints=constraints_tensor
            ),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
        ).next_epoch_itr(shuffle=False)
        for batch in itr:
            ids = batch["id"]
            src_tokens = batch["net_input"]["src_tokens"]
            src_lengths = batch["net_input"]["src_lengths"]
            constraints = batch.get("constraints", None)

            yield Batch(
                ids=ids,
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                constraints=constraints,
            )

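    # End-to-end decoding for a text file: read and G2P each line, batch, run
    # the generator, and collect the n-best unit hypotheses per input line.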
    def generate_for_text_file_input(self, input):
        start_time = time.time()
        total_translate_time = 0

        hypo_outputs = []
        start_id = 0
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            results = []
            for batch in self.make_batches(inputs, dummy_encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if self.use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translate_start_time = time.time()
                translations = self.task.inference_step(
                    self.generator, self.models, sample, constraints=constraints
                )
                translate_time = time.time() - translate_start_time
                total_translate_time += translate_time
                list_constraints = [[] for _ in range(bsz)]
                if self.cfg.generation.constraints:
                    list_constraints = [unpack_constraints(c) for c in constraints]
                for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                    constraints = list_constraints[i]
                    results.append(
                        (
                            start_id + id,
                            src_tokens_i,
                            hypos,
                            {
                                "constraints": constraints,
                                "time": translate_time / len(translations),
                            },
                        )
                    )

            # Sort output to match the original input order.
            for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                output = {}
                output["src_tokens"] = []

                if self.src_dict is not None:
                    src_str = self.src_dict.string(
                        src_tokens, self.cfg.common_eval.post_process
                    )
                    output["src_tokens"] = src_str.split()

                # Keep the top-n hypotheses as lists of unit tokens.
                output["hypotheses"] = []
                for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                    hypo_str = self.tgt_dict.string(
                        hypo["tokens"].int().cpu(),
                        self.cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            self.generator
                        ),
                    )
                    output["hypotheses"].append(
                        {
                            "hypo_tokens": hypo_str.split(),
                            "alignment": hypo["alignment"],
                        }
                    )

                hypo_outputs.append(output)

            start_id += len(inputs)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs

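    # Split a phone string into segments at PHONE_SPLITTER symbols; every segment
    # except possibly the last spans at least max_segment_len phones.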
    def split_phone_segments(self, phones, max_segment_len=0):
        phone_segments = []
        phone_splits = phones.split()
        seps = []
        for idx in range(len(phone_splits)):
            ph = phone_splits[idx]
            if ph in PHONE_SPLITTER:
                seps.append(idx)

        if len(seps) <= 0:
            return [phones]

        # Make sure the final phone always closes a segment.
        if seps[-1] < len(phone_splits) - 1:
            seps.append(len(phone_splits) - 1)
        segment_start = 0
        segment_end = 0
        for idx in range(len(seps)):
            seglen = seps[idx] - segment_start + 1
            if seglen >= max_segment_len or idx == len(seps) - 1:
                segment_end = segment_start + seglen
                phone_segments.append(" ".join(phone_splits[segment_start:segment_end]))
                segment_start = segment_end
            else:
                continue

        # Sanity check: the concatenated segments must reproduce the original phones.
        reproduce_phone = " ".join(phone_segments)
        if phones != reproduce_phone:
            logging.error("Segments do not reconstruct the original phone sequence")
            sys.exit(1)

        return phone_segments

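    # Decode long phone sequences: split each input into segments, decode all
    # segments in batches, then stitch the unit tokens back together per input.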
    def generate_for_long_input_text(self, input_phones, max_segment_len=0):
        total_translate_time = 0
        input_segments = []
        segment_lens = []
        for input in input_phones:
            segments = self.split_phone_segments(input, max_segment_len)
            segment_lens.append(len(segments))
            input_segments.extend(segments)

        logging.info(
            f"Splitting {len(input_phones)} inputs into {len(input_segments)} segments"
        )
        results = []
        start_id = 0
        for batch in self.make_batches(input_segments, dummy_encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }

            logging.info(f"processing batch: {bsz}")
            translate_start_time = time.time()
            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=constraints
            )
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time

            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                results.append((start_id + id, hypos))

        # Regroup per-segment hypotheses back to their source inputs using the
        # recorded segment counts, keeping only the best hypothesis per segment.
        segment_results = []
        sorted_results = sorted(results, key=lambda x: x[0])
        start_pos = 0
        for sl in segment_lens:
            segment_results.append(sorted_results[start_pos : start_pos + sl])
            start_pos += sl

        assert len(input_phones) == len(segment_results)

        hypo_tokens = []
        for seg_res in segment_results:
            token_res = []
            for id_, hypos in seg_res:
                hypo = hypos[0]
                hypo_str = self.tgt_dict.string(
                    hypo["tokens"].int().cpu(),
                    self.cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        self.generator
                    ),
                )
                token_res.extend(hypo_str.split())
            hypo_tokens.append(token_res)

        return hypo_tokens, total_translate_time

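    # Same as generate_for_text_file_input, but routes each buffered chunk
    # through the segment-based long-input path.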
    def generate_for_long_text_input_file(self, input, max_segment_len=0):
        start_time = time.time()
        total_translate_time = 0

        hypo_outputs = []
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            logging.info(f"processing inputs: {len(inputs)}")

            hypo_tokens, translate_time = self.generate_for_long_input_text(
                inputs, max_segment_len=max_segment_len
            )
            total_translate_time += translate_time
            hypo_outputs.extend(hypo_tokens)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs


def infer(unk_args, output_file, max_seg_len):
    output_fp = sys.stdout
    if output_file is not None:
        output_fp = open(output_file, "w")

    t2u = Text2TokenGenerator(unk_args)
    if max_seg_len <= 0:
        speech_tokens_info = t2u.generate_for_text_file_input(t2u.cfg.interactive.input)
        for infor in speech_tokens_info:
            output_fp.write(" ".join(infor["hypotheses"][0]["hypo_tokens"]) + "\n")
    else:
        speech_tokens_info = t2u.generate_for_long_text_input_file(
            t2u.cfg.interactive.input, max_segment_len=max_seg_len
        )
        for infor in speech_tokens_info:
            output_fp.write(" ".join(infor) + "\n")

    output_fp.flush()
    # Only close the stream if we opened it; sys.stdout should stay open.
    if output_file is not None:
        output_fp.close()
    return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output",
        dest="output",
        required=False,
        default=None,
        help="output file (defaults to stdout)",
    )
    parser.add_argument(
        "--max-seg-len",
        dest="max_seg_len",
        required=False,
        default=0,
        type=int,
        help="max segment length; 0 disables segment-based decoding",
    )
    # Unknown arguments are forwarded to the fairseq interactive-generation parser.
    args, unknown_args = parser.parse_known_args()
    infer(unknown_args, args.output, args.max_seg_len)