diff --git "a/lib/python3.11/site-packages/huggingface_hub/hf_api.py" "b/lib/python3.11/site-packages/huggingface_hub/hf_api.py" new file mode 100644--- /dev/null +++ "b/lib/python3.11/site-packages/huggingface_hub/hf_api.py" @@ -0,0 +1,8405 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import inspect +import json +import re +import struct +import warnings +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field +from datetime import datetime +from functools import wraps +from itertools import islice +from pathlib import Path +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Tuple, + TypedDict, + TypeVar, + Union, + overload, +) +from urllib.parse import quote, urlencode + +import requests +from requests.exceptions import HTTPError +from tqdm.auto import tqdm as base_tqdm +from tqdm.contrib.concurrent import thread_map + +from ._commit_api import ( + CommitOperation, + CommitOperationAdd, + CommitOperationCopy, + CommitOperationDelete, + _fetch_lfs_files_to_copy, + _fetch_upload_modes, + _prepare_commit_payload, + _upload_lfs_files, + _warn_on_overwriting_operations, +) +from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType +from ._multi_commits import ( + MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_BAD_REQUEST_TEMPLATE, + MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_NO_CHANGES_TEMPLATE, + MULTI_COMMIT_PR_CLOSING_COMMENT_TEMPLATE, + MULTI_COMMIT_PR_COMPLETION_COMMENT_TEMPLATE, + MultiCommitException, + MultiCommitStep, + MultiCommitStrategy, + multi_commit_create_pull_request, + multi_commit_generate_comment, + multi_commit_parse_pr_description, + plan_multi_commits, +) +from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable +from .community import ( + Discussion, + DiscussionComment, + DiscussionStatusChange, + DiscussionTitleChange, + DiscussionWithDetails, + deserialize_event, +) +from .constants import ( + DEFAULT_ETAG_TIMEOUT, + DEFAULT_REQUEST_TIMEOUT, + DEFAULT_REVISION, + DISCUSSION_STATUS, + DISCUSSION_TYPES, + ENDPOINT, + INFERENCE_ENDPOINTS_ENDPOINT, + REGEX_COMMIT_OID, + REPO_TYPE_MODEL, + REPO_TYPES, + REPO_TYPES_MAPPING, + REPO_TYPES_URL_PREFIXES, + SAFETENSORS_INDEX_FILE, + SAFETENSORS_MAX_HEADER_LENGTH, + SAFETENSORS_SINGLE_FILE, + SPACES_SDK_TYPES, + DiscussionStatusFilter, + DiscussionTypeFilter, +) +from .file_download import HfFileMetadata, get_hf_file_metadata, hf_hub_url +from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData +from .utils import ( # noqa: F401 # imported for backward compatibility + IGNORE_GIT_FOLDER_PATTERNS, + BadRequestError, + EntryNotFoundError, + GatedRepoError, + HfFolder, + HfHubHTTPError, + LocalTokenNotFoundError, + NotASafetensorsRepoError, + RepositoryNotFoundError, + RevisionNotFoundError, + SafetensorsFileMetadata, + SafetensorsParsingError, + SafetensorsRepoMetadata, + TensorInfo, + build_hf_headers, + experimental, + filter_repo_objects, + get_session, + hf_raise_for_status, + logging, + paginate, + parse_datetime, + validate_hf_hub_args, +) +from .utils import tqdm as hf_tqdm +from .utils._deprecation import _deprecate_method +from .utils._typing import CallableT +from .utils.endpoint_helpers import ( + DatasetFilter, + ModelFilter, + _is_emission_within_treshold, +) + + +R = TypeVar("R") # Return type +CollectionItemType_T = Literal["model", "dataset", "space", "paper"] + +USERNAME_PLACEHOLDER = "hf_user" +_REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") + +_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = ( + "\nNote: Creating a commit assumes that the repo already exists on the" + " Huggingface Hub. Please use `create_repo` if it's not the case." +) + +logger = logging.get_logger(__name__) + + +def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]: + """ + Returns the repo type and ID from a huggingface.co URL linking to a + repository + + Args: + hf_id (`str`): + An URL or ID of a repository on the HF hub. Accepted values are: + + - https://huggingface.co/// + - https://huggingface.co// + - hf://// + - hf:/// + - // + - / + - + hub_url (`str`, *optional*): + The URL of the HuggingFace Hub, defaults to https://huggingface.co + + Returns: + A tuple with three items: repo_type (`str` or `None`), namespace (`str` or + `None`) and repo_id (`str`). + + Raises: + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If URL cannot be parsed. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `repo_type` is unknown. + """ + input_hf_id = hf_id + hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else ENDPOINT) + is_hf_url = hub_url in hf_id and "@" not in hf_id + + HFFS_PREFIX = "hf://" + if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists + hf_id = hf_id[len(HFFS_PREFIX) :] + + url_segments = hf_id.split("/") + is_hf_id = len(url_segments) <= 3 + + namespace: Optional[str] + if is_hf_url: + namespace, repo_id = url_segments[-2:] + if namespace == hub_url: + namespace = None + if len(url_segments) > 2 and hub_url not in url_segments[-3]: + repo_type = url_segments[-3] + elif namespace in REPO_TYPES_MAPPING: + # Mean canonical dataset or model + repo_type = REPO_TYPES_MAPPING[namespace] + namespace = None + else: + repo_type = None + elif is_hf_id: + if len(url_segments) == 3: + # Passed // or // + repo_type, namespace, repo_id = url_segments[-3:] + elif len(url_segments) == 2: + if url_segments[0] in REPO_TYPES_MAPPING: + # Passed '' or 'datasets/' for a canonical model or dataset + repo_type = REPO_TYPES_MAPPING[url_segments[0]] + namespace = None + repo_id = hf_id.split("/")[-1] + else: + # Passed / or / + namespace, repo_id = hf_id.split("/")[-2:] + repo_type = None + else: + # Passed + repo_id = url_segments[0] + namespace, repo_type = None, None + else: + raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") + + # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) + if repo_type in REPO_TYPES_MAPPING: + repo_type = REPO_TYPES_MAPPING[repo_type] + if repo_type == "": + repo_type = None + if repo_type not in REPO_TYPES: + raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')") + + return repo_type, namespace, repo_id + + +class LastCommitInfo(TypedDict, total=False): + oid: str + title: str + date: datetime + + +class BlobLfsInfo(TypedDict, total=False): + size: int + sha256: str + pointer_size: int + + +class BlobSecurityInfo(TypedDict, total=False): + safe: bool + av_scan: Optional[Dict] + pickle_import_scan: Optional[Dict] + + +class TransformersInfo(TypedDict, total=False): + auto_model: str + custom_class: Optional[str] + # possible `pipeline_tag` values: https://github.com/huggingface/huggingface.js/blob/3ee32554b8620644a6287e786b2a83bf5caf559c/packages/tasks/src/pipelines.ts#L72 + pipeline_tag: Optional[str] + processor: Optional[str] + + +class SafeTensorsInfo(TypedDict, total=False): + parameters: List[Dict[str, int]] + total: int + + +@dataclass +class CommitInfo(str): + """Data structure containing information about a newly created commit. + + Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`], + [`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific + to `str` is deprecated. + + Attributes: + commit_url (`str`): + Url where to find the commit. + + commit_message (`str`): + The summary (first line) of the commit that has been created. + + commit_description (`str`): + Description of the commit that has been created. Can be empty. + + oid (`str`): + Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. + + pr_url (`str`, *optional*): + Url to the PR that has been created, if any. Populated when `create_pr=True` + is passed. + + pr_revision (`str`, *optional*): + Revision of the PR that has been created, if any. Populated when + `create_pr=True` is passed. Example: `"refs/pr/1"`. + + pr_num (`int`, *optional*): + Number of the PR discussion that has been created, if any. Populated when + `create_pr=True` is passed. Can be passed as `discussion_num` in + [`get_discussion_details`]. Example: `1`. + + _url (`str`, *optional*): + Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by + [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on + the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this + attribute. Please use `commit_url` instead. + """ + + commit_url: str + commit_message: str + commit_description: str + oid: str + pr_url: Optional[str] = None + + # Computed from `pr_url` in `__post_init__` + pr_revision: Optional[str] = field(init=False) + pr_num: Optional[str] = field(init=False) + + # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.) + _url: str = field(repr=False, default=None) # type: ignore # defaults to `commit_url` + + def __new__(cls, *args, commit_url: str, _url: Optional[str] = None, **kwargs): + return str.__new__(cls, _url or commit_url) + + def __post_init__(self): + """Populate pr-related fields after initialization. + + See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing. + """ + if self.pr_url is not None: + self.pr_revision = _parse_revision_from_pr_url(self.pr_url) + self.pr_num = int(self.pr_revision.split("/")[-1]) + else: + self.pr_revision = None + self.pr_num = None + + +@dataclass +class AccessRequest: + """Data structure containing information about a user access request. + + Attributes: + username (`str`): + Username of the user who requested access. + fullname (`str`): + Fullname of the user who requested access. + email (`str`): + Email of the user who requested access. + timestamp (`datetime`): + Timestamp of the request. + status (`Literal["pending", "accepted", "rejected"]`): + Status of the request. Can be one of `["pending", "accepted", "rejected"]`. + fields (`Dict[str, Any]`, *optional*): + Additional fields filled by the user in the gate form. + """ + + username: str + fullname: str + email: str + timestamp: datetime + status: Literal["pending", "accepted", "rejected"] + + # Additional fields filled by the user in the gate form + fields: Optional[Dict[str, Any]] = None + + +class RepoUrl(str): + """Subclass of `str` describing a repo URL on the Hub. + + `RepoUrl` is returned by `HfApi.create_repo`. It inherits from `str` for backward + compatibility. At initialization, the URL is parsed to populate properties: + - endpoint (`str`) + - namespace (`Optional[str]`) + - repo_name (`str`) + - repo_id (`str`) + - repo_type (`Literal["model", "dataset", "space"]`) + - url (`str`) + + Args: + url (`Any`): + String value of the repo url. + endpoint (`str`, *optional*): + Endpoint of the Hub. Defaults to . + + Example: + ```py + >>> RepoUrl('https://huggingface.co/gpt2') + RepoUrl('https://huggingface.co/gpt2', endpoint='https://huggingface.co', repo_type='model', repo_id='gpt2') + + >>> RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co') + RepoUrl('https://hub-ci.huggingface.co/datasets/dummy_user/dummy_dataset', endpoint='https://hub-ci.huggingface.co', repo_type='dataset', repo_id='dummy_user/dummy_dataset') + + >>> RepoUrl('hf://datasets/my-user/my-dataset') + RepoUrl('hf://datasets/my-user/my-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='user/dataset') + + >>> HfApi.create_repo("dummy_model") + RepoUrl('https://huggingface.co/Wauplin/dummy_model', endpoint='https://huggingface.co', repo_type='model', repo_id='Wauplin/dummy_model') + ``` + + Raises: + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If URL cannot be parsed. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `repo_type` is unknown. + """ + + def __new__(cls, url: Any, endpoint: Optional[str] = None): + return super(RepoUrl, cls).__new__(cls, url) + + def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: + super().__init__() + # Parse URL + self.endpoint = endpoint or ENDPOINT + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(self, hub_url=self.endpoint) + + # Populate fields + self.namespace = namespace + self.repo_name = repo_name + self.repo_id = repo_name if namespace is None else f"{namespace}/{repo_name}" + self.repo_type = repo_type or REPO_TYPE_MODEL + self.url = str(self) # just in case it's needed + + def __repr__(self) -> str: + return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')" + + +@dataclass +class RepoSibling: + """ + Contains basic information about a repo file inside a repo on the Hub. + + + + All attributes of this class are optional except `rfilename`. This is because only the file names are returned when + listing repositories on the Hub (with [`list_models`], [`list_datasets`] or [`list_spaces`]). If you need more + information like file size, blob id or lfs details, you must request them specifically from one repo at a time + (using [`model_info`], [`dataset_info`] or [`space_info`]) as it adds more constraints on the backend server to + retrieve these. + + + + Attributes: + rfilename (str): + file name, relative to the repo root. + size (`int`, *optional*): + The file's size, in bytes. This attribute is defined when `files_metadata` argument of [`repo_info`] is set + to `True`. It's `None` otherwise. + blob_id (`str`, *optional*): + The file's git OID. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to + `True`. It's `None` otherwise. + lfs (`BlobLfsInfo`, *optional*): + The file's LFS metadata. This attribute is defined when`files_metadata` argument of [`repo_info`] is set to + `True` and the file is stored with Git LFS. It's `None` otherwise. + """ + + rfilename: str + size: Optional[int] = None + blob_id: Optional[str] = None + lfs: Optional[BlobLfsInfo] = None + + +@dataclass +class RepoFile: + """ + Contains information about a file on the Hub. + + Attributes: + path (str): + file path relative to the repo root. + size (`int`): + The file's size, in bytes. + blob_id (`str`): + The file's git OID. + lfs (`BlobLfsInfo`): + The file's LFS metadata. + last_commit (`LastCommitInfo`, *optional*): + The file's last commit metadata. Only defined if [`list_files_info`], [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + security (`BlobSecurityInfo`, *optional*): + The file's security scan metadata. Only defined if [`list_files_info`], [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + """ + + path: str + size: int + blob_id: str + lfs: Optional[BlobLfsInfo] = None + last_commit: Optional[LastCommitInfo] = None + security: Optional[BlobSecurityInfo] = None + + def __init__(self, **kwargs): + self.path = kwargs.pop("path") + self.size = kwargs.pop("size") + self.blob_id = kwargs.pop("oid") + lfs = kwargs.pop("lfs", None) + if lfs is not None: + lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"]) + self.lfs = lfs + last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) + if last_commit is not None: + last_commit = LastCommitInfo( + oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) + ) + self.last_commit = last_commit + security = kwargs.pop("security", None) + if security is not None: + security = BlobSecurityInfo( + safe=security["safe"], av_scan=security["avScan"], pickle_import_scan=security["pickleImportScan"] + ) + self.security = security + + # backwards compatibility + self.rfilename = self.path + self.lastCommit = self.last_commit + + +@dataclass +class RepoFolder: + """ + Contains information about a folder on the Hub. + + Attributes: + path (str): + folder path relative to the repo root. + tree_id (`str`): + The folder's git OID. + last_commit (`LastCommitInfo`, *optional*): + The folder's last commit metadata. Only defined if [`list_repo_tree`] and [`get_paths_info`] + are called with `expand=True`. + """ + + path: str + tree_id: str + last_commit: Optional[LastCommitInfo] = None + + def __init__(self, **kwargs): + self.path = kwargs.pop("path") + self.tree_id = kwargs.pop("oid") + last_commit = kwargs.pop("lastCommit", None) or kwargs.pop("last_commit", None) + if last_commit is not None: + last_commit = LastCommitInfo( + oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"]) + ) + self.last_commit = last_commit + + +@dataclass +class ModelInfo: + """ + Contains information about a model on the Hub. + + + + Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + In general, the more specific the query, the more information is returned. On the contrary, when listing models + using [`list_models`] only a subset of the attributes are returned. + + + + Attributes: + id (`str`): + ID of model. + author (`str`, *optional*): + Author of the model. + sha (`str`, *optional*): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + disabled (`bool`, *optional*): + Is the repo disabled. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + downloads (`int`): + Number of downloads of the model. + likes (`int`): + Number of likes of the model. + library_name (`str`, *optional*): + Library associated with the model. + tags (`List[str]`): + List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub + (e.g. supported libraries, model's arXiv). + pipeline_tag (`str`, *optional*): + Pipeline tag associated with the model. + mask_token (`str`, *optional*): + Mask token used by the model. + widget_data (`Any`, *optional*): + Widget data associated with the model. + model_index (`Dict`, *optional*): + Model index for evaluation. + config (`Dict`, *optional*): + Model configuration. + transformers_info (`TransformersInfo`, *optional*): + Transformers-specific info (auto class, processor, etc.) associated with the model. + card_data (`ModelCardData`, *optional*): + Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object. + siblings (`List[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model. + spaces (`List[str]`, *optional*): + List of spaces using the model. + safetensors (`SafeTensorsInfo`, *optional*): + Model's safetensors information. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: bool + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + downloads: int + likes: int + library_name: Optional[str] + tags: List[str] + pipeline_tag: Optional[str] + mask_token: Optional[str] + card_data: Optional[ModelCardData] + widget_data: Optional[Any] + model_index: Optional[Dict] + config: Optional[Dict] + transformers_info: Optional[TransformersInfo] + siblings: Optional[List[RepoSibling]] + spaces: Optional[List[str]] + safetensors: Optional[SafeTensorsInfo] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + self.private = kwargs.pop("private") + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.downloads = kwargs.pop("downloads") + self.likes = kwargs.pop("likes") + self.library_name = kwargs.pop("library_name", None) + self.tags = kwargs.pop("tags") + self.pipeline_tag = kwargs.pop("pipeline_tag", None) + self.mask_token = kwargs.pop("mask_token", None) + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + ModelCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + + self.widget_data = kwargs.pop("widget_data", None) + self.model_index = kwargs.pop("model-index", None) or kwargs.pop("model_index", None) + self.config = kwargs.pop("config", None) + transformers_info = kwargs.pop("transformersInfo", None) or kwargs.pop("transformers_info", None) + self.transformers_info = TransformersInfo(**transformers_info) if transformers_info else None + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings + else None + ) + self.spaces = kwargs.pop("spaces", None) + safetensors = kwargs.pop("safetensors", None) + self.safetensors = SafeTensorsInfo(**safetensors) if safetensors else None + + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.transformersInfo = self.transformers_info + self.__dict__.update(**kwargs) + + +@dataclass +class DatasetInfo: + """ + Contains information about a dataset on the Hub. + + + + Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + In general, the more specific the query, the more information is returned. On the contrary, when listing datasets + using [`list_datasets`] only a subset of the attributes are returned. + + + + Attributes: + id (`str`): + ID of dataset. + author (`str`): + Author of the dataset. + sha (`str`): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + disabled (`bool`, *optional*): + Is the repo disabled. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + downloads (`int`): + Number of downloads of the dataset. + likes (`int`): + Number of likes of the dataset. + tags (`List[str]`): + List of tags of the dataset. + card_data (`DatasetCardData`, *optional*): + Model Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object. + siblings (`List[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: bool + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + downloads: int + likes: int + paperswithcode_id: Optional[str] + tags: List[str] + card_data: Optional[DatasetCardData] + siblings: Optional[List[RepoSibling]] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + self.private = kwargs.pop("private") + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.downloads = kwargs.pop("downloads") + self.likes = kwargs.pop("likes") + self.paperswithcode_id = kwargs.pop("paperswithcode_id", None) + self.tags = kwargs.pop("tags") + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + DatasetCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings + else None + ) + + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.__dict__.update(**kwargs) + + +@dataclass +class SpaceInfo: + """ + Contains information about a Space on the Hub. + + + + Most attributes of this class are optional. This is because the data returned by the Hub depends on the query made. + In general, the more specific the query, the more information is returned. On the contrary, when listing spaces + using [`list_spaces`] only a subset of the attributes are returned. + + + + Attributes: + id (`str`): + ID of the Space. + author (`str`, *optional*): + Author of the Space. + sha (`str`, *optional*): + Repo SHA at this particular revision. + created_at (`datetime`, *optional*): + Date of creation of the repo on the Hub. Note that the lowest value is `2022-03-02T23:29:04.000Z`, + corresponding to the date when we began to store creation dates. + last_modified (`datetime`, *optional*): + Date of last commit to the repo. + private (`bool`): + Is the repo private. + gated (`Literal["auto", "manual", False]`, *optional*): + Is the repo gated. + If so, whether there is manual or automatic approval. + disabled (`bool`, *optional*): + Is the Space disabled. + host (`str`, *optional*): + Host URL of the Space. + subdomain (`str`, *optional*): + Subdomain of the Space. + likes (`int`): + Number of likes of the Space. + tags (`List[str]`): + List of tags of the Space. + siblings (`List[RepoSibling]`): + List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space. + card_data (`SpaceCardData`, *optional*): + Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object. + runtime (`SpaceRuntime`, *optional*): + Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object. + sdk (`str`, *optional*): + SDK used by the Space. + models (`List[str]`, *optional*): + List of models used by the Space. + datasets (`List[str]`, *optional*): + List of datasets used by the Space. + """ + + id: str + author: Optional[str] + sha: Optional[str] + created_at: Optional[datetime] + last_modified: Optional[datetime] + private: bool + gated: Optional[Literal["auto", "manual", False]] + disabled: Optional[bool] + host: Optional[str] + subdomain: Optional[str] + likes: int + sdk: Optional[str] + tags: List[str] + siblings: Optional[List[RepoSibling]] + card_data: Optional[SpaceCardData] + runtime: Optional[SpaceRuntime] + models: Optional[List[str]] + datasets: Optional[List[str]] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.author = kwargs.pop("author", None) + self.sha = kwargs.pop("sha", None) + created_at = kwargs.pop("createdAt", None) or kwargs.pop("created_at", None) + self.created_at = parse_datetime(created_at) if created_at else None + last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None) + self.last_modified = parse_datetime(last_modified) if last_modified else None + self.private = kwargs.pop("private") + self.gated = kwargs.pop("gated", None) + self.disabled = kwargs.pop("disabled", None) + self.host = kwargs.pop("host", None) + self.subdomain = kwargs.pop("subdomain", None) + self.likes = kwargs.pop("likes") + self.sdk = kwargs.pop("sdk", None) + self.tags = kwargs.pop("tags") + card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None) + self.card_data = ( + SpaceCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data + ) + siblings = kwargs.pop("siblings", None) + self.siblings = ( + [ + RepoSibling( + rfilename=sibling["rfilename"], + size=sibling.get("size"), + blob_id=sibling.get("blobId"), + lfs=( + BlobLfsInfo( + size=sibling["lfs"]["size"], + sha256=sibling["lfs"]["sha256"], + pointer_size=sibling["lfs"]["pointerSize"], + ) + if sibling.get("lfs") + else None + ), + ) + for sibling in siblings + ] + if siblings + else None + ) + runtime = kwargs.pop("runtime", None) + self.runtime = SpaceRuntime(runtime) if runtime else None + self.models = kwargs.pop("models", None) + self.datasets = kwargs.pop("datasets", None) + + # backwards compatibility + self.lastModified = self.last_modified + self.cardData = self.card_data + self.__dict__.update(**kwargs) + + +@dataclass +class MetricInfo: + """ + Contains information about a metric on the Hub. + + Attributes: + id (`str`): + ID of the metric. E.g. `"accuracy"`. + space_id (`str`): + ID of the space associated with the metric. E.g. `"Accuracy"`. + description (`str`): + Description of the metric. + """ + + id: str + space_id: str + description: Optional[str] + + def __init__(self, **kwargs): + self.id = kwargs.pop("id") + self.space_id = kwargs.pop("spaceId") + self.description = kwargs.pop("description", None) + # backwards compatibility + self.spaceId = self.space_id + self.__dict__.update(**kwargs) + + +@dataclass +class CollectionItem: + """ + Contains information about an item of a Collection (model, dataset, Space or paper). + + Attributes: + item_object_id (`str`): + Unique ID of the item in the collection. + item_id (`str`): + ID of the underlying object on the Hub. Can be either a repo_id or a paper id + e.g. `"jbilcke-hf/ai-comic-factory"`, `"2307.09288"`. + item_type (`str`): + Type of the underlying object. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. + position (`int`): + Position of the item in the collection. + note (`str`, *optional*): + Note associated with the item, as plain text. + """ + + item_object_id: str # id in database + item_id: str # repo_id or paper id + item_type: str + position: int + note: Optional[str] = None + + def __init__( + self, _id: str, id: str, type: CollectionItemType_T, position: int, note: Optional[Dict] = None, **kwargs + ) -> None: + self.item_object_id: str = _id # id in database + self.item_id: str = id # repo_id or paper id + self.item_type: CollectionItemType_T = type + self.position: int = position + self.note: str = note["text"] if note is not None else None + + +@dataclass +class Collection: + """ + Contains information about a Collection on the Hub. + + Attributes: + slug (`str`): + Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection. E.g. `"Recent models"`. + owner (`str`): + Owner of the collection. E.g. `"TheBloke"`. + items (`List[CollectionItem]`): + List of items in the collection. + last_updated (`datetime`): + Date of the last update of the collection. + position (`int`): + Position of the collection in the list of collections of the owner. + private (`bool`): + Whether the collection is private or not. + theme (`str`): + Theme of the collection. E.g. `"green"`. + upvotes (`int`): + Number of upvotes of the collection. + description (`str`, *optional*): + Description of the collection, as plain text. + url (`str`): + (property) URL of the collection on the Hub. + """ + + slug: str + title: str + owner: str + items: List[CollectionItem] + last_updated: datetime + position: int + private: bool + theme: str + upvotes: int + description: Optional[str] = None + + def __init__(self, **kwargs) -> None: + self.slug = kwargs.pop("slug") + self.title = kwargs.pop("title") + self.owner = kwargs.pop("owner") + self.items = [CollectionItem(**item) for item in kwargs.pop("items")] + self.last_updated = parse_datetime(kwargs.pop("lastUpdated")) + self.position = kwargs.pop("position") + self.private = kwargs.pop("private") + self.theme = kwargs.pop("theme") + self.upvotes = kwargs.pop("upvotes") + self.description = kwargs.pop("description", None) + endpoint = kwargs.pop("endpoint", None) + if endpoint is None: + endpoint = ENDPOINT + self._url = f"{endpoint}/collections/{self.slug}" + + @property + def url(self) -> str: + """Returns the URL of the collection on the Hub.""" + return self._url + + +@dataclass +class GitRefInfo: + """ + Contains information about a git reference for a repo on the Hub. + + Attributes: + name (`str`): + Name of the reference (e.g. tag name or branch name). + ref (`str`): + Full git ref on the Hub (e.g. `"refs/heads/main"` or `"refs/tags/v1.0"`). + target_commit (`str`): + OID of the target commit for the ref (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) + """ + + name: str + ref: str + target_commit: str + + +@dataclass +class GitRefs: + """ + Contains information about all git references for a repo on the Hub. + + Object is returned by [`list_repo_refs`]. + + Attributes: + branches (`List[GitRefInfo]`): + A list of [`GitRefInfo`] containing information about branches on the repo. + converts (`List[GitRefInfo]`): + A list of [`GitRefInfo`] containing information about "convert" refs on the repo. + Converts are refs used (internally) to push preprocessed data in Dataset repos. + tags (`List[GitRefInfo]`): + A list of [`GitRefInfo`] containing information about tags on the repo. + pull_requests (`List[GitRefInfo]`, *optional*): + A list of [`GitRefInfo`] containing information about pull requests on the repo. + Only returned if `include_prs=True` is set. + """ + + branches: List[GitRefInfo] + converts: List[GitRefInfo] + tags: List[GitRefInfo] + pull_requests: Optional[List[GitRefInfo]] + + +@dataclass +class GitCommitInfo: + """ + Contains information about a git commit for a repo on the Hub. Check out [`list_repo_commits`] for more details. + + Attributes: + commit_id (`str`): + OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`) + authors (`List[str]`): + List of authors of the commit. + created_at (`datetime`): + Datetime when the commit was created. + title (`str`): + Title of the commit. This is a free-text value entered by the authors. + message (`str`): + Description of the commit. This is a free-text value entered by the authors. + formatted_title (`str`): + Title of the commit formatted as HTML. Only returned if `formatted=True` is set. + formatted_message (`str`): + Description of the commit formatted as HTML. Only returned if `formatted=True` is set. + """ + + commit_id: str + + authors: List[str] + created_at: datetime + title: str + message: str + + formatted_title: Optional[str] + formatted_message: Optional[str] + + +@dataclass +class UserLikes: + """ + Contains information about a user likes on the Hub. + + Attributes: + user (`str`): + Name of the user for which we fetched the likes. + total (`int`): + Total number of likes. + datasets (`List[str]`): + List of datasets liked by the user (as repo_ids). + models (`List[str]`): + List of models liked by the user (as repo_ids). + spaces (`List[str]`): + List of spaces liked by the user (as repo_ids). + """ + + # Metadata + user: str + total: int + + # User likes + datasets: List[str] + models: List[str] + spaces: List[str] + + +@dataclass +class User: + """ + Contains information about a user on the Hub. + + Attributes: + avatar_url (`str`): + URL of the user's avatar. + username (`str`): + Name of the user on the Hub (unique). + fullname (`str`): + User's full name. + """ + + # Metadata + avatar_url: str + username: str + fullname: str + + +def future_compatible(fn: CallableT) -> CallableT: + """Wrap a method of `HfApi` to handle `run_as_future=True`. + + A method flagged as "future_compatible" will be called in a thread if `run_as_future=True` and return a + `concurrent.futures.Future` instance. Otherwise, it will be called normally and return the result. + """ + sig = inspect.signature(fn) + args_params = list(sig.parameters)[1:] # remove "self" from list + + @wraps(fn) + def _inner(self, *args, **kwargs): + # Get `run_as_future` value if provided (default to False) + if "run_as_future" in kwargs: + run_as_future = kwargs["run_as_future"] + kwargs["run_as_future"] = False # avoid recursion error + else: + run_as_future = False + for param, value in zip(args_params, args): + if param == "run_as_future": + run_as_future = value + break + + # Call the function in a thread if `run_as_future=True` + if run_as_future: + return self.run_as_future(fn, self, *args, **kwargs) + + # Otherwise, call the function normally + return fn(self, *args, **kwargs) + + _inner.is_future_compatible = True # type: ignore + return _inner # type: ignore + + +class HfApi: + def __init__( + self, + endpoint: Optional[str] = None, + token: Optional[str] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + ) -> None: + """Create a HF client to interact with the Hub via HTTP. + + The client is initialized with some high-level settings used in all requests + made to the Hub (HF endpoint, authentication, user agents...). Using the `HfApi` + client is preferred but not mandatory as all of its public methods are exposed + directly at the root of `huggingface_hub`. + + Args: + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if + not provided. + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. Will be added to + the user-agent header. Example: `"transformers"`. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. Will be added + to the user-agent header. Example: `"4.24.0"`. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. It will + be completed with information about the installed packages. + """ + self.endpoint = endpoint if endpoint is not None else ENDPOINT + self.token = token + self.library_name = library_name + self.library_version = library_version + self.user_agent = user_agent + self._thread_pool: Optional[ThreadPoolExecutor] = None + + def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: + """ + Run a method in the background and return a Future instance. + + The main goal is to run methods without blocking the main thread (e.g. to push data during a training). + Background jobs are queued to preserve order but are not ran in parallel. If you need to speed-up your scripts + by parallelizing lots of call to the API, you must setup and use your own [ThreadPoolExecutor](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor). + + Note: Most-used methods like [`upload_file`], [`upload_folder`] and [`create_commit`] have a `run_as_future: bool` + argument to directly call them in the background. This is equivalent to calling `api.run_as_future(...)` on them + but less verbose. + + Args: + fn (`Callable`): + The method to run in the background. + *args, **kwargs: + Arguments with which the method will be called. + + Return: + `Future`: a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) instance to + get the result of the task. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> future = api.run_as_future(api.whoami) # instant + >>> future.done() + False + >>> future.result() # wait until complete and return result + (...) + >>> future.done() + True + ``` + """ + if self._thread_pool is None: + self._thread_pool = ThreadPoolExecutor(max_workers=1) + self._thread_pool + return self._thread_pool.submit(fn, *args, **kwargs) + + @validate_hf_hub_args + def whoami(self, token: Optional[str] = None) -> Dict: + """ + Call HF API to know "whoami". + + Args: + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if + not provided. + """ + r = get_session().get( + f"{self.endpoint}/api/whoami-v2", + headers=self._build_hf_headers( + # If `token` is provided and not `None`, it will be used by default. + # Otherwise, the token must be retrieved from cache or env variable. + token=(token or self.token or True), + ), + ) + try: + hf_raise_for_status(r) + except HTTPError as e: + raise HTTPError( + "Invalid user token. If you didn't pass a user token, make sure you " + "are properly logged in by executing `huggingface-cli login`, and " + "if you did pass a user token, double-check it's correct.", + request=e.request, + response=e.response, + ) from e + return r.json() + + def get_token_permission(self, token: Optional[str] = None) -> Literal["read", "write", None]: + """ + Check if a given `token` is valid and return its permissions. + + For more details about tokens, please refer to https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens. + + Args: + token (`str`, *optional*): + The token to check for validity. Defaults to the one saved locally. + + Returns: + `Literal["read", "write", None]`: Permission granted by the token ("read" or "write"). Returns `None` if no + token passed or token is invalid. + """ + try: + return self.whoami(token=token)["auth"]["accessToken"]["role"] + except (LocalTokenNotFoundError, HTTPError): + return None + + def get_model_tags(self) -> Dict: + """ + List all valid model tags as a nested namespace object + """ + path = f"{self.endpoint}/api/models-tags-by-type" + r = get_session().get(path) + hf_raise_for_status(r) + return r.json() + + def get_dataset_tags(self) -> Dict: + """ + List all valid dataset tags as a nested namespace object. + """ + path = f"{self.endpoint}/api/datasets-tags-by-type" + r = get_session().get(path) + hf_raise_for_status(r) + return r.json() + + @validate_hf_hub_args + def list_models( + self, + *, + filter: Union[ModelFilter, str, Iterable[str], None] = None, + author: Optional[str] = None, + search: Optional[str] = None, + emissions_thresholds: Optional[Tuple[float, float]] = None, + sort: Union[Literal["last_modified"], str, None] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + full: Optional[bool] = None, + cardData: bool = False, + fetch_config: bool = False, + token: Optional[Union[bool, str]] = None, + ) -> Iterable[ModelInfo]: + """ + List models hosted on the Huggingface Hub, given some filters. + + Args: + filter ([`ModelFilter`] or `str` or `Iterable`, *optional*): + A string or [`ModelFilter`] which can be used to identify models + on the Hub. + author (`str`, *optional*): + A string which identify the author (user or organization) of the + returned models + search (`str`, *optional*): + A string that will be contained in the returned model ids. + emissions_thresholds (`Tuple`, *optional*): + A tuple of two ints or floats representing a minimum and maximum + carbon footprint to filter the resulting models with in grams. + sort (`Literal["last_modified"]` or `str`, *optional*): + The key with which to sort the resulting models. Possible values + are the properties of the [`huggingface_hub.hf_api.ModelInfo`] class. + direction (`Literal[-1]` or `int`, *optional*): + Direction in which to sort. The value `-1` sorts by descending + order while all other values sort by ascending order. + limit (`int`, *optional*): + The limit on the number of models fetched. Leaving this option + to `None` fetches all models. + full (`bool`, *optional*): + Whether to fetch all model data, including the `last_modified`, + the `sha`, the files and the `tags`. This is set to `True` by + default when using a filter. + cardData (`bool`, *optional*): + Whether to grab the metadata for the model as well. Can contain + useful information such as carbon emissions, metrics, and + datasets trained on. + fetch_config (`bool`, *optional*): + Whether to fetch the model configs as well. This is not included + in `full` due to its size. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects. + + Example usage with the `filter` argument: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + >>> # List all models + >>> api.list_models() + + >>> # List only the text classification models + >>> api.list_models(filter="text-classification") + >>> # Using the `ModelFilter` + >>> filt = ModelFilter(task="text-classification") + + + >>> # List only models from the AllenNLP library + >>> api.list_models(filter="allennlp") + ``` + + Example usage with the `search` argument: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + >>> # List all models with "bert" in their name + >>> api.list_models(search="bert") + + >>> # List all models with "bert" in their name made by google + >>> api.list_models(search="bert", author="google") + ``` + """ + if emissions_thresholds is not None and cardData is None: + raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.") + + path = f"{self.endpoint}/api/models" + headers = self._build_hf_headers(token=token) + params = {} + if filter is not None: + if isinstance(filter, ModelFilter): + params = self._unpack_model_filter(filter) + else: + params.update({"filter": filter}) + params.update({"full": True}) + if author is not None: + params.update({"author": author}) + if search is not None: + params.update({"search": search}) + if sort is not None: + params.update({"sort": "lastModified" if sort == "last_modified" else sort}) + if direction is not None: + params.update({"direction": direction}) + if limit is not None: + params.update({"limit": limit}) + if full is not None: + if full: + params.update({"full": True}) + elif "full" in params: + del params["full"] + if fetch_config: + params.update({"config": True}) + if cardData: + params.update({"cardData": True}) + + # `items` is a generator + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + model_info = ModelInfo(**item) + if emissions_thresholds is None or _is_emission_within_treshold(model_info, *emissions_thresholds): + yield model_info + + def _unpack_model_filter(self, model_filter: ModelFilter): + """ + Unpacks a [`ModelFilter`] into something readable for `list_models` + """ + model_str = "" + + # Handling author + if model_filter.author is not None: + model_str = f"{model_filter.author}/" + + # Handling model_name + if model_filter.model_name is not None: + model_str += model_filter.model_name + + filter_list: List[str] = [] + + # Handling tasks + if model_filter.task is not None: + filter_list.extend([model_filter.task] if isinstance(model_filter.task, str) else model_filter.task) + + # Handling dataset + if model_filter.trained_dataset is not None: + if not isinstance(model_filter.trained_dataset, (list, tuple)): + model_filter.trained_dataset = [model_filter.trained_dataset] + for dataset in model_filter.trained_dataset: + if "dataset:" not in dataset: + dataset = f"dataset:{dataset}" + filter_list.append(dataset) + + # Handling library + if model_filter.library: + filter_list.extend( + [model_filter.library] if isinstance(model_filter.library, str) else model_filter.library + ) + + # Handling tags + if model_filter.tags: + filter_list.extend([model_filter.tags] if isinstance(model_filter.tags, str) else model_filter.tags) + + query_dict: Dict[str, Any] = {} + if model_str is not None: + query_dict["search"] = model_str + if isinstance(model_filter.language, list): + filter_list.extend(model_filter.language) + elif isinstance(model_filter.language, str): + filter_list.append(model_filter.language) + query_dict["filter"] = tuple(filter_list) + return query_dict + + @validate_hf_hub_args + def list_datasets( + self, + *, + filter: Union[DatasetFilter, str, Iterable[str], None] = None, + author: Optional[str] = None, + search: Optional[str] = None, + sort: Union[Literal["last_modified"], str, None] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + full: Optional[bool] = None, + token: Optional[str] = None, + ) -> Iterable[DatasetInfo]: + """ + List datasets hosted on the Huggingface Hub, given some filters. + + Args: + filter ([`DatasetFilter`] or `str` or `Iterable`, *optional*): + A string or [`DatasetFilter`] which can be used to identify + datasets on the hub. + author (`str`, *optional*): + A string which identify the author of the returned datasets. + search (`str`, *optional*): + A string that will be contained in the returned datasets. + sort (`Literal["last_modified"]` or `str`, *optional*): + The key with which to sort the resulting datasets. Possible + values are the properties of the [`huggingface_hub.hf_api.DatasetInfo`] class. + direction (`Literal[-1]` or `int`, *optional*): + Direction in which to sort. The value `-1` sorts by descending + order while all other values sort by ascending order. + limit (`int`, *optional*): + The limit on the number of datasets fetched. Leaving this option + to `None` fetches all datasets. + full (`bool`, *optional*): + Whether to fetch all dataset data, including the `last_modified`, + the `card_data` and the files. Can contain useful information such as the + PapersWithCode ID. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + `Iterable[DatasetInfo]`: an iterable of [`huggingface_hub.hf_api.DatasetInfo`] objects. + + Example usage with the `filter` argument: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + >>> # List all datasets + >>> api.list_datasets() + + + >>> # List only the text classification datasets + >>> api.list_datasets(filter="task_categories:text-classification") + >>> # Using the `DatasetFilter` + >>> filt = DatasetFilter(task_categories="text-classification") + + + >>> # List only the datasets in russian for language modeling + >>> api.list_datasets( + ... filter=("language:ru", "task_ids:language-modeling") + ... ) + >>> # Using the `DatasetFilter` + >>> filt = DatasetFilter(language="ru", task_ids="language-modeling") + + >>> api.list_datasets(filter=filt) + ``` + + Example usage with the `search` argument: + + ```python + >>> from huggingface_hub import HfApi + + >>> api = HfApi() + + >>> # List all datasets with "text" in their name + >>> api.list_datasets(search="text") + + >>> # List all datasets with "text" in their name made by google + >>> api.list_datasets(search="text", author="google") + ``` + """ + path = f"{self.endpoint}/api/datasets" + headers = self._build_hf_headers(token=token) + params = {} + if filter is not None: + if isinstance(filter, DatasetFilter): + params = self._unpack_dataset_filter(filter) + else: + params.update({"filter": filter}) + if author is not None: + params.update({"author": author}) + if search is not None: + params.update({"search": search}) + if sort is not None: + params.update({"sort": "lastModified" if sort == "last_modified" else sort}) + if direction is not None: + params.update({"direction": direction}) + if limit is not None: + params.update({"limit": limit}) + if full: + params.update({"full": True}) + + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + yield DatasetInfo(**item) + + def _unpack_dataset_filter(self, dataset_filter: DatasetFilter): + """ + Unpacks a [`DatasetFilter`] into something readable for `list_datasets` + """ + dataset_str = "" + + # Handling author + if dataset_filter.author is not None: + dataset_str = f"{dataset_filter.author}/" + + # Handling dataset_name + if dataset_filter.dataset_name is not None: + dataset_str += dataset_filter.dataset_name + + filter_list = [] + data_attributes = [ + "benchmark", + "language_creators", + "language", + "multilinguality", + "size_categories", + "task_categories", + "task_ids", + ] + + for attr in data_attributes: + curr_attr = getattr(dataset_filter, attr) + if curr_attr is not None: + if not isinstance(curr_attr, (list, tuple)): + curr_attr = [curr_attr] + for data in curr_attr: + if f"{attr}:" not in data: + data = f"{attr}:{data}" + filter_list.append(data) + + query_dict: Dict[str, Any] = {} + if dataset_str is not None: + query_dict["search"] = dataset_str + query_dict["filter"] = tuple(filter_list) + return query_dict + + def list_metrics(self) -> List[MetricInfo]: + """ + Get the public list of all the metrics on huggingface.co + + Returns: + `List[MetricInfo]`: a list of [`MetricInfo`] objects which. + """ + path = f"{self.endpoint}/api/metrics" + r = get_session().get(path) + hf_raise_for_status(r) + d = r.json() + return [MetricInfo(**x) for x in d] + + @validate_hf_hub_args + def list_spaces( + self, + *, + filter: Union[str, Iterable[str], None] = None, + author: Optional[str] = None, + search: Optional[str] = None, + sort: Union[Literal["last_modified"], str, None] = None, + direction: Optional[Literal[-1]] = None, + limit: Optional[int] = None, + datasets: Union[str, Iterable[str], None] = None, + models: Union[str, Iterable[str], None] = None, + linked: bool = False, + full: Optional[bool] = None, + token: Optional[str] = None, + ) -> Iterable[SpaceInfo]: + """ + List spaces hosted on the Huggingface Hub, given some filters. + + Args: + filter (`str` or `Iterable`, *optional*): + A string tag or list of tags that can be used to identify Spaces on the Hub. + author (`str`, *optional*): + A string which identify the author of the returned Spaces. + search (`str`, *optional*): + A string that will be contained in the returned Spaces. + sort (`Literal["last_modified"]` or `str`, *optional*): + The key with which to sort the resulting Spaces. Possible + values are the properties of the [`huggingface_hub.hf_api.SpaceInfo`]` class. + direction (`Literal[-1]` or `int`, *optional*): + Direction in which to sort. The value `-1` sorts by descending + order while all other values sort by ascending order. + limit (`int`, *optional*): + The limit on the number of Spaces fetched. Leaving this option + to `None` fetches all Spaces. + datasets (`str` or `Iterable`, *optional*): + Whether to return Spaces that make use of a dataset. + The name of a specific dataset can be passed as a string. + models (`str` or `Iterable`, *optional*): + Whether to return Spaces that make use of a model. + The name of a specific model can be passed as a string. + linked (`bool`, *optional*): + Whether to return Spaces that make use of either a model or a dataset. + full (`bool`, *optional*): + Whether to fetch all Spaces data, including the `last_modified`, `siblings` + and `card_data` fields. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + `Iterable[SpaceInfo]`: an iterable of [`huggingface_hub.hf_api.SpaceInfo`] objects. + """ + path = f"{self.endpoint}/api/spaces" + headers = self._build_hf_headers(token=token) + params: Dict[str, Any] = {} + if filter is not None: + params.update({"filter": filter}) + if author is not None: + params.update({"author": author}) + if search is not None: + params.update({"search": search}) + if sort is not None: + params.update({"sort": "lastModified" if sort == "last_modified" else sort}) + if direction is not None: + params.update({"direction": direction}) + if limit is not None: + params.update({"limit": limit}) + if full: + params.update({"full": True}) + if linked: + params.update({"linked": True}) + if datasets is not None: + params.update({"datasets": datasets}) + if models is not None: + params.update({"models": models}) + + items = paginate(path, params=params, headers=headers) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + for item in items: + if "siblings" not in item: + item["siblings"] = None + yield SpaceInfo(**item) + + @validate_hf_hub_args + def like( + self, + repo_id: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Like a given repo on the Hub (e.g. set as favorite). + + See also [`unlike`] and [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to like. Example: `"user/my-cool-model"`. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if liking a dataset or space, `None` or + `"model"` if liking a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + + Example: + ```python + >>> from huggingface_hub import like, list_liked_repos, unlike + >>> like("gpt2") + >>> "gpt2" in list_liked_repos().models + True + >>> unlike("gpt2") + >>> "gpt2" in list_liked_repos().models + False + ``` + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + response = get_session().post( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + @validate_hf_hub_args + def unlike( + self, + repo_id: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Unlike a given repo on the Hub (e.g. remove from favorite list). + + See also [`like`] and [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to unlike. Example: `"user/my-cool-model"`. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if unliking a dataset or space, `None` or + `"model"` if unliking a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + + Example: + ```python + >>> from huggingface_hub import like, list_liked_repos, unlike + >>> like("gpt2") + >>> "gpt2" in list_liked_repos().models + True + >>> unlike("gpt2") + >>> "gpt2" in list_liked_repos().models + False + ``` + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + response = get_session().delete( + url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(response) + + @validate_hf_hub_args + def list_liked_repos( + self, + user: Optional[str] = None, + *, + token: Optional[str] = None, + ) -> UserLikes: + """ + List all public repos liked by a user on huggingface.co. + + This list is public so token is optional. If `user` is not passed, it defaults to + the logged in user. + + See also [`like`] and [`unlike`]. + + Args: + user (`str`, *optional*): + Name of the user for which you want to fetch the likes. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + Used only if `user` is not passed to implicitly determine the current + user name. + + Returns: + [`UserLikes`]: object containing the user name and 3 lists of repo ids (1 for + models, 1 for datasets and 1 for Spaces). + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `user` is not passed and no token found (either from argument or from machine). + + Example: + ```python + >>> from huggingface_hub import list_liked_repos + + >>> likes = list_liked_repos("julien-c") + + >>> likes.user + "julien-c" + + >>> likes.models + ["osanseviero/streamlit_1.15", "Xhaheen/ChatGPT_HF", ...] + ``` + """ + # User is either provided explicitly or retrieved from current token. + if user is None: + me = self.whoami(token=token) + if me["type"] == "user": + user = me["name"] + else: + raise ValueError( + "Cannot list liked repos. You must provide a 'user' as input or be logged in as a user." + ) + + path = f"{self.endpoint}/api/users/{user}/likes" + headers = self._build_hf_headers(token=token) + + likes = list(paginate(path, params={}, headers=headers)) + # Looping over a list of items similar to: + # { + # 'createdAt': '2021-09-09T21:53:27.000Z', + # 'repo': { + # 'name': 'PaddlePaddle/PaddleOCR', + # 'type': 'space' + # } + # } + # Let's loop 3 times over the received list. Less efficient but more straightforward to read. + return UserLikes( + user=user, + total=len(likes), + models=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "model"], + datasets=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "dataset"], + spaces=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "space"], + ) + + @validate_hf_hub_args + def list_repo_likers( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> List[User]: + """ + List all users who liked a given repo on the hugging Face Hub. + + See also [`like`] and [`list_liked_repos`]. + + Args: + repo_id (`str`): + The repository to retrieve . Example: `"user/my-cool-model"`. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: + `List[User]`: a list of [`User`] objects. + """ + + # Construct the API endpoint + if repo_type is None: + repo_type = REPO_TYPE_MODEL + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/likers" + headers = self._build_hf_headers(token=token) + + # Make the request + response = get_session().get(path, headers=headers) + hf_raise_for_status(response) + + # Parse the results into User objects + likers_data = response.json() + return [ + User( + username=user_data["user"], + fullname=user_data["fullname"], + avatar_url=user_data["avatarUrl"], + ) + for user_data in likers_data + ] + + @validate_hf_hub_args + def model_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + securityStatus: Optional[bool] = None, + files_metadata: bool = False, + token: Optional[Union[bool, str]] = None, + ) -> ModelInfo: + """ + Get info on one specific model on huggingface.co + + Model can be private if you pass an acceptable token or are logged in. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the model repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + securityStatus (`bool`, *optional*): + Whether to retrieve the security status from the model + repository as well. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + [`huggingface_hub.hf_api.ModelInfo`]: The model repository information. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + + + """ + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/models/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}") + ) + params = {} + if securityStatus: + params["securityStatus"] = True + if files_metadata: + params["blobs"] = True + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return ModelInfo(**data) + + @validate_hf_hub_args + def dataset_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + token: Optional[Union[bool, str]] = None, + ) -> DatasetInfo: + """ + Get info on one specific dataset on huggingface.co. + + Dataset can be private if you pass an acceptable token. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the dataset repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + [`hf_api.DatasetInfo`]: The dataset repository information. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + + + """ + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/datasets/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") + ) + params = {} + if files_metadata: + params["blobs"] = True + + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return DatasetInfo(**data) + + @validate_hf_hub_args + def space_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + token: Optional[Union[bool, str]] = None, + ) -> SpaceInfo: + """ + Get info on one specific Space on huggingface.co. + + Space can be private if you pass an acceptable token. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the space repository from which to get the + information. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + [`~hf_api.SpaceInfo`]: The space repository information. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + + + """ + headers = self._build_hf_headers(token=token) + path = ( + f"{self.endpoint}/api/spaces/{repo_id}" + if revision is None + else (f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") + ) + params = {} + if files_metadata: + params["blobs"] = True + + r = get_session().get(path, headers=headers, timeout=timeout, params=params) + hf_raise_for_status(r) + data = r.json() + return SpaceInfo(**data) + + @validate_hf_hub_args + def repo_info( + self, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + timeout: Optional[float] = None, + files_metadata: bool = False, + token: Optional[Union[bool, str]] = None, + ) -> Union[ModelInfo, DatasetInfo, SpaceInfo]: + """ + Get the info object for a given repo of a given type. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The revision of the repository from which to get the + information. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + timeout (`float`, *optional*): + Whether to set a timeout for the request to the Hub. + files_metadata (`bool`, *optional*): + Whether or not to retrieve metadata for files in the repository + (size, LFS metadata, etc). Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + `Union[SpaceInfo, DatasetInfo, ModelInfo]`: The repository information, as a + [`huggingface_hub.hf_api.DatasetInfo`], [`huggingface_hub.hf_api.ModelInfo`] + or [`huggingface_hub.hf_api.SpaceInfo`] object. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + + + """ + if repo_type is None or repo_type == "model": + method = self.model_info + elif repo_type == "dataset": + method = self.dataset_info # type: ignore + elif repo_type == "space": + method = self.space_info # type: ignore + else: + raise ValueError("Unsupported repo type.") + return method( + repo_id, + revision=revision, + token=token, + timeout=timeout, + files_metadata=files_metadata, + ) + + @validate_hf_hub_args + def repo_exists( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> bool: + """ + Checks if a repository exists on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + True if the repository exists, False otherwise. + + + + Examples: + ```py + >>> from huggingface_hub import repo_exists + >>> repo_exists("huggingface/transformers") + True + >>> repo_exists("huggingface/not-a-repo") + False + ``` + + + """ + try: + self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) + return True + except GatedRepoError: + return True # we don't have access but it exists + except RepositoryNotFoundError: + return False + + @validate_hf_hub_args + def file_exists( + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + ) -> bool: + """ + Checks if a file exists in a repository on the Hugging Face Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + filename (`str`): + The name of the file to check, for example: + `"config.json"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if getting repository info from a dataset or a space, + `None` or `"model"` if getting repository info from a model. Default is `None`. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + True if the file exists, False otherwise. + + + + Examples: + ```py + >>> from huggingface_hub import file_exists + >>> file_exists("bigcode/starcoder", "config.json") + True + >>> file_exists("bigcode/starcoder", "not-a-file") + False + >>> file_exists("bigcode/not-a-repo", "config.json") + False + ``` + + + """ + url = hf_hub_url( + repo_id=repo_id, repo_type=repo_type, revision=revision, filename=filename, endpoint=self.endpoint + ) + try: + if token is None: + token = self.token + get_hf_file_metadata(url, token=token) + return True + except GatedRepoError: # raise specifically on gated repo + raise + except (RepositoryNotFoundError, EntryNotFoundError, RevisionNotFoundError): + return False + + @validate_hf_hub_args + @_deprecate_method(version="0.23", message="Use `list_repo_tree` and `get_paths_info` instead.") + def list_files_info( + self, + repo_id: str, + paths: Union[List[str], str, None] = None, + *, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Optional[Union[bool, str]] = None, + ) -> Iterable[RepoFile]: + """ + List files on a repo and get information about them. + + Takes as input a list of paths. Those paths can be either files or folders. Two server endpoints are called: + 1. POST "/paths-info" to get information about the provided paths. Called once. + 2. GET "/tree?recursive=True" to paginate over the input folders. Called only if a folder path is provided as + input. Will be called multiple times to follow pagination. + If no path is provided as input, step 1. is ignored and all files from the repo are listed. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + paths (`Union[List[str], str, None]`, *optional*): + The paths to get information about. Paths to files are directly resolved. Paths to folders are resolved + recursively which means that information is returned about all files in the folder and its subfolders. + If `None`, all files are returned (the default). If a path do not exist, it is ignored without raising + an exception. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the files (e.g. last commit and security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`. + Defaults to `"model"`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and + machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be + retrieved from the cache. If `False`, token is not sent in the request header. + + Returns: + `Iterable[RepoFile]`: + The information about the files, as an iterable of [`RepoFile`] objects. The order of the files is + not guaranteed. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + + Examples: + + Get information about files on a repo. + ```py + >>> from huggingface_hub import list_files_info + >>> files_info = list_files_info("lysandre/arxiv-nlp", ["README.md", "config.json"]) + >>> files_info + + >>> list(files_info) + [ + RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None), + RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None) + ] + ``` + + Get even more information about files on a repo (last commit and security scan results) + ```py + >>> from huggingface_hub import list_files_info + >>> files_info = list_files_info("prompthero/openjourney-v4", expand=True) + >>> list(files_info) + [ + RepoFile( + path='safety_checker/pytorch_model.bin', + size=1216064769, + blob_id='c8835557a0d3af583cb06c7c154b7e54a069c41d', + lfs={ + 'size': 1216064769, + 'sha256': '16d28f2b37109f222cdc33620fdd262102ac32112be0352a7f77e9614b35a394', + 'pointer_size': 135 + }, + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 10, 5, 27, tzinfo=datetime.timezone.utc) + }, + security={ + 'safe': True, + 'av_scan': { + 'virusFound': False, + 'virusNames': None + }, + 'pickle_import_scan': { + 'highestSafetyLevel': 'innocuous', + 'imports': [ + {'module': 'torch', 'name': 'FloatStorage', 'safety': 'innocuous'}, + {'module': 'collections', 'name': 'OrderedDict', 'safety': 'innocuous'}, + {'module': 'torch', 'name': 'LongStorage', 'safety': 'innocuous'}, + {'module': 'torch._utils', 'name': '_rebuild_tensor_v2', 'safety': 'innocuous'} + ] + } + } + ), + RepoFile( + path='scheduler/scheduler_config.json', + size=465, + blob_id='70d55e3e9679f41cbc66222831b38d5abce683dd', + lfs=None, + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 10, 5, 27, tzinfo=datetime.timezone.utc)}, + security={ + 'safe': True, + 'av_scan': { + 'virusFound': False, + 'virusNames': None + }, + 'pickle_import_scan': None + } + ), + ... + ] + ``` + + List LFS files from the "vae/" folder in "stabilityai/stable-diffusion-2" repository. + + ```py + >>> from huggingface_hub import list_files_info + >>> [info.path for info in list_files_info("stabilityai/stable-diffusion-2", "vae") if info.lfs is not None] + ['vae/diffusion_pytorch_model.bin', 'vae/diffusion_pytorch_model.safetensors'] + ``` + + List all files on a repo. + ```py + >>> from huggingface_hub import list_files_info + >>> [info.path for info in list_files_info("glue", repo_type="dataset")] + ['.gitattributes', 'README.md', 'dataset_infos.json', 'glue.py'] + ``` + """ + repo_type = repo_type or REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + folder_paths = [] + if paths is None: + # `paths` is not provided => list all files from the repo + folder_paths.append("") + elif paths == []: + # corner case: server would return a 400 error if `paths` is an empty list. Let's return early. + return + else: + # `paths` is provided => get info about those + response = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}", + data={ + "paths": paths if isinstance(paths, list) else [paths], + "expand": expand, + }, + headers=headers, + ) + hf_raise_for_status(response) + paths_info = response.json() + + # List top-level files first + for path_info in paths_info: + if path_info["type"] == "file": + yield RepoFile(**path_info) + else: + folder_paths.append(path_info["path"]) + + # List files in subdirectories + for path in folder_paths: + encoded_path = "/" + quote(path, safe="") if path else "" + tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path}" + for subpath_info in paginate(path=tree_url, headers=headers, params={"recursive": True, "expand": expand}): + if subpath_info["type"] == "file": + yield RepoFile(**subpath_info) + + @validate_hf_hub_args + def list_repo_files( + self, + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Optional[Union[bool, str]] = None, + ) -> List[str]: + """ + Get the list of files in a given repo. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + revision (`str`, *optional*): + The revision of the model repository from which to get the information. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to + a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and + machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be + retrieved from the cache. If `False`, token is not sent in the request header. + + Returns: + `List[str]`: the list of files in a given repository. + """ + return [ + f.rfilename + for f in self.list_repo_tree( + repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type, token=token + ) + if isinstance(f, RepoFile) + ] + + @validate_hf_hub_args + def list_repo_tree( + self, + repo_id: str, + path_in_repo: Optional[str] = None, + *, + recursive: bool = False, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Optional[Union[bool, str]] = None, + ) -> Iterable[Union[RepoFile, RepoFolder]]: + """ + List a repo tree's files and folders and get information about them. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + path_in_repo (`str`, *optional*): + Relative path of the tree (folder) in the repo, for example: + `"checkpoints/1fec34a/results"`. Will default to the root tree (folder) of the repository. + recursive (`bool`, *optional*, defaults to `False`): + Whether to list tree's files and folders recursively. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the tree's files and folders (e.g. last commit and files' security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the tree. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the tree (`"model"`, `"dataset"` or `"space"`. + Defaults to `"model"`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and + machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be + retrieved from the cache. If `False`, token is not sent in the request header. + + Returns: + `Iterable[Union[RepoFile, RepoFolder]]`: + The information about the tree's files and folders, as an iterable of [`RepoFile`] and [`RepoFolder`] objects. The order of the files and folders is + not guaranteed. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + [`~utils.EntryNotFoundError`]: + If the tree (folder) does not exist (error 404) on the repo. + + Examples: + + Get information about a repo's tree. + ```py + >>> from huggingface_hub import list_repo_tree + >>> repo_tree = list_repo_tree("lysandre/arxiv-nlp") + >>> repo_tree + + >>> list(repo_tree) + [ + RepoFile(path='.gitattributes', size=391, blob_id='ae8c63daedbd4206d7d40126955d4e6ab1c80f8f', lfs=None, last_commit=None, security=None), + RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None), + RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None), + RepoFile( + path='flax_model.msgpack', size=497764107, blob_id='8095a62ccb4d806da7666fcda07467e2d150218e', + lfs={'size': 497764107, 'sha256': 'd88b0d6a6ff9c3f8151f9d3228f57092aaea997f09af009eefd7373a77b5abb9', 'pointer_size': 134}, last_commit=None, security=None + ), + RepoFile(path='merges.txt', size=456318, blob_id='226b0752cac7789c48f0cb3ec53eda48b7be36cc', lfs=None, last_commit=None, security=None), + RepoFile( + path='pytorch_model.bin', size=548123560, blob_id='64eaa9c526867e404b68f2c5d66fd78e27026523', + lfs={'size': 548123560, 'sha256': '9be78edb5b928eba33aa88f431551348f7466ba9f5ef3daf1d552398722a5436', 'pointer_size': 134}, last_commit=None, security=None + ), + RepoFile(path='vocab.json', size=898669, blob_id='b00361fece0387ca34b4b8b8539ed830d644dbeb', lfs=None, last_commit=None, security=None)] + ] + ``` + + Get even more information about a repo's tree (last commit and files' security scan results) + ```py + >>> from huggingface_hub import list_repo_tree + >>> repo_tree = list_repo_tree("prompthero/openjourney-v4", expand=True) + >>> list(repo_tree) + [ + RepoFolder( + path='feature_extractor', + tree_id='aa536c4ea18073388b5b0bc791057a7296a00398', + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) + } + ), + RepoFolder( + path='safety_checker', + tree_id='65aef9d787e5557373fdf714d6c34d4fcdd70440', + last_commit={ + 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190', + 'title': 'Upload diffusers weights (#48)', + 'date': datetime.datetime(2023, 3, 21, 9, 5, 27, tzinfo=datetime.timezone.utc) + } + ), + RepoFile( + path='model_index.json', + size=582, + blob_id='d3d7c1e8c3e78eeb1640b8e2041ee256e24c9ee1', + lfs=None, + last_commit={ + 'oid': 'b195ed2d503f3eb29637050a886d77bd81d35f0e', + 'title': 'Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`. (#54)', + 'date': datetime.datetime(2023, 5, 15, 21, 41, 59, tzinfo=datetime.timezone.utc) + }, + security={ + 'safe': True, + 'av_scan': {'virusFound': False, 'virusNames': None}, + 'pickle_import_scan': None + } + ) + ... + ] + ``` + """ + repo_type = repo_type or REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + encoded_path_in_repo = "/" + quote(path_in_repo, safe="") if path_in_repo else "" + tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path_in_repo}" + for path_info in paginate(path=tree_url, headers=headers, params={"recursive": recursive, "expand": expand}): + yield (RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info)) + + @validate_hf_hub_args + def list_repo_refs( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + include_pull_requests: bool = False, + token: Optional[Union[bool, str]] = None, + ) -> GitRefs: + """ + Get the list of refs of a given repo (both tags and branches). + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing refs from a dataset or a Space, + `None` or `"model"` if listing from a model. Default is `None`. + include_pull_requests (`bool`, *optional*): + Whether to include refs from pull requests in the list. Defaults to `False`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> api.list_repo_refs("gpt2") + GitRefs(branches=[GitRefInfo(name='main', ref='refs/heads/main', target_commit='e7da7f221d5bf496a48136c0cd264e630fe9fcc8')], converts=[], tags=[]) + + >>> api.list_repo_refs("bigcode/the-stack", repo_type='dataset') + GitRefs( + branches=[ + GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'), + GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714') + ], + converts=[], + tags=[ + GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da') + ] + ) + ``` + + Returns: + [`GitRefs`]: object containing all information about branches and tags for a + repo on the Hub. + """ + repo_type = repo_type or REPO_TYPE_MODEL + response = get_session().get( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs", + headers=self._build_hf_headers(token=token), + params={"include_prs": 1} if include_pull_requests else {}, + ) + hf_raise_for_status(response) + data = response.json() + + def _format_as_git_ref_info(item: Dict) -> GitRefInfo: + return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"]) + + return GitRefs( + branches=[_format_as_git_ref_info(item) for item in data["branches"]], + converts=[_format_as_git_ref_info(item) for item in data["converts"]], + tags=[_format_as_git_ref_info(item) for item in data["tags"]], + pull_requests=[_format_as_git_ref_info(item) for item in data["pullRequests"]] + if include_pull_requests + else None, + ) + + @validate_hf_hub_args + def list_repo_commits( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + formatted: bool = False, + ) -> List[GitCommitInfo]: + """ + Get the list of commits of a given revision for a repo on the Hub. + + Commits are sorted by date (last commit first). + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if + listing from a model. Default is `None`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + formatted (`bool`): + Whether to return the HTML-formatted title and description of the commits. Defaults to False. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + + # Commits are sorted by date (last commit first) + >>> initial_commit = api.list_repo_commits("gpt2")[-1] + + # Initial commit is always a system commit containing the `.gitattributes` file. + >>> initial_commit + GitCommitInfo( + commit_id='9b865efde13a30c13e0a33e536cf3e4a5a9d71d8', + authors=['system'], + created_at=datetime.datetime(2019, 2, 18, 10, 36, 15, tzinfo=datetime.timezone.utc), + title='initial commit', + message='', + formatted_title=None, + formatted_message=None + ) + + # Create an empty branch by deriving from initial commit + >>> api.create_branch("gpt2", "new_empty_branch", revision=initial_commit.commit_id) + ``` + + Returns: + List[[`GitCommitInfo`]]: list of objects containing information about the commits for a repo on the Hub. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + """ + repo_type = repo_type or REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + + # Paginate over results and return the list of commits. + return [ + GitCommitInfo( + commit_id=item["id"], + authors=[author["user"] for author in item["authors"]], + created_at=parse_datetime(item["date"]), + title=item["title"], + message=item["message"], + formatted_title=item.get("formatted", {}).get("title"), + formatted_message=item.get("formatted", {}).get("message"), + ) + for item in paginate( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/commits/{revision}", + headers=self._build_hf_headers(token=token), + params={"expand[]": "formatted"} if formatted else {}, + ) + ] + + @validate_hf_hub_args + def get_paths_info( + self, + repo_id: str, + paths: Union[List[str], str], + *, + expand: bool = False, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Optional[Union[bool, str]] = None, + ) -> List[Union[RepoFile, RepoFolder]]: + """ + Get information about a repo's paths. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + paths (`Union[List[str], str]`, *optional*): + The paths to get information about. If a path do not exist, it is ignored without raising + an exception. + expand (`bool`, *optional*, defaults to `False`): + Whether to fetch more information about the paths (e.g. last commit and files' security scan results). This + operation is more expensive for the server so only 50 results are returned per page (instead of 1000). + As pagination is implemented in `huggingface_hub`, this is transparent for you except for the time it + takes to get the results. + revision (`str`, *optional*): + The revision of the repository from which to get the information. Defaults to `"main"` branch. + repo_type (`str`, *optional*): + The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`. + Defaults to `"model"`. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and + machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be + retrieved from the cache. If `False`, token is not sent in the request header. + + Returns: + `List[Union[RepoFile, RepoFolder]]`: + The information about the paths, as a list of [`RepoFile`] and [`RepoFolder`] objects. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + + Example: + ```py + >>> from huggingface_hub import get_paths_info + >>> paths_info = get_paths_info("allenai/c4", ["README.md", "en"], repo_type="dataset") + >>> paths_info + [ + RepoFile(path='README.md', size=2379, blob_id='f84cb4c97182890fc1dbdeaf1a6a468fd27b4fff', lfs=None, last_commit=None, security=None), + RepoFolder(path='en', tree_id='dc943c4c40f53d02b31ced1defa7e5f438d5862e', last_commit=None) + ] + ``` + """ + repo_type = repo_type or REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + headers = self._build_hf_headers(token=token) + + response = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}", + data={ + "paths": paths if isinstance(paths, list) else [paths], + "expand": expand, + }, + headers=headers, + ) + hf_raise_for_status(response) + paths_info = response.json() + return [ + RepoFile(**path_info) if path_info["type"] == "file" else RepoFolder(**path_info) + for path_info in paths_info + ] + + @validate_hf_hub_args + def super_squash_history( + self, + repo_id: str, + *, + branch: Optional[str] = None, + commit_message: Optional[str] = None, + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> None: + """Squash commit history on a branch for a repo on the Hub. + + Squashing the repo history is useful when you know you'll make hundreds of commits and you don't want to + clutter the history. Squashing commits can only be performed from the head of a branch. + + + + Once squashed, the commit history cannot be retrieved. This is a non-revertible operation. + + + + + + Once the history of a branch has been squashed, it is not possible to merge it back into another branch since + their history will have diverged. + + + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated by a `/`. + branch (`str`, *optional*): + The branch to squash. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The commit message to use for the squashed commit. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if listing commits from a dataset or a Space, `None` or `"model"` if + listing from a model. Default is `None`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If the machine is logged in + (through `huggingface-cli login` or [`~huggingface_hub.login`]), token can be automatically retrieved + from the cache. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo + does not exist. + [`~utils.RevisionNotFoundError`]: + If the branch to squash cannot be found. + [`~utils.BadRequestError`]: + If invalid reference for a branch. You cannot squash history on tags. + + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + + # Create repo + >>> repo_id = api.create_repo("test-squash").repo_id + + # Make a lot of commits. + >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"content") + >>> api.upload_file(repo_id=repo_id, path_in_repo="lfs.bin", path_or_fileobj=b"content") + >>> api.upload_file(repo_id=repo_id, path_in_repo="file.txt", path_or_fileobj=b"another_content") + + # Squash history + >>> api.super_squash_history(repo_id=repo_id) + ``` + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + if repo_type not in REPO_TYPES: + raise ValueError("Invalid repo type") + if branch is None: + branch = DEFAULT_REVISION + + # Prepare request + url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/super-squash/{branch}" + headers = self._build_hf_headers(token=token, is_write_action=True) + commit_message = commit_message or f"Super-squash branch '{branch}' using huggingface_hub" + + # Super-squash + response = get_session().post(url=url, headers=headers, json={"message": commit_message}) + hf_raise_for_status(response) + + @validate_hf_hub_args + def create_repo( + self, + repo_id: str, + *, + token: Optional[str] = None, + private: bool = False, + repo_type: Optional[str] = None, + exist_ok: bool = False, + space_sdk: Optional[str] = None, + space_hardware: Optional[SpaceHardware] = None, + space_storage: Optional[SpaceStorage] = None, + space_sleep_time: Optional[int] = None, + space_secrets: Optional[List[Dict[str, str]]] = None, + space_variables: Optional[List[Dict[str, str]]] = None, + ) -> RepoUrl: + """Create an empty repo on the HuggingFace Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + private (`bool`, *optional*, defaults to `False`): + Whether the model repo should be private. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo already exists. + space_sdk (`str`, *optional*): + Choice of SDK to use if repo_type is "space". Can be "streamlit", "gradio", "docker", or "static". + space_hardware (`SpaceHardware` or `str`, *optional*): + Choice of Hardware if repo_type is "space". See [`SpaceHardware`] for a complete list. + space_storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + space_sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + space_secrets (`List[Dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + space_variables (`List[Dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + """ + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + path = f"{self.endpoint}/api/repos/create" + + if repo_type not in REPO_TYPES: + raise ValueError("Invalid repo type") + + json: Dict[str, Any] = {"name": name, "organization": organization, "private": private} + if repo_type is not None: + json["type"] = repo_type + if repo_type == "space": + if space_sdk is None: + raise ValueError( + "No space_sdk provided. `create_repo` expects space_sdk to be one" + f" of {SPACES_SDK_TYPES} when repo_type is 'space'`" + ) + if space_sdk not in SPACES_SDK_TYPES: + raise ValueError(f"Invalid space_sdk. Please choose one of {SPACES_SDK_TYPES}.") + json["sdk"] = space_sdk + + if space_sdk is not None and repo_type != "space": + warnings.warn("Ignoring provided space_sdk because repo_type is not 'space'.") + + function_args = [ + "space_hardware", + "space_storage", + "space_sleep_time", + "space_secrets", + "space_variables", + ] + json_keys = ["hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] + values = [space_hardware, space_storage, space_sleep_time, space_secrets, space_variables] + + if repo_type == "space": + json.update({k: v for k, v in zip(json_keys, values) if v is not None}) + else: + provided_space_args = [key for key, value in zip(function_args, values) if value is not None] + + if provided_space_args: + warnings.warn(f"Ignoring provided {', '.join(provided_space_args)} because repo_type is not 'space'.") + + if getattr(self, "_lfsmultipartthresh", None): + # Testing purposes only. + # See https://github.com/huggingface/huggingface_hub/pull/733/files#r820604472 + json["lfsmultipartthresh"] = self._lfsmultipartthresh # type: ignore + headers = self._build_hf_headers(token=token, is_write_action=True) + + while True: + r = get_session().post(path, headers=headers, json=json) + if r.status_code == 409 and "Cannot create repo: another conflicting operation is in progress" in r.text: + # Since https://github.com/huggingface/moon-landing/pull/7272 (private repo), it is not possible to + # concurrently create repos on the Hub for a same user. This is rarely an issue, except when running + # tests. To avoid any inconvenience, we retry to create the repo for this specific error. + # NOTE: This could have being fixed directly in the tests but adding it here should fixed CIs for all + # dependent libraries. + # NOTE: If a fix is implemented server-side, we should be able to remove this retry mechanism. + logger.debug("Create repo failed due to a concurrency issue. Retrying...") + continue + break + + try: + hf_raise_for_status(r) + except HTTPError as err: + if exist_ok and err.response.status_code == 409: + # Repo already exists and `exist_ok=True` + pass + elif exist_ok and err.response.status_code == 403: + # No write permission on the namespace but repo might already exist + try: + self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) + if repo_type is None or repo_type == REPO_TYPE_MODEL: + return RepoUrl(f"{self.endpoint}/{repo_id}") + return RepoUrl(f"{self.endpoint}/{repo_type}/{repo_id}") + except HfHubHTTPError: + raise + else: + raise + + d = r.json() + return RepoUrl(d["url"], endpoint=self.endpoint) + + @validate_hf_hub_args + def delete_repo( + self, + repo_id: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + missing_ok: bool = False, + ) -> None: + """ + Delete a repo from the HuggingFace Hub. CAUTION: this is irreversible. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. + missing_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo does not exist. + + Raises: + - [`~utils.RepositoryNotFoundError`] + If the repository to delete from cannot be found and `missing_ok` is set to False (default). + """ + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + path = f"{self.endpoint}/api/repos/delete" + + if repo_type not in REPO_TYPES: + raise ValueError("Invalid repo type") + + json = {"name": name, "organization": organization} + if repo_type is not None: + json["type"] = repo_type + + headers = self._build_hf_headers(token=token, is_write_action=True) + r = get_session().delete(path, headers=headers, json=json) + try: + hf_raise_for_status(r) + except RepositoryNotFoundError: + if not missing_ok: + raise + + @validate_hf_hub_args + def update_repo_visibility( + self, + repo_id: str, + private: bool = False, + *, + token: Optional[str] = None, + organization: Optional[str] = None, + repo_type: Optional[str] = None, + name: Optional[str] = None, + ) -> Dict[str, bool]: + """Update the visibility setting of a repository. + + Args: + repo_id (`str`, *optional*): + A namespace (user or an organization) and a repo name separated + by a `/`. + private (`bool`, *optional*, defaults to `False`): + Whether the model repo should be private. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: + The HTTP response in json. + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if repo_type not in REPO_TYPES: + raise ValueError("Invalid repo type") + + organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) + + if organization is None: + namespace = self.whoami(token)["name"] + else: + namespace = organization + + if repo_type is None: + repo_type = REPO_TYPE_MODEL # default repo type + + r = get_session().put( + url=f"{self.endpoint}/api/{repo_type}s/{namespace}/{name}/settings", + headers=self._build_hf_headers(token=token, is_write_action=True), + json={"private": private}, + ) + hf_raise_for_status(r) + return r.json() + + def move_repo( + self, + from_id: str, + to_id: str, + *, + repo_type: Optional[str] = None, + token: Optional[str] = None, + ): + """ + Moving a repository from namespace1/repo_name1 to namespace2/repo_name2 + + Note there are certain limitations. For more information about moving + repositories, please see + https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo. + + Args: + from_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. Original repository identifier. + to_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. Final repository identifier. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + + + Raises the following errors: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if len(from_id.split("/")) != 2: + raise ValueError(f"Invalid repo_id: {from_id}. It should have a namespace (:namespace:/:repo_name:)") + + if len(to_id.split("/")) != 2: + raise ValueError(f"Invalid repo_id: {to_id}. It should have a namespace (:namespace:/:repo_name:)") + + if repo_type is None: + repo_type = REPO_TYPE_MODEL # Hub won't accept `None`. + + json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type} + + path = f"{self.endpoint}/api/repos/move" + headers = self._build_hf_headers(token=token, is_write_action=True) + r = get_session().post(path, headers=headers, json=json) + try: + hf_raise_for_status(r) + except HfHubHTTPError as e: + e.append_to_message( + "\nFor additional documentation please see" + " https://hf.co/docs/hub/repositories-settings#renaming-or-transferring-a-repo." + ) + raise + + @overload + def create_commit( # type: ignore + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: Literal[False] = ..., + ) -> CommitInfo: + ... + + @overload + def create_commit( + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: Literal[True] = ..., + ) -> Future[CommitInfo]: + ... + + @validate_hf_hub_args + @future_compatible + def create_commit( + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: bool = False, + ) -> Union[CommitInfo, Future[CommitInfo]]: + """ + Creates a commit in the given repo, deleting & uploading files as needed. + + + + The input list of `CommitOperation` will be mutated during the commit process. Do not reuse the same objects + for multiple commits. + + + + + + `create_commit` assumes that the repo already exists on the Hub. If you get a + Client error 404, please make sure you are authenticated and that `repo_id` and + `repo_type` are set correctly. If repo does not exist, create it first using + [`~hf_api.create_repo`]. + + + + + + `create_commit` is limited to 25k LFS files and a 1GB payload for regular files. + + + + Args: + repo_id (`str`): + The repository in which the commit will be created, for example: + `"username/custom_transformers"` + + operations (`Iterable` of [`~hf_api.CommitOperation`]): + An iterable of operations to include in the commit, either: + + - [`~hf_api.CommitOperationAdd`] to upload a file + - [`~hf_api.CommitOperationDelete`] to delete a file + - [`~hf_api.CommitOperationCopy`] to copy a file + + Operation objects will be mutated to include information relative to the upload. Do not reuse the + same objects for multiple commits. + + commit_message (`str`): + The summary (first line) of the commit that will be created. + + commit_description (`str`, *optional*): + The description of the commit that will be created + + token (`str`, *optional*): + Authentication token, obtained with `HfApi.login` method. Will + default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + + num_threads (`int`, *optional*): + Number of concurrent threads for uploading files. Defaults to 5. + Setting it to 2 means at most 2 files will be uploaded concurrently. + + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. + Shorthands (7 first characters) are also supported. If specified and `create_pr` is `False`, + the commit will fail if `revision` does not point to `parent_commit`. If specified and `create_pr` + is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` + ensures the repo has not changed before committing the changes, and can be especially useful + if the repo is updated / committed to concurrently. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. + + Returns: + [`CommitInfo`] or `Future`: + Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit + url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will + contain the result when executed. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If commit message is empty. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If parent commit is not a valid commit OID. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the Hub API returns an HTTP 400 error (bad request) + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `create_pr` is `True` and revision is neither `None` nor `"main"`. + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + """ + if parent_commit is not None and not REGEX_COMMIT_OID.fullmatch(parent_commit): + raise ValueError( + f"`parent_commit` is not a valid commit OID. It must match the following regex: {REGEX_COMMIT_OID}" + ) + + if commit_message is None or len(commit_message) == 0: + raise ValueError("`commit_message` can't be empty, please pass a value.") + + commit_description = commit_description if commit_description is not None else "" + repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + unquoted_revision = revision or DEFAULT_REVISION + revision = quote(unquoted_revision, safe="") + create_pr = create_pr if create_pr is not None else False + + operations = list(operations) + additions = [op for op in operations if isinstance(op, CommitOperationAdd)] + copies = [op for op in operations if isinstance(op, CommitOperationCopy)] + nb_additions = len(additions) + nb_copies = len(copies) + nb_deletions = len(operations) - nb_additions - nb_copies + + for addition in additions: + if addition._is_committed: + raise ValueError( + f"CommitOperationAdd {addition} has already being committed and cannot be reused. Please create a" + " new CommitOperationAdd object if you want to create a new commit." + ) + + logger.debug( + f"About to commit to the hub: {len(additions)} addition(s), {len(copies)} copie(s) and" + f" {nb_deletions} deletion(s)." + ) + + # If updating twice the same file or update then delete a file in a single commit + _warn_on_overwriting_operations(operations) + + self.preupload_lfs_files( + repo_id=repo_id, + additions=additions, + token=token, + repo_type=repo_type, + revision=unquoted_revision, # first-class methods take unquoted revision + create_pr=create_pr, + num_threads=num_threads, + free_memory=False, # do not remove `CommitOperationAdd.path_or_fileobj` on LFS files for "normal" users + ) + files_to_copy = _fetch_lfs_files_to_copy( + copies=copies, + repo_type=repo_type, + repo_id=repo_id, + token=token or self.token, + revision=revision, + endpoint=self.endpoint, + ) + commit_payload = _prepare_commit_payload( + operations=operations, + files_to_copy=files_to_copy, + commit_message=commit_message, + commit_description=commit_description, + parent_commit=parent_commit, + ) + commit_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/commit/{revision}" + + def _payload_as_ndjson() -> Iterable[bytes]: + for item in commit_payload: + yield json.dumps(item).encode() + yield b"\n" + + headers = { + # See https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073 + "Content-Type": "application/x-ndjson", + **self._build_hf_headers(token=token, is_write_action=True), + } + data = b"".join(_payload_as_ndjson()) + params = {"create_pr": "1"} if create_pr else None + + try: + commit_resp = get_session().post(url=commit_url, headers=headers, data=data, params=params) + hf_raise_for_status(commit_resp, endpoint_name="commit") + except RepositoryNotFoundError as e: + e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) + raise + except EntryNotFoundError as e: + if nb_deletions > 0 and "A file with this name doesn't exist" in str(e): + e.append_to_message( + "\nMake sure to differentiate file and folder paths in delete" + " operations with a trailing '/' or using `is_folder=True/False`." + ) + raise + + # Mark additions as committed (cannot be reused in another commit) + for addition in additions: + addition._is_committed = True + + commit_data = commit_resp.json() + return CommitInfo( + commit_url=commit_data["commitUrl"], + commit_message=commit_message, + commit_description=commit_description, + oid=commit_data["commitOid"], + pr_url=commit_data["pullRequestUrl"] if create_pr else None, + ) + + @experimental + @validate_hf_hub_args + def create_commits_on_pr( + self, + *, + repo_id: str, + addition_commits: List[List[CommitOperationAdd]], + deletion_commits: List[List[CommitOperationDelete]], + commit_message: str, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + merge_pr: bool = True, + num_threads: int = 5, # TODO: use to multithread uploads + verbose: bool = False, + ) -> str: + """Push changes to the Hub in multiple commits. + + Commits are pushed to a draft PR branch. If the upload fails or gets interrupted, it can be resumed. Progress + is tracked in the PR description. At the end of the process, the PR is set as open and the title is updated to + match the initial commit message. If `merge_pr=True` is passed, the PR is merged automatically. + + All deletion commits are pushed first, followed by the addition commits. The order of the commits is not + guaranteed as we might implement parallel commits in the future. Be sure that your are not updating several + times the same file. + + + + `create_commits_on_pr` is experimental. Its API and behavior is subject to change in the future without prior notice. + + + + + + `create_commits_on_pr` assumes that the repo already exists on the Hub. If you get a Client error 404, please + make sure you are authenticated and that `repo_id` and `repo_type` are set correctly. If repo does not exist, + create it first using [`~hf_api.create_repo`]. + + + + Args: + repo_id (`str`): + The repository in which the commits will be pushed. Example: `"username/my-cool-model"`. + + addition_commits (`List` of `List` of [`~hf_api.CommitOperationAdd`]): + A list containing lists of [`~hf_api.CommitOperationAdd`]. Each sublist will result in a commit on the + PR. + + deletion_commits + A list containing lists of [`~hf_api.CommitOperationDelete`]. Each sublist will result in a commit on + the PR. Deletion commits are pushed before addition commits. + + commit_message (`str`): + The summary (first line) of the commit that will be created. Will also be the title of the PR. + + commit_description (`str`, *optional*): + The description of the commit that will be created. The description will be added to the PR. + + token (`str`, *optional*): + Authentication token, obtained with `HfApi.login` method. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to + a model. Default is `None`. + + merge_pr (`bool`): + If set to `True`, the Pull Request is merged at the end of the process. Defaults to `True`. + + num_threads (`int`, *optional*): + Number of concurrent threads for uploading files. Defaults to 5. + + verbose (`bool`): + If set to `True`, process will run on verbose mode i.e. print information about the ongoing tasks. + Defaults to `False`. + + Returns: + `str`: URL to the created PR. + + Example: + ```python + >>> from huggingface_hub import HfApi, plan_multi_commits + >>> addition_commits, deletion_commits = plan_multi_commits( + ... operations=[ + ... CommitOperationAdd(...), + ... CommitOperationAdd(...), + ... CommitOperationDelete(...), + ... CommitOperationDelete(...), + ... CommitOperationAdd(...), + ... ], + ... ) + >>> HfApi().create_commits_on_pr( + ... repo_id="my-cool-model", + ... addition_commits=addition_commits, + ... deletion_commits=deletion_commits, + ... (...) + ... verbose=True, + ... ) + ``` + + Raises: + [`MultiCommitException`]: + If an unexpected issue occur in the process: empty commits, unexpected commits in a PR, unexpected PR + description, etc. + """ + logger = logging.get_logger(__name__ + ".create_commits_on_pr") + if verbose: + logger.setLevel("INFO") + + # 1. Get strategy ID + logger.info( + f"Will create {len(deletion_commits)} deletion commit(s) and {len(addition_commits)} addition commit(s)," + f" totalling {sum(len(ops) for ops in addition_commits+deletion_commits)} atomic operations." + ) + strategy = MultiCommitStrategy( + addition_commits=[MultiCommitStep(operations=operations) for operations in addition_commits], # type: ignore + deletion_commits=[MultiCommitStep(operations=operations) for operations in deletion_commits], # type: ignore + ) + logger.info(f"Multi-commits strategy with ID {strategy.id}.") + + # 2. Get or create a PR with this strategy ID + for discussion in self.get_repo_discussions(repo_id=repo_id, repo_type=repo_type, token=token): + # search for a draft PR with strategy ID + if discussion.is_pull_request and discussion.status == "draft" and strategy.id in discussion.title: + pr = self.get_discussion_details( + repo_id=repo_id, discussion_num=discussion.num, repo_type=repo_type, token=token + ) + logger.info(f"PR already exists: {pr.url}. Will resume process where it stopped.") + break + else: + # did not find a PR matching the strategy ID + pr = multi_commit_create_pull_request( + self, + repo_id=repo_id, + commit_message=commit_message, + commit_description=commit_description, + strategy=strategy, + token=token, + repo_type=repo_type, + ) + logger.info(f"New PR created: {pr.url}") + + # 3. Parse PR description to check consistency with strategy (e.g. same commits are scheduled) + for event in pr.events: + if isinstance(event, DiscussionComment): + pr_comment = event + break + else: + raise MultiCommitException(f"PR #{pr.num} must have at least 1 comment") + + description_commits = multi_commit_parse_pr_description(pr_comment.content) + if len(description_commits) != len(strategy.all_steps): + raise MultiCommitException( + f"Corrupted multi-commit PR #{pr.num}: got {len(description_commits)} steps in" + f" description but {len(strategy.all_steps)} in strategy." + ) + for step_id in strategy.all_steps: + if step_id not in description_commits: + raise MultiCommitException( + f"Corrupted multi-commit PR #{pr.num}: expected step {step_id} but didn't find" + f" it (have {', '.join(description_commits)})." + ) + + # 4. Retrieve commit history (and check consistency) + commits_on_main_branch = { + commit.commit_id + for commit in self.list_repo_commits( + repo_id=repo_id, repo_type=repo_type, token=token, revision=DEFAULT_REVISION + ) + } + pr_commits = [ + commit + for commit in self.list_repo_commits( + repo_id=repo_id, repo_type=repo_type, token=token, revision=pr.git_reference + ) + if commit.commit_id not in commits_on_main_branch + ] + if len(pr_commits) > 0: + logger.info(f"Found {len(pr_commits)} existing commits on the PR.") + + # At this point `pr_commits` is a list of commits pushed to the PR. We expect all of these commits (if any) to have + # a step_id as title. We raise exception if an unexpected commit has been pushed. + if len(pr_commits) > len(strategy.all_steps): + raise MultiCommitException( + f"Corrupted multi-commit PR #{pr.num}: scheduled {len(strategy.all_steps)} steps but" + f" {len(pr_commits)} commits have already been pushed to the PR." + ) + + # Check which steps are already completed + remaining_additions = {step.id: step for step in strategy.addition_commits} + remaining_deletions = {step.id: step for step in strategy.deletion_commits} + for commit in pr_commits: + if commit.title in remaining_additions: + step = remaining_additions.pop(commit.title) + step.completed = True + elif commit.title in remaining_deletions: + step = remaining_deletions.pop(commit.title) + step.completed = True + + if len(remaining_deletions) > 0 and len(remaining_additions) < len(strategy.addition_commits): + raise MultiCommitException( + f"Corrupted multi-commit PR #{pr.num}: some addition commits have already been pushed to the PR but" + " deletion commits are not all completed yet." + ) + nb_remaining = len(remaining_deletions) + len(remaining_additions) + if len(pr_commits) > 0: + logger.info( + f"{nb_remaining} commits remaining ({len(remaining_deletions)} deletion commits and" + f" {len(remaining_additions)} addition commits)" + ) + + # 5. Push remaining commits to the PR + update description + # TODO: multi-thread this + for step in list(remaining_deletions.values()) + list(remaining_additions.values()): + # Push new commit + self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + token=token, + commit_message=step.id, + revision=pr.git_reference, + num_threads=num_threads, + operations=step.operations, + create_pr=False, + ) + step.completed = True + nb_remaining -= 1 + logger.info(f" step {step.id} completed (still {nb_remaining} to go).") + + # Update PR description + self.edit_discussion_comment( + repo_id=repo_id, + repo_type=repo_type, + token=token, + discussion_num=pr.num, + comment_id=pr_comment.id, + new_content=multi_commit_generate_comment( + commit_message=commit_message, commit_description=commit_description, strategy=strategy + ), + ) + logger.info("All commits have been pushed.") + + # 6. Update PR (and merge) + self.rename_discussion( + repo_id=repo_id, + repo_type=repo_type, + token=token, + discussion_num=pr.num, + new_title=commit_message, + ) + self.change_discussion_status( + repo_id=repo_id, + repo_type=repo_type, + token=token, + discussion_num=pr.num, + new_status="open", + comment=MULTI_COMMIT_PR_COMPLETION_COMMENT_TEMPLATE, + ) + logger.info("PR is now open for reviews.") + + if merge_pr: # User don't want a PR => merge it + try: + self.merge_pull_request( + repo_id=repo_id, + repo_type=repo_type, + token=token, + discussion_num=pr.num, + comment=MULTI_COMMIT_PR_CLOSING_COMMENT_TEMPLATE, + ) + logger.info("PR has been automatically merged (`merge_pr=True` was passed).") + except BadRequestError as error: + if error.server_message is not None and "no associated changes" in error.server_message: + # PR cannot be merged as no changes are associated. We close the PR without merging with a comment to + # explain. + self.change_discussion_status( + repo_id=repo_id, + repo_type=repo_type, + token=token, + discussion_num=pr.num, + comment=MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_NO_CHANGES_TEMPLATE, + new_status="closed", + ) + logger.warning("Couldn't merge the PR: no associated changes.") + else: + # PR cannot be merged for another reason (conflicting files for example). We comment the PR to explain + # and re-raise the exception. + self.comment_discussion( + repo_id=repo_id, + repo_type=repo_type, + token=token, + discussion_num=pr.num, + comment=MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_BAD_REQUEST_TEMPLATE.format( + error_message=error.server_message + ), + ) + raise MultiCommitException( + f"Couldn't merge Pull Request in multi-commit: {error.server_message}" + ) from error + + return pr.url + + def preupload_lfs_files( + self, + repo_id: str, + additions: Iterable[CommitOperationAdd], + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + free_memory: bool = True, + gitignore_content: Optional[str] = None, + ): + """Pre-upload LFS files to S3 in preparation on a future commit. + + This method is useful if you are generating the files to upload on-the-fly and you don't want to store them + in memory before uploading them all at once. + + + + This is a power-user method. You shouldn't need to call it directly to make a normal commit. + Use [`create_commit`] directly instead. + + + + + + Commit operations will be mutated during the process. In particular, the attached `path_or_fileobj` will be + removed after the upload to save memory (and replaced by an empty `bytes` object). Do not reuse the same + objects except to pass them to [`create_commit`]. If you don't want to remove the attached content from the + commit operation object, pass `free_memory=False`. + + + + Args: + repo_id (`str`): + The repository in which you will commit the files, for example: `"username/custom_transformers"`. + + operations (`Iterable` of [`CommitOperationAdd`]): + The list of files to upload. Warning: the objects in this list will be mutated to include information + relative to the upload. Do not reuse the same objects for multiple commits. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + The type of repository to upload to (e.g. `"model"` -default-, `"dataset"` or `"space"`). + + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + + create_pr (`boolean`, *optional*): + Whether or not you plan to create a Pull Request with that commit. Defaults to `False`. + + num_threads (`int`, *optional*): + Number of concurrent threads for uploading files. Defaults to 5. + Setting it to 2 means at most 2 files will be uploaded concurrently. + + gitignore_content (`str`, *optional*): + The content of the `.gitignore` file to know which files should be ignored. The order of priority + is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present + in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub + (if any). + + Example: + ```py + >>> from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit, create_repo + + >>> repo_id = create_repo("test_preupload").repo_id + + # Generate and preupload LFS files one by one + >>> operations = [] # List of all `CommitOperationAdd` objects that will be generated + >>> for i in range(5): + ... content = ... # generate binary content + ... addition = CommitOperationAdd(path_in_repo=f"shard_{i}_of_5.bin", path_or_fileobj=content) + ... preupload_lfs_files(repo_id, additions=[addition]) # upload + free memory + ... operations.append(addition) + + # Create commit + >>> create_commit(repo_id, operations=operations, commit_message="Commit all shards") + ``` + """ + repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + create_pr = create_pr if create_pr is not None else False + + # Check if a `gitignore` file is being committed to the Hub. + additions = list(additions) + if gitignore_content is None: + for addition in additions: + if addition.path_in_repo == ".gitignore": + with addition.as_file() as f: + gitignore_content = f.read().decode() + break + + # Filter out already uploaded files + new_additions = [addition for addition in additions if not addition._is_uploaded] + + # Check which new files are LFS + try: + _fetch_upload_modes( + additions=new_additions, + repo_type=repo_type, + repo_id=repo_id, + token=token or self.token, + revision=revision, + endpoint=self.endpoint, + create_pr=create_pr or False, + gitignore_content=gitignore_content, + ) + except RepositoryNotFoundError as e: + e.append_to_message(_CREATE_COMMIT_NO_REPO_ERROR_MESSAGE) + raise + + # Filter out regular files + new_lfs_additions = [addition for addition in new_additions if addition._upload_mode == "lfs"] + + # Filter out files listed in .gitignore + new_lfs_additions_to_upload = [] + for addition in new_lfs_additions: + if addition._should_ignore: + logger.debug(f"Skipping upload for LFS file '{addition.path_in_repo}' (ignored by gitignore file).") + else: + new_lfs_additions_to_upload.append(addition) + if len(new_lfs_additions) != len(new_lfs_additions_to_upload): + logger.info( + f"Skipped upload for {len(new_lfs_additions) - len(new_lfs_additions_to_upload)} LFS file(s) " + "(ignored by gitignore file)." + ) + + # Upload new LFS files + _upload_lfs_files( + additions=new_lfs_additions_to_upload, + repo_type=repo_type, + repo_id=repo_id, + token=token or self.token, + endpoint=self.endpoint, + num_threads=num_threads, + # If `create_pr`, we don't want to check user permission on the revision as users with read permission + # should still be able to create PRs even if they don't have write permission on the target branch of the + # PR (i.e. `revision`). + revision=revision if not create_pr else None, + ) + for addition in new_lfs_additions_to_upload: + addition._is_uploaded = True + if free_memory: + addition.path_or_fileobj = b"" + + @overload + def upload_file( # type: ignore + self, + *, + path_or_fileobj: Union[str, Path, bytes, BinaryIO], + path_in_repo: str, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + run_as_future: Literal[False] = ..., + ) -> CommitInfo: + ... + + @overload + def upload_file( + self, + *, + path_or_fileobj: Union[str, Path, bytes, BinaryIO], + path_in_repo: str, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + run_as_future: Literal[True] = ..., + ) -> Future[CommitInfo]: + ... + + @validate_hf_hub_args + @future_compatible + def upload_file( + self, + *, + path_or_fileobj: Union[str, Path, bytes, BinaryIO], + path_in_repo: str, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + run_as_future: bool = False, + ) -> Union[CommitInfo, Future[CommitInfo]]: + """ + Upload a local file (up to 50 GB) to the given repo. The upload is done + through a HTTP post request, and doesn't require git or git-lfs to be + installed. + + Args: + path_or_fileobj (`str`, `Path`, `bytes`, or `IO`): + Path to a file on the local machine or binary data stream / + fileobj / buffer. + path_in_repo (`str`): + Relative filepath in the repo, for example: + `"checkpoints/1fec34a/weights.bin"` + repo_id (`str`): + The repository to which the file will be uploaded, for example: + `"username/custom_transformers"` + token (`str`, *optional*): + Authentication token, obtained with `HfApi.login` method. Will + default to the stored token. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit + commit_description (`str` *optional*) + The description of the generated commit + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. + + + Returns: + [`CommitInfo`] or `Future`: + Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit + url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will + contain the result when executed. + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + + + + + + `upload_file` assumes that the repo already exists on the Hub. If you get a + Client error 404, please make sure you are authenticated and that `repo_id` and + `repo_type` are set correctly. If repo does not exist, create it first using + [`~hf_api.create_repo`]. + + + + Example: + + ```python + >>> from huggingface_hub import upload_file + + >>> with open("./local/filepath", "rb") as fobj: + ... upload_file( + ... path_or_fileobj=fileobj, + ... path_in_repo="remote/file/path.h5", + ... repo_id="username/my-dataset", + ... repo_type="dataset", + ... token="my_token", + ... ) + "https://huggingface.co/datasets/username/my-dataset/blob/main/remote/file/path.h5" + + >>> upload_file( + ... path_or_fileobj=".\\\\local\\\\file\\\\path", + ... path_in_repo="remote/file/path.h5", + ... repo_id="username/my-model", + ... token="my_token", + ... ) + "https://huggingface.co/username/my-model/blob/main/remote/file/path.h5" + + >>> upload_file( + ... path_or_fileobj=".\\\\local\\\\file\\\\path", + ... path_in_repo="remote/file/path.h5", + ... repo_id="username/my-model", + ... token="my_token", + ... create_pr=True, + ... ) + "https://huggingface.co/username/my-model/blob/refs%2Fpr%2F1/remote/file/path.h5" + ``` + """ + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + + commit_message = ( + commit_message if commit_message is not None else f"Upload {path_in_repo} with huggingface_hub" + ) + operation = CommitOperationAdd( + path_or_fileobj=path_or_fileobj, + path_in_repo=path_in_repo, + ) + + commit_info = self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + operations=[operation], + commit_message=commit_message, + commit_description=commit_description, + token=token, + revision=revision, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + if commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") + if repo_type in REPO_TYPES_URL_PREFIXES: + repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + revision = revision if revision is not None else DEFAULT_REVISION + + return CommitInfo( + commit_url=commit_info.commit_url, + commit_message=commit_info.commit_message, + commit_description=commit_info.commit_description, + oid=commit_info.oid, + pr_url=commit_info.pr_url, + # Similar to `hf_hub_url` but it's "blob" instead of "resolve" + # TODO: remove this in v1.0 + _url=f"{self.endpoint}/{repo_id}/blob/{revision}/{path_in_repo}", + ) + + @overload + def upload_folder( # type: ignore + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: Literal[False] = ..., + multi_commits_verbose: bool = False, + run_as_future: Literal[False] = ..., + ) -> CommitInfo: + ... + + @overload + def upload_folder( # type: ignore + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: Literal[True] = ..., + multi_commits_verbose: bool = False, + run_as_future: Literal[False] = ..., + ) -> str: # Only the PR url in multi-commits mode + ... + + @overload + def upload_folder( # type: ignore + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: Literal[False] = ..., + multi_commits_verbose: bool = False, + run_as_future: Literal[True] = ..., + ) -> Future[CommitInfo]: + ... + + @overload + def upload_folder( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: Literal[True] = ..., + multi_commits_verbose: bool = False, + run_as_future: Literal[True] = ..., + ) -> Future[str]: # Only the PR url in multi-commits mode + ... + + @validate_hf_hub_args + @future_compatible + def upload_folder( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: bool = False, + multi_commits_verbose: bool = False, + run_as_future: bool = False, + ) -> Union[CommitInfo, str, Future[CommitInfo], Future[str]]: + """ + Upload a local folder to the given repo. The upload is done through a HTTP requests, and doesn't require git or + git-lfs to be installed. + + The structure of the folder will be preserved. Files with the same name already present in the repository will + be overwritten. Others will be left untouched. + + Use the `allow_patterns` and `ignore_patterns` arguments to specify which files to upload. These parameters + accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as + documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). If both `allow_patterns` and + `ignore_patterns` are provided, both constraints apply. By default, all files from the folder are uploaded. + + Use the `delete_patterns` argument to specify remote files you want to delete. Input type is the same as for + `allow_patterns` (see above). If `path_in_repo` is also provided, the patterns are matched against paths + relative to this folder. For example, `upload_folder(..., path_in_repo="experiment", delete_patterns="logs/*")` + will delete any remote file under `./experiment/logs/`. Note that the `.gitattributes` file will not be deleted + even if it matches the patterns. + + Any `.git/` folder present in any subdirectory will be ignored. However, please be aware that the `.gitignore` + file is not taken into account. + + Uses `HfApi.create_commit` under the hood. + + Args: + repo_id (`str`): + The repository to which the file will be uploaded, for example: + `"username/custom_transformers"` + folder_path (`str` or `Path`): + Path to the folder to upload on the local file system + path_in_repo (`str`, *optional*): + Relative path of the directory in the repo, for example: + `"checkpoints/1fec34a/results"`. Will default to the root folder of the repository. + token (`str`, *optional*): + Authentication token, obtained with `HfApi.login` method. Will + default to the stored token. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to: + `f"Upload {path_in_repo} with huggingface_hub"` + commit_description (`str` *optional*): + The description of the generated commit + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. If `revision` is not + set, PR is opened against the `"main"` branch. If `revision` is set and is a branch, PR is opened + against this branch. If `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. If both `multi_commits` and `create_pr` are True, + the PR created in the multi-commit process is kept opened. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are uploaded. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not uploaded. + delete_patterns (`List[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo while committing + new files. This is useful if you don't know which files have already been uploaded. + Note: to avoid discrepancies the `.gitattributes` file is not deleted even if it matches the pattern. + multi_commits (`bool`): + If True, changes are pushed to a PR using a multi-commit process. Defaults to `False`. + multi_commits_verbose (`bool`): + If True and `multi_commits` is used, more information will be displayed to the user. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. + + Returns: + [`CommitInfo`] or `Future`: + Instance of [`CommitInfo`] containing information about the newly created commit (commit hash, commit + url, pr url, commit message,...). If `run_as_future=True` is passed, returns a Future object which will + contain the result when executed. + [`str`] or `Future`: + If `multi_commits=True`, returns the url of the PR created to push the changes. If `run_as_future=True` + is passed, returns a Future object which will contain the result when executed. + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + + + + + `upload_folder` assumes that the repo already exists on the Hub. If you get a Client error 404, please make + sure you are authenticated and that `repo_id` and `repo_type` are set correctly. If repo does not exist, create + it first using [`~hf_api.create_repo`]. + + + + + + `multi_commits` is experimental. Its API and behavior is subject to change in the future without prior notice. + + + + Example: + + ```python + # Upload checkpoints folder except the log files + >>> upload_folder( + ... folder_path="local/checkpoints", + ... path_in_repo="remote/experiment/checkpoints", + ... repo_id="username/my-dataset", + ... repo_type="datasets", + ... token="my_token", + ... ignore_patterns="**/logs/*.txt", + ... ) + # "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints" + + # Upload checkpoints folder including logs while deleting existing logs from the repo + # Useful if you don't know exactly which log files have already being pushed + >>> upload_folder( + ... folder_path="local/checkpoints", + ... path_in_repo="remote/experiment/checkpoints", + ... repo_id="username/my-dataset", + ... repo_type="datasets", + ... token="my_token", + ... delete_patterns="**/logs/*.txt", + ... ) + "https://huggingface.co/datasets/username/my-dataset/tree/main/remote/experiment/checkpoints" + + # Upload checkpoints folder while creating a PR + >>> upload_folder( + ... folder_path="local/checkpoints", + ... path_in_repo="remote/experiment/checkpoints", + ... repo_id="username/my-dataset", + ... repo_type="datasets", + ... token="my_token", + ... create_pr=True, + ... ) + "https://huggingface.co/datasets/username/my-dataset/tree/refs%2Fpr%2F1/remote/experiment/checkpoints" + + ``` + """ + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + + if multi_commits: + if revision is not None and revision != DEFAULT_REVISION: + raise ValueError("Cannot use `multi_commit` to commit changes other than the main branch.") + + # By default, upload folder to the root directory in repo. + if path_in_repo is None: + path_in_repo = "" + + # Do not upload .git folder + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + ignore_patterns += IGNORE_GIT_FOLDER_PATTERNS + + delete_operations = self._prepare_upload_folder_deletions( + repo_id=repo_id, + repo_type=repo_type, + revision=DEFAULT_REVISION if create_pr else revision, + token=token, + path_in_repo=path_in_repo, + delete_patterns=delete_patterns, + ) + add_operations = _prepare_upload_folder_additions( + folder_path, + path_in_repo, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + # Optimize operations: if some files will be overwritten, we don't need to delete them first + if len(add_operations) > 0: + added_paths = set(op.path_in_repo for op in add_operations) + delete_operations = [ + delete_op for delete_op in delete_operations if delete_op.path_in_repo not in added_paths + ] + commit_operations = delete_operations + add_operations + + commit_message = commit_message or "Upload folder using huggingface_hub" + if multi_commits: + addition_commits, deletion_commits = plan_multi_commits(operations=commit_operations) + pr_url = self.create_commits_on_pr( + repo_id=repo_id, + repo_type=repo_type, + addition_commits=addition_commits, + deletion_commits=deletion_commits, + commit_message=commit_message, + commit_description=commit_description, + token=token, + merge_pr=not create_pr, + verbose=multi_commits_verbose, + ) + # Defining a CommitInfo object is not really relevant in this case + # Let's return early with pr_url only (as string). + return pr_url + + commit_info = self.create_commit( + repo_type=repo_type, + repo_id=repo_id, + operations=commit_operations, + commit_message=commit_message, + commit_description=commit_description, + token=token, + revision=revision, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + # Create url to uploaded folder (for legacy return value) + if create_pr and commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") + if repo_type in REPO_TYPES_URL_PREFIXES: + repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + revision = revision if revision is not None else DEFAULT_REVISION + + return CommitInfo( + commit_url=commit_info.commit_url, + commit_message=commit_info.commit_message, + commit_description=commit_info.commit_description, + oid=commit_info.oid, + pr_url=commit_info.pr_url, + # Similar to `hf_hub_url` but it's "tree" instead of "resolve" + # TODO: remove this in v1.0 + _url=f"{self.endpoint}/{repo_id}/tree/{revision}/{path_in_repo}", + ) + + @validate_hf_hub_args + def delete_file( + self, + path_in_repo: str, + repo_id: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ) -> CommitInfo: + """ + Deletes a file in the given repo. + + Args: + path_in_repo (`str`): + Relative filepath in the repo, for example: + `"checkpoints/1fec34a/weights.bin"` + repo_id (`str`): + The repository from which the file will be deleted, for example: + `"username/custom_transformers"` + token (`str`, *optional*): + Authentication token, obtained with `HfApi.login` method. Will + default to the stored token. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or + space, `None` or `"model"` if in a model. Default is `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to + `f"Delete {path_in_repo} with huggingface_hub"`. + commit_description (`str` *optional*) + The description of the generated commit + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + - [`~utils.EntryNotFoundError`] + If the file to download cannot be found. + + + + """ + commit_message = ( + commit_message if commit_message is not None else f"Delete {path_in_repo} with huggingface_hub" + ) + + operations = [CommitOperationDelete(path_in_repo=path_in_repo)] + + return self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + token=token, + operations=operations, + revision=revision, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + @validate_hf_hub_args + def delete_folder( + self, + path_in_repo: str, + repo_id: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ) -> CommitInfo: + """ + Deletes a folder in the given repo. + + Simple wrapper around [`create_commit`] method. + + Args: + path_in_repo (`str`): + Relative folder path in the repo, for example: `"checkpoints/1fec34a"`. + repo_id (`str`): + The repository from which the folder will be deleted, for example: + `"username/custom_transformers"` + token (`str`, *optional*): + Authentication token, obtained with `HfApi.login` method. Will default + to the stored token. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the folder is in a dataset or + space, `None` or `"model"` if in a model. Default is `None`. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. Defaults to + `f"Delete folder {path_in_repo} with huggingface_hub"`. + commit_description (`str` *optional*) + The description of the generated commit. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request with that commit. Defaults to `False`. + If `revision` is not set, PR is opened against the `"main"` branch. If + `revision` is set and is a branch, PR is opened against this branch. If + `revision` is set and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + """ + return self.create_commit( + repo_id=repo_id, + repo_type=repo_type, + token=token, + operations=[CommitOperationDelete(path_in_repo=path_in_repo, is_folder=True)], + revision=revision, + commit_message=( + commit_message if commit_message is not None else f"Delete folder {path_in_repo} with huggingface_hub" + ), + commit_description=commit_description, + create_pr=create_pr, + parent_commit=parent_commit, + ) + + @validate_hf_hub_args + def get_hf_file_metadata( + self, + *, + url: str, + token: Union[bool, str, None] = None, + proxies: Optional[Dict] = None, + timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT, + ) -> HfFileMetadata: + """Fetch metadata of a file versioned on the Hub for a given url. + + Args: + url (`str`): + File url, for example returned by [`hf_hub_url`]. + token (`str` or `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the HuggingFace config + folder. + - If `False` or `None`, no token is provided. + - If a string, it's used as the authentication token. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. + timeout (`float`, *optional*, defaults to 10): + How many seconds to wait for the server to send metadata before giving up. + + Returns: + A [`HfFileMetadata`] object containing metadata such as location, etag, size and commit_hash. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + + return get_hf_file_metadata( + url=url, + token=token, + proxies=proxies, + timeout=timeout, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + ) + + @validate_hf_hub_args + def hf_hub_download( + self, + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + force_download: bool = False, + force_filename: Optional[str] = None, + proxies: Optional[Dict] = None, + etag_timeout: float = DEFAULT_ETAG_TIMEOUT, + resume_download: bool = False, + token: Optional[Union[str, bool]] = None, + local_files_only: bool = False, + legacy_cache_layout: bool = False, + ) -> str: + """Download a given file if it's not already present in the local cache. + + The new cache file layout looks like this: + - The cache directory contains one subfolder per repo_id (namespaced by repo type) + - inside each repo folder: + - refs is a list of the latest known revision => commit_hash pairs + - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on + whether they're LFS files or not) + - snapshots contains one subfolder per commit, each "commit" contains the subset of the files + that have been resolved at that particular commit. Each filename is a symlink to the blob + at that particular commit. + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure + how you want to move those files: + - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob + files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal + is to be able to manually edit and save small files without corrupting the cache while saving disk space for + binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD` + environment variable. + - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`. + This is optimal in term of disk usage but files must not be manually edited. + - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the + local dir. This means disk usage is not optimized. + - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the + files are downloaded and directly placed under `local_dir`. This means if you need to download them again later, + they will be re-downloaded entirely. + + ``` + [ 96] . + └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + ``` + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + subfolder (`str`, *optional*): + An optional value corresponding to a folder inside the model repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded file will be placed under this directory, either as a symlink (default) or + a regular file (see description for more details). + local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`): + To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either + duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be + created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if + already exists) or downloaded from the Hub and not cached. See description for more details. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in + the local cache. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + resume_download (`bool`, *optional*, defaults to `False`): + If `True`, resume a previously interrupted download. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`] + then `cached_download`. This is deprecated as the new cache layout is + more powerful. + + Returns: + Local path (string) of file or if networking is off, last version of + file cached on disk. + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) + if ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + - [`~utils.EntryNotFoundError`] + If the file to download cannot be found. + - [`~utils.LocalEntryNotFoundError`] + If network is disabled or unavailable and file is not found in cache. + + + """ + from .file_download import hf_hub_download + + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + + return hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + library_name=self.library_name, + library_version=self.library_version, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent=self.user_agent, + force_download=force_download, + force_filename=force_filename, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + legacy_cache_layout=legacy_cache_layout, + ) + + @validate_hf_hub_args + def snapshot_download( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + proxies: Optional[Dict] = None, + etag_timeout: float = DEFAULT_ETAG_TIMEOUT, + resume_download: bool = False, + force_download: bool = False, + token: Optional[Union[str, bool]] = None, + local_files_only: bool = False, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[base_tqdm] = None, + ) -> str: + """Download repo files. + + Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from + a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. You can also filter which files to download using + `allow_patterns` and `ignore_patterns`. + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure + how you want to move those files: + - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob + files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal + is to be able to manually edit and save small files without corrupting the cache while saving disk space for + binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD` + environment variable. + - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`. + This is optimal in term of disk usage but files must not be manually edited. + - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the + local dir. This means disk usage is not optimized. + - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the + files are downloaded and directly placed under `local_dir`. This means if you need to download them again later, + they will be re-downloaded entirely. + + An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly + configured. It is also not possible to filter which files to download when cloning a repository using git. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded files will be placed under this directory, either as symlinks (default) or + regular files (see description for more details). + local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`): + To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either + duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be + created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if + already exists) or downloaded from the Hub and not cached. See description for more details. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + resume_download (`bool`, *optional*, defaults to `False): + If `True`, resume a previously interrupted download. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in the local cache. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are downloaded. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not downloaded. + max_workers (`int`, *optional*): + Number of concurrent threads to download files (1 thread = 1 file download). + Defaults to 8. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. + Note that the `tqdm_class` is not passed to each individual download. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + + Returns: + Local folder path (string) of repo snapshot + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + + """ + from ._snapshot_download import snapshot_download + + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + + return snapshot_download( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + force_download=force_download, + token=token, + local_files_only=local_files_only, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + max_workers=max_workers, + tqdm_class=tqdm_class, + ) + + def get_safetensors_metadata( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + ) -> SafetensorsRepoMetadata: + """ + Parse metadata for a safetensors repo on the Hub. + + We first check if the repo has a single safetensors file or a sharded safetensors repo. If it's a single + safetensors file, we parse the metadata from this file. If it's a sharded safetensors repo, we parse the + metadata from the index file and then parse the metadata from each shard. + + To parse metadata from a single safetensors file, use [`get_safetensors_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a + model. Default is `None`. + revision (`str`, *optional*): + The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the + head of the `"main"` branch. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and + machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be + retrieved from the cache. If `False`, token is not sent in the request header. + + Returns: + [`SafetensorsRepoMetadata`]: information related to safetensors repo. + + Raises: + - [`NotASafetensorsRepoError`]: if the repo is not a safetensors repo i.e. doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file. + - [`SafetensorsParsingError`]: if a safetensors file header couldn't be parsed correctly. + + Example: + ```py + # Parse repo with single weights file + >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m") + >>> metadata + SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...}, + files_metadata={'model.safetensors': SafetensorsFileMetadata(...)} + ) + >>> metadata.files_metadata["model.safetensors"].metadata + {'format': 'pt'} + + # Parse repo with sharded model + >>> metadata = get_safetensors_metadata("bigscience/bloom") + Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s] + >>> metadata + SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...}) + >>> len(metadata.files_metadata) + 72 # All safetensors files have been fetched + + # Parse repo with sharded model + >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5") + NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files. + ``` + """ + if self.file_exists( # Single safetensors file => non-sharded model + repo_id=repo_id, filename=SAFETENSORS_SINGLE_FILE, repo_type=repo_type, revision=revision, token=token + ): + file_metadata = self.parse_safetensors_file_metadata( + repo_id=repo_id, filename=SAFETENSORS_SINGLE_FILE, repo_type=repo_type, revision=revision, token=token + ) + return SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={tensor_name: SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys()}, + files_metadata={SAFETENSORS_SINGLE_FILE: file_metadata}, + ) + elif self.file_exists( # Multiple safetensors files => sharded with index + repo_id=repo_id, filename=SAFETENSORS_INDEX_FILE, repo_type=repo_type, revision=revision, token=token + ): + # Fetch index + index_file = self.hf_hub_download( + repo_id=repo_id, filename=SAFETENSORS_INDEX_FILE, repo_type=repo_type, revision=revision, token=token + ) + with open(index_file) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + + # Fetch metadata per shard + files_metadata = {} + + def _parse(filename: str) -> None: + files_metadata[filename] = self.parse_safetensors_file_metadata( + repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, token=token + ) + + thread_map( + _parse, + set(weight_map.values()), + desc="Parse safetensors files", + tqdm_class=hf_tqdm, + ) + + return SafetensorsRepoMetadata( + metadata=index.get("metadata", None), + sharded=True, + weight_map=weight_map, + files_metadata=files_metadata, + ) + else: + # Not a safetensors repo + raise NotASafetensorsRepoError( + f"'{repo_id}' is not a safetensors repo. Couldn't find '{SAFETENSORS_INDEX_FILE}' or '{SAFETENSORS_SINGLE_FILE}' files." + ) + + def parse_safetensors_file_metadata( + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + ) -> SafetensorsFileMetadata: + """ + Parse metadata from a safetensors file on the Hub. + + To parse metadata from all safetensors files in a repo at once, use [`get_safetensors_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if the file is in a dataset or space, `None` or `"model"` if in a + model. Default is `None`. + revision (`str`, *optional*): + The git revision to fetch the file from. Can be a branch name, a tag, or a commit hash. Defaults to the + head of the `"main"` branch. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and + machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be + retrieved from the cache. If `False`, token is not sent in the request header. + + Returns: + [`SafetensorsFileMetadata`]: information related to a safetensors file. + + Raises: + - [`NotASafetensorsRepoError`]: if the repo is not a safetensors repo i.e. doesn't have either a + `model.safetensors` or a `model.safetensors.index.json` file. + - [`SafetensorsParsingError`]: if a safetensors file header couldn't be parsed correctly. + """ + url = hf_hub_url( + repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision, endpoint=self.endpoint + ) + _headers = self._build_hf_headers(token=token) + + # 1. Fetch first 100kb + # Empirically, 97% of safetensors files have a metadata size < 100kb (over the top 1000 models on the Hub). + # We assume fetching 100kb is faster than making 2 GET requests. Therefore we always fetch the first 100kb to + # avoid the 2nd GET in most cases. + # See https://github.com/huggingface/huggingface_hub/pull/1855#discussion_r1404286419. + response = get_session().get(url, headers={**_headers, "range": "bytes=0-100000"}) + hf_raise_for_status(response) + + # 2. Parse metadata size + metadata_size = struct.unpack(" SAFETENSORS_MAX_HEADER_LENGTH: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or DEFAULT_REVISION}'): safetensors header is too big. Maximum supported size is " + f"{SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})." + ) + + # 3.a. Get metadata from payload + if metadata_size <= 100000: + metadata_as_bytes = response.content[8 : 8 + metadata_size] + else: # 3.b. Request full metadata + response = get_session().get(url, headers={**_headers, "range": f"bytes=8-{metadata_size+7}"}) + hf_raise_for_status(response) + metadata_as_bytes = response.content + + # 4. Parse json header + try: + metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore")) + except json.JSONDecodeError as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or DEFAULT_REVISION}'): header is not json-encoded string. Please make sure this is a " + "correctly formatted safetensors file." + ) from e + + try: + return SafetensorsFileMetadata( + metadata=metadata_as_dict.get("__metadata__", {}), + tensors={ + key: TensorInfo( + dtype=tensor["dtype"], + shape=tensor["shape"], + data_offsets=tuple(tensor["data_offsets"]), # type: ignore + ) + for key, tensor in metadata_as_dict.items() + if key != "__metadata__" + }, + ) + except (KeyError, IndexError) as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or DEFAULT_REVISION}'): header format not recognized. Please make sure this is a correctly" + " formatted safetensors file." + ) from e + + @validate_hf_hub_args + def create_branch( + self, + repo_id: str, + *, + branch: str, + revision: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Create a new branch for a repo on the Hub, starting from the specified revision (defaults to `main`). + To find a revision suiting your needs, you can use [`list_repo_refs`] or [`list_repo_commits`]. + + Args: + repo_id (`str`): + The repository in which the branch will be created. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to create. + + revision (`str`, *optional*): + The git revision to create the branch from. It can be a branch name or + the OID/SHA of a commit, as a hexadecimal string. Defaults to the head + of the `"main"` branch. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if creating a branch on a dataset or + space, `None` or `"model"` if tagging a model. Default is `None`. + + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if branch already exists. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.BadRequestError`]: + If invalid reference for a branch. Ex: `refs/pr/5` or 'refs/foo/bar'. + [`~utils.HfHubHTTPError`]: + If the branch already exists on the repo (error 409) and `exist_ok` is + set to `False`. + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + branch = quote(branch, safe="") + + # Prepare request + branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" + headers = self._build_hf_headers(token=token, is_write_action=True) + payload = {} + if revision is not None: + payload["startingPoint"] = revision + + # Create branch + response = get_session().post(url=branch_url, headers=headers, json=payload) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + if not (e.response.status_code == 409 and exist_ok): + raise + + @validate_hf_hub_args + def delete_branch( + self, + repo_id: str, + *, + branch: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Delete a branch from a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a branch will be deleted. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to delete. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if creating a branch on a dataset or + space, `None` or `"model"` if tagging a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.HfHubHTTPError`]: + If trying to delete a protected branch. Ex: `main` cannot be deleted. + [`~utils.HfHubHTTPError`]: + If trying to delete a branch that does not exist. + + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + branch = quote(branch, safe="") + + # Prepare request + branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" + headers = self._build_hf_headers(token=token, is_write_action=True) + + # Delete branch + response = get_session().delete(url=branch_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def create_tag( + self, + repo_id: str, + *, + tag: str, + tag_message: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Tag a given commit of a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a commit will be tagged. + Example: `"user/my-cool-model"`. + + tag (`str`): + The name of the tag to create. + + tag_message (`str`, *optional*): + The description of the tag to create. + + revision (`str`, *optional*): + The git revision to tag. It can be a branch name or the OID/SHA of a + commit, as a hexadecimal string. Shorthands (7 first characters) are + also supported. Defaults to the head of the `"main"` branch. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or + space, `None` or `"model"` if tagging a model. Default is + `None`. + + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if tag already exists. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + [`~utils.HfHubHTTPError`]: + If the branch already exists on the repo (error 409) and `exist_ok` is + set to `False`. + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}" + headers = self._build_hf_headers(token=token, is_write_action=True) + payload = {"tag": tag} + if tag_message is not None: + payload["message"] = tag_message + + # Tag + response = get_session().post(url=tag_url, headers=headers, json=payload) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + if not (e.response.status_code == 409 and exist_ok): + raise + + @validate_hf_hub_args + def delete_tag( + self, + repo_id: str, + *, + tag: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Delete a tag from a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a tag will be deleted. + Example: `"user/my-cool-model"`. + + tag (`str`): + The name of the tag to delete. + + token (`str`, *optional*): + Authentication token. Will default to the stored token. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or + `"model"` if tagging a model. Default is `None`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If tag is not found. + """ + if repo_type is None: + repo_type = REPO_TYPE_MODEL + tag = quote(tag, safe="") + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{tag}" + headers = self._build_hf_headers(token=token, is_write_action=True) + + # Un-tag + response = get_session().delete(url=tag_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def get_full_repo_name( + self, + model_id: str, + *, + organization: Optional[str] = None, + token: Optional[Union[bool, str]] = None, + ): + """ + Returns the repository name for a given model ID and optional + organization. + + Args: + model_id (`str`): + The name of the model. + organization (`str`, *optional*): + If passed, the repository name will be in the organization + namespace instead of the user namespace. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + + Returns: + `str`: The repository name in the user's namespace + ({username}/{model_id}) if no organization is passed, and under the + organization namespace ({organization}/{model_id}) otherwise. + """ + if organization is None: + if "/" in model_id: + username = model_id.split("/")[0] + else: + username = self.whoami(token=token)["name"] # type: ignore + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + @validate_hf_hub_args + def get_repo_discussions( + self, + repo_id: str, + *, + author: Optional[str] = None, + discussion_type: Optional[DiscussionTypeFilter] = None, + discussion_status: Optional[DiscussionStatusFilter] = None, + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> Iterator[Discussion]: + """ + Fetches Discussions and Pull Requests for the given repo. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + author (`str`, *optional*): + Pass a value to filter by discussion author. `None` means no filter. + Default is `None`. + discussion_type (`str`, *optional*): + Set to `"pull_request"` to fetch only pull requests, `"discussion"` + to fetch only discussions. Set to `"all"` or `None` to fetch both. + Default is `None`. + discussion_status (`str`, *optional*): + Set to `"open"` (respectively `"closed"`) to fetch only open + (respectively closed) discussions. Set to `"all"` or `None` + to fetch both. + Default is `None`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if fetching from a dataset or + space, `None` or `"model"` if fetching from a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + `Iterator[Discussion]`: An iterator of [`Discussion`] objects. + + Example: + Collecting all discussions of a repo in a list: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) + ``` + + Iterating over discussions of a repo: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> for discussion in get_repo_discussions(repo_id="bert-base-uncased"): + ... print(discussion.num, discussion.title) + ``` + """ + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}") + + headers = self._build_hf_headers(token=token) + query_dict: Dict[str, str] = {} + if discussion_type is not None: + query_dict["type"] = discussion_type + if discussion_status is not None: + query_dict["status"] = discussion_status + if author is not None: + query_dict["author"] = author + + def _fetch_discussion_page(page_index: int): + query_string = urlencode({**query_dict, "page_index": page_index}) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}" + resp = get_session().get(path, headers=headers) + hf_raise_for_status(resp) + paginated_discussions = resp.json() + total = paginated_discussions["count"] + start = paginated_discussions["start"] + discussions = paginated_discussions["discussions"] + has_next = (start + len(discussions)) < total + return discussions, has_next + + has_next, page_index = True, 0 + + while has_next: + discussions, has_next = _fetch_discussion_page(page_index=page_index) + for discussion in discussions: + yield Discussion( + title=discussion["title"], + num=discussion["num"], + author=discussion.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion["createdAt"]), + status=discussion["status"], + repo_id=discussion["repo"]["name"], + repo_type=discussion["repo"]["type"], + is_pull_request=discussion["isPullRequest"], + endpoint=self.endpoint, + ) + page_index = page_index + 1 + + @validate_hf_hub_args + def get_discussion_details( + self, + repo_id: str, + discussion_num: int, + *, + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> DiscussionWithDetails: + """Fetches a Discussion's / Pull Request 's details from the Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: [`DiscussionWithDetails`] + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" + headers = self._build_hf_headers(token=token) + resp = get_session().get(path, params={"diff": "1"}, headers=headers) + hf_raise_for_status(resp) + + discussion_details = resp.json() + is_pull_request = discussion_details["isPullRequest"] + + target_branch = discussion_details["changes"]["base"] if is_pull_request else None + conflicting_files = discussion_details["filesWithConflicts"] if is_pull_request else None + merge_commit_oid = discussion_details["changes"].get("mergeCommitId", None) if is_pull_request else None + + return DiscussionWithDetails( + title=discussion_details["title"], + num=discussion_details["num"], + author=discussion_details.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion_details["createdAt"]), + status=discussion_details["status"], + repo_id=discussion_details["repo"]["name"], + repo_type=discussion_details["repo"]["type"], + is_pull_request=discussion_details["isPullRequest"], + events=[deserialize_event(evt) for evt in discussion_details["events"]], + conflicting_files=conflicting_files, + target_branch=target_branch, + merge_commit_oid=merge_commit_oid, + diff=discussion_details.get("diff"), + endpoint=self.endpoint, + ) + + @validate_hf_hub_args + def create_discussion( + self, + repo_id: str, + title: str, + *, + token: Optional[str] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + pull_request: bool = False, + ) -> DiscussionWithDetails: + """Creates a Discussion or Pull Request. + + Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + pull_request (`bool`, *optional*): + Whether to create a Pull Request or discussion. If `True`, creates a Pull Request. + If `False`, creates a discussion. Defaults to `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + """ + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + + if description is not None: + description = description.strip() + description = ( + description + if description + else ( + f"{'Pull Request' if pull_request else 'Discussion'} opened with the" + " [huggingface_hub Python" + " library](https://huggingface.co/docs/huggingface_hub)" + ) + ) + + headers = self._build_hf_headers(token=token, is_write_action=True) + resp = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions", + json={ + "title": title.strip(), + "description": description, + "pullRequest": pull_request, + }, + headers=headers, + ) + hf_raise_for_status(resp) + num = resp.json()["num"] + return self.get_discussion_details( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=num, + token=token, + ) + + @validate_hf_hub_args + def create_pull_request( + self, + repo_id: str, + title: str, + *, + token: Optional[str] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionWithDetails: + """Creates a Pull Request . Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]; + + This is a wrapper around [`HfApi.create_discussion`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + """ + return self.create_discussion( + repo_id=repo_id, + title=title, + token=token, + description=description, + repo_type=repo_type, + pull_request=True, + ) + + def _post_discussion_changes( + self, + *, + repo_id: str, + discussion_num: int, + resource: str, + body: Optional[dict] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> requests.Response: + """Internal utility to POST changes to a Discussion or Pull Request""" + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + repo_id = f"{repo_type}s/{repo_id}" + + path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" + + headers = self._build_hf_headers(token=token, is_write_action=True) + resp = requests.post(path, headers=headers, json=body) + hf_raise_for_status(resp) + return resp + + @validate_hf_hub_args + def comment_discussion( + self, + repo_id: str, + discussion_num: int, + comment: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Creates a new comment on the given Discussion. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment (`str`): + The content of the comment to create. Comments support markdown formatting. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: + [`DiscussionComment`]: the newly created comment + + + Examples: + ```python + + >>> comment = \"\"\" + ... Hello @otheruser! + ... + ... # This is a title + ... + ... **This is bold**, *this is italic* and ~this is strikethrough~ + ... And [this](http://url) is a link + ... \"\"\" + + >>> HfApi().comment_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... comment=comment + ... ) + # DiscussionComment(id='deadbeef0000000', type='comment', ...) + + ``` + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="comment", + body={"comment": comment}, + ) + return deserialize_event(resp.json()["newMessage"]) # type: ignore + + @validate_hf_hub_args + def rename_discussion( + self, + repo_id: str, + discussion_num: int, + new_title: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionTitleChange: + """Renames a Discussion. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + new_title (`str`): + The new title for the discussion + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: + [`DiscussionTitleChange`]: the title change event + + + Examples: + ```python + >>> new_title = "New title, fixing a typo" + >>> HfApi().rename_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... new_title=new_title + ... ) + # DiscussionTitleChange(id='deadbeef0000000', type='title-change', ...) + + ``` + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="title", + body={"title": new_title}, + ) + return deserialize_event(resp.json()["newTitle"]) # type: ignore + + @validate_hf_hub_args + def change_discussion_status( + self, + repo_id: str, + discussion_num: int, + new_status: Literal["open", "closed"], + *, + token: Optional[str] = None, + comment: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionStatusChange: + """Closes or re-opens a Discussion or Pull Request. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + new_status (`str`): + The new status for the discussion, either `"open"` or `"closed"`. + comment (`str`, *optional*): + An optional comment to post with the status change. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: + [`DiscussionStatusChange`]: the status change event + + + Examples: + ```python + >>> new_title = "New title, fixing a typo" + >>> HfApi().rename_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... new_title=new_title + ... ) + # DiscussionStatusChange(id='deadbeef0000000', type='status-change', ...) + + ``` + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + if new_status not in ["open", "closed"]: + raise ValueError("Invalid status, valid statuses are: 'open' and 'closed'") + body: Dict[str, str] = {"status": new_status} + if comment and comment.strip(): + body["comment"] = comment.strip() + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="status", + body=body, + ) + return deserialize_event(resp.json()["newStatus"]) # type: ignore + + @validate_hf_hub_args + def merge_pull_request( + self, + repo_id: str, + discussion_num: int, + *, + token: Optional[str] = None, + comment: Optional[str] = None, + repo_type: Optional[str] = None, + ): + """Merges a Pull Request. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment (`str`, *optional*): + An optional comment to post with the status change. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: + [`DiscussionStatusChange`]: the status change event + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource="merge", + body={"comment": comment.strip()} if comment and comment.strip() else None, + ) + + @validate_hf_hub_args + def edit_discussion_comment( + self, + repo_id: str, + discussion_num: int, + comment_id: str, + new_content: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Edits a comment on a Discussion / Pull Request. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment_id (`str`): + The ID of the comment to edit. + new_content (`str`): + The new content of the comment. Comments support markdown formatting. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: + [`DiscussionComment`]: the edited comment + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource=f"comment/{comment_id.lower()}/edit", + body={"content": new_content}, + ) + return deserialize_event(resp.json()["updatedComment"]) # type: ignore + + @validate_hf_hub_args + def hide_discussion_comment( + self, + repo_id: str, + discussion_num: int, + comment_id: str, + *, + token: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Hides a comment on a Discussion / Pull Request. + + + Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible. + + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment_id (`str`): + The ID of the comment to edit. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token) + + Returns: + [`DiscussionComment`]: the hidden comment + + + + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + + """ + warnings.warn( + "Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible.", + UserWarning, + ) + resp = self._post_discussion_changes( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=discussion_num, + token=token, + resource=f"comment/{comment_id.lower()}/hide", + ) + return deserialize_event(resp.json()["updatedComment"]) # type: ignore + + @validate_hf_hub_args + def add_space_secret( + self, repo_id: str, key: str, value: str, *, description: Optional[str] = None, token: Optional[str] = None + ) -> None: + """Adds or updates a secret in a Space. + + Secrets allow to set secret keys or tokens to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Secret key. Example: `"GITHUB_API_KEY"` + value (`str`): + Secret value. Example: `"your_github_api_key"`. + description (`str`, *optional*): + Secret description. Example: `"Github API key to access the Github API"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + """ + payload = {"key": key, "value": value} + if description is not None: + payload["description"] = description + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/secrets", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + + @validate_hf_hub_args + def delete_space_secret(self, repo_id: str, key: str, *, token: Optional[str] = None) -> None: + """Deletes a secret from a Space. + + Secrets allow to set secret keys or tokens to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Secret key. Example: `"GITHUB_API_KEY"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + """ + r = get_session().delete( + f"{self.endpoint}/api/spaces/{repo_id}/secrets", + headers=self._build_hf_headers(token=token), + json={"key": key}, + ) + hf_raise_for_status(r) + + @validate_hf_hub_args + def get_space_variables(self, repo_id: str, *, token: Optional[str] = None) -> Dict[str, SpaceVariable]: + """Gets all variables from a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to query. Example: `"bigcode/in-the-stack"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + """ + r = get_session().get( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def add_space_variable( + self, repo_id: str, key: str, value: str, *, description: Optional[str] = None, token: Optional[str] = None + ) -> Dict[str, SpaceVariable]: + """Adds or updates a variable in a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + value (`str`): + Variable value. Example: `"the_model_repo_id"`. + description (`str`): + Description of the variable. Example: `"Model Repo ID of the implemented model"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + """ + payload = {"key": key, "value": value} + if description is not None: + payload["description"] = description + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def delete_space_variable( + self, repo_id: str, key: str, *, token: Optional[str] = None + ) -> Dict[str, SpaceVariable]: + """Deletes a variable from a Space. + + Variables allow to set environment variables to a Space without hardcoding them. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + key (`str`): + Variable key. Example: `"MODEL_REPO_ID"` + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + """ + r = get_session().delete( + f"{self.endpoint}/api/spaces/{repo_id}/variables", + headers=self._build_hf_headers(token=token), + json={"key": key}, + ) + hf_raise_for_status(r) + return {k: SpaceVariable(k, v) for k, v in r.json().items()} + + @validate_hf_hub_args + def get_space_runtime(self, repo_id: str, *, token: Optional[str] = None) -> SpaceRuntime: + """Gets runtime information about a Space. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if + not provided. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + """ + r = get_session().get( + f"{self.endpoint}/api/spaces/{repo_id}/runtime", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def request_space_hardware( + self, + repo_id: str, + hardware: SpaceHardware, + *, + token: Optional[str] = None, + sleep_time: Optional[int] = None, + ) -> SpaceRuntime: + """Request new hardware for a Space. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + hardware (`str` or [`SpaceHardware`]): + Hardware on which to run the Space. Example: `"t4-medium"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + + + It is also possible to request hardware directly when creating the Space repo! See [`create_repo`] for details. + + + """ + if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + payload: Dict[str, Any] = {"flavor": hardware} + if sleep_time is not None: + payload["sleepTimeSeconds"] = sleep_time + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/hardware", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def set_space_sleep_time(self, repo_id: str, sleep_time: int, *, token: Optional[str] = None) -> SpaceRuntime: + """Set a custom sleep time for a Space running on upgraded hardware.. + + Your Space will go to sleep after X seconds of inactivity. You are not billed when your Space is in "sleep" + mode. If a new visitor lands on your Space, it will "wake it up". Only upgraded hardware can have a + configurable sleep time. To know more about the sleep stage, please refer to + https://huggingface.co/docs/hub/spaces-gpus#sleep-time. + + Args: + repo_id (`str`): + ID of the repo to update. Example: `"bigcode/in-the-stack"`. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to pause (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + + + It is also possible to set a custom sleep time when requesting hardware with [`request_space_hardware`]. + + + """ + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/sleeptime", + headers=self._build_hf_headers(token=token), + json={"seconds": sleep_time}, + ) + hf_raise_for_status(r) + runtime = SpaceRuntime(r.json()) + + hardware = runtime.requested_hardware or runtime.hardware + if hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + return runtime + + @validate_hf_hub_args + def pause_space(self, repo_id: str, *, token: Optional[str] = None) -> SpaceRuntime: + """Pause your Space. + + A paused Space stops executing until manually restarted by its owner. This is different from the sleeping + state in which free Spaces go after 48h of inactivity. Paused time is not billed to your account, no matter the + hardware you've selected. To restart your Space, use [`restart_space`] and go to your Space settings page. + + For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause). + + Args: + repo_id (`str`): + ID of the Space to pause. Example: `"Salesforce/BLIP2"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: + [`SpaceRuntime`]: Runtime information about your Space including `stage=PAUSED` and requested hardware. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you + are not authenticated. + [`~utils.HfHubHTTPError`]: + 403 Forbidden: only the owner of a Space can pause it. If you want to manage a Space that you don't + own, either ask the owner by opening a Discussion or duplicate the Space. + [`~utils.BadRequestError`]: + If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide + a static Space, you can set it to private. + """ + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/pause", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def restart_space( + self, repo_id: str, *, token: Optional[str] = None, factory_reboot: bool = False + ) -> SpaceRuntime: + """Restart your Space. + + This is the only way to programmatically restart a Space if you've put it on Pause (see [`pause_space`]). You + must be the owner of the Space to restart it. If you are using an upgraded hardware, your account will be + billed as soon as the Space is restarted. You can trigger a restart no matter the current state of a Space. + + For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause). + + Args: + repo_id (`str`): + ID of the Space to restart. Example: `"Salesforce/BLIP2"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + factory_reboot (`bool`, *optional*): + If `True`, the Space will be rebuilt from scratch without caching any requirements. + + Returns: + [`SpaceRuntime`]: Runtime information about your Space. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you + are not authenticated. + [`~utils.HfHubHTTPError`]: + 403 Forbidden: only the owner of a Space can restart it. If you want to restart a Space that you don't + own, either ask the owner by opening a Discussion or duplicate the Space. + [`~utils.BadRequestError`]: + If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide + a static Space, you can set it to private. + """ + params = {} + if factory_reboot: + params["factory"] = "true" + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/restart", headers=self._build_hf_headers(token=token), params=params + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def duplicate_space( + self, + from_id: str, + to_id: Optional[str] = None, + *, + private: Optional[bool] = None, + token: Optional[str] = None, + exist_ok: bool = False, + hardware: Optional[SpaceHardware] = None, + storage: Optional[SpaceStorage] = None, + sleep_time: Optional[int] = None, + secrets: Optional[List[Dict[str, str]]] = None, + variables: Optional[List[Dict[str, str]]] = None, + ) -> RepoUrl: + """Duplicate a Space. + + Programmatically duplicate a Space. The new Space will be created in your account and will be in the same state + as the original Space (running or paused). You can duplicate a Space no matter the current state of a Space. + + Args: + from_id (`str`): + ID of the Space to duplicate. Example: `"pharma/CLIP-Interrogator"`. + to_id (`str`, *optional*): + ID of the new Space. Example: `"dog/CLIP-Interrogator"`. If not provided, the new Space will have the same + name as the original Space, but in your account. + private (`bool`, *optional*): + Whether the new Space should be private or not. Defaults to the same privacy as the original Space. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if repo already exists. + hardware (`SpaceHardware` or `str`, *optional*): + Choice of Hardware. Example: `"t4-medium"`. See [`SpaceHardware`] for a complete list. + storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + secrets (`List[Dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + variables (`List[Dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + + Raises: + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`~utils.RepositoryNotFoundError`] + If one of `from_id` or `to_id` cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + Example: + ```python + >>> from huggingface_hub import duplicate_space + + # Duplicate a Space to your account + >>> duplicate_space("multimodalart/dreambooth-training") + RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) + + # Can set custom destination id and visibility flag. + >>> duplicate_space("multimodalart/dreambooth-training", to_id="my-dreambooth", private=True) + RepoUrl('https://huggingface.co/spaces/nateraw/my-dreambooth',...) + ``` + """ + # Parse to_id if provided + parsed_to_id = RepoUrl(to_id) if to_id is not None else None + + # Infer target repo_id + to_namespace = ( # set namespace manually or default to username + parsed_to_id.namespace + if parsed_to_id is not None and parsed_to_id.namespace is not None + else self.whoami(token)["name"] + ) + to_repo_name = parsed_to_id.repo_name if to_id is not None else RepoUrl(from_id).repo_name # type: ignore + + # repository must be a valid repo_id (namespace/repo_name). + payload: Dict[str, Any] = {"repository": f"{to_namespace}/{to_repo_name}"} + + keys = ["private", "hardware", "storageTier", "sleepTimeSeconds", "secrets", "variables"] + values = [private, hardware, storage, sleep_time, secrets, variables] + payload.update({k: v for k, v in zip(keys, values) if v is not None}) + + if sleep_time is not None and hardware == SpaceHardware.CPU_BASIC: + warnings.warn( + "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more" + " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if" + " you want to set a custom sleep time, you need to upgrade to a paid Hardware.", + UserWarning, + ) + + r = get_session().post( + f"{self.endpoint}/api/spaces/{from_id}/duplicate", + headers=self._build_hf_headers(token=token, is_write_action=True), + json=payload, + ) + + try: + hf_raise_for_status(r) + except HTTPError as err: + if exist_ok and err.response.status_code == 409: + # Repo already exists and `exist_ok=True` + pass + else: + raise + + return RepoUrl(r.json()["url"], endpoint=self.endpoint) + + @validate_hf_hub_args + def request_space_storage( + self, + repo_id: str, + storage: SpaceStorage, + *, + token: Optional[str] = None, + ) -> SpaceRuntime: + """Request persistent storage for a Space. + + Args: + repo_id (`str`): + ID of the Space to update. Example: `"HuggingFaceH4/open_llm_leaderboard"`. + storage (`str` or [`SpaceStorage`]): + Storage tier. Either 'small', 'medium', or 'large'. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + + + + It is not possible to decrease persistent storage after its granted. To do so, you must delete it + via [`delete_space_storage`]. + + + """ + payload: Dict[str, SpaceStorage] = {"tier": storage} + r = get_session().post( + f"{self.endpoint}/api/spaces/{repo_id}/storage", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + @validate_hf_hub_args + def delete_space_storage( + self, + repo_id: str, + *, + token: Optional[str] = None, + ) -> SpaceRuntime: + """Delete persistent storage for a Space. + + Args: + repo_id (`str`): + ID of the Space to update. Example: `"HuggingFaceH4/open_llm_leaderboard"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Returns: + [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware. + Raises: + [`BadRequestError`] + If space has no persistent storage. + + """ + r = get_session().delete( + f"{self.endpoint}/api/spaces/{repo_id}/storage", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + ####################### + # Inference Endpoints # + ####################### + + def list_inference_endpoints( + self, namespace: Optional[str] = None, *, token: Optional[str] = None + ) -> List[InferenceEndpoint]: + """Lists all inference endpoints for the given namespace. + + Args: + namespace (`str`, *optional*): + The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints + from all namespaces (i.e. personal namespace and all orgs the user belongs to). + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> api.list_inference_endpoints() + [InferenceEndpoint(name='my-endpoint', ...), ...] + ``` + """ + # Special case: list all endpoints for all namespaces the user has access to + if namespace == "*": + user = self.whoami(token=token) + + # List personal endpoints first + endpoints: List[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token)) + + # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access) + for org in user.get("orgs", []): + try: + endpoints += list_inference_endpoints(namespace=org["name"], token=token) + except HfHubHTTPError as error: + if error.response.status_code == 401: # Either no billing or user don't have access) + logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error) + pass + + return endpoints + + # Normal case: list endpoints for a specific namespace + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return [ + InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token) + for endpoint in response.json()["items"] + ] + + def create_inference_endpoint( + self, + name: str, + *, + repository: str, + framework: str, + accelerator: str, + instance_size: str, + instance_type: str, + region: str, + vendor: str, + account_id: Optional[str] = None, + min_replica: int = 0, + max_replica: int = 1, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[Dict] = None, + type: InferenceEndpointType = InferenceEndpointType.PROTECTED, + namespace: Optional[str] = None, + token: Optional[str] = None, + ) -> InferenceEndpoint: + """Create a new Inference Endpoint. + + Args: + name (`str`): + The unique name for the new Inference Endpoint. + repository (`str`): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`): + The machine learning framework used for the model (e.g. `"custom"`). + accelerator (`str`): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`): + The size or type of the instance to be used for hosting the model (e.g. `"large"`). + instance_type (`str`): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`). + region (`str`): + The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`). + vendor (`str`): + The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`). + account_id (`str`, *optional*): + The account ID used to link a VPC to a private Inference Endpoint (if applicable). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. Defaults to 0. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1. + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`Dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + type ([`InferenceEndpointType]`, *optional*): + The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`. + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + [`InferenceEndpoint`]: information about the updated Inference Endpoint. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> create_inference_endpoint( + ... "my-endpoint-name", + ... repository="gpt2", + ... framework="pytorch", + ... task="text-generation", + ... accelerator="cpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="medium", + ... instance_type="c6i", + ... ) + >>> endpoint + InferenceEndpoint(name='my-endpoint-name', status="pending",...) + + # Run inference on the endpoint + >>> endpoint.client.text_generation(...) + "..." + ``` + + ```python + # Start an Inference Endpoint running Zephyr-7b-beta on TGI + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> create_inference_endpoint( + ... "aws-zephyr-7b-beta-0486", + ... repository="HuggingFaceH4/zephyr-7b-beta", + ... framework="pytorch", + ... task="text-generation", + ... accelerator="gpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="medium", + ... instance_type="g5.2xlarge", + ... custom_image={ + ... "health_route": "/health", + ... "env": { + ... "MAX_BATCH_PREFILL_TOKENS": "2048", + ... "MAX_INPUT_LENGTH": "1024", + ... "MAX_TOTAL_TOKENS": "1512", + ... "MODEL_ID": "/repository" + ... }, + ... "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", + ... }, + ... ) + + ``` + """ + namespace = namespace or self._get_namespace(token=token) + + image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}} + payload: Dict = { + "accountId": account_id, + "compute": { + "accelerator": accelerator, + "instanceSize": instance_size, + "instanceType": instance_type, + "scaling": { + "maxReplica": max_replica, + "minReplica": min_replica, + }, + }, + "model": { + "framework": framework, + "repository": repository, + "revision": revision, + "task": task, + "image": image, + }, + "name": name, + "provider": { + "region": region, + "vendor": vendor, + }, + "type": type, + } + + response = get_session().post( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def get_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None + ) -> InferenceEndpoint: + """Get information about an Inference Endpoint. + + Args: + name (`str`): + The name of the Inference Endpoint to retrieve information about. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + [`InferenceEndpoint`]: information about the requested Inference Endpoint. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) + + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) + ``` + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def update_inference_endpoint( + self, + name: str, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + # Other + namespace: Optional[str] = None, + token: Optional[str] = None, + ) -> InferenceEndpoint: + """Update an Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, or both. All arguments are + optional but at least one must be provided. + + For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`]. + + Args: + name (`str`): + The name of the Inference Endpoint to update. + + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"large"`). + instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. + + repository (`str`, *optional*): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`, *optional*): + The machine learning framework used for the model (e.g. `"custom"`). + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + [`InferenceEndpoint`]: information about the updated Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + payload: Dict = {} + if any(value is not None for value in (accelerator, instance_size, instance_type, min_replica, max_replica)): + payload["compute"] = { + "accelerator": accelerator, + "instanceSize": instance_size, + "instanceType": instance_type, + "scaling": { + "maxReplica": max_replica, + "minReplica": min_replica, + }, + } + if any(value is not None for value in (repository, framework, revision, task)): + payload["model"] = { + "framework": framework, + "repository": repository, + "revision": revision, + "task": task, + "image": {"huggingface": {}}, + } + + response = get_session().put( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def delete_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None + ) -> None: + """Delete an Inference Endpoint. + + This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`]. + + For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`]. + + Args: + name (`str`): + The name of the Inference Endpoint to delete. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + """ + namespace = namespace or self._get_namespace(token=token) + response = get_session().delete( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + def pause_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None + ) -> InferenceEndpoint: + """Pause an Inference Endpoint. + + A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`]. + This is different than scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`], which + would be automatically restarted when a request is made to it. + + For convenience, you can also pause an Inference Endpoint using [`pause_inference_endpoint`]. + + Args: + name (`str`): + The name of the Inference Endpoint to pause. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + [`InferenceEndpoint`]: information about the paused Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def resume_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None + ) -> InferenceEndpoint: + """Resume an Inference Endpoint. + + For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`]. + + Args: + name (`str`): + The name of the Inference Endpoint to resume. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + [`InferenceEndpoint`]: information about the resumed Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def scale_to_zero_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None + ) -> InferenceEndpoint: + """Scale Inference Endpoint to zero. + + An Inference Endpoint scaled to zero will not be charged. It will be resume on the next request to it, with a + cold start delay. This is different than pausing the Inference Endpoint with [`pause_inference_endpoint`], which + would require a manual resume with [`resume_inference_endpoint`]. + + For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`]. + + Args: + name (`str`): + The name of the Inference Endpoint to scale to zero. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (`str`, *optional*): + An authentication token (See https://huggingface.co/settings/token). + + Returns: + [`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint. + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def _get_namespace(self, token: Optional[str] = None) -> str: + """Get the default namespace for the current user.""" + me = self.whoami(token=token) + if me["type"] == "user": + return me["name"] + else: + raise ValueError( + "Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a" + " user." + ) + + ######################## + # Collection Endpoints # + ######################## + @validate_hf_hub_args + def list_collections( + self, + *, + owner: Union[List[str], str, None] = None, + item: Union[List[str], str, None] = None, + sort: Optional[Literal["lastModified", "trending", "upvotes"]] = None, + limit: Optional[int] = None, + token: Optional[Union[bool, str]] = None, + ) -> Iterable[Collection]: + """List collections on the Huggingface Hub, given some filters. + + + + When listing collections, the item list per collection is truncated to 4 items maximum. To retrieve all items + from a collection, you must use [`get_collection`]. + + + + Args: + owner (`List[str]` or `str`, *optional*): + Filter by owner's username. + item (`List[str]` or `str`, *optional*): + Filter collections containing a particular items. Example: `"models/teknium/OpenHermes-2.5-Mistral-7B"`, `"datasets/squad"` or `"papers/2311.12983"`. + sort (`Literal["lastModified", "trending", "upvotes"]`, *optional*): + Sort collections by last modified, trending or upvotes. + limit (`int`, *optional*): + Maximum number of collections to be returned. + token (`bool` or `str`, *optional*): + An authentication token (see https://huggingface.co/settings/token). + + Returns: + `Iterable[Collection]`: an iterable of [`Collection`] objects. + """ + # Construct the API endpoint + path = f"{self.endpoint}/api/collections" + headers = self._build_hf_headers(token=token) + params: Dict = {} + if owner is not None: + params.update({"owner": owner}) + if item is not None: + params.update({"item": item}) + if sort is not None: + params.update({"sort": sort}) + if limit is not None: + params.update({"limit": limit}) + + # Paginate over the results until limit is reached + items = paginate(path, headers=headers, params=params) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + + # Parse as Collection and return + for position, collection_data in enumerate(items): + yield Collection(position=position, **collection_data) + + def get_collection(self, collection_slug: str, *, token: Optional[str] = None) -> Collection: + """Gets information about a Collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection of the Hub. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import get_collection + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + >>> collection.title + 'Recent models' + >>> len(collection.items) + 37 + >>> collection.items[0] + CollectionItem( + item_object_id='651446103cd773a050bf64c2', + item_id='TheBloke/U-Amethyst-20B-AWQ', + item_type='model', + position=88, + note=None + ) + ``` + """ + r = get_session().get( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def create_collection( + self, + title: str, + *, + namespace: Optional[str] = None, + description: Optional[str] = None, + private: bool = False, + exists_ok: bool = False, + token: Optional[str] = None, + ) -> Collection: + """Create a new Collection on the Hub. + + Args: + title (`str`): + Title of the collection to create. Example: `"Recent models"`. + namespace (`str`, *optional*): + Namespace of the collection to create (username or org). Will default to the owner name. + description (`str`, *optional*): + Description of the collection to create. + private (`bool`, *optional*): + Whether the collection should be private or not. Defaults to `False` (i.e. public collection). + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if collection already exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import create_collection + >>> collection = create_collection( + ... title="ICCV 2023", + ... description="Portfolio of models, papers and demos I presented at ICCV 2023", + ... ) + >>> collection.slug + "username/iccv-2023-64f9a55bb3115b4f513ec026" + ``` + """ + if namespace is None: + namespace = self.whoami(token)["name"] + + payload = { + "title": title, + "namespace": namespace, + "private": private, + } + if description is not None: + payload["description"] = description + + r = get_session().post( + f"{self.endpoint}/api/collections", headers=self._build_hf_headers(token=token), json=payload + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if exists_ok and err.response.status_code == 409: + # Collection already exists and `exists_ok=True` + slug = r.json()["slug"] + return self.get_collection(slug, token=token) + else: + raise + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def update_collection_metadata( + self, + collection_slug: str, + *, + title: Optional[str] = None, + description: Optional[str] = None, + position: Optional[int] = None, + private: Optional[bool] = None, + theme: Optional[str] = None, + token: Optional[str] = None, + ) -> Collection: + """Update metadata of a collection on the Hub. + + All arguments are optional. Only provided metadata will be updated. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection to update. + description (`str`, *optional*): + Description of the collection to update. + position (`int`, *optional*): + New position of the collection in the list of collections of the user. + private (`bool`, *optional*): + Whether the collection should be private or not. + theme (`str`, *optional*): + Theme of the collection on the Hub. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import update_collection_metadata + >>> collection = update_collection_metadata( + ... collection_slug="username/iccv-2023-64f9a55bb3115b4f513ec026", + ... title="ICCV Oct. 2023" + ... description="Portfolio of models, datasets, papers and demos I presented at ICCV Oct. 2023", + ... private=False, + ... theme="pink", + ... ) + >>> collection.slug + "username/iccv-oct-2023-64f9a55bb3115b4f513ec026" + # ^collection slug got updated but not the trailing ID + ``` + """ + payload = { + "position": position, + "private": private, + "theme": theme, + "title": title, + "description": description, + } + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + return Collection(**{**r.json()["data"], "endpoint": self.endpoint}) + + def delete_collection( + self, collection_slug: str, *, missing_ok: bool = False, token: Optional[str] = None + ) -> None: + """Delete a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to delete. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if collection doesn't exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Example: + + ```py + >>> from huggingface_hub import delete_collection + >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) + ``` + + + + This is a non-revertible action. A deleted collection cannot be restored. + + + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Collection doesn't exists and `missing_ok=True` + return + else: + raise + + def add_collection_item( + self, + collection_slug: str, + item_id: str, + item_type: CollectionItemType_T, + *, + note: Optional[str] = None, + exists_ok: bool = False, + token: Optional[str] = None, + ) -> Collection: + """Add an item to a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_id (`str`): + ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`) + or a paper id (e.g. `"2307.09288"`). + item_type (`str`): + Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if item already exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import add_collection_item + >>> collection = add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="pierre-loic/climate-news-articles", + ... item_type="dataset" + ... ) + >>> collection.items[-1].item_id + "pierre-loic/climate-news-articles" + # ^item got added to the collection on last position + + # Add item with a note + >>> add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="datasets/climate_fever", + ... item_type="dataset" + ... note="This dataset adopts the FEVER methodology that consists of 1,535 real-world claims regarding climate-change collected on the internet." + ... ) + (...) + ``` + """ + payload: Dict[str, Any] = {"item": {"id": item_id, "type": item_type}} + if note is not None: + payload["note"] = note + r = get_session().post( + f"{self.endpoint}/api/collections/{collection_slug}/items", + headers=self._build_hf_headers(token=token), + json=payload, + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if exists_ok and err.response.status_code == 409: + # Item already exists and `exists_ok=True` + return self.get_collection(collection_slug, token=token) + else: + raise + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def update_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + note: Optional[str] = None, + position: Optional[int] = None, + token: Optional[str] = None, + ) -> None: + """Update an item in a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + position (`int`, *optional*): + New position of the item in the collection. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Example: + + ```py + >>> from huggingface_hub import get_collection, update_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Update item based on its ID (add note + update position) + >>> update_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... item_object_id=collection.items[-1].item_object_id, + ... note="Newly updated model!" + ... position=0, + ... ) + ``` + """ + payload = {"position": position, "note": note} + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + + def delete_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + missing_ok: bool = False, + token: Optional[str] = None, + ) -> None: + """Delete an item from a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0]._id`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if item doesn't exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Example: + + ```py + >>> from huggingface_hub import get_collection, delete_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Delete item based on its ID + >>> delete_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... item_object_id=collection.items[-1].item_object_id, + ... ) + ``` + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Item already deleted and `missing_ok=True` + return + else: + raise + + ########################## + # Manage access requests # + ########################## + + @validate_hf_hub_args + def list_pending_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> List[AccessRequest]: + """ + Get pending access requests for a given gated repo. + + A pending request means the user has requested access to the repo but the request has not been processed yet. + If the approval mode is automatic, this list should be empty. Pending requests can be accepted or rejected + using [`accept_access_request`] and [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Returns: + `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_pending_access_requests, accept_access_request + + # List pending requests + >>> requests = list_pending_access_requests("meta-llama/Llama-2-7b") + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem 🤗', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='pending', + fields=None, + ), + ... + ] + + # Accept Clem's request + >>> accept_access_request("meta-llama/Llama-2-7b", "clem") + ``` + """ + return self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def list_accepted_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> List[AccessRequest]: + """ + Get accepted access requests for a given gated repo. + + An accepted request means the user has requested access to the repo and the request has been accepted. The user + can download any file of the repo. If the approval mode is automatic, this list should contains by default all + requests. Accepted requests can be cancelled or rejected at any time using [`cancel_access_request`] and + [`reject_access_request`]. A cancelled request will go back to the pending list while a rejected request will + go to the rejected list. In both cases, the user will lose access to the repo. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Returns: + `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_accepted_access_requests + + >>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b") + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem 🤗', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='accepted', + fields=None, + ), + ... + ] + ``` + """ + return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def list_rejected_access_requests( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> List[AccessRequest]: + """ + Get rejected access requests for a given gated repo. + + A rejected request means the user has requested access to the repo and the request has been explicitly rejected + by a repo owner (either you or another user from your organization). The user cannot download any file of the + repo. Rejected requests can be accepted or cancelled at any time using [`accept_access_request`] and + [`cancel_access_request`]. A cancelled request will go back to the pending list while an accepted request will + go to the accepted list. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Returns: + `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_rejected_access_requests + + >>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b") + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem 🤗', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='rejected', + fields=None, + ), + ... + ] + ``` + """ + return self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token) + + def _list_access_requests( + self, + repo_id: str, + status: Literal["accepted", "rejected", "pending"], + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> List[AccessRequest]: + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + + response = get_session().get( + f"{ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + return [ + AccessRequest( + username=request["user"]["user"], + fullname=request["user"]["fullname"], + email=request["user"]["email"], + status=request["status"], + timestamp=parse_datetime(request["timestamp"]), + fields=request.get("fields"), # only if custom fields in form + ) + for request in response.json() + ] + + @validate_hf_hub_args + def cancel_access_request( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> None: + """ + Cancel an access request from a user for a given gated repo. + + A cancelled request will go back to the pending list and the user will lose access to the repo. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to cancel access request for. + user (`str`): + The username of the user which access request should be cancelled. + repo_type (`str`, *optional*): + The type of the repo to cancel access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + `HTTPError`: + HTTP 404 if the user does not exist on the Hub. + `HTTPError`: + HTTP 404 if the user access request cannot be found. + `HTTPError`: + HTTP 404 if the user access request is already in the pending list. + """ + self._handle_access_request(repo_id, user, "pending", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def accept_access_request( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> None: + """ + Accept an access request from a user for a given gated repo. + + Once the request is accepted, the user will be able to download any file of the repo and access the community + tab. If the approval mode is automatic, you don't have to accept requests manually. An accepted request can be + cancelled or rejected at any time using [`cancel_access_request`] and [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to accept access request for. + user (`str`): + The username of the user which access request should be accepted. + repo_type (`str`, *optional*): + The type of the repo to accept access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + `HTTPError`: + HTTP 404 if the user does not exist on the Hub. + `HTTPError`: + HTTP 404 if the user access request cannot be found. + `HTTPError`: + HTTP 404 if the user access request is already in the accepted list. + """ + self._handle_access_request(repo_id, user, "accepted", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def reject_access_request( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> None: + """ + Reject an access request from a user for a given gated repo. + + A rejected request will go to the rejected list. The user cannot download any file of the repo. Rejected + requests can be accepted or cancelled at any time using [`accept_access_request`] and [`cancel_access_request`]. + A cancelled request will go back to the pending list while an accepted request will go to the accepted list. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to reject access request for. + user (`str`): + The username of the user which access request should be rejected. + repo_type (`str`, *optional*): + The type of the repo to reject access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + `HTTPError`: + HTTP 404 if the user does not exist on the Hub. + `HTTPError`: + HTTP 404 if the user access request cannot be found. + `HTTPError`: + HTTP 404 if the user access request is already in the rejected list. + """ + self._handle_access_request(repo_id, user, "rejected", repo_type=repo_type, token=token) + + @validate_hf_hub_args + def _handle_access_request( + self, + repo_id: str, + user: str, + status: Literal["accepted", "rejected", "pending"], + repo_type: Optional[str] = None, + token: Optional[str] = None, + ) -> None: + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + + response = get_session().post( + f"{ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/handle", + headers=self._build_hf_headers(token=token), + json={"user": user, "status": status}, + ) + hf_raise_for_status(response) + + @validate_hf_hub_args + def grant_access( + self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Optional[str] = None + ) -> None: + """ + Grant access to a user for a given gated repo. + + Granting access don't require for the user to send an access request by themselves. The user is automatically + added to the accepted list meaning they can download the files You can revoke the granted access at any time + using [`cancel_access_request`] or [`reject_access_request`]. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to grant access to. + user (`str`): + The username of the user to grant access. + repo_type (`str`, *optional*): + The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (`str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + + Raises: + `HTTPError`: + HTTP 400 if the repo is not gated. + `HTTPError`: + HTTP 400 if the user already has access to the repo. + `HTTPError`: + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + `HTTPError`: + HTTP 404 if the user does not exist on the Hub. + """ + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type is None: + repo_type = REPO_TYPE_MODEL + + response = get_session().post( + f"{ENDPOINT}/api/models/{repo_id}/user-access-request/grant", + headers=self._build_hf_headers(token=token), + json={"user": user}, + ) + hf_raise_for_status(response) + return response.json() + + ############# + # Internals # + ############# + + def _build_hf_headers( + self, + token: Optional[Union[bool, str]] = None, + is_write_action: bool = False, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + ) -> Dict[str, str]: + """ + Alias for [`build_hf_headers`] that uses the token from [`HfApi`] client + when `token` is not provided. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + return build_hf_headers( + token=token, + is_write_action=is_write_action, + library_name=library_name or self.library_name, + library_version=library_version or self.library_version, + user_agent=user_agent or self.user_agent, + ) + + def _prepare_upload_folder_deletions( + self, + repo_id: str, + repo_type: Optional[str], + revision: Optional[str], + token: Optional[str], + path_in_repo: str, + delete_patterns: Optional[Union[List[str], str]], + ) -> List[CommitOperationDelete]: + """Generate the list of Delete operations for a commit to delete files from a repo. + + List remote files and match them against the `delete_patterns` constraints. Returns a list of [`CommitOperationDelete`] + with the matching items. + + Note: `.gitattributes` file is essential to make a repo work properly on the Hub. This file will always be + kept even if it matches the `delete_patterns` constraints. + """ + if delete_patterns is None: + # If no delete patterns, no need to list and filter remote files + return [] + + # List remote files + filenames = self.list_repo_files(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) + + # Compute relative path in repo + if path_in_repo: + path_in_repo = path_in_repo.strip("/") + "/" # harmonize + relpath_to_abspath = { + file[len(path_in_repo) :]: file for file in filenames if file.startswith(path_in_repo) + } + else: + relpath_to_abspath = {file: file for file in filenames} + + # Apply filter on relative paths and return + return [ + CommitOperationDelete(path_in_repo=relpath_to_abspath[relpath], is_folder=False) + for relpath in filter_repo_objects(relpath_to_abspath.keys(), allow_patterns=delete_patterns) + if relpath_to_abspath[relpath] != ".gitattributes" + ] + + +def _prepare_upload_folder_additions( + folder_path: Union[str, Path], + path_in_repo: str, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, +) -> List[CommitOperationAdd]: + """Generate the list of Add operations for a commit to upload a folder. + + Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist) + constraints are discarded. + """ + folder_path = Path(folder_path).expanduser().resolve() + if not folder_path.is_dir(): + raise ValueError(f"Provided path: '{folder_path}' is not a directory") + + # List files from folder + relpath_to_abspath = { + path.relative_to(folder_path).as_posix(): path + for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic + if path.is_file() + } + + # Filter files and return + # Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering. + prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else "" + return [ + CommitOperationAdd( + path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk + path_in_repo=prefix + relpath, # "absolute" path in repo + ) + for relpath in filter_repo_objects( + relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns + ) + ] + + +def _parse_revision_from_pr_url(pr_url: str) -> str: + """Safely parse revision number from a PR url. + + Example: + ```py + >>> _parse_revision_from_pr_url("https://huggingface.co/bigscience/bloom/discussions/2") + "refs/pr/2" + ``` + """ + re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) + if re_match is None: + raise RuntimeError(f"Unexpected response from the hub, expected a Pull Request URL but got: '{pr_url}'") + return f"refs/pr/{re_match[1]}" + + +api = HfApi() + +whoami = api.whoami +get_token_permission = api.get_token_permission + +list_models = api.list_models +model_info = api.model_info + +list_datasets = api.list_datasets +dataset_info = api.dataset_info + +list_spaces = api.list_spaces +space_info = api.space_info + +repo_exists = api.repo_exists +file_exists = api.file_exists +repo_info = api.repo_info +list_repo_files = api.list_repo_files +list_repo_refs = api.list_repo_refs +list_repo_commits = api.list_repo_commits +list_files_info = api.list_files_info +list_repo_tree = api.list_repo_tree +get_paths_info = api.get_paths_info + +list_metrics = api.list_metrics + +get_model_tags = api.get_model_tags +get_dataset_tags = api.get_dataset_tags + +create_commit = api.create_commit +create_repo = api.create_repo +delete_repo = api.delete_repo +update_repo_visibility = api.update_repo_visibility +super_squash_history = api.super_squash_history +move_repo = api.move_repo +upload_file = api.upload_file +upload_folder = api.upload_folder +delete_file = api.delete_file +delete_folder = api.delete_folder +create_commits_on_pr = api.create_commits_on_pr +preupload_lfs_files = api.preupload_lfs_files +create_branch = api.create_branch +delete_branch = api.delete_branch +create_tag = api.create_tag +delete_tag = api.delete_tag +get_full_repo_name = api.get_full_repo_name + +# Safetensors helpers +get_safetensors_metadata = api.get_safetensors_metadata +parse_safetensors_file_metadata = api.parse_safetensors_file_metadata + +# Background jobs +run_as_future = api.run_as_future + +# Activity API +list_liked_repos = api.list_liked_repos +list_repo_likers = api.list_repo_likers +like = api.like +unlike = api.unlike + +# Community API +get_discussion_details = api.get_discussion_details +get_repo_discussions = api.get_repo_discussions +create_discussion = api.create_discussion +create_pull_request = api.create_pull_request +change_discussion_status = api.change_discussion_status +comment_discussion = api.comment_discussion +edit_discussion_comment = api.edit_discussion_comment +rename_discussion = api.rename_discussion +merge_pull_request = api.merge_pull_request + +# Space API +add_space_secret = api.add_space_secret +delete_space_secret = api.delete_space_secret +get_space_variables = api.get_space_variables +add_space_variable = api.add_space_variable +delete_space_variable = api.delete_space_variable +get_space_runtime = api.get_space_runtime +request_space_hardware = api.request_space_hardware +set_space_sleep_time = api.set_space_sleep_time +pause_space = api.pause_space +restart_space = api.restart_space +duplicate_space = api.duplicate_space +request_space_storage = api.request_space_storage +delete_space_storage = api.delete_space_storage + +# Inference Endpoint API +list_inference_endpoints = api.list_inference_endpoints +create_inference_endpoint = api.create_inference_endpoint +get_inference_endpoint = api.get_inference_endpoint +update_inference_endpoint = api.update_inference_endpoint +delete_inference_endpoint = api.delete_inference_endpoint +pause_inference_endpoint = api.pause_inference_endpoint +resume_inference_endpoint = api.resume_inference_endpoint +scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint + +# Collections API +get_collection = api.get_collection +list_collections = api.list_collections +create_collection = api.create_collection +update_collection_metadata = api.update_collection_metadata +delete_collection = api.delete_collection +add_collection_item = api.add_collection_item +update_collection_item = api.update_collection_item +delete_collection_item = api.delete_collection_item +delete_collection_item = api.delete_collection_item + +# Access requests API +list_pending_access_requests = api.list_pending_access_requests +list_accepted_access_requests = api.list_accepted_access_requests +list_rejected_access_requests = api.list_rejected_access_requests +cancel_access_request = api.cancel_access_request +accept_access_request = api.accept_access_request +reject_access_request = api.reject_access_request +grant_access = api.grant_access