JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
2.87 kB
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <exception>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "fstext/fstext-lib.h" // @manual
#include "util/common-utils.h" // @manual
/*
* This program modifies an FST that has no self-loops as follows:
* for each state, and for each distinct non-eps input symbol found on the
* arcs entering that state, add a self-loop arc on the state with that
* symbol as input and eps as output.
*
* This ensures the resulting FST can deduplicate repeated symbols, which
* are very common in acoustic model output.
*
*/
namespace {
/**
 * Adds "deduplication" self-loops to `fst`, in place.
 *
 * Steps:
 *   1. Calls fst::MakePrecedingInputSymbolsSame(false, fst). The number of
 *      states before and after that call is logged.
 *   2. For every state, collects the distinct non-epsilon (non-zero) input
 *      labels found on the arcs that enter it.
 *   3. For every (state, label) pair collected, appends a self-loop arc on
 *      that state with ilabel = label, olabel = 0 (epsilon) and weight One().
 *
 * No check is made for self-loops that already exist; an existing non-epsilon
 * self-loop counts as an incoming arc like any other.
 *
 * @param fst  FST to modify; must not be null.
 * @return     Number of self-loop arcs added.
 */
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
  KALDI_ASSERT(fst != nullptr);

  const int32 num_states_before = fst->NumStates();
  fst::MakePrecedingInputSymbolsSame(false, fst);
  const int32 num_states = fst->NumStates();
  // KALDI_LOG terminates each message itself, so no std::endl is streamed
  // (it would only produce an extra blank line in the log).
  KALDI_LOG << "There are " << num_states_before
            << " states in the original FST; "
            << " after MakePrecedingInputSymbolsSame, there are "
            << num_states << " states ";

  // Pass 1: gather, per destination state, the set of non-epsilon input
  // labels arriving at it. Gathering everything before adding any arc keeps
  // the new self-loops from being seen as "incoming" arcs themselves.
  // The scan is read-only, so a plain ArcIterator and a const reference to
  // each arc are sufficient.
  std::vector<std::set<int32>> incoming_labels(num_states);
  for (int32 state = 0; state < num_states; ++state) {
    for (fst::ArcIterator<fst::StdVectorFst> aiter(*fst, state); !aiter.Done();
         aiter.Next()) {
      const fst::StdArc& arc = aiter.Value();
      if (arc.ilabel != 0) {
        incoming_labels[arc.nextstate].insert(arc.ilabel);
      }
    }
  }

  // Pass 2: add one self-loop per collected (state, label) pair. std::set
  // iterates in ascending label order, so arcs are appended in that order.
  fst::StdArc self_loop;
  self_loop.olabel = 0;
  self_loop.weight = fst::StdArc::Weight::One();
  int32 num_arc_added = 0;
  for (int32 state = 0; state < num_states; ++state) {
    self_loop.nextstate = state;
    for (const int32 ilabel : incoming_labels[state]) {
      self_loop.ilabel = ilabel;
      fst->AddArc(state, self_loop);
      ++num_arc_added;
    }
  }
  return num_arc_added;
}
// Writes the command-line synopsis for this tool to standard output.
void print_usage() {
  static const char kUsageText[] =
      "add-self-loop-simple usage:\n"
      "\tadd-self-loop-simple <in-fst> <out-fst> \n";
  std::cout << kUsageText;
}
} // namespace
int main(int argc, char** argv) {
if (argc != 3) {
print_usage();
exit(1);
}
auto input = argv[1];
auto output = argv[2];
auto fst = fst::ReadFstKaldi(input);
auto num_states = fst->NumStates();
KALDI_LOG << "Loading FST from " << input << " with " << num_states
<< " states." << std::endl;
int32 num_arc_added = AddSelfLoopsSimple(fst);
KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl;
fst::WriteFstKaldi(*fst, std::string(output));
KALDI_LOG << "Writing FST to " << output << std::endl;
delete fst;
}