DSTK / text2token /simple_infer.py
gooorillax's picture
first push of codes and models for g2p, t2u, tokenizer and detokenizer
cd8454d
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Yusen Sun,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import fileinput
import logging
import os
import sys
import time
import argparse
from collections import namedtuple
from tqdm import tqdm
from pathlib import Path
import numpy as np
import torch
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
from fairseq.models import import_models
# Phone tokens that Text2TokenGenerator.split_phone_segments treats as
# positions where a phone sequence may be cut.
PHONE_SPLITTER = {"[SIL]", "[CM]", "[PD]", "[QN]", "[EX]"}

# Directory that contains this file.
current_root = Path(__file__).absolute().parent

# Both directories must be on sys.path *before* the G2P import below:
# this directory itself and the sibling "thirdparty/G2P" tree.
sys.path.extend(
    [
        str(current_root),
        str(current_root.parent / "thirdparty/G2P"),
    ]
)
from G2P_processors import MultilingualG2P  # noqa: E402

# Register the model definitions under "<this directory name>.models".
relative_path = Path(current_root.name)
namespace = str(relative_path / "models").replace("/", ".")
import_models(str(current_root / "models"), namespace)
# Opt-in switch read from the environment; only the exact value "1" enables it.
TOKENIZE_ON_NPU = os.environ.get("TOKENIZE_ON_NPU")
if TOKENIZE_ON_NPU == "1":
    # Neither name is referenced again in this file, so these imports are
    # presumably needed only for their import-time side effects
    # (NOTE(review): confirm against the torch_npu documentation).
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    logging.info("Applying Patches for NPU!!!")
# Route everything logged through the root logger to a single INFO-level
# console handler with a uniform line format.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)

# Drop every handler that is already installed on the root logger.
# Iterate over a copy: removeHandler() mutates logging.root.handlers, and
# removing items from the list being iterated would skip every other handler.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)

logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)
# One mini-batch as yielded by Text2TokenGenerator.make_batches.
Batch = namedtuple(
    "Batch",
    ["ids", "src_tokens", "src_lengths", "constraints"],
)
# Record type for a translated sentence (not instantiated in this file).
Translation = namedtuple(
    "Translation",
    ["src_str", "hypos", "pos_scores", "alignments"],
)
# Baseline command-line arguments handed to the fairseq interactive-generation
# parser; Text2TokenGenerator appends any caller-supplied arguments after them.
# "--input" is deliberately not defaulted (a "./sample.txt" entry used to be
# commented out here), so it has to arrive through the extra arguments.
DEFAULT_T2U_ARGS = [
    f"{current_root}/data_bin",
    "--path", f"{current_root}/ckpt/40ms.checkpoint15.pt",
    "--batch-size", "1",
    "--buffer-size", "2",
    "--beam", "5",
    "--max-len-b", "1024",
    "--source-lang", "ph",
    "--target-lang", "tgt.unit",
]
def dummy_encode_fn(x):
    """Identity encoder: hand ``x`` back untouched.

    Passed as the ``encode_fn`` argument of ``make_batches`` when the input
    lines need no further encoding.
    """
    return x
class Text2TokenGenerator:
    """Text -> phone -> token inference built on a fairseq model ensemble.

    Raw text is normalised and converted to a phone string by
    ``MultilingualG2P``; the phone string is then decoded by the fairseq
    generator into a sequence of target-dictionary tokens.
    """

    def __init__(self, args=None) -> None:
        """Create the generator.

        Args:
            args: optional sequence of extra fairseq command-line arguments,
                appended after ``DEFAULT_T2U_ARGS``.
        """
        self._initialize(args)

    def _initialize(self, args):
        """Parse the fairseq config, load the models and set up the G2P front end."""
        # Work on a copy so the module-level defaults can never be mutated.
        t2u_args = list(DEFAULT_T2U_ARGS)
        if args:
            t2u_args += list(args)

        parser = options.get_interactive_generation_parser()
        t2u_fairseq_args = options.parse_args_and_arch(
            parser=parser, input_args=t2u_args
        )
        cfg: FairseqConfig = convert_namespace_to_omegaconf(t2u_fairseq_args)

        utils.import_user_module(cfg.common)

        if cfg.interactive.buffer_size < 1:
            cfg.interactive.buffer_size = 1
        if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
            cfg.dataset.batch_size = 1

        assert (
            not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        self.cfg = cfg
        logging.info(self.cfg)

        # Fix seed for stochastic decoding
        if (
            self.cfg.common.seed is not None
            and not self.cfg.generation.no_seed_provided
        ):
            np.random.seed(self.cfg.common.seed)
            utils.set_torch_seed(self.cfg.common.seed)

        self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu

        # Setup task, e.g., translation
        self.task = tasks.setup_task(self.cfg.task)

        # Load ensemble
        overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
        logging.info("loading model(s) from {}".format(self.cfg.common_eval.path))
        self.models, _model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        # Set dictionaries
        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize ensemble for generation
        for model in self.models:
            if model is None:
                continue
            if self.cfg.common.fp16:
                model.half()
            if (
                self.use_cuda
                and not self.cfg.distributed_training.pipeline_model_parallel
            ):
                model.cuda()
            model.prepare_for_inference_(cfg)

        # Initialize generator
        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        # Handle tokenization and BPE
        self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
        self.bpe = self.task.build_bpe(cfg.bpe)

        self.align_dict = None

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

        # init G2P
        # "zh" means the model treats all non-English as Chinese, "en" means
        # the model treats all languages as English.
        self.language = "zh"
        self.mG2P = MultilingualG2P(
            "wenet", remove_interjections=False, remove_erhua=False
        )  # 'baidu' or 'wenet'

    def text2phone(self, text):
        """Normalise ``text`` and return its phones joined by single spaces."""
        phones, _norm_text = self.mG2P.text_normalization_and_g2p(
            text, self.language, with_lang_prefix=True, normalize_punct=True
        )
        return " ".join(phones)

    def buffered_read(self, input, buffer_size):
        """Yield lists of phone strings read from a UTF-8 text file.

        Each line of ``input`` is stripped and converted with ``text2phone``.
        Lists hold ``buffer_size`` entries, except possibly the final one.

        Args:
            input: path of the text file, one utterance per line.
            buffer_size: number of lines collected before a list is yielded.
        """
        buffer = []
        with fileinput.input(
            files=[input], openhook=fileinput.hook_encoded("utf-8")
        ) as h:
            for src_str in h:
                buffer.append(self.text2phone(src_str.strip()))
                if len(buffer) >= buffer_size:
                    yield buffer
                    buffer = []

        if buffer:
            yield buffer

    def make_batches(self, lines, encode_fn):
        """Tokenise ``lines`` and yield ``Batch`` tuples in non-shuffled order.

        When ``cfg.generation.constraints`` is set, tab-separated constraints
        are stripped from each line (``lines`` is modified in place) and
        packed into a constraints tensor.

        Args:
            lines: list of source strings.
            encode_fn: callable applied to each line / constraint before it is
                encoded with the dictionary.
        """
        constraints_tensor = None
        if self.cfg.generation.constraints:
            # Strip (tab-delimited) constraints, if present, from input lines,
            # store them in batch_constraints
            batch_constraints = [list() for _ in lines]
            for i, line in enumerate(lines):
                if "\t" in line:
                    lines[i], *batch_constraints[i] = line.split("\t")

            # Convert each List[str] to List[Tensor]
            for i, constraint_list in enumerate(batch_constraints):
                batch_constraints[i] = [
                    self.task.target_dictionary.encode_line(
                        encode_fn(constraint),
                        append_eos=False,
                        add_if_not_exist=False,
                    )
                    for constraint in constraint_list
                ]
            constraints_tensor = pack_constraints(batch_constraints)

        tokens, lengths = self.task.get_interactive_tokens_and_lengths(lines, encode_fn)

        itr = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(
                tokens, lengths, constraints=constraints_tensor
            ),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
        ).next_epoch_itr(shuffle=False)
        for batch in itr:
            yield Batch(
                ids=batch["id"],
                src_tokens=batch["net_input"]["src_tokens"],
                src_lengths=batch["net_input"]["src_lengths"],
                constraints=batch.get("constraints", None),
            )

    def generate_for_text_file_input(self, input):
        """Decode every line of a text file, one model input per line.

        Args:
            input: path of the text file, one utterance per line.

        Returns:
            A list with one dict per decoded line, in input order. Each dict
            has ``"src_tokens"`` (list of source token strings) and
            ``"hypotheses"`` (up to ``cfg.generation.nbest`` dicts with
            ``"hypo_tokens"`` and ``"alignment"``).
        """
        start_time = time.time()
        total_translate_time = 0
        strip_symbols = get_symbols_to_strip_from_output(self.generator)

        hypo_outputs = []
        start_id = 0
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            results = []
            for batch in self.make_batches(inputs, dummy_encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if self.use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translate_start_time = time.time()
                translations = self.task.inference_step(
                    self.generator, self.models, sample, constraints=constraints
                )
                translate_time = time.time() - translate_start_time
                total_translate_time += translate_time

                list_constraints = [[] for _ in range(bsz)]
                if self.cfg.generation.constraints:
                    list_constraints = [unpack_constraints(c) for c in constraints]
                for i, (sample_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)
                ):
                    src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                    results.append(
                        (
                            start_id + sample_id,
                            src_tokens_i,
                            hypos,
                            {
                                "constraints": list_constraints[i],
                                "time": translate_time / len(translations),
                            },
                        )
                    )

            # sort output to match input order
            for _id, src_tokens_i, hypos, _info in sorted(results, key=lambda x: x[0]):
                output = {"src_tokens": [], "hypotheses": []}
                if self.src_dict is not None:
                    src_str = self.src_dict.string(
                        src_tokens_i, self.cfg.common_eval.post_process
                    )
                    output["src_tokens"] = src_str.split()

                # Process top predictions
                for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                    hypo_str = self.tgt_dict.string(
                        hypo["tokens"].int().cpu(),
                        self.cfg.common_eval.post_process,
                        extra_symbols_to_ignore=strip_symbols,
                    )
                    output["hypotheses"].append(
                        {
                            "hypo_tokens": hypo_str.split(),
                            "alignment": hypo["alignment"],
                        }
                    )
                hypo_outputs.append(output)

            # update running id_ counter
            start_id += len(inputs)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs

    def split_phone_segments(self, phones, max_segment_len=0, splitters=None):
        """Cut a phone string into segments that end on splitter tokens.

        The string is split on whitespace. A segment is closed at a splitter
        token once it holds at least ``max_segment_len`` tokens; whatever
        remains after the last cut forms the final segment. With the default
        ``max_segment_len=0`` a cut is made at every splitter.

        Args:
            phones: whitespace-separated phone tokens.
            max_segment_len: minimum number of tokens a segment must reach
                before it may be closed at a splitter.
            splitters: collection of tokens that mark possible cut points;
                defaults to the module-level ``PHONE_SPLITTER``.

        Returns:
            A list of segment strings with tokens joined by single spaces.
            If ``phones`` contains no splitter, ``[phones]`` is returned
            unchanged.

        Raises:
            RuntimeError: if the segments do not reproduce the input tokens.
        """
        if splitters is None:
            splitters = PHONE_SPLITTER

        phone_splits = phones.split()
        seps = [idx for idx, ph in enumerate(phone_splits) if ph in splitters]
        if not seps:
            return [phones]

        # Make sure trailing tokens after the last splitter are not lost.
        last_idx = len(phone_splits) - 1
        if seps[-1] < last_idx:
            seps.append(last_idx)

        phone_segments = []
        segment_start = 0
        last_sep = len(seps) - 1
        for pos, sep_idx in enumerate(seps):
            seglen = sep_idx - segment_start + 1
            if seglen >= max_segment_len or pos == last_sep:
                segment_end = sep_idx + 1
                phone_segments.append(" ".join(phone_splits[segment_start:segment_end]))
                segment_start = segment_end

        # Compare token-wise rather than against the raw string: irregular
        # whitespace in ``phones`` is not an error and must not abort the
        # process (the previous check called exit() in that case).
        if " ".join(phone_segments) != " ".join(phone_splits):
            raise RuntimeError("phone segments do not reproduce the input phones")
        return phone_segments

    def generate_for_long_input_text(self, input_phones, max_segment_len=0):
        """Decode phone strings segment by segment and re-join per input.

        Each phone string is cut with ``split_phone_segments``; all segments
        are decoded together and the best hypothesis of every segment is
        concatenated back in order for the input it came from.

        Args:
            input_phones: list of phone strings.
            max_segment_len: forwarded to ``split_phone_segments``.

        Returns:
            ``(hypo_tokens, total_translate_time)`` where ``hypo_tokens`` holds
            one list of token strings per entry of ``input_phones`` and
            ``total_translate_time`` is the time spent in inference, in seconds.
        """
        total_translate_time = 0

        input_segments = []
        segment_lens = []
        for phones in input_phones:
            segments = self.split_phone_segments(phones, max_segment_len)
            segment_lens.append(len(segments))
            input_segments.extend(segments)
        logging.info(
            f"Spliting {len(input_phones)} inputs into {len(input_segments)} segments"
        )
        if not input_segments:
            # Nothing to decode; avoid building an empty dataset.
            return [], total_translate_time

        # Keyed by segment id so that results are matched to their segment
        # even if the batch iterator drops some segments.
        hypos_by_id = {}
        for batch in self.make_batches(input_segments, dummy_encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            logging.info(f"processing batch: {bsz}")
            translate_start_time = time.time()
            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=constraints
            )
            total_translate_time += time.time() - translate_start_time

            for seg_id, hypos in zip(batch.ids.tolist(), translations):
                hypos_by_id[seg_id] = hypos

        strip_symbols = get_symbols_to_strip_from_output(self.generator)
        hypo_tokens = []
        first_id = 0
        for num_segments in segment_lens:
            token_res = []
            for seg_id in range(first_id, first_id + num_segments):
                hypos = hypos_by_id.get(seg_id)
                if hypos is None:
                    logging.warning(
                        f"No translation returned for segment {seg_id}; skipping it"
                    )
                    continue
                # Process top prediction only
                hypo_str = self.tgt_dict.string(
                    hypos[0]["tokens"].int().cpu(),
                    self.cfg.common_eval.post_process,
                    extra_symbols_to_ignore=strip_symbols,
                )
                token_res.extend(hypo_str.split())
            hypo_tokens.append(token_res)
            first_id += num_segments

        return hypo_tokens, total_translate_time

    def generate_for_long_text_input_file(self, input, max_segment_len=0):
        """Decode every line of a text file using segment-wise decoding.

        Args:
            input: path of the text file, one utterance per line.
            max_segment_len: forwarded to ``generate_for_long_input_text``.

        Returns:
            A list with one list of token strings per input line.
        """
        start_time = time.time()
        total_translate_time = 0

        hypo_outputs = []
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            logging.info(f"processing inputs: {len(inputs)}")
            hypo_tokens, translate_time = self.generate_for_long_input_text(
                inputs, max_segment_len=max_segment_len
            )
            total_translate_time += translate_time
            hypo_outputs.extend(hypo_tokens)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs
def infer(unk_args, output_file, max_seg_len):
    """Run text-to-token inference and write one line of tokens per input line.

    The input path is taken from ``cfg.interactive.input`` of the generator
    built from ``unk_args``.

    Args:
        unk_args: list of extra fairseq command-line arguments passed to
            ``Text2TokenGenerator``.
        output_file: path of the file to write, or ``None`` to write to
            standard output.
        max_seg_len: when ``<= 0`` every input line is decoded as a whole;
            otherwise lines are decoded segment-wise with this value as
            ``max_segment_len``.
    """
    owns_output = output_file is not None
    # Input is read as UTF-8, so write the result with the same encoding.
    output_fp = open(output_file, "w", encoding="utf-8") if owns_output else sys.stdout
    try:
        t2u = Text2TokenGenerator(unk_args)
        input_path = t2u.cfg.interactive.input
        if max_seg_len <= 0:
            speech_tokens_info = t2u.generate_for_text_file_input(input_path)
            for infor in speech_tokens_info:
                output_fp.write(" ".join(infor["hypotheses"][0]["hypo_tokens"]) + "\n")
        else:
            speech_tokens_info = t2u.generate_for_long_text_input_file(
                input_path, max_segment_len=max_seg_len
            )
            for infor in speech_tokens_info:
                output_fp.write(" ".join(infor) + "\n")
        output_fp.flush()
    finally:
        # Close only a file opened here; sys.stdout must stay usable for the
        # rest of the process.
        if owns_output:
            output_fp.close()
if __name__ == "__main__":
    # Only the two options below are consumed here; every other command-line
    # argument is collected by parse_known_args() and handed on to infer().
    cli = argparse.ArgumentParser()
    cli.add_argument("--output", default=None, help="output file")
    cli.add_argument(
        "--max-seg-len", type=int, default=0, help="max segment length"
    )
    known, passthrough = cli.parse_known_args()
    infer(passthrough, known.output, known.max_seg_len)