# DSTK / text2token / infer_for_detok.py
# (repo header: gooorillax, commit cd8454d — "first push of codes and models
# for g2p, t2u, tokenizer and detokenizer")
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Yusen Sun,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from simple_infer import Text2TokenGenerator, dummy_encode_fn
from fairseq.dataclass.configs import FairseqConfig
import fileinput
from fairseq import utils, options
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from collections import namedtuple
import time
import logging
import sys
import os
from tqdm import tqdm
# Container for one padded inference batch: example indices within the current
# input chunk, the padded source-token tensor, per-example lengths, and optional
# decoding constraints.  NOTE(review): not referenced directly in this file —
# presumably produced by make_batches in the parent class; confirm.
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
class T2USeedTTS(Text2TokenGenerator):
    """Text-to-speech-token generator for a SeedTTS-style pipeline.

    Reads pipe-separated input records, converts the text field to phones via
    the parent class, and runs fairseq inference to produce speech-token
    hypotheses for a downstream detokenizer/vocoder.
    """
    def __init__(self, args):
        # All model/task/dictionary/generator setup happens in the parent.
        super().__init__(args)
    def buffered_read(self, input, buffer_size):
        """Yield chunks of parsed input records, at most ``buffer_size`` each.

        Each record is ``[ref_wav, ref_wav_tokens, id, text, phones]`` where
        ``phones`` comes from ``self.text2phone`` (inherited; presumably a
        text-to-phone front end — confirm in Text2TokenGenerator).
        Every yielded chunk except possibly the last has exactly
        ``buffer_size`` records.
        """
        buffer = []
        with fileinput.input(
            files=[input], openhook=fileinput.hook_encoded("utf-8")
        ) as h:
            for src_str in h:
                fields = src_str.strip().split("|")
                # Phones are derived from the LAST field; the record keeps
                # fields[0..3], so the format is assumed to be exactly
                # ref_wav|ref_wav_tokens|id|text — TODO confirm data format.
                phones = self.text2phone(fields[-1])
                buffer.append(
                    [fields[0], fields[1], fields[2], fields[3], phones]
                ) # (ref_wav, ref_wav_tokens, id, text, phones)
                if len(buffer) >= buffer_size:
                    yield buffer
                    buffer = []
        # Flush the trailing partial chunk, if any.
        if len(buffer) > 0:
            yield buffer
    def generate_for_text_file_input(self, input):
        """Run batched inference over ``input`` and collect all hypotheses.

        Returns a list (in input-file order) of dicts with keys:
        ``src_info`` (the parsed input record), ``src_tokens`` (the
        re-decoded source phone tokens), and ``hypotheses`` — up to
        ``generation.nbest`` dicts each holding ``hypo_tokens`` and
        ``alignment``.
        """
        start_time = time.time()
        total_translate_time = 0
        hypo_outputs = []
        start_id = 0
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            phone_lines = [x[-1] for x in inputs]
            results = []
            for batch in self.make_batches(phone_lines, dummy_encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if self.use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()
                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                logging.info(f"Processing batch of size: {bsz}")
                translate_start_time = time.time()
                translations = self.task.inference_step(
                    self.generator, self.models, sample, constraints=constraints
                )
                translate_time = time.time() - translate_start_time
                total_translate_time += translate_time
                # Constraint decoding is disabled; keep empty placeholders so
                # the per-example bookkeeping below stays uniform.
                list_constraints = [[] for _ in range(bsz)]
                # if self.cfg.generation.constraints:
                # list_constraints = [unpack_constraints(c) for c in constraints]
                for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                    constraints = list_constraints[i]
                    results.append(
                        (
                            start_id + id,
                            src_tokens_i,
                            hypos,
                            {
                                "constraints": constraints,
                                "time": translate_time / len(translations),
                            },
                        )
                    )
            # sort output to match input order
            for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                output = {}
                output["src_tokens"] = []
                # src info
                # NOTE(review): id_ % buffer_size recovers the index within
                # this chunk; this relies on every chunk except the last being
                # exactly buffer_size long (buffered_read guarantees that).
                input_info = inputs[id_ % self.cfg.interactive.buffer_size]
                output["src_info"] = input_info
                # src_str = ""
                if self.src_dict is not None:
                    src_str = self.src_dict.string(
                        src_tokens, self.cfg.common_eval.post_process
                    )
                    # Sanity check: re-decoded source must match the phones we
                    # fed in; a mismatch means ids got out of sync.
                    if src_str != input_info[-1]:
                        logging.info(f"ERROR, input output mismatch!!")
                        logging.info(f"{src_str}")
                        logging.info(f"{ input_info[-1]}")
                    output["src_tokens"] = src_str.split()
                # Process top predictions
                output["hypotheses"] = []
                for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                    hypo_str = self.tgt_dict.string(
                        hypo["tokens"].int().cpu(),
                        self.cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            self.generator
                        ),
                    )
                    output["hypotheses"].append(
                        {
                            "hypo_tokens": hypo_str.split(),
                            "alignment": hypo["alignment"],
                        }
                    )
                hypo_outputs.append(output)
                logging.info(f"output records: {len(hypo_outputs)}")
            # update running id_ counter
            start_id += len(inputs)
        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs
    def generate_for_long_text_input_file(self, input, max_segment_len=0):
        """Segment-wise inference for long texts.

        Delegates each chunk of records to ``self.generate_for_long_input_text``
        (inherited — presumably splits phones into segments of at most
        ``max_segment_len`` and concatenates the generated tokens; confirm).
        Returns a list of dicts with ``hypotheses`` (a flat token list per
        record, unlike generate_for_text_file_input) and ``src_info``.
        """
        start_time = time.time()
        total_translate_time = 0
        hypo_outputs = []
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            logging.info(f"processing inputs: {len(inputs)}")
            phones = [input_info[-1] for input_info in inputs]
            hypo_tokens, translate_time = self.generate_for_long_input_text(
                phones, max_segment_len=max_segment_len
            )
            total_translate_time += translate_time
            for tok, info in zip(hypo_tokens, inputs):
                hypo_outputs.append({"hypotheses": tok, "src_info": info})
        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs
def format_test_line(infor, token_str):
    """Render one output record as a pipe-separated test line.

    ``infor["src_info"]`` is the record built by ``T2USeedTTS.buffered_read``:
    ``[ref_wav, ref_wav_tokens, id, text, ...]``.  Returns
    ``ref_wav|ref_wav_tokens|<id>.wav|<token_str>|text``.
    """
    ref_wav = infor["src_info"][0]
    ref_token = infor["src_info"][1]
    test_id = infor["src_info"][2]
    text = infor["src_info"][3]
    return f"{ref_wav}|{ref_token}|{test_id}.wav|{token_str}|{text}"

def infer(unk_args, output_file, max_seg_len):
    """Run text-to-token inference and emit one pipe-separated line per input.

    Args:
        unk_args: leftover CLI args forwarded to the fairseq-based generator.
        output_file: destination path, or None to write to stdout.
        max_seg_len: if > 0, long inputs are split into segments of at most
            this length before generation.
    """
    t2u = T2USeedTTS(unk_args)
    logging.info(f"Using max-seg-len = {max_seg_len}")
    if max_seg_len <= 0:
        speech_tokens_info = t2u.generate_for_text_file_input(t2u.cfg.interactive.input)
        # Keep only the best (first) hypothesis of each record.
        lines = [
            format_test_line(infor, " ".join(infor["hypotheses"][0]["hypo_tokens"]))
            for infor in speech_tokens_info
        ]
    else:
        logging.info(f"Split long text into segments of length: {max_seg_len}")
        speech_tokens_info = t2u.generate_for_long_text_input_file(
            t2u.cfg.interactive.input, max_segment_len=max_seg_len
        )
        # Long-text output is already a flat token list per record.
        lines = [
            format_test_line(infor, " ".join(infor["hypotheses"]))
            for infor in speech_tokens_info
        ]
    if output_file is not None:
        # Context manager guarantees the file is closed even if a write fails
        # (the original leaked the handle on any exception before close()).
        with open(output_file, "w") as output_fp:
            output_fp.writelines(line + "\n" for line in lines)
    else:
        # Bug fix: the original unconditionally close()d output_fp, which was
        # sys.stdout when --output was not given.
        for line in lines:
            sys.stdout.write(line + "\n")
        sys.stdout.flush()
    return
if __name__ == "__main__":
    # Parse only the script-local flags; every unrecognized argument is handed
    # through untouched to the fairseq-based generator.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--output", dest="output", required=False, default=None, help="output file"
    )
    cli_parser.add_argument(
        "--max-seg-len",
        dest="max_seg_len",
        required=False,
        default=0,
        type=int,
        help="max segment length",
    )
    cli_args, passthrough_args = cli_parser.parse_known_args()
    infer(passthrough_args, cli_args.output, cli_args.max_seg_len)