| import tensorflow as tf |
| import os |
| import torch |
| from torch.utils.data import IterableDataset |
|
|
def _parse_function(example_proto):
    """Parse one serialized episode Example from the Language-Table TFRecords.

    Args:
        example_proto: A scalar string tensor holding a serialized
            ``tf.train.Example`` (one episode with per-step sequences).

    Returns:
        A dict of tensors keyed by the original feature names, where the
        JPEG-encoded RGB frames have been decoded into uint8 image tensors.
    """
    # Per-step sequences of unknown length, hence FixedLenSequenceFeature
    # with allow_missing=True.
    features = {
        'steps/observation/rgb': tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        'steps/observation/instruction': tf.io.FixedLenSequenceFeature([512], tf.int64, allow_missing=True),
        'steps/observation/effector_translation': tf.io.FixedLenSequenceFeature([2], tf.float32, allow_missing=True),
        'steps/action': tf.io.FixedLenSequenceFeature([2], tf.float32, allow_missing=True),
        'steps/reward': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
        'steps/is_first': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'steps/is_last': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    }

    example = tf.io.parse_single_example(example_proto, features)

    # Each frame is stored as a JPEG-encoded string; decode the full sequence
    # to a uint8 image tensor per step.
    example['steps/observation/rgb'] = tf.map_fn(
        tf.io.decode_jpeg,
        example['steps/observation/rgb'],
        fn_output_signature=tf.uint8,
    )

    return example
|
|
class LanguageTableDataset(IterableDataset):
    """Streams Language-Table episodes from TFRecord shards as torch tensors.

    Safe to use with ``torch.utils.data.DataLoader(num_workers > 0)``: the
    shard files are partitioned across workers so episodes are not duplicated.
    """

    def __init__(self, data_dir, num_shards=None):
        """
        Args:
            data_dir: Directory containing ``language_table-train.tfrecord*`` shards.
            num_shards: Optional cap on the number of (sorted) shard files to read.

        Raises:
            FileNotFoundError: If no shard files match the pattern, instead of
                silently yielding an empty dataset.
        """
        self.data_dir = data_dir
        self.file_pattern = os.path.join(data_dir, "language_table-train.tfrecord*")
        # Sort unconditionally so iteration order is deterministic across runs
        # (previously only sorted when num_shards was given).
        self.files = sorted(tf.io.gfile.glob(self.file_pattern))
        if not self.files:
            raise FileNotFoundError(
                f"No TFRecord files matched pattern: {self.file_pattern}"
            )
        if num_shards:
            self.files = self.files[:num_shards]

    def __iter__(self):
        files = self.files
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Shard files across DataLoader workers; without this every worker
            # would iterate the full dataset and yield duplicate episodes.
            files = files[worker_info.id::worker_info.num_workers]

        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.map(_parse_function)

        for item in dataset:
            # Convert eager TF tensors to torch tensors (copies via numpy).
            yield {
                'obs': torch.from_numpy(item['steps/observation/rgb'].numpy()),
                'actions': torch.from_numpy(item['steps/action'].numpy()),
                'rewards': torch.from_numpy(item['steps/reward'].numpy()),
                'effector': torch.from_numpy(item['steps/observation/effector_translation'].numpy()),
            }
|
|