# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Translate text or files using trained transformer model.""" | |
# Import libraries | |
from absl import logging | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from official.legacy.transformer.utils import tokenizer | |
_EXTRA_DECODE_LENGTH = 100 | |
_BEAM_SIZE = 4 | |
_ALPHA = 0.6 | |


def _get_sorted_inputs(filename):
  """Read lines from the file and sort them by decreasing length.

  Args:
    filename: String name of file to read inputs from.

  Returns:
    Sorted list of inputs, and a list mapping original index -> sorted index
    of each element.
  """
  with tf.io.gfile.GFile(filename) as f:
    records = f.read().split("\n")
    inputs = [record.strip() for record in records]
    if not inputs[-1]:
      inputs.pop()

  input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
  sorted_input_lens = sorted(input_lens, key=lambda x: x[1], reverse=True)
  sorted_inputs = [None] * len(sorted_input_lens)
  sorted_keys = [0] * len(sorted_input_lens)
  for i, (index, _) in enumerate(sorted_input_lens):
    sorted_inputs[i] = inputs[index]
    sorted_keys[index] = i
  return sorted_inputs, sorted_keys
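
# Illustrative example: for a file containing the lines ["a b", "c d e", "f"],
# _get_sorted_inputs returns sorted_inputs == ["c d e", "a b", "f"] and
# sorted_keys == [1, 0, 2], so the translation of original line k is found at
# translations[sorted_keys[k]] once decoding runs in sorted order.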


def _encode_and_add_eos(line, subtokenizer):
  """Encode line with subtokenizer, and add EOS id to the end."""
  return subtokenizer.encode(line) + [tokenizer.EOS_ID]


def _trim_and_decode(ids, subtokenizer):
  """Trim EOS and PAD tokens from ids, and decode to return a string."""
  try:
    index = list(ids).index(tokenizer.EOS_ID)
    return subtokenizer.decode(ids[:index])
  except ValueError:  # No EOS found in sequence
    return subtokenizer.decode(ids)
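
# A minimal round-trip sketch of the two helpers above, kept as a comment.
# It assumes a trained subword vocabulary file, hypothetically named
# "vocab.ende.32768"; tokenizer.Subtokenizer is the class provided by
# official.legacy.transformer.utils.tokenizer:
#
#   subtokenizer = tokenizer.Subtokenizer("vocab.ende.32768")
#   ids = _encode_and_add_eos("hello world", subtokenizer)
#   assert _trim_and_decode(ids, subtokenizer) == "hello world"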


def translate_file(model,
                   params,
                   subtokenizer,
                   input_file,
                   output_file=None,
                   print_all_translations=True,
                   distribution_strategy=None):
  """Translate lines in file, and save to output file if specified.

  Args:
    model: A Keras model, used to generate the translations.
    params: A dictionary, containing the translation related parameters.
    subtokenizer: A subtokenizer object, used for encoding and decoding source
      and translated lines.
    input_file: A file containing lines to translate.
    output_file: A file that stores the generated translations.
    print_all_translations: A bool. If true, all translations are printed to
      stdout.
    distribution_strategy: A distribution strategy, used to perform inference
      directly with tf.function instead of Keras model.predict().

  Raises:
    ValueError: if output file is invalid.
  """
  batch_size = params["decode_batch_size"]

  # Read and sort inputs by length. Keep a list (original index --> new index
  # in the sorted list) to write translations in the original order.
  sorted_inputs, sorted_keys = _get_sorted_inputs(input_file)
  total_samples = len(sorted_inputs)
  num_decode_batches = (total_samples - 1) // batch_size + 1

  def input_generator():
    """Yield encoded strings from sorted_inputs."""
    for i in range(num_decode_batches):
      lines = [
          sorted_inputs[j + i * batch_size]
          for j in range(batch_size)
          if j + i * batch_size < total_samples
      ]
      lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
      if distribution_strategy:
        # Pad the final batch with EOS-only lines so every replica receives a
        # full batch.
        for j in range(batch_size - len(lines)):
          lines.append([tokenizer.EOS_ID])
      batch = tf_keras.preprocessing.sequence.pad_sequences(
          lines,
          maxlen=params["decode_max_length"],
          dtype="int32",
          padding="post")
      logging.info("Decoding batch %d out of %d.", i, num_decode_batches)
      yield batch
  def predict_step(inputs):
    """Decoding step function for TPU runs."""

    def _step_fn(inputs):
      """Per replica step function."""
      tag = inputs[0]
      val_inputs = inputs[1]
      val_outputs, _ = model([val_inputs], training=False)
      return tag, val_outputs

    return distribution_strategy.run(_step_fn, args=(inputs,))
  translations = []
  if distribution_strategy:
    num_replicas = distribution_strategy.num_replicas_in_sync
    local_batch_size = params["decode_batch_size"] // num_replicas
  for i, text in enumerate(input_generator()):
    if distribution_strategy:
      text = np.reshape(text, [num_replicas, local_batch_size, -1])
      # Add tag to the input of each replica with the reordering logic after
      # outputs, to ensure the output order matches the input order.
      text = tf.constant(text)

      def text_as_per_replica():
        replica_context = tf.distribute.get_replica_context()
        replica_id = replica_context.replica_id_in_sync_group
        return replica_id, text[replica_id]  # pylint: disable=cell-var-from-loop

      text = distribution_strategy.run(text_as_per_replica)
      outputs = distribution_strategy.experimental_local_results(
          predict_step(text))
      val_outputs = [output for _, output in outputs]
      val_outputs = np.reshape(val_outputs, [params["decode_batch_size"], -1])
    else:
      val_outputs, _ = model.predict(text)

    length = len(val_outputs)
    for j in range(length):
      if j + i * batch_size < total_samples:
        translation = _trim_and_decode(val_outputs[j], subtokenizer)
        translations.append(translation)
        if print_all_translations:
          logging.info("Translating:\n\tInput: %s\n\tOutput: %s",
                       sorted_inputs[j + i * batch_size], translation)
  # Write translations in the order they appeared in the original file.
  if output_file is not None:
    if tf.io.gfile.isdir(output_file):
      raise ValueError("File output is a directory, will not save outputs to "
                       "file.")
    logging.info("Writing to file %s", output_file)
    with tf.io.gfile.GFile(output_file, "w") as f:
      for i in sorted_keys:
        f.write("%s\n" % translations[i])
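
# A minimal usage sketch, kept as a comment because this module is imported as
# a library. The model object, vocab file name, and params values below are
# illustrative assumptions, not values fixed by this file:
#
#   subtokenizer = tokenizer.Subtokenizer("vocab.ende.32768")
#   params = {"decode_batch_size": 32, "decode_max_length": 97}
#   translate_file(transformer_model, params, subtokenizer,
#                  input_file="newstest2014.en",
#                  output_file="translations.de")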


def translate_from_text(model, subtokenizer, txt):
  """Translate a single text string and log the result."""
  encoded_txt = _encode_and_add_eos(txt, subtokenizer)
  result = model.predict(encoded_txt)
  outputs = result["outputs"]
  logging.info("Original: \"%s\"", txt)
  translate_from_input(outputs, subtokenizer)


def translate_from_input(outputs, subtokenizer):
  """Decode model output ids and log the translation."""
  translation = _trim_and_decode(outputs, subtokenizer)
  logging.info("Translation: \"%s\"", translation)