# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import logging
import os
from pathlib import Path
from typing import Type, TypeVar

import packaging.version  # import the submodule explicitly; `packaging.version.parse` is used below
import safetensors
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn

from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig

T = TypeVar("T", bound="PreTrainedPolicy")

DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""


class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    """
    Base class for policy models.
    """

    # Subclasses must set both of these; this is enforced in `__init_subclass__`.
    config_class: Type[PreTrainedConfig]
    name: str

    def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, PreTrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PreTrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if not getattr(cls, "config_class", None):
            raise TypeError(f"Class {cls.__name__} must define 'config_class'")
        if not getattr(cls, "name", None):
            raise TypeError(f"Class {cls.__name__} must define 'name'")

    def _save_pretrained(self, save_directory: Path) -> None:
        self.config._save_pretrained(save_directory)
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
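
    # Note (illustrative): this is the private hook that the public `save_pretrained()` /
    # `push_to_hub()` entry points provided by `HubMixin` are expected to call. It writes the
    # config plus a single `model.safetensors` file into `save_directory`, e.g.:
    #
    #     policy.save_pretrained("outputs/train/my_policy")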

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = False,
        **kwargs,
    ) -> T:
""" | |
The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are | |
deactivated). To train it, you should first set it back in training mode with `policy.train()`. | |
""" | |
if config is None: | |
config = PreTrainedConfig.from_pretrained( | |
pretrained_name_or_path=pretrained_name_or_path, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
token=token, | |
cache_dir=cache_dir, | |
local_files_only=local_files_only, | |
revision=revision, | |
**kwargs, | |
) | |
        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
        else:
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                ) from e

        policy.to(config.device)
        policy.eval()
        return policy
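
    # Example usage (illustrative; the class name and repo id below are placeholders):
    #
    #     policy = MyPolicy.from_pretrained("user/repo_id")
    #     policy.train()  # re-enable dropout etc. before fine-tuning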

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
            load_model_as_safetensor(model, model_file, strict=strict)
            if map_location != "cpu":
                logging.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)
        else:
            safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:
    #     card = ModelCard.from_template(
    #         card_data=self._hub_mixin_info.model_card_data,
    #         template_str=self._hub_mixin_info.model_card_template,
    #         repo_url=self._hub_mixin_info.repo_url,
    #         docs_url=self._hub_mixin_info.docs_url,
    #         **kwargs,
    #     )
    #     return card

    @abc.abstractmethod
    def get_optim_params(self) -> dict:
        """
        Returns the policy-specific parameters dict to be passed on to the optimizer.
        """
        raise NotImplementedError
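
    # A subclass typically returns the parameters (or parameter groups) that the optimizer should
    # update, e.g. all parameters with `requires_grad=True`. The exact structure is policy-specific;
    # this base class only fixes the contract.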

    @abc.abstractmethod
    def reset(self):
        """To be called whenever the environment is reset.

        Does things like clearing caches.
        """
        raise NotImplementedError

    # TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
    @abc.abstractmethod
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        """Run the batch through the model and compute the loss for training or validation.

        Args:
            batch (dict[str, Tensor]): A batch of inputs (e.g. observations and target actions).

        Returns:
            tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
                is a Tensor, all other items should be logging-friendly, native Python types.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
        with caching.
        """
        raise NotImplementedError
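

# Illustrative sketch (not part of this module): a concrete policy subclass is expected to set
# `config_class` and `name` and implement the abstract methods above. `MyPolicyConfig` and
# `MyPolicy` are hypothetical placeholder names.
#
#     class MyPolicy(PreTrainedPolicy):
#         config_class = MyPolicyConfig
#         name = "my_policy"
#
#         def get_optim_params(self) -> dict:
#             return {"params": [p for p in self.parameters() if p.requires_grad]}
#
#         def reset(self):
#             pass  # e.g. clear queues of cached observations/actions
#
#         def forward(self, batch):
#             loss = ...  # compute the training loss from the batch
#             return loss, {"loss": loss.item()}
#
#         def select_action(self, batch):
#             ...  # return a Tensor with the next action(s) to execute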