# examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py
import math
import os
import json
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
import yaml

from fairseq import checkpoint_utils, tasks
from fairseq.file_io import PathManager

try:
    from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
    from simuleval.agents import SpeechAgent
    from simuleval.states import ListEntry, SpeechStates
except ImportError:
    print("Please install simuleval 'pip install simuleval'")

SHIFT_SIZE = 10  # frame shift of the feature extraction window, in ms
WINDOW_SIZE = 25  # length of the feature extraction window, in ms
SAMPLE_RATE = 16000  # audio sample rate, in Hz
FEATURE_DIM = 80  # number of mel filterbank bins
BOW_PREFIX = "\u2581"  # sentencepiece begin-of-word marker


class OnlineFeatureExtractor:
    """
    Extract speech features on the fly.
    """

    def __init__(self, args):
        self.shift_size = args.shift_size
        self.window_size = args.window_size
        assert self.window_size >= self.shift_size

        self.sample_rate = args.sample_rate
        self.feature_dim = args.feature_dim
        # With the defaults (10 ms shift, 25 ms window, 16 kHz audio) this is
        # 160 samples per shift and 400 samples per window.
        self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000)
        self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000)
        self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000
        self.previous_residual_samples = []
        self.global_cmvn = args.global_cmvn

    def clear_cache(self):
        self.previous_residual_samples = []

    def __call__(self, new_samples):
        samples = self.previous_residual_samples + new_samples
        if len(samples) < self.num_samples_per_window:
            self.previous_residual_samples = samples
            return

        # num_frames is the number of frames from the new segment
        num_frames = math.floor(
            (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size))
            / self.num_samples_per_shift
        )

        # the number of samples used for feature extraction,
        # including some part of the previous segment
        effective_num_samples = int(
            num_frames * self.len_ms_to_samples(self.shift_size)
            + self.len_ms_to_samples(self.window_size - self.shift_size)
        )

        input_samples = samples[:effective_num_samples]
        self.previous_residual_samples = samples[
            num_frames * self.num_samples_per_shift:
        ]

        torch.manual_seed(1)
        output = kaldi.fbank(
            torch.FloatTensor(input_samples).unsqueeze(0),
            num_mel_bins=self.feature_dim,
            frame_length=self.window_size,
            frame_shift=self.shift_size,
        ).numpy()

        output = self.transform(output)

        return torch.from_numpy(output)

    def transform(self, input):
        if self.global_cmvn is None:
            return input

        mean = self.global_cmvn["mean"]
        std = self.global_cmvn["std"]
        x = np.subtract(input, mean)
        x = np.divide(x, std)
        return x
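

# A minimal usage sketch (not part of the original file) showing how the
# extractor buffers a stream chunk by chunk. The 300 ms chunk size and the
# argparse.Namespace construction are assumptions for illustration only.
def _demo_online_feature_extraction():
    import argparse

    args = argparse.Namespace(
        shift_size=SHIFT_SIZE,
        window_size=WINDOW_SIZE,
        sample_rate=SAMPLE_RATE,
        feature_dim=FEATURE_DIM,
        global_cmvn=None,
    )
    extractor = OnlineFeatureExtractor(args)
    chunk = [0.0] * int(SAMPLE_RATE * 300 / 1000)  # 300 ms of silence
    # Returns a (num_frames, FEATURE_DIM) tensor, or None while fewer than
    # one window's worth of samples has been buffered.
    feats = extractor(chunk)
    return feats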


class TensorListEntry(ListEntry):
    """
    Data structure to store a list of tensors.
    """

    def append(self, value):
        if len(self.value) == 0:
            self.value = value
            return

        self.value = torch.cat([self.value] + [value], dim=0)

    def info(self):
        return {
            "type": str(self.new_value_type),
            "length": self.__len__(),
            "value": "" if type(self.value) is list else self.value.size(),
        }
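

# A small illustration (assumed shapes, not in the original file): appending a
# (T1, FEATURE_DIM) tensor and then a (T2, FEATURE_DIM) tensor concatenates
# along the time axis, so info()["value"] reports torch.Size([T1 + T2,
# FEATURE_DIM]); the first append simply adopts the tensor as-is.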


class FairseqSimulSTAgent(SpeechAgent):

    speech_segment_size = 40  # in ms, 4 pooling ratio * 10 ms step size

    def __init__(self, args):
        super().__init__(args)

        self.eos = DEFAULT_EOS

        self.gpu = getattr(args, "gpu", False)

        self.args = args

        self.load_model_vocab(args)

        if getattr(
            self.model.decoder.layers[0].encoder_attn,
            'pre_decision_ratio',
            None
        ) is not None:
            # Models with a fixed pre-decision module make one read/write
            # decision per `pre_decision_ratio` encoder states, so enlarge
            # the speech segment read per action accordingly.
            self.speech_segment_size *= (
                self.model.decoder.layers[0].encoder_attn.pre_decision_ratio
            )

        args.global_cmvn = None
        if args.config:
            with open(os.path.join(args.data_bin, args.config), "r") as f:
                config = yaml.load(f, Loader=yaml.BaseLoader)

            if "global_cmvn" in config:
                args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"])

        if args.global_stats:
            with PathManager.open(args.global_stats, "r") as f:
                global_cmvn = json.loads(f.read())
                self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]}

        self.feature_extractor = OnlineFeatureExtractor(args)

        self.max_len = args.max_len

        self.force_finish = args.force_finish

        torch.set_grad_enabled(False)

    def build_states(self, args, client, sentence_id):
        # Initialize states here, for example add a customized entry to states.
        # This function will be called at the beginning of every new sentence.
        states = SpeechStates(args, client, sentence_id, self)
        self.initialize_states(states)
        return states

    def to_device(self, tensor):
        if self.gpu:
            return tensor.cuda()
        else:
            return tensor.cpu()

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--model-path', type=str, required=True,
                            help='path to your pretrained model.')
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--config", type=str, default=None,
                            help="Path to config yaml file")
        parser.add_argument("--global-stats", type=str, default=None,
                            help="Path to json file containing cmvn stats")
        parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
                            help="Subword splitter type for target text")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text")
        parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation",
                            help="User directory for simultaneous translation")
        parser.add_argument("--max-len", type=int, default=200,
                            help="Max length of translation")
        parser.add_argument("--force-finish", default=False, action="store_true",
                            help="Force the model to finish the hypothesis if the source is not finished")
        parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE,
                            help="Shift size of feature extraction window.")
        parser.add_argument("--window-size", type=int, default=WINDOW_SIZE,
                            help="Window size of feature extraction window.")
        parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE,
                            help="Sample rate")
        parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM,
                            help="Acoustic feature dimension.")
        # fmt: on
        return parser

    def load_model_vocab(self, args):
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        task_args = state["cfg"]["task"]
        task_args.data = args.data_bin

        if args.config is not None:
            task_args.config_yaml = args.config

        task = tasks.setup_task(task_args)

        # build model for ensemble
        state["cfg"]["model"].load_pretrained_encoder_from = None
        state["cfg"]["model"].load_pretrained_decoder_from = None
        self.model = task.build_model(state["cfg"]["model"])
        self.model.load_state_dict(state["model"], strict=True)
        self.model.eval()
        self.model.share_memory()

        if self.gpu:
            self.model.cuda()

        # Set dictionary
        self.dict = {}
        self.dict["tgt"] = task.target_dictionary

    def initialize_states(self, states):
        self.feature_extractor.clear_cache()
        states.units.source = TensorListEntry()
        states.units.target = ListEntry()
        states.incremental_states = dict()

    def segment_to_units(self, segment, states):
        # Convert speech samples to features
        features = self.feature_extractor(segment)
        if features is not None:
            return [features]
        else:
            return []

    def units_to_segment(self, units, states):
        # Merge subwords into full words.
        if self.model.decoder.dictionary.eos() == units[0]:
            return DEFAULT_EOS

        segment = []
        if None in units.value:
            units.value.remove(None)

        for index in units:
            if index is None:
                units.pop()
            token = self.model.decoder.dictionary.string([index])
            if token.startswith(BOW_PREFIX):
                if len(segment) == 0:
                    segment += [token.replace(BOW_PREFIX, "")]
                else:
                    # A new word begins: flush the buffered subwords as one
                    # word and pop them from the unit queue.
                    for j in range(len(segment)):
                        units.pop()

                    string_to_return = ["".join(segment)]

                    if self.model.decoder.dictionary.eos() == units[0]:
                        string_to_return += [DEFAULT_EOS]

                    return string_to_return
            else:
                segment += [token.replace(BOW_PREFIX, "")]

        if (
            len(units) > 0
            and self.model.decoder.dictionary.eos() == units[-1]
            or len(states.units.target) > self.max_len
        ):
            tokens = [self.model.decoder.dictionary.string([unit]) for unit in units]
            return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS]

        return None
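
    # Illustration (hypothetical tokens, not in the original file): if the
    # buffered units decode to ["▁he", "llo", "▁world"], the loop collects
    # "he" and "llo" until the next BOW_PREFIX token, pops those two consumed
    # units, and returns ["hello"]; "▁world" stays buffered until the word
    # after it begins or the sentence ends.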

    def update_model_encoder(self, states):
        if len(states.units.source) == 0:
            return

        src_indices = self.to_device(
            states.units.source.value.unsqueeze(0)
        )
        src_lengths = self.to_device(
            torch.LongTensor([states.units.source.value.size(0)])
        )

        states.encoder_states = self.model.encoder(src_indices, src_lengths)

        torch.cuda.empty_cache()

    def update_states_read(self, states):
        # Happens after a read action.
        self.update_model_encoder(states)

    def policy(self, states):
        if not getattr(states, "encoder_states", None):
            return READ_ACTION

        tgt_indices = self.to_device(
            torch.LongTensor(
                [self.model.decoder.dictionary.eos()]
                + [x for x in states.units.target.value if x is not None]
            ).unsqueeze(0)
        )

        states.incremental_states["steps"] = {
            "src": states.encoder_states["encoder_out"][0].size(0),
            "tgt": 1 + len(states.units.target),
        }

        states.incremental_states["online"] = {"only": torch.tensor(not states.finish_read())}

        x, outputs = self.model.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=states.encoder_states,
            incremental_state=states.incremental_states,
        )

        states.decoder_out = x
        states.decoder_out_extra = outputs

        torch.cuda.empty_cache()

        if outputs.action == 0:
            return READ_ACTION
        else:
            return WRITE_ACTION
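
    # How this fits together (a simplified sketch of the SimulEval loop): the
    # evaluator repeatedly calls policy(); READ_ACTION pulls more speech in
    # through segment_to_units() followed by update_states_read(), while
    # WRITE_ACTION triggers predict() to emit the next subword index, which
    # units_to_segment() merges into words for the client.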

    def predict(self, states):
        decoder_states = states.decoder_out

        lprobs = self.model.get_normalized_probs(
            [decoder_states[:, -1:]], log_probs=True
        )

        index = lprobs.argmax(dim=-1)
        index = index[0, 0].item()

        if (
            self.force_finish
            and index == self.model.decoder.dictionary.eos()
            and not states.finish_read()
        ):
            # If we want to force finish the translation
            # (don't stop before finishing reading), return None
            # self.model.decoder.clear_cache(states.incremental_states)
            index = None

        return index
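

# The agent is normally driven by the SimulEval CLI rather than imported
# directly. A hedged invocation sketch (the paths are placeholders; flags
# beyond the ones defined in add_args() follow standard simuleval usage):
#
#   simuleval \
#       --agent fairseq_simul_st_agent.py \
#       --source source_audio_list.txt --target target_text.txt \
#       --model-path checkpoint.pt --data-bin ./data --config config.yaml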