ACT-Estimator / modeling_act_estimator.py
keishihara's picture
Upload folder using huggingface_hub
ecaafec verified
raw
history blame
526 Bytes
from typing import Optional

from torch import Tensor
from transformers import PreTrainedModel

from .configuration_act_estimator import ActEstimatorConfig
from .model import VideoActionEstimator
class ActEstimator(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``VideoActionEstimator``.

    The actual network lives in ``self.model``. This class exposes it through
    the ``transformers`` model API, with ``ActEstimatorConfig`` as its
    configuration class.
    """

    config_class = ActEstimatorConfig

    def __init__(self, config: ActEstimatorConfig):
        """Build the wrapped estimator from ``config``.

        Args:
            config: Model configuration. Every entry of ``config.to_dict()``
                is forwarded as a keyword argument to ``VideoActionEstimator``.
        """
        super().__init__(config)
        # NOTE(review): to_dict() may also contain generic config bookkeeping
        # fields besides the model-specific ones; this presumably relies on
        # VideoActionEstimator accepting or ignoring such extra kwargs — confirm.
        self.model = VideoActionEstimator(**config.to_dict())

    def forward(
        self, frames: Tensor, timestamps: Optional[Tensor] = None
    ) -> dict[str, Tensor]:
        """Run the wrapped estimator on a batch of inputs.

        Args:
            frames: Input tensor passed straight through to the wrapped model.
                Its expected shape is defined by ``VideoActionEstimator``
                (TODO confirm).
            timestamps: Optional tensor forwarded positionally as-is, including
                when it is ``None``. Presumably per-frame timestamps — confirm
                against ``VideoActionEstimator``.

        Returns:
            Whatever ``self.model`` returns. It is annotated as a mapping from
            output names to tensors.
        """
        return self.model(frames, timestamps)