Spaces:
Build error
Build error
# Copyright 2022 The MT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Simple debugging utility for printing out task contents.""" | |
import re | |
from absl import app | |
from absl import flags | |
import mt3.tasks # pylint: disable=unused-import | |
import seqio | |
import tensorflow as tf | |
FLAGS = flags.FLAGS

# Which task to inspect and where its cached data (if any) lives.
flags.DEFINE_string(
    name="task",
    default=None,
    help="A registered Task.")
flags.DEFINE_string(
    name="task_cache_dir",
    default=None,
    help="Directory to use for task cache.")

# How much to print and how each example is rendered.
flags.DEFINE_integer(
    name="max_examples",
    default=10,
    help="Maximum number of examples (-1 for no limit).")
flags.DEFINE_string(
    name="format_string",
    default="targets = {targets}",
    help="Format for printing examples.")

# Dataset split and the per-feature sequence lengths requested from the task.
flags.DEFINE_string(
    name="split",
    default="train",
    help="Which split of the dataset, e.g. train or validation.")
flags.DEFINE_integer(
    name="sequence_length_inputs",
    default=256,
    help="Sequence length for inputs.")
flags.DEFINE_integer(
    name="sequence_length_targets",
    default=1024,
    help="Sequence length for targets.")
def main(_):
  """Prints feature shapes and decoded contents for examples of a task.

  Looks up the task (or mixture) named by --task, builds its unshuffled
  dataset for --split, and for the examples taken via --max_examples prints
  the shape of every feature followed by --format_string filled in with the
  vocabulary-decoded features it references.

  Args:
    _: Positional argument passed by `app.run`; unused.
  """
  cache_dir = FLAGS.task_cache_dir
  if cache_dir:
    seqio.add_global_cache_dirs([cache_dir])

  task = seqio.get_mixture_or_task(FLAGS.task)

  sequence_length = {
      "inputs": FLAGS.sequence_length_inputs,
      "targets": FLAGS.sequence_length_targets,
  }
  # The cached version of the task is only requested when a cache directory
  # was supplied on the command line.
  dataset = task.get_dataset(
      sequence_length=sequence_length,
      split=FLAGS.split,
      use_cached=bool(cache_dir),
      shuffle=False)

  format_string = FLAGS.format_string
  # Names of the `{placeholder}` fields that appear in the format string.
  field_names = re.findall(r"{([\w+]+)}", format_string)

  def _example_to_string(ex):
    """Renders one example through the format string.

    Each referenced feature present in `ex` is converted to a Python list and
    decoded with that feature's vocabulary; referenced features missing from
    `ex` are rendered as the empty string.
    """
    decoded = {
        name: (
            task.output_features[name].vocabulary.decode(
                ex[name].numpy().tolist())
            if name in ex else "")
        for name in field_names
    }
    return format_string.format(**decoded)

  for example in dataset.take(FLAGS.max_examples):
    for feature_name, value in example.items():
      print(f"{feature_name}: {tf.shape(value)}")
    print(_example_to_string(example))
    print()
if __name__ == "__main__":
  # --task has no usable default, so flag parsing must fail when it is omitted.
  flags.mark_flag_as_required("task")
  app.run(main)