#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union

import torch


class EpisodeAwareSampler:
    def __init__(
        self,
        episode_data_index: dict,
        episode_indices_to_use: Union[list, None] = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
        shuffle: bool = False,
    ):
"""Sampler that optionally incorporates episode boundary information. | |
Args: | |
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. | |
episode_indices_to_use: List of episode indices to use. If None, all episodes are used. | |
Assumes that episodes are indexed from 0 to N-1. | |
drop_n_first_frames: Number of frames to drop from the start of each episode. | |
drop_n_last_frames: Number of frames to drop from the end of each episode. | |
shuffle: Whether to shuffle the indices. | |
""" | |
        indices = []
        for episode_idx, (start_index, end_index) in enumerate(
            zip(episode_data_index["from"], episode_data_index["to"], strict=True)
        ):
            if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
                indices.extend(
                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
                )

        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            # Draw a fresh random permutation of the usable indices on each pass.
            for i in torch.randperm(len(self.indices)):
                yield self.indices[i]
        else:
            for i in self.indices:
                yield i

    def __len__(self) -> int:
        return len(self.indices)
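

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file): the
    # dataset, episode boundaries, and batch size below are made-up placeholders.
    # It shows how the sampler plugs into a standard torch DataLoader via `sampler=`.
    from torch.utils.data import DataLoader, TensorDataset

    # Two toy "episodes": frames 0..29 and 30..54 of a 55-frame dataset.
    episode_data_index = {
        "from": torch.tensor([0, 30]),
        "to": torch.tensor([30, 55]),
    }
    dataset = TensorDataset(torch.arange(55).unsqueeze(1).float())

    sampler = EpisodeAwareSampler(
        episode_data_index,
        drop_n_last_frames=7,  # e.g. skip frames too close to each episode's end
        shuffle=True,
    )
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)
    print(f"{len(sampler)} usable frames out of {len(dataset)}")
    for batch in loader:
        pass  # a training step would go here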