72cce7d5913aad64af86c29560a31d664b4ae723604106669ab386e118a0be60
- lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py +168 -0
- lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py +115 -0
- lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py +379 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__init__.py +27 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/download.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/env.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/user.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/commands/_cli_utils.py +63 -0
- lib/python3.11/site-packages/huggingface_hub/commands/delete_cache.py +427 -0
- lib/python3.11/site-packages/huggingface_hub/commands/download.py +214 -0
- lib/python3.11/site-packages/huggingface_hub/commands/env.py +35 -0
- lib/python3.11/site-packages/huggingface_hub/commands/huggingface_cli.py +53 -0
- lib/python3.11/site-packages/huggingface_hub/commands/lfs.py +199 -0
- lib/python3.11/site-packages/huggingface_hub/commands/scan_cache.py +138 -0
- lib/python3.11/site-packages/huggingface_hub/commands/upload.py +297 -0
- lib/python3.11/site-packages/huggingface_hub/commands/user.py +188 -0
- lib/python3.11/site-packages/huggingface_hub/community.py +354 -0
- lib/python3.11/site-packages/huggingface_hub/constants.py +213 -0
- lib/python3.11/site-packages/huggingface_hub/fastai_utils.py +425 -0
- lib/python3.11/site-packages/huggingface_hub/file_download.py +1727 -0
- lib/python3.11/site-packages/huggingface_hub/hf_api.py +0 -0
- lib/python3.11/site-packages/huggingface_hub/hf_file_system.py +670 -0
- lib/python3.11/site-packages/huggingface_hub/hub_mixin.py +368 -0
- lib/python3.11/site-packages/huggingface_hub/inference/__init__.py +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_text_generation.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_types.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_client.py +1990 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_common.py +327 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py +2020 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_text_generation.py +546 -0
- lib/python3.11/site-packages/huggingface_hub/inference/_types.py +183 -0
- lib/python3.11/site-packages/huggingface_hub/inference_api.py +217 -0
- lib/python3.11/site-packages/huggingface_hub/keras_mixin.py +480 -0
- lib/python3.11/site-packages/huggingface_hub/lfs.py +522 -0
- lib/python3.11/site-packages/huggingface_hub/repocard.py +818 -0
- lib/python3.11/site-packages/huggingface_hub/repocard_data.py +711 -0
- lib/python3.11/site-packages/huggingface_hub/repository.py +1476 -0
lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py
ADDED
@@ -0,0 +1,168 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains a logger to push training logs to the Hub, using Tensorboard."""
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from huggingface_hub._commit_scheduler import CommitScheduler
+
+from .utils import experimental, is_tensorboard_available
+
+
+if is_tensorboard_available():
+    from tensorboardX import SummaryWriter
+
+    # TODO: clarify: should we import from torch.utils.tensorboard ?
+
+else:
+    SummaryWriter = object  # Dummy class to avoid failing at import. Will raise on instance creation.
+
+if TYPE_CHECKING:
+    from tensorboardX import SummaryWriter
+
+
+class HFSummaryWriter(SummaryWriter):
+    """
+    Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.
+
+    Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
+    thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
+    issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
+    minutes (default to every 5 minutes).
+
+    <Tip warning={true}>
+
+    `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.
+
+    </Tip>
+
+    Args:
+        repo_id (`str`):
+            The id of the repo to which the logs will be pushed.
+        logdir (`str`, *optional*):
+            The directory where the logs will be written. If not specified, a local directory will be created by the
+            underlying `SummaryWriter` object.
+        commit_every (`int` or `float`, *optional*):
+            The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
+        squash_history (`bool`, *optional*):
+            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
+            useful to avoid degraded performances on the repo when it grows too large.
+        repo_type (`str`, *optional*):
+            The type of the repo to which the logs will be pushed. Defaults to "model".
+        repo_revision (`str`, *optional*):
+            The revision of the repo to which the logs will be pushed. Defaults to "main".
+        repo_private (`bool`, *optional*):
+            Whether to create a private repo or not. Defaults to False. This argument is ignored if the repo already
+            exists.
+        path_in_repo (`str`, *optional*):
+            The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/".
+        repo_allow_patterns (`List[str]` or `str`, *optional*):
+            A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
+            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
+        repo_ignore_patterns (`List[str]` or `str`, *optional*):
+            A list of patterns to exclude in the upload. Check out the
+            [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
+        token (`str`, *optional*):
+            Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
+            details
+        kwargs:
+            Additional keyword arguments passed to `SummaryWriter`.
+
+    Examples:
+    ```py
+    >>> from huggingface_hub import HFSummaryWriter
+
+    # Logs are automatically pushed every 15 minutes
+    >>> logger = HFSummaryWriter(repo_id="test_hf_logger", commit_every=15)
+    >>> logger.add_scalar("a", 1)
+    >>> logger.add_scalar("b", 2)
+    ...
+
+    # You can also trigger a push manually
+    >>> logger.scheduler.trigger()
+    ```
+
+    ```py
+    >>> from huggingface_hub import HFSummaryWriter
+
+    # Logs are automatically pushed every 5 minutes (default) + when exiting the context manager
+    >>> with HFSummaryWriter(repo_id="test_hf_logger") as logger:
+    ...     logger.add_scalar("a", 1)
+    ...     logger.add_scalar("b", 2)
+    ```
+    """
+
+    @experimental
+    def __new__(cls, *args, **kwargs) -> "HFSummaryWriter":
+        if not is_tensorboard_available():
+            raise ImportError(
+                "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade"
+                " tensorboardX` first."
+            )
+        return super().__new__(cls)
+
+    def __init__(
+        self,
+        repo_id: str,
+        *,
+        logdir: Optional[str] = None,
+        commit_every: Union[int, float] = 5,
+        squash_history: bool = False,
+        repo_type: Optional[str] = None,
+        repo_revision: Optional[str] = None,
+        repo_private: bool = False,
+        path_in_repo: Optional[str] = "tensorboard",
+        repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*",
+        repo_ignore_patterns: Optional[Union[List[str], str]] = None,
+        token: Optional[str] = None,
+        **kwargs,
+    ):
+        # Initialize SummaryWriter
+        super().__init__(logdir=logdir, **kwargs)
+
+        # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it.
+        if not isinstance(self.logdir, str):
+            raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")
+
+        # Append logdir name to `path_in_repo`
+        if path_in_repo is None or path_in_repo == "":
+            path_in_repo = Path(self.logdir).name
+        else:
+            path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name
+
+        # Initialize scheduler
+        self.scheduler = CommitScheduler(
+            folder_path=self.logdir,
+            path_in_repo=path_in_repo,
+            repo_id=repo_id,
+            repo_type=repo_type,
+            revision=repo_revision,
+            private=repo_private,
+            token=token,
+            allow_patterns=repo_allow_patterns,
+            ignore_patterns=repo_ignore_patterns,
+            every=commit_every,
+            squash_history=squash_history,
+        )
+
+        # Exposing some high-level info at root level
+        self.repo_id = self.scheduler.repo_id
+        self.repo_type = self.scheduler.repo_type
+        self.repo_revision = self.scheduler.revision
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Push to hub in a non-blocking way when exiting the logger's context manager."""
+        super().__exit__(exc_type, exc_val, exc_tb)
+        future = self.scheduler.trigger()
+        future.result()
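For orientation, a minimal usage sketch of the `HFSummaryWriter` added above, following its own docstring examples. It assumes `tensorboardX` is installed and a Hub token is configured; the repo id and metric name are hypothetical.

```python
# Minimal sketch, assuming `tensorboardX` is installed and you are logged in.
# "my-user/training-logs" and the metric name are hypothetical.
from huggingface_hub import HFSummaryWriter

logger = HFSummaryWriter(repo_id="my-user/training-logs", commit_every=5)
for step in range(100):
    # Written locally by tensorboardX; the CommitScheduler pushes the *.tfevents.*
    # files to the Hub every `commit_every` minutes in a background thread.
    logger.add_scalar("loss", 1.0 / (step + 1), global_step=step)

# Trigger one last push manually; `trigger()` returns a Future we can wait on.
logger.scheduler.trigger().result()
```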
lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py
ADDED
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright 2023-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains data structures to parse the webhooks payload."""
+from typing import List, Literal, Optional
+
+from pydantic import BaseModel
+
+
+# This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
+# are not in used anymore. To keep in sync when format is updated in
+# https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link).
+
+
+WebhookEvent_T = Literal[
+    "create",
+    "delete",
+    "move",
+    "update",
+]
+RepoChangeEvent_T = Literal[
+    "add",
+    "move",
+    "remove",
+    "update",
+]
+RepoType_T = Literal[
+    "dataset",
+    "model",
+    "space",
+]
+DiscussionStatus_T = Literal[
+    "closed",
+    "draft",
+    "open",
+    "merged",
+]
+SupportedWebhookVersion = Literal[3]
+
+
+class ObjectId(BaseModel):
+    id: str
+
+
+class WebhookPayloadUrl(BaseModel):
+    web: str
+    api: Optional[str] = None
+
+
+class WebhookPayloadMovedTo(BaseModel):
+    name: str
+    owner: ObjectId
+
+
+class WebhookPayloadWebhook(ObjectId):
+    version: SupportedWebhookVersion
+
+
+class WebhookPayloadEvent(BaseModel):
+    action: WebhookEvent_T
+    scope: str
+
+
+class WebhookPayloadDiscussionChanges(BaseModel):
+    base: str
+    mergeCommitId: Optional[str] = None
+
+
+class WebhookPayloadComment(ObjectId):
+    author: ObjectId
+    hidden: bool
+    content: Optional[str] = None
+    url: WebhookPayloadUrl
+
+
+class WebhookPayloadDiscussion(ObjectId):
+    num: int
+    author: ObjectId
+    url: WebhookPayloadUrl
+    title: str
+    isPullRequest: bool
+    status: DiscussionStatus_T
+    changes: Optional[WebhookPayloadDiscussionChanges] = None
+    pinned: Optional[bool] = None
+
+
+class WebhookPayloadRepo(ObjectId):
+    owner: ObjectId
+    head_sha: Optional[str] = None
+    name: str
+    private: bool
+    subdomain: Optional[str] = None
+    tags: Optional[List[str]] = None
+    type: Literal["dataset", "model", "space"]
+    url: WebhookPayloadUrl
+
+
+class WebhookPayload(BaseModel):
+    event: WebhookPayloadEvent
+    repo: WebhookPayloadRepo
+    discussion: Optional[WebhookPayloadDiscussion] = None
+    comment: Optional[WebhookPayloadComment] = None
+    webhook: WebhookPayloadWebhook
+    movedTo: Optional[WebhookPayloadMovedTo] = None
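As a quick illustration of these pydantic models, the sketch below parses a hand-written payload dict. All field values are made up; only the required fields of each model are filled.

```python
# Sketch only: the payload dict is hypothetical, shaped after the models above.
from huggingface_hub import WebhookPayload

payload_dict = {
    "event": {"action": "update", "scope": "repo.content"},
    "repo": {
        "id": "some-repo-id",
        "owner": {"id": "some-owner-id"},
        "name": "user/my-model",
        "private": False,
        "type": "model",
        "url": {"web": "https://huggingface.co/user/my-model"},
    },
    "webhook": {"id": "some-webhook-id", "version": 3},
}

payload = WebhookPayload(**payload_dict)  # pydantic validates and builds the nested models
print(payload.repo.name, payload.event.action)
```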
lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py
ADDED
@@ -0,0 +1,379 @@
+# coding=utf-8
+# Copyright 2023-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
+import atexit
+import inspect
+import os
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+
+from .utils import experimental, is_gradio_available
+from .utils._deprecation import _deprecate_method
+
+
+if TYPE_CHECKING:
+    import gradio as gr
+
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+
+_global_app: Optional["WebhooksServer"] = None
+_is_local = os.getenv("SYSTEM") != "spaces"
+
+
+@experimental
+class WebhooksServer:
+    """
+    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
+    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
+    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `run` method has to be
+    called to start the app.
+
+    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
+    model that contains all the information about the webhook event. The data will be parsed automatically for you.
+
+    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
+    WebhooksServer and deploy it on a Space.
+
+    <Tip warning={true}>
+
+    `WebhooksServer` is experimental. Its API is subject to change in the future.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).
+
+    </Tip>
+
+    Args:
+        ui (`gradio.Blocks`, optional):
+            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
+            about the configured webhooks is created.
+        webhook_secret (`str`, optional):
+            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
+            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
+            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
+            webhook endpoints are opened without any security.
+
+    Example:
+
+        ```python
+        import gradio as gr
+        from huggingface_hub import WebhooksServer, WebhookPayload
+
+        with gr.Blocks() as ui:
+            ...
+
+        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")
+
+        @app.add_webhook("/say_hello")
+        async def hello(payload: WebhookPayload):
+            return {"message": "hello"}
+
+        app.run()
+        ```
+    """
+
+    def __new__(cls, *args, **kwargs) -> "WebhooksServer":
+        if not is_gradio_available():
+            raise ImportError(
+                "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
+                " first."
+            )
+        return super().__new__(cls)
+
+    def __init__(
+        self,
+        ui: Optional["gr.Blocks"] = None,
+        webhook_secret: Optional[str] = None,
+    ) -> None:
+        self._ui = ui
+
+        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
+        self.registered_webhooks: Dict[str, Callable] = {}
+        _warn_on_empty_secret(self.webhook_secret)
+
+    def add_webhook(self, path: Optional[str] = None) -> Callable:
+        """
+        Decorator to add a webhook to the [`WebhooksServer`] server.
+
+        Args:
+            path (`str`, optional):
+                The URL path to register the webhook function. If not provided, the function name will be used as the
+                path. In any case, all webhooks are registered under `/webhooks`.
+
+        Raises:
+            ValueError: If the provided path is already registered as a webhook.
+
+        Example:
+            ```python
+            from huggingface_hub import WebhooksServer, WebhookPayload
+
+            app = WebhooksServer()
+
+            @app.add_webhook
+            async def trigger_training(payload: WebhookPayload):
+                if payload.repo.type == "dataset" and payload.event.action == "update":
+                    # Trigger a training job if a dataset is updated
+                    ...
+
+            app.run()
+            ```
+        """
+        # Usage: directly as decorator. Example: `@app.add_webhook`
+        if callable(path):
+            # If path is a function, it means it was used as a decorator without arguments
+            return self.add_webhook()(path)
+
+        # Usage: provide a path. Example: `@app.add_webhook(...)`
+        @wraps(FastAPI.post)
+        def _inner_post(*args, **kwargs):
+            func = args[0]
+            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
+            if abs_path in self.registered_webhooks:
+                raise ValueError(f"Webhook {abs_path} already exists.")
+            self.registered_webhooks[abs_path] = func
+
+        return _inner_post
+
+    def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
+        """Launch the Gradio app and register webhooks to the underlying FastAPI server.
+
+        Input parameters are forwarded to Gradio when launching the app.
+        """
+        ui = self._ui or self._get_default_ui()
+
+        # Start Gradio App
+        #   - as non-blocking so that webhooks can be added afterwards
+        #   - as shared if launch locally (to debug webhooks)
+        launch_kwargs.setdefault("share", _is_local)
+        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)
+
+        # Register webhooks to FastAPI app
+        for path, func in self.registered_webhooks.items():
+            # Add secret check if required
+            if self.webhook_secret is not None:
+                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)
+
+            # Add route to FastAPI app
+            self.fastapi_app.post(path)(func)
+
+        # Print instructions and block main thread
+        url = (ui.share_url or ui.local_url).strip("/")
+        message = "\nWebhooks are correctly setup and ready to use:"
+        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
+        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
+        print(message)
+
+        if not prevent_thread_lock:
+            ui.block_thread()
+
+    @_deprecate_method(version="0.23", message="Use `WebhooksServer.launch` instead.")
+    def run(self) -> None:
+        return self.launch()
+
+    def _get_default_ui(self) -> "gr.Blocks":
+        """Default UI if not provided (lists webhooks and provides basic instructions)."""
+        import gradio as gr
+
+        with gr.Blocks() as ui:
+            gr.Markdown("# This is an app to process 🤗 Webhooks")
+            gr.Markdown(
+                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
+                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
+                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
+                " know more about webhooks on the Huggingface Hub."
+            )
+            gr.Markdown(
+                f"{len(self.registered_webhooks)} webhook(s) are registered:"
+                + "\n\n"
+                + "\n ".join(
+                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
+                    for webhook_path, webhook in self.registered_webhooks.items()
+                )
+            )
+            gr.Markdown(
+                "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
+                + "\nYou app is running locally. Please look at the logs to check the full URL you need to set."
+                if _is_local
+                else (
+                    "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
+                    " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
+                )
+            )
+        return ui
+
+
+@experimental
+def webhook_endpoint(path: Optional[str] = None) -> Callable:
+    """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.
+
+    This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret),
+    you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using
+    this decorator multiple times.
+
+    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
+    server and deploy it on a Space.
+
+    <Tip warning={true}>
+
+    `webhook_endpoint` is experimental. Its API is subject to change in the future.
+
+    </Tip>
+
+    <Tip warning={true}>
+
+    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).
+
+    </Tip>
+
+    Args:
+        path (`str`, optional):
+            The URL path to register the webhook function. If not provided, the function name will be used as the path.
+            In any case, all webhooks are registered under `/webhooks`.
+
+    Examples:
+        The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
+        The server will be started automatically at exit (i.e. at the end of the script).
+
+        ```python
+        from huggingface_hub import webhook_endpoint, WebhookPayload
+
+        @webhook_endpoint
+        async def trigger_training(payload: WebhookPayload):
+            if payload.repo.type == "dataset" and payload.event.action == "update":
+                # Trigger a training job if a dataset is updated
+                ...
+
+        # Server is automatically started at the end of the script.
+        ```
+
+        Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
+        are running it in a notebook.
+
+        ```python
+        from huggingface_hub import webhook_endpoint, WebhookPayload
+
+        @webhook_endpoint
+        async def trigger_training(payload: WebhookPayload):
+            if payload.repo.type == "dataset" and payload.event.action == "update":
+                # Trigger a training job if a dataset is updated
+                ...
+
+        # Start the server manually
+        trigger_training.run()
+        ```
+    """
+    if callable(path):
+        # If path is a function, it means it was used as a decorator without arguments
+        return webhook_endpoint()(path)
+
+    @wraps(WebhooksServer.add_webhook)
+    def _inner(func: Callable) -> Callable:
+        app = _get_global_app()
+        app.add_webhook(path)(func)
+        if len(app.registered_webhooks) == 1:
+            # Register `app.run` to run at exit (only once)
+            atexit.register(app.run)
+
+        @wraps(app.run)
+        def _run_now():
+            # Run the app directly (without waiting atexit)
+            atexit.unregister(app.run)
+            app.run()
+
+        func.run = _run_now  # type: ignore
+        return func
+
+    return _inner
+
+
+def _get_global_app() -> WebhooksServer:
+    global _global_app
+    if _global_app is None:
+        _global_app = WebhooksServer()
+    return _global_app
+
+
+def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
+    if webhook_secret is None:
+        print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
+        print(
+            "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
+            "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
+        )
+        print(
+            "For more details about webhook secrets, please refer to"
+            " https://huggingface.co/docs/hub/webhooks#webhook-secret."
+        )
+    else:
+        print("Webhook secret is correctly defined.")
+
+
+def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
+    """Returns the anchor to a given webhook in the docs (experimental)"""
+    return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
+
+
+def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
+    """Wraps a webhook function to check the webhook secret before calling the function.
+
+    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route
+    parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
+    object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
+    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
+    Gradio internals (and not by us), we cannot add a middleware.
+
+    This method is called only when a secret has been defined by the user. If a request is sent without the
+    "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
+    the function will return a 403 error (forbidden).
+
+    Inspired by https://stackoverflow.com/a/33112180.
+    """
+    initial_sig = inspect.signature(func)
+
+    @wraps(func)
+    async def _protected_func(request: Request, **kwargs):
+        request_secret = request.headers.get("x-webhook-secret")
+        if request_secret is None:
+            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
+        if request_secret != webhook_secret:
+            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)
+
+        # Inject `request` in kwargs if required
+        if "request" in initial_sig.parameters:
+            kwargs["request"] = request
+
+        # Handle both sync and async routes
+        if inspect.iscoroutinefunction(func):
+            return await func(**kwargs)
+        else:
+            return func(**kwargs)
+
+    # Update signature to include request
+    if "request" not in initial_sig.parameters:
+        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
+            parameters=(
+                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
+            )
+            + tuple(initial_sig.parameters.values())
+        )
+
+    # Return protected route
+    return _protected_func
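Once a server built from this module is running, a registered endpoint can be exercised with a plain HTTP POST. The sketch below only illustrates the secret check implemented by `_wrap_webhook_to_check_secret`; the Space URL, secret, and payload values are all hypothetical.

```python
# Hypothetical URL, secret and payload. A request without the "x-webhook-secret"
# header gets a 401, a wrong value gets a 403, a correct one reaches the webhook.
import requests

url = "https://my-user-my-space.hf.space/webhooks/trigger_training"
payload = {
    "event": {"action": "update", "scope": "repo.content"},
    "repo": {
        "id": "some-repo-id",
        "owner": {"id": "some-owner-id"},
        "name": "user/my-dataset",
        "private": False,
        "type": "dataset",
        "url": {"web": "https://huggingface.co/datasets/user/my-dataset"},
    },
    "webhook": {"id": "some-webhook-id", "version": 3},
}

response = requests.post(url, json=payload, headers={"x-webhook-secret": "my_secret_key"})
print(response.status_code, response.text)
```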
lib/python3.11/site-packages/huggingface_hub/commands/__init__.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import _SubParsersAction
+
+
+class BaseHuggingfaceCLICommand(ABC):
+    @staticmethod
+    @abstractmethod
+    def register_subcommand(parser: _SubParsersAction):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def run(self):
+        raise NotImplementedError()
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.16 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-311.pyc
ADDED
Binary file (3.65 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-311.pyc
ADDED
Binary file (20 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/download.cpython-311.pyc
ADDED
Binary file (9.56 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/env.cpython-311.pyc
ADDED
Binary file (1.67 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-311.pyc
ADDED
Binary file (2.24 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-311.pyc
ADDED
Binary file (9.61 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-311.pyc
ADDED
Binary file (6.85 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-311.pyc
ADDED
Binary file (13.8 kB)
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/user.cpython-311.pyc
ADDED
Binary file (11.5 kB)
lib/python3.11/site-packages/huggingface_hub/commands/_cli_utils.py
ADDED
@@ -0,0 +1,63 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains a utility for good-looking prints."""
+import os
+from typing import List, Union
+
+
+class ANSI:
+    """
+    Helper for en.wikipedia.org/wiki/ANSI_escape_code
+    """
+
+    _bold = "\u001b[1m"
+    _gray = "\u001b[90m"
+    _red = "\u001b[31m"
+    _reset = "\u001b[0m"
+
+    @classmethod
+    def bold(cls, s: str) -> str:
+        return cls._format(s, cls._bold)
+
+    @classmethod
+    def gray(cls, s: str) -> str:
+        return cls._format(s, cls._gray)
+
+    @classmethod
+    def red(cls, s: str) -> str:
+        return cls._format(s, cls._bold + cls._red)
+
+    @classmethod
+    def _format(cls, s: str, code: str) -> str:
+        if os.environ.get("NO_COLOR"):
+            # See https://no-color.org/
+            return s
+        return f"{code}{s}{cls._reset}"
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)
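A short sketch of the two helpers in this file; the rows and headers below are made-up values. `tabulate` left-aligns each column to the widest cell and inserts a dashed separator row under the headers.

```python
# Sketch with hypothetical rows; import path matches the file added above.
from huggingface_hub.commands._cli_utils import ANSI, tabulate

print(ANSI.bold("Cached repos"))
print(
    tabulate(
        rows=[["bert-base-uncased", "model", "420M"], ["squad", "dataset", "35M"]],
        headers=["REPO ID", "TYPE", "SIZE"],
    )
)
```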
lib/python3.11/site-packages/huggingface_hub/commands/delete_cache.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022-present, the HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Contains command to delete some revisions from the HF cache directory.
|
16 |
+
|
17 |
+
Usage:
|
18 |
+
huggingface-cli delete-cache
|
19 |
+
huggingface-cli delete-cache --disable-tui
|
20 |
+
huggingface-cli delete-cache --dir ~/.cache/huggingface/hub
|
21 |
+
|
22 |
+
NOTE:
|
23 |
+
This command is based on `InquirerPy` to build the multiselect menu in the terminal.
|
24 |
+
This dependency has to be installed with `pip install huggingface_hub[cli]`. Since
|
25 |
+
we want to avoid as much as possible cross-platform issues, I chose a library that
|
26 |
+
is built on top of `python-prompt-toolkit` which seems to be a reference in terminal
|
27 |
+
GUI (actively maintained on both Unix and Windows, 7.9k stars).
|
28 |
+
|
29 |
+
For the moment, the TUI feature is in beta.
|
30 |
+
|
31 |
+
See:
|
32 |
+
- https://github.com/kazhala/InquirerPy
|
33 |
+
- https://inquirerpy.readthedocs.io/en/latest/
|
34 |
+
- https://github.com/prompt-toolkit/python-prompt-toolkit
|
35 |
+
|
36 |
+
Other solutions could have been:
|
37 |
+
- `simple_term_menu`: would be good as well for our use case but some issues suggest
|
38 |
+
that Windows is less supported.
|
39 |
+
See: https://github.com/IngoMeyer441/simple-term-menu
|
40 |
+
- `PyInquirer`: very similar to `InquirerPy` but older and not maintained anymore.
|
41 |
+
In particular, no support of Python3.10.
|
42 |
+
See: https://github.com/CITGuru/PyInquirer
|
43 |
+
- `pick` (or `pickpack`): easy to use and flexible but built on top of Python's
|
44 |
+
standard library `curses` that is specific to Unix (not implemented on Windows).
|
45 |
+
See https://github.com/wong2/pick and https://github.com/anafvana/pickpack.
|
46 |
+
- `inquirer`: lot of traction (700 stars) but explicitly states "experimental
|
47 |
+
support of Windows". Not built on top of `python-prompt-toolkit`.
|
48 |
+
See https://github.com/magmax/python-inquirer
|
49 |
+
|
50 |
+
TODO: add support for `huggingface-cli delete-cache aaaaaa bbbbbb cccccc (...)` ?
|
51 |
+
TODO: add "--keep-last" arg to delete revisions that are not on `main` ref
|
52 |
+
TODO: add "--filter" arg to filter repositories by name ?
|
53 |
+
TODO: add "--sort" arg to sort by size ?
|
54 |
+
TODO: add "--limit" arg to limit to X repos ?
|
55 |
+
TODO: add "-y" arg for immediate deletion ?
|
56 |
+
See discussions in https://github.com/huggingface/huggingface_hub/issues/1025.
|
57 |
+
"""
|
58 |
+
import os
|
59 |
+
from argparse import Namespace, _SubParsersAction
|
60 |
+
from functools import wraps
|
61 |
+
from tempfile import mkstemp
|
62 |
+
from typing import Any, Callable, Iterable, List, Optional, Union
|
63 |
+
|
64 |
+
from ..utils import CachedRepoInfo, CachedRevisionInfo, HFCacheInfo, scan_cache_dir
|
65 |
+
from . import BaseHuggingfaceCLICommand
|
66 |
+
from ._cli_utils import ANSI
|
67 |
+
|
68 |
+
|
69 |
+
try:
|
70 |
+
from InquirerPy import inquirer
|
71 |
+
from InquirerPy.base.control import Choice
|
72 |
+
from InquirerPy.separator import Separator
|
73 |
+
|
74 |
+
_inquirer_py_available = True
|
75 |
+
except ImportError:
|
76 |
+
_inquirer_py_available = False
|
77 |
+
|
78 |
+
|
79 |
+
def require_inquirer_py(fn: Callable) -> Callable:
|
80 |
+
"""Decorator to flag methods that require `InquirerPy`."""
|
81 |
+
|
82 |
+
# TODO: refactor this + imports in a unified pattern across codebase
|
83 |
+
@wraps(fn)
|
84 |
+
def _inner(*args, **kwargs):
|
85 |
+
if not _inquirer_py_available:
|
86 |
+
raise ImportError(
|
87 |
+
"The `delete-cache` command requires extra dependencies to work with"
|
88 |
+
" the TUI.\nPlease run `pip install huggingface_hub[cli]` to install"
|
89 |
+
" them.\nOtherwise, disable TUI using the `--disable-tui` flag."
|
90 |
+
)
|
91 |
+
|
92 |
+
return fn(*args, **kwargs)
|
93 |
+
|
94 |
+
return _inner
|
95 |
+
|
96 |
+
|
97 |
+
# Possibility for the user to cancel deletion
|
98 |
+
_CANCEL_DELETION_STR = "CANCEL_DELETION"
|
99 |
+
|
100 |
+
|
101 |
+
class DeleteCacheCommand(BaseHuggingfaceCLICommand):
|
102 |
+
@staticmethod
|
103 |
+
def register_subcommand(parser: _SubParsersAction):
|
104 |
+
delete_cache_parser = parser.add_parser("delete-cache", help="Delete revisions from the cache directory.")
|
105 |
+
|
106 |
+
delete_cache_parser.add_argument(
|
107 |
+
"--dir",
|
108 |
+
type=str,
|
109 |
+
default=None,
|
110 |
+
help="cache directory (optional). Default to the default HuggingFace cache.",
|
111 |
+
)
|
112 |
+
|
113 |
+
delete_cache_parser.add_argument(
|
114 |
+
"--disable-tui",
|
115 |
+
action="store_true",
|
116 |
+
help=(
|
117 |
+
"Disable Terminal User Interface (TUI) mode. Useful if your"
|
118 |
+
" platform/terminal doesn't support the multiselect menu."
|
119 |
+
),
|
120 |
+
)
|
121 |
+
|
122 |
+
delete_cache_parser.set_defaults(func=DeleteCacheCommand)
|
123 |
+
|
124 |
+
def __init__(self, args: Namespace) -> None:
|
125 |
+
self.cache_dir: Optional[str] = args.dir
|
126 |
+
self.disable_tui: bool = args.disable_tui
|
127 |
+
|
128 |
+
def run(self):
|
129 |
+
"""Run `delete-cache` command with or without TUI."""
|
130 |
+
# Scan cache directory
|
131 |
+
hf_cache_info = scan_cache_dir(self.cache_dir)
|
132 |
+
|
133 |
+
# Manual review from the user
|
134 |
+
if self.disable_tui:
|
135 |
+
selected_hashes = _manual_review_no_tui(hf_cache_info, preselected=[])
|
136 |
+
else:
|
137 |
+
selected_hashes = _manual_review_tui(hf_cache_info, preselected=[])
|
138 |
+
|
139 |
+
# If deletion is not cancelled
|
140 |
+
if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes:
|
141 |
+
confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion ?"
|
142 |
+
|
143 |
+
# Confirm deletion
|
144 |
+
if self.disable_tui:
|
145 |
+
confirmed = _ask_for_confirmation_no_tui(confirm_message)
|
146 |
+
else:
|
147 |
+
confirmed = _ask_for_confirmation_tui(confirm_message)
|
148 |
+
|
149 |
+
# Deletion is confirmed
|
150 |
+
if confirmed:
|
151 |
+
strategy = hf_cache_info.delete_revisions(*selected_hashes)
|
152 |
+
print("Start deletion.")
|
153 |
+
strategy.execute()
|
154 |
+
print(
|
155 |
+
f"Done. Deleted {len(strategy.repos)} repo(s) and"
|
156 |
+
f" {len(strategy.snapshots)} revision(s) for a total of"
|
157 |
+
f" {strategy.expected_freed_size_str}."
|
158 |
+
)
|
159 |
+
return
|
160 |
+
|
161 |
+
# Deletion is cancelled
|
162 |
+
print("Deletion is cancelled. Do nothing.")
|
163 |
+
|
164 |
+
|
165 |
+
@require_inquirer_py
|
166 |
+
def _manual_review_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]:
|
167 |
+
"""Ask the user for a manual review of the revisions to delete.
|
168 |
+
|
169 |
+
Displays a multi-select menu in the terminal (TUI).
|
170 |
+
"""
|
171 |
+
# Define multiselect list
|
172 |
+
choices = _get_tui_choices_from_scan(repos=hf_cache_info.repos, preselected=preselected)
|
173 |
+
checkbox = inquirer.checkbox(
|
174 |
+
message="Select revisions to delete:",
|
175 |
+
choices=choices, # List of revisions with some pre-selection
|
176 |
+
cycle=False, # No loop between top and bottom
|
177 |
+
height=100, # Large list if possible
|
178 |
+
# We use the instruction to display to the user the expected effect of the
|
179 |
+
# deletion.
|
180 |
+
instruction=_get_expectations_str(
|
181 |
+
hf_cache_info,
|
182 |
+
selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled],
|
183 |
+
),
|
184 |
+
# We use the long instruction to should keybindings instructions to the user
|
185 |
+
long_instruction="Press <space> to select, <enter> to validate and <ctrl+c> to quit without modification.",
|
186 |
+
# Message that is displayed once the user validates its selection.
|
187 |
+
transformer=lambda result: f"{len(result)} revision(s) selected.",
|
188 |
+
)
|
189 |
+
|
190 |
+
# Add a callback to update the information line when a revision is
|
191 |
+
# selected/unselected
|
192 |
+
def _update_expectations(_) -> None:
|
193 |
+
# Hacky way to dynamically set an instruction message to the checkbox when
|
194 |
+
# a revision hash is selected/unselected.
|
195 |
+
checkbox._instruction = _get_expectations_str(
|
196 |
+
hf_cache_info,
|
197 |
+
selected_hashes=[choice["value"] for choice in checkbox.content_control.choices if choice["enabled"]],
|
198 |
+
)
|
199 |
+
|
200 |
+
checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations})
|
201 |
+
|
202 |
+
# Finally display the form to the user.
|
203 |
+
try:
|
204 |
+
return checkbox.execute()
|
205 |
+
except KeyboardInterrupt:
|
206 |
+
return [] # Quit without deletion
|
207 |
+
|
208 |
+
|
209 |
+
@require_inquirer_py
|
210 |
+
def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool:
|
211 |
+
"""Ask for confirmation using Inquirer."""
|
212 |
+
return inquirer.confirm(message, default=default).execute()
|
213 |
+
|
214 |
+
|
215 |
+
def _get_tui_choices_from_scan(repos: Iterable[CachedRepoInfo], preselected: List[str]) -> List:
|
216 |
+
"""Build a list of choices from the scanned repos.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
repos (*Iterable[`CachedRepoInfo`]*):
|
220 |
+
List of scanned repos on which we want to delete revisions.
|
221 |
+
preselected (*List[`str`]*):
|
222 |
+
List of revision hashes that will be preselected.
|
223 |
+
|
224 |
+
Return:
|
225 |
+
The list of choices to pass to `inquirer.checkbox`.
|
226 |
+
"""
|
227 |
+
choices: List[Union[Choice, Separator]] = []
|
228 |
+
|
229 |
+
# First choice is to cancel the deletion. If selected, nothing will be deleted,
|
230 |
+
# no matter the other selected items.
|
231 |
+
choices.append(
|
232 |
+
Choice(
|
233 |
+
_CANCEL_DELETION_STR,
|
234 |
+
name="None of the following (if selected, nothing will be deleted).",
|
235 |
+
enabled=False,
|
236 |
+
)
|
237 |
+
)
|
238 |
+
|
239 |
+
# Display a separator per repo and a Choice for each revisions of the repo
|
240 |
+
for repo in sorted(repos, key=_repo_sorting_order):
|
241 |
+
# Repo as separator
|
242 |
+
choices.append(
|
243 |
+
Separator(
|
244 |
+
f"\n{repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str},"
|
245 |
+
f" used {repo.last_accessed_str})"
|
246 |
+
)
|
247 |
+
)
|
248 |
+
for revision in sorted(repo.revisions, key=_revision_sorting_order):
|
249 |
+
# Revision as choice
|
250 |
+
choices.append(
|
251 |
+
Choice(
|
252 |
+
revision.commit_hash,
|
253 |
+
name=(
|
254 |
+
f"{revision.commit_hash[:8]}:"
|
                        f" {', '.join(sorted(revision.refs)) or '(detached)'} #"
                        f" modified {revision.last_modified_str}"
                    ),
                    enabled=revision.commit_hash in preselected,
                )
            )

    # Return choices
    return choices


def _manual_review_no_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]:
    """Ask the user for a manual review of the revisions to delete.

    Used when TUI is disabled. Manual review happens in a separate tmp file that the
    user can manually edit.
    """
    # 1. Generate temporary file with delete commands.
    fd, tmp_path = mkstemp(suffix=".txt")  # suffix to make it easier to find by editors
    os.close(fd)

    lines = []
    for repo in sorted(hf_cache_info.repos, key=_repo_sorting_order):
        lines.append(
            f"\n# {repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str},"
            f" used {repo.last_accessed_str})"
        )
        for revision in sorted(repo.revisions, key=_revision_sorting_order):
            lines.append(
                # Deselect by prepending a '#'
                f"{'' if revision.commit_hash in preselected else '#'} "
                f" {revision.commit_hash} # Refs:"
                # Print `refs` as comment on same line
                f" {', '.join(sorted(revision.refs)) or '(detached)'} # modified"
                # Print `last_modified` as comment on same line
                f" {revision.last_modified_str}"
            )

    with open(tmp_path, "w") as f:
        f.write(_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS)
        f.write("\n".join(lines))

    # 2. Prompt instructions to user.
    instructions = f"""
    TUI is disabled. In order to select which revisions you want to delete, please edit
    the following file using the text editor of your choice. Instructions for manual
    editing are located at the beginning of the file. Edit the file, save it and confirm
    to continue.
    File to edit: {ANSI.bold(tmp_path)}
    """
    print("\n".join(line.strip() for line in instructions.strip().split("\n")))

    # 3. Wait for user confirmation.
    while True:
        selected_hashes = _read_manual_review_tmp_file(tmp_path)
        if _ask_for_confirmation_no_tui(
            _get_expectations_str(hf_cache_info, selected_hashes) + " Continue ?",
            default=False,
        ):
            break

    # 4. Return selected_hashes
    os.remove(tmp_path)
    return selected_hashes


def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool:
    """Ask for confirmation using pure-python."""
    YES = ("y", "yes", "1")
    NO = ("n", "no", "0")
    DEFAULT = ""
    ALL = YES + NO + (DEFAULT,)
    full_message = message + (" (Y/n) " if default else " (y/N) ")
    while True:
        answer = input(full_message).lower()
        if answer == DEFAULT:
            return default
        if answer in YES:
            return True
        if answer in NO:
            return False
        print(f"Invalid input. Must be one of {ALL}")


def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str:
    """Format a string to display to the user how much space would be saved.

    Example:
    ```
    >>> _get_expectations_str(hf_cache_info, selected_hashes)
    '7 revisions selected counting for 4.3G.'
    ```
    """
    if _CANCEL_DELETION_STR in selected_hashes:
        return "Nothing will be deleted."
    strategy = hf_cache_info.delete_revisions(*selected_hashes)
    return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}."


def _read_manual_review_tmp_file(tmp_path: str) -> List[str]:
    """Read the manually reviewed instruction file and return a list of revision hash.

    Example:
        ```txt
        # This is the tmp file content
        ###

        # Commented out line
        123456789 # revision hash

        # Something else
        # a_newer_hash # 2 days ago
        an_older_hash # 3 days ago
        ```

        ```py
        >>> _read_manual_review_tmp_file(tmp_path)
        ['123456789', 'an_older_hash']
        ```
    """
    with open(tmp_path) as f:
        content = f.read()

    # Split lines
    lines = [line.strip() for line in content.split("\n")]

    # Filter commented lines
    selected_lines = [line for line in lines if not line.startswith("#")]

    # Select only before comment
    selected_hashes = [line.split("#")[0].strip() for line in selected_lines]

    # Return revision hashes
    return [hash for hash in selected_hashes if len(hash) > 0]


_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS = f"""
# INSTRUCTIONS
# ------------
# This is a temporary file created by running `huggingface-cli delete-cache` with the
# `--disable-tui` option. It contains a set of revisions that can be deleted from your
# local cache directory.
#
# Please manually review the revisions you want to delete:
#   - Revision hashes can be commented out with '#'.
#   - Only non-commented revisions in this file will be deleted.
#   - Revision hashes that are removed from this file are ignored as well.
#   - If `{_CANCEL_DELETION_STR}` line is uncommented, the all cache deletion is cancelled and
#     no changes will be applied.
#
# Once you've manually reviewed this file, please confirm deletion in the terminal. This
# file will be automatically removed once done.
# ------------

# KILL SWITCH
# ------------
# Un-comment following line to completely cancel the deletion process
# {_CANCEL_DELETION_STR}
# ------------

# REVISIONS
# ------------
""".strip()


def _repo_sorting_order(repo: CachedRepoInfo) -> Any:
    # First split by Dataset/Model, then sort by last accessed (oldest first)
    return (repo.repo_type, repo.last_accessed)


def _revision_sorting_order(revision: CachedRevisionInfo) -> Any:
    # Sort by last modified (oldest first)
    return revision.last_modified
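The selection rule implemented by `_read_manual_review_tmp_file` above can be illustrated with a small self-contained sketch; `parse_selection` is a hypothetical helper written only for this illustration and mirrors the same strip / drop-commented / split-before-`#` steps on an in-memory string.

```py
from typing import List

def parse_selection(content: str) -> List[str]:
    # Same three steps as the function above: strip, drop commented lines, keep text before '#'
    lines = [line.strip() for line in content.split("\n")]
    uncommented = [line for line in lines if not line.startswith("#")]
    return [line.split("#")[0].strip() for line in uncommented if line.split("#")[0].strip()]

example = """
# a_commented_hash  # will be ignored
123456789           # revision hash, kept
an_older_hash       # kept as well
"""
print(parse_selection(example))  # ['123456789', 'an_older_hash']
```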
lib/python3.11/site-packages/huggingface_hub/commands/download.py
ADDED
@@ -0,0 +1,214 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains command to download files from the Hub with the CLI.

Usage:
    huggingface-cli download --help

    # Download file
    huggingface-cli download gpt2 config.json

    # Download entire repo
    huggingface-cli download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78

    # Download repo with filters
    huggingface-cli download gpt2 --include="*.safetensors"

    # Download with token
    huggingface-cli download Wauplin/private-model --token=hf_***

    # Download quietly (no progress bar, no warnings, only the returned path)
    huggingface-cli download gpt2 config.json --quiet

    # Download to local dir
    huggingface-cli download gpt2 --local-dir=./models/gpt2
"""
import warnings
from argparse import Namespace, _SubParsersAction
from typing import List, Literal, Optional, Union

from huggingface_hub import logging
from huggingface_hub._snapshot_download import snapshot_download
from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars


logger = logging.get_logger(__name__)


class DownloadCommand(BaseHuggingfaceCLICommand):
    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        download_parser = parser.add_parser("download", help="Download files from the Hub")
        download_parser.add_argument(
            "repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)."
        )
        download_parser.add_argument(
            "filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)."
        )
        download_parser.add_argument(
            "--repo-type",
            choices=["model", "dataset", "space"],
            default="model",
            help="Type of repo to download from (e.g. `dataset`).",
        )
        download_parser.add_argument(
            "--revision",
            type=str,
            help="An optional Git revision id which can be a branch name, a tag, or a commit hash.",
        )
        download_parser.add_argument(
            "--include", nargs="*", type=str, help="Glob patterns to match files to download."
        )
        download_parser.add_argument(
            "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download."
        )
        download_parser.add_argument(
            "--cache-dir", type=str, help="Path to the directory where to save the downloaded files."
        )
        download_parser.add_argument(
            "--local-dir",
            type=str,
            help=(
                "If set, the downloaded file will be placed under this directory either as a symlink (default) or a"
                " regular file. Check out"
                " https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder for more"
                " details."
            ),
        )
        download_parser.add_argument(
            "--local-dir-use-symlinks",
            choices=["auto", "True", "False"],
            default="auto",
            help=(
                "To be used with `local_dir`. If set to 'auto', the cache directory will be used and the file will be"
                " either duplicated or symlinked to the local directory depending on its size. It set to `True`, a"
                " symlink will be created, no matter the file size. If set to `False`, the file will either be"
                " duplicated from cache (if already exists) or downloaded from the Hub and not cached."
            ),
        )
        download_parser.add_argument(
            "--force-download",
            action="store_true",
            help="If True, the files will be downloaded even if they are already cached.",
        )
        download_parser.add_argument(
            "--resume-download", action="store_true", help="If True, resume a previously interrupted download."
        )
        download_parser.add_argument(
            "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
        )
        download_parser.add_argument(
            "--quiet",
            action="store_true",
            help="If True, progress bars are disabled and only the path to the download files is printed.",
        )
        download_parser.set_defaults(func=DownloadCommand)

    def __init__(self, args: Namespace) -> None:
        self.token = args.token
        self.repo_id: str = args.repo_id
        self.filenames: List[str] = args.filenames
        self.repo_type: str = args.repo_type
        self.revision: Optional[str] = args.revision
        self.include: Optional[List[str]] = args.include
        self.exclude: Optional[List[str]] = args.exclude
        self.cache_dir: Optional[str] = args.cache_dir
        self.local_dir: Optional[str] = args.local_dir
        self.force_download: bool = args.force_download
        self.resume_download: bool = args.resume_download
        self.quiet: bool = args.quiet

        # Raise if local_dir_use_symlinks is invalid
        self.local_dir_use_symlinks: Union[Literal["auto"], bool]
        use_symlinks_lowercase = args.local_dir_use_symlinks.lower()
        if use_symlinks_lowercase == "true":
            self.local_dir_use_symlinks = True
        elif use_symlinks_lowercase == "false":
            self.local_dir_use_symlinks = False
        elif use_symlinks_lowercase == "auto":
            self.local_dir_use_symlinks = "auto"
        else:
            raise ValueError(
                f"'{args.local_dir_use_symlinks}' is not a valid value for `local_dir_use_symlinks`. It must be either"
                " 'auto', 'True' or 'False'."
            )

    def run(self) -> None:
        if self.quiet:
            disable_progress_bars()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                print(self._download())  # Print path to downloaded files
            enable_progress_bars()
        else:
            logging.set_verbosity_info()
            print(self._download())  # Print path to downloaded files
            logging.set_verbosity_warning()

    def _download(self) -> str:
        # Warn user if patterns are ignored
        if len(self.filenames) > 0:
            if self.include is not None and len(self.include) > 0:
                warnings.warn("Ignoring `--include` since filenames have being explicitly set.")
            if self.exclude is not None and len(self.exclude) > 0:
                warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.")

        if not HF_HUB_ENABLE_HF_TRANSFER:
            logger.info(
                "Consider using `hf_transfer` for faster downloads. This solution comes with some limitations. See"
                " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details."
            )

        # Single file to download: use `hf_hub_download`
        if len(self.filenames) == 1:
            return hf_hub_download(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                filename=self.filenames[0],
                cache_dir=self.cache_dir,
                resume_download=self.resume_download,
                force_download=self.force_download,
                token=self.token,
                local_dir=self.local_dir,
                local_dir_use_symlinks=self.local_dir_use_symlinks,
                library_name="huggingface-cli",
            )

        # Otherwise: use `snapshot_download` to ensure all files comes from same revision
        elif len(self.filenames) == 0:
            allow_patterns = self.include
            ignore_patterns = self.exclude
        else:
            allow_patterns = self.filenames
            ignore_patterns = None

        return snapshot_download(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            resume_download=self.resume_download,
            force_download=self.force_download,
            cache_dir=self.cache_dir,
            token=self.token,
            local_dir=self.local_dir,
            local_dir_use_symlinks=self.local_dir_use_symlinks,
            library_name="huggingface-cli",
        )
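As the docstring above suggests, the command is essentially a dispatcher between two library calls: one explicit filename goes through `hf_hub_download`, zero (or several) filenames go through `snapshot_download`. A rough programmatic sketch of the same two paths, using `gpt2` purely as an example repo:

```py
from huggingface_hub import hf_hub_download, snapshot_download

# One explicit filename -> hf_hub_download (returns the path of the cached file)
config_path = hf_hub_download(repo_id="gpt2", filename="config.json")

# No filename (optionally with include/exclude patterns) -> snapshot_download
repo_path = snapshot_download(repo_id="gpt2", allow_patterns=["*.json"])

print(config_path)
print(repo_path)
```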
lib/python3.11/site-packages/huggingface_hub/commands/env.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains command to print information about the environment.

Usage:
    huggingface-cli env
"""
from argparse import _SubParsersAction

from ..utils import dump_environment_info
from . import BaseHuggingfaceCLICommand


class EnvironmentCommand(BaseHuggingfaceCLICommand):
    def __init__(self, args):
        self.args = args

    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        env_parser = parser.add_parser("env", help="Print information about the environment.")
        env_parser.set_defaults(func=EnvironmentCommand)

    def run(self) -> None:
        dump_environment_info()
lib/python3.11/site-packages/huggingface_hub/commands/huggingface_cli.py
ADDED
@@ -0,0 +1,53 @@
#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from huggingface_hub.commands.delete_cache import DeleteCacheCommand
from huggingface_hub.commands.download import DownloadCommand
from huggingface_hub.commands.env import EnvironmentCommand
from huggingface_hub.commands.lfs import LfsCommands
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.commands.upload import UploadCommand
from huggingface_hub.commands.user import UserCommands


def main():
    parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="huggingface-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    UserCommands.register_subcommand(commands_parser)
    UploadCommand.register_subcommand(commands_parser)
    DownloadCommand.register_subcommand(commands_parser)
    LfsCommands.register_subcommand(commands_parser)
    ScanCacheCommand.register_subcommand(commands_parser)
    DeleteCacheCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
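Every subcommand registered in `main()` follows the same `register_subcommand` / `run` contract. A sketch with a hypothetical `hello` command (not part of huggingface_hub, written only to show the pattern):

```py
from argparse import ArgumentParser, Namespace, _SubParsersAction

from huggingface_hub.commands import BaseHuggingfaceCLICommand


class HelloCommand(BaseHuggingfaceCLICommand):
    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        hello_parser = parser.add_parser("hello", help="Toy command for illustration.")
        hello_parser.add_argument("name", type=str)
        hello_parser.set_defaults(func=HelloCommand)  # same trick as the real commands

    def __init__(self, args: Namespace) -> None:
        self.name = args.name

    def run(self) -> None:
        print(f"Hello {self.name}")


parser = ArgumentParser("demo-cli")
HelloCommand.register_subcommand(parser.add_subparsers())
args = parser.parse_args(["hello", "world"])
args.func(args).run()  # prints "Hello world"
```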
lib/python3.11/site-packages/huggingface_hub/commands/lfs.py
ADDED
@@ -0,0 +1,199 @@
"""
Implementation of a custom transfer agent for the transfer type "multipart" for
git-lfs.

Inspired by:
github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py

Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md


To launch debugger while developing:

``` [lfs "customtransfer.multipart"]
path = /path/to/huggingface_hub/.env/bin/python args = -m debugpy --listen 5678
--wait-for-client
/path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py
lfs-multipart-upload ```"""

import json
import os
import subprocess
import sys
from argparse import _SubParsersAction
from typing import Dict, List, Optional

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND, SliceFileObj

from ..utils import get_session, hf_raise_for_status, logging


logger = logging.get_logger(__name__)


class LfsCommands(BaseHuggingfaceCLICommand):
    """
    Implementation of a custom transfer agent for the transfer type "multipart"
    for git-lfs. This lets users upload large files >5GB 🔥. Spec for LFS custom
    transfer agent is:
    https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md

    This introduces two commands to the CLI:

    1. $ huggingface-cli lfs-enable-largefiles

    This should be executed once for each model repo that contains a model file
    >5GB. It's documented in the error message you get if you just try to git
    push a 5GB file without having enabled it before.

    2. $ huggingface-cli lfs-multipart-upload

    This command is called by lfs directly and is not meant to be called by the
    user.
    """

    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        enable_parser = parser.add_parser(
            "lfs-enable-largefiles", help="Configure your repository to enable upload of files > 5GB."
        )
        enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
        enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))

        # Command will get called by git-lfs, do not call it directly.
        upload_parser = parser.add_parser(LFS_MULTIPART_UPLOAD_COMMAND, add_help=False)
        upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args))


class LfsEnableCommand:
    def __init__(self, args):
        self.args = args

    def run(self):
        local_path = os.path.abspath(self.args.path)
        if not os.path.isdir(local_path):
            print("This does not look like a valid git repo.")
            exit(1)
        subprocess.run(
            "git config lfs.customtransfer.multipart.path huggingface-cli".split(),
            check=True,
            cwd=local_path,
        )
        subprocess.run(
            f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
            check=True,
            cwd=local_path,
        )
        print("Local repo set up for largefiles")


def write_msg(msg: Dict):
    """Write out the message in Line delimited JSON."""
    msg_str = json.dumps(msg) + "\n"
    sys.stdout.write(msg_str)
    sys.stdout.flush()


def read_msg() -> Optional[Dict]:
    """Read Line delimited JSON from stdin."""
    msg = json.loads(sys.stdin.readline().strip())

    if "terminate" in (msg.get("type"), msg.get("event")):
        # terminate message received
        return None

    if msg.get("event") not in ("download", "upload"):
        logger.critical("Received unexpected message")
        sys.exit(1)

    return msg


class LfsUploadCommand:
    def __init__(self, args) -> None:
        self.args = args

    def run(self) -> None:
        # Immediately after invoking a custom transfer process, git-lfs
        # sends initiation data to the process over stdin.
        # This tells the process useful information about the configuration.
        init_msg = json.loads(sys.stdin.readline().strip())
        if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
            write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
            sys.exit(1)

        # The transfer process should use the information it needs from the
        # initiation structure, and also perform any one-off setup tasks it
        # needs to do. It should then respond on stdout with a simple empty
        # confirmation structure, as follows:
        write_msg({})

        # After the initiation exchange, git-lfs will send any number of
        # transfer requests to the stdin of the transfer process, in a serial sequence.
        while True:
            msg = read_msg()
            if msg is None:
                # When all transfers have been processed, git-lfs will send
                # a terminate event to the stdin of the transfer process.
                # On receiving this message the transfer process should
                # clean up and terminate. No response is expected.
                sys.exit(0)

            oid = msg["oid"]
            filepath = msg["path"]
            completion_url = msg["action"]["href"]
            header = msg["action"]["header"]
            chunk_size = int(header.pop("chunk_size"))
            presigned_urls: List[str] = list(header.values())

            # Send a "started" progress event to allow other workers to start.
            # Otherwise they're delayed until first "progress" event is reported,
            # i.e. after the first 5GB by default (!)
            write_msg(
                {
                    "event": "progress",
                    "oid": oid,
                    "bytesSoFar": 1,
                    "bytesSinceLast": 0,
                }
            )

            parts = []
            with open(filepath, "rb") as file:
                for i, presigned_url in enumerate(presigned_urls):
                    with SliceFileObj(
                        file,
                        seek_from=i * chunk_size,
                        read_limit=chunk_size,
                    ) as data:
                        r = get_session().put(presigned_url, data=data)
                        hf_raise_for_status(r)
                        parts.append(
                            {
                                "etag": r.headers.get("etag"),
                                "partNumber": i + 1,
                            }
                        )
                        # In order to support progress reporting while data is uploading / downloading,
                        # the transfer process should post messages to stdout
                        write_msg(
                            {
                                "event": "progress",
                                "oid": oid,
                                "bytesSoFar": (i + 1) * chunk_size,
                                "bytesSinceLast": chunk_size,
                            }
                        )
                        # Not precise but that's ok.

            r = get_session().post(
                completion_url,
                json={
                    "oid": oid,
                    "parts": parts,
                },
            )
            hf_raise_for_status(r)

            write_msg({"event": "complete", "oid": oid})
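The comments in `LfsUploadCommand.run` describe a line-delimited JSON exchange with git-lfs. A small sketch of the kind of messages that flow over stdin/stdout (field values are illustrative, shaped after the custom-transfer spec and the fields the code above reads and writes):

```py
import json

messages = [
    {"event": "init", "operation": "upload"},                        # sent by git-lfs on stdin
    {},                                                              # empty confirmation from the agent
    {"event": "upload", "oid": "abc123", "path": "/tmp/big.bin",     # one transfer request per file
     "action": {"href": "https://example.com/complete",
                "header": {"chunk_size": "5242880"}}},
    {"event": "progress", "oid": "abc123", "bytesSoFar": 1, "bytesSinceLast": 0},
    {"event": "complete", "oid": "abc123"},
]
for msg in messages:
    print(json.dumps(msg))  # one JSON object per line, exactly what `write_msg` emits
```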
lib/python3.11/site-packages/huggingface_hub/commands/scan_cache.py
ADDED
@@ -0,0 +1,138 @@
# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains command to scan the HF cache directory.

Usage:
    huggingface-cli scan-cache
    huggingface-cli scan-cache -v
    huggingface-cli scan-cache -vvv
    huggingface-cli scan-cache --dir ~/.cache/huggingface/hub
"""
import time
from argparse import Namespace, _SubParsersAction
from typing import Optional

from ..utils import CacheNotFound, HFCacheInfo, scan_cache_dir
from . import BaseHuggingfaceCLICommand
from ._cli_utils import ANSI, tabulate


class ScanCacheCommand(BaseHuggingfaceCLICommand):
    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        scan_cache_parser = parser.add_parser("scan-cache", help="Scan cache directory.")

        scan_cache_parser.add_argument(
            "--dir",
            type=str,
            default=None,
            help="cache directory to scan (optional). Default to the default HuggingFace cache.",
        )
        scan_cache_parser.add_argument(
            "-v",
            "--verbose",
            action="count",
            default=0,
            help="show a more verbose output",
        )
        scan_cache_parser.set_defaults(func=ScanCacheCommand)

    def __init__(self, args: Namespace) -> None:
        self.verbosity: int = args.verbose
        self.cache_dir: Optional[str] = args.dir

    def run(self):
        try:
            t0 = time.time()
            hf_cache_info = scan_cache_dir(self.cache_dir)
            t1 = time.time()
        except CacheNotFound as exc:
            cache_dir = exc.cache_dir
            print(f"Cache directory not found: {cache_dir}")
            return

        self._print_hf_cache_info_as_table(hf_cache_info)

        print(
            f"\nDone in {round(t1-t0,1)}s. Scanned {len(hf_cache_info.repos)} repo(s)"
            f" for a total of {ANSI.red(hf_cache_info.size_on_disk_str)}."
        )
        if len(hf_cache_info.warnings) > 0:
            message = f"Got {len(hf_cache_info.warnings)} warning(s) while scanning."
            if self.verbosity >= 3:
                print(ANSI.gray(message))
                for warning in hf_cache_info.warnings:
                    print(ANSI.gray(warning))
            else:
                print(ANSI.gray(message + " Use -vvv to print details."))

    def _print_hf_cache_info_as_table(self, hf_cache_info: HFCacheInfo) -> None:
        if self.verbosity == 0:
            print(
                tabulate(
                    rows=[
                        [
                            repo.repo_id,
                            repo.repo_type,
                            "{:>12}".format(repo.size_on_disk_str),
                            repo.nb_files,
                            repo.last_accessed_str,
                            repo.last_modified_str,
                            ", ".join(sorted(repo.refs)),
                            str(repo.repo_path),
                        ]
                        for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
                    ],
                    headers=[
                        "REPO ID",
                        "REPO TYPE",
                        "SIZE ON DISK",
                        "NB FILES",
                        "LAST_ACCESSED",
                        "LAST_MODIFIED",
                        "REFS",
                        "LOCAL PATH",
                    ],
                )
            )
        else:
            print(
                tabulate(
                    rows=[
                        [
                            repo.repo_id,
                            repo.repo_type,
                            revision.commit_hash,
                            "{:>12}".format(revision.size_on_disk_str),
                            revision.nb_files,
                            revision.last_modified_str,
                            ", ".join(sorted(revision.refs)),
                            str(revision.snapshot_path),
                        ]
                        for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
                        for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
                    ],
                    headers=[
                        "REPO ID",
                        "REPO TYPE",
                        "REVISION",
                        "SIZE ON DISK",
                        "NB FILES",
                        "LAST_MODIFIED",
                        "REFS",
                        "LOCAL PATH",
                    ],
                )
            )
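The command is a thin presentation layer over `scan_cache_dir`, which returns the `HFCacheInfo` object the tables above are built from. A minimal programmatic sketch of the same scan:

```py
from huggingface_hub import scan_cache_dir

info = scan_cache_dir()  # scans the default HF cache; pass a path to scan another directory
for repo in sorted(info.repos, key=lambda r: r.repo_path):
    print(repo.repo_id, repo.repo_type, repo.size_on_disk_str, repo.nb_files)
print("Total size on disk:", info.size_on_disk_str)
```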
lib/python3.11/site-packages/huggingface_hub/commands/upload.py
ADDED
@@ -0,0 +1,297 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains command to upload a repo or file with the CLI.

Usage:
    # Upload file (implicit)
    huggingface-cli upload my-cool-model ./my-cool-model.safetensors

    # Upload file (explicit)
    huggingface-cli upload my-cool-model ./my-cool-model.safetensors model.safetensors

    # Upload directory (implicit). If `my-cool-model/` is a directory it will be uploaded, otherwise an exception is raised.
    huggingface-cli upload my-cool-model

    # Upload directory (explicit)
    huggingface-cli upload my-cool-model ./models/my-cool-model .

    # Upload filtered directory (example: tensorboard logs except for the last run)
    huggingface-cli upload my-cool-model ./model/training /logs --include "*.tfevents.*" --exclude "*20230905*"

    # Upload private dataset
    huggingface-cli upload Wauplin/my-cool-dataset ./data . --repo-type=dataset --private

    # Upload with token
    huggingface-cli upload Wauplin/my-cool-model --token=hf_****

    # Sync local Space with Hub (upload new files, delete removed files)
    huggingface-cli upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub"

    # Schedule commits every 30 minutes
    huggingface-cli upload Wauplin/my-cool-model --every=30
"""
import os
import time
import warnings
from argparse import Namespace, _SubParsersAction
from typing import List, Optional

from huggingface_hub import logging
from huggingface_hub._commit_scheduler import CommitScheduler
from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import RevisionNotFoundError, disable_progress_bars, enable_progress_bars


logger = logging.get_logger(__name__)


class UploadCommand(BaseHuggingfaceCLICommand):
    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        upload_parser = parser.add_parser("upload", help="Upload a file or a folder to a repo on the Hub")
        upload_parser.add_argument(
            "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)."
        )
        upload_parser.add_argument(
            "local_path", nargs="?", help="Local path to the file or folder to upload. Defaults to current directory."
        )
        upload_parser.add_argument(
            "path_in_repo",
            nargs="?",
            help="Path of the file or folder in the repo. Defaults to the relative path of the file or folder.",
        )
        upload_parser.add_argument(
            "--repo-type",
            choices=["model", "dataset", "space"],
            default="model",
            help="Type of the repo to upload to (e.g. `dataset`).",
        )
        upload_parser.add_argument(
            "--revision",
            type=str,
            help=(
                "An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not"
                " exist and `--create-pr` is not set, a branch will be automatically created."
            ),
        )
        upload_parser.add_argument(
            "--private",
            action="store_true",
            help=(
                "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already"
                " exists."
            ),
        )
        upload_parser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.")
        upload_parser.add_argument(
            "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload."
        )
        upload_parser.add_argument(
            "--delete",
            nargs="*",
            type=str,
            help="Glob patterns for file to be deleted from the repo while committing.",
        )
        upload_parser.add_argument(
            "--commit-message", type=str, help="The summary / title / first line of the generated commit."
        )
        upload_parser.add_argument("--commit-description", type=str, help="The description of the generated commit.")
        upload_parser.add_argument(
            "--create-pr", action="store_true", help="Whether to upload content as a new Pull Request."
        )
        upload_parser.add_argument(
            "--every",
            type=float,
            help="If set, a background job is scheduled to create commits every `every` minutes.",
        )
        upload_parser.add_argument(
            "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
        )
        upload_parser.add_argument(
            "--quiet",
            action="store_true",
            help="If True, progress bars are disabled and only the path to the uploaded files is printed.",
        )
        upload_parser.set_defaults(func=UploadCommand)

    def __init__(self, args: Namespace) -> None:
        self.repo_id: str = args.repo_id
        self.repo_type: Optional[str] = args.repo_type
        self.revision: Optional[str] = args.revision
        self.private: bool = args.private

        self.include: Optional[List[str]] = args.include
        self.exclude: Optional[List[str]] = args.exclude
        self.delete: Optional[List[str]] = args.delete

        self.commit_message: Optional[str] = args.commit_message
        self.commit_description: Optional[str] = args.commit_description
        self.create_pr: bool = args.create_pr
        self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli")
        self.quiet: bool = args.quiet  # disable warnings and progress bars

        # Check `--every` is valid
        if args.every is not None and args.every <= 0:
            raise ValueError(f"`every` must be a positive value (got '{args.every}')")
        self.every: Optional[float] = args.every

        # Resolve `local_path` and `path_in_repo`
        repo_name: str = args.repo_id.split("/")[-1]  # e.g. "Wauplin/my-cool-model" => "my-cool-model"
        self.local_path: str
        self.path_in_repo: str
        if args.local_path is None and os.path.isfile(repo_name):
            # Implicit case 1: user provided only a repo_id which happen to be a local file as well => upload it with same name
            self.local_path = repo_name
            self.path_in_repo = repo_name
        elif args.local_path is None and os.path.isdir(repo_name):
            # Implicit case 2: user provided only a repo_id which happen to be a local folder as well => upload it at root
            self.local_path = repo_name
            self.path_in_repo = "."
        elif args.local_path is None:
            # Implicit case 3: user provided only a repo_id that does not match a local file or folder
            # => the user must explicitly provide a local_path => raise exception
            raise ValueError(f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly.")
        elif args.path_in_repo is None and os.path.isfile(args.local_path):
            # Explicit local path to file, no path in repo => upload it at root with same name
            self.local_path = args.local_path
            self.path_in_repo = os.path.basename(args.local_path)
        elif args.path_in_repo is None:
            # Explicit local path to folder, no path in repo => upload at root
            self.local_path = args.local_path
            self.path_in_repo = "."
        else:
            # Finally, if both paths are explicit
            self.local_path = args.local_path
            self.path_in_repo = args.path_in_repo

    def run(self) -> None:
        if self.quiet:
            disable_progress_bars()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                print(self._upload())
            enable_progress_bars()
        else:
            logging.set_verbosity_info()
            print(self._upload())
            logging.set_verbosity_warning()

    def _upload(self) -> str:
        if os.path.isfile(self.local_path):
            if self.include is not None and len(self.include) > 0:
                warnings.warn("Ignoring `--include` since a single file is uploaded.")
            if self.exclude is not None and len(self.exclude) > 0:
                warnings.warn("Ignoring `--exclude` since a single file is uploaded.")
            if self.delete is not None and len(self.delete) > 0:
                warnings.warn("Ignoring `--delete` since a single file is uploaded.")

        if not HF_HUB_ENABLE_HF_TRANSFER:
            logger.info(
                "Consider using `hf_transfer` for faster uploads. This solution comes with some limitations. See"
                " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details."
            )

        # Schedule commits if `every` is set
        if self.every is not None:
            if os.path.isfile(self.local_path):
                # If file => watch entire folder + use allow_patterns
                folder_path = os.path.dirname(self.local_path)
                path_in_repo = (
                    self.path_in_repo[: -len(self.local_path)]  # remove filename from path_in_repo
                    if self.path_in_repo.endswith(self.local_path)
                    else self.path_in_repo
                )
                allow_patterns = [self.local_path]
                ignore_patterns = []
            else:
                folder_path = self.local_path
                path_in_repo = self.path_in_repo
                allow_patterns = self.include or []
                ignore_patterns = self.exclude or []
                if self.delete is not None and len(self.delete) > 0:
                    warnings.warn("Ignoring `--delete` when uploading with scheduled commits.")

            scheduler = CommitScheduler(
                folder_path=folder_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                path_in_repo=path_in_repo,
                private=self.private,
                every=self.every,
                hf_api=self.api,
            )
            print(f"Scheduling commits every {self.every} minutes to {scheduler.repo_id}.")
            try:  # Block main thread until KeyboardInterrupt
                while True:
                    time.sleep(100)
            except KeyboardInterrupt:
                scheduler.stop()
                return "Stopped scheduled commits."

        # Otherwise, create repo and proceed with the upload
        if not os.path.isfile(self.local_path) and not os.path.isdir(self.local_path):
            raise FileNotFoundError(f"No such file or directory: '{self.local_path}'.")
        repo_id = self.api.create_repo(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            exist_ok=True,
            private=self.private,
            space_sdk="gradio" if self.repo_type == "space" else None,
            # ^ We don't want it to fail when uploading to a Space => let's set Gradio by default.
            # ^ I'd rather not add CLI args to set it explicitly as we already have `huggingface-cli repo create` for that.
        ).repo_id

        # Check if branch already exists and if not, create it
        if self.revision is not None and not self.create_pr:
            try:
                self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision)
            except RevisionNotFoundError:
                logger.info(f"Branch '{self.revision}' not found. Creating it...")
                self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True)
                # ^ `exist_ok=True` to avoid race concurrency issues

        # File-based upload
        if os.path.isfile(self.local_path):
            return self.api.upload_file(
                path_or_fileobj=self.local_path,
                path_in_repo=self.path_in_repo,
                repo_id=repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                commit_message=self.commit_message,
                commit_description=self.commit_description,
                create_pr=self.create_pr,
            )

        # Folder-based upload
        else:
            return self.api.upload_folder(
                folder_path=self.local_path,
                path_in_repo=self.path_in_repo,
                repo_id=repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                commit_message=self.commit_message,
                commit_description=self.commit_description,
                create_pr=self.create_pr,
                allow_patterns=self.include,
                ignore_patterns=self.exclude,
                delete_patterns=self.delete,
            )
lib/python3.11/site-packages/huggingface_hub/commands/user.py
ADDED
@@ -0,0 +1,188 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from argparse import _SubParsersAction

from requests.exceptions import HTTPError

from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.constants import (
    ENDPOINT,
    REPO_TYPES,
    REPO_TYPES_URL_PREFIXES,
    SPACES_SDK_TYPES,
)
from huggingface_hub.hf_api import HfApi

from .._login import (  # noqa: F401 # for backward compatibility
    NOTEBOOK_LOGIN_PASSWORD_HTML,
    NOTEBOOK_LOGIN_TOKEN_HTML_END,
    NOTEBOOK_LOGIN_TOKEN_HTML_START,
    login,
    logout,
    notebook_login,
)
from ..utils import get_token
from ._cli_utils import ANSI


class UserCommands(BaseHuggingfaceCLICommand):
    @staticmethod
    def register_subcommand(parser: _SubParsersAction):
        login_parser = parser.add_parser("login", help="Log in using a token from huggingface.co/settings/tokens")
        login_parser.add_argument(
            "--token",
            type=str,
            help="Token generated from https://huggingface.co/settings/tokens",
        )
        login_parser.add_argument(
            "--add-to-git-credential",
            action="store_true",
            help="Optional: Save token to git credential helper.",
        )
        login_parser.set_defaults(func=lambda args: LoginCommand(args))
        whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
        logout_parser = parser.add_parser("logout", help="Log out")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))

        # new system: git-based repo system
        repo_parser = parser.add_parser("repo", help="{create} Commands to interact with your huggingface.co repos.")
        repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
        repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
        repo_create_parser.add_argument(
            "name",
            type=str,
            help="Name for your repo. Will be namespaced under your username to build the repo id.",
        )
        repo_create_parser.add_argument(
            "--type",
            type=str,
            help='Optional: repo_type: set to "dataset" or "space" if creating a dataset or space, default is model.',
        )
        repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
        repo_create_parser.add_argument(
            "--space_sdk",
            type=str,
            help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".',
            choices=SPACES_SDK_TYPES,
        )
        repo_create_parser.add_argument(
            "-y",
            "--yes",
            action="store_true",
            help="Optional: answer Yes to the prompt",
        )
        repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))


class BaseUserCommand:
    def __init__(self, args):
        self.args = args
        self._api = HfApi()


class LoginCommand(BaseUserCommand):
    def run(self):
        login(token=self.args.token, add_to_git_credential=self.args.add_to_git_credential)


class LogoutCommand(BaseUserCommand):
    def run(self):
        logout()


class WhoamiCommand(BaseUserCommand):
    def run(self):
        token = get_token()
        if token is None:
            print("Not logged in")
            exit()
        try:
            info = self._api.whoami(token)
            print(info["name"])
            orgs = [org["name"] for org in info["orgs"]]
            if orgs:
                print(ANSI.bold("orgs: "), ",".join(orgs))

            if ENDPOINT != "https://huggingface.co":
                print(f"Authenticated through private endpoint: {ENDPOINT}")
        except HTTPError as e:
            print(e)
            print(ANSI.red(e.response.text))
            exit(1)


class RepoCreateCommand(BaseUserCommand):
    def run(self):
        token = get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            print("Looks like you do not have git installed, please install.")

        try:
            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            print(
                ANSI.red(
                    "Looks like you do not have git-lfs installed, please install."
                    " You can install from https://git-lfs.github.com/."
                    " Then run `git lfs install` (you only have to do this once)."
                )
            )
        print("")

        user = self._api.whoami(token)["name"]
        namespace = self.args.organization if self.args.organization is not None else user

        repo_id = f"{namespace}/{self.args.name}"

        if self.args.type not in REPO_TYPES:
            print("Invalid repo --type")
            exit(1)

        if self.args.type in REPO_TYPES_URL_PREFIXES:
            prefixed_repo_id = REPO_TYPES_URL_PREFIXES[self.args.type] + repo_id
        else:
            prefixed_repo_id = repo_id

        print(f"You are about to create {ANSI.bold(prefixed_repo_id)}")

        if not self.args.yes:
            choice = input("Proceed? [Y/n] ").lower()
            if not (choice == "" or choice == "y" or choice == "yes"):
                print("Abort")
                exit()
        try:
            url = self._api.create_repo(
                repo_id=repo_id,
                token=token,
                repo_type=self.args.type,
                space_sdk=self.args.space_sdk,
            )
        except HTTPError as e:
            print(e)
            print(ANSI.red(e.response.text))
            exit(1)
        print("\nYour repo now lives at:")
        print(f"  {ANSI.bold(url)}")
        print("\nYou can clone it locally with the command below, and commit/push as usual.")
        print(f"\n  git clone {url}")
        print("")
lib/python3.11/site-packages/huggingface_hub/community.py
ADDED
@@ -0,0 +1,354 @@
1 |
+
"""
|
2 |
+
Data structures to interact with Discussions and Pull Requests on the Hub.
|
3 |
+
|
4 |
+
See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)
|
5 |
+
for more information on Pull Requests, Discussions, and the community tab.
|
6 |
+
"""
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from datetime import datetime
|
9 |
+
from typing import List, Literal, Optional, Union
|
10 |
+
|
11 |
+
from .constants import REPO_TYPE_MODEL
|
12 |
+
from .utils import parse_datetime
|
13 |
+
|
14 |
+
|
15 |
+
DiscussionStatus = Literal["open", "closed", "merged", "draft"]
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class Discussion:
|
20 |
+
"""
|
21 |
+
A Discussion or Pull Request on the Hub.
|
22 |
+
|
23 |
+
This dataclass is not intended to be instantiated directly.
|
24 |
+
|
25 |
+
Attributes:
|
26 |
+
title (`str`):
|
27 |
+
The title of the Discussion / Pull Request
|
28 |
+
status (`str`):
|
29 |
+
The status of the Discussion / Pull Request.
|
30 |
+
It must be one of:
|
31 |
+
* `"open"`
|
32 |
+
* `"closed"`
|
33 |
+
* `"merged"` (only for Pull Requests)
|
34 |
+
* `"draft"` (only for Pull Requests)
|
35 |
+
num (`int`):
|
36 |
+
The number of the Discussion / Pull Request.
|
37 |
+
repo_id (`str`):
|
38 |
+
The id (`"{namespace}/{repo_name}"`) of the repo on which
|
39 |
+
the Discussion / Pull Request was opened.
|
40 |
+
repo_type (`str`):
|
41 |
+
The type of the repo on which the Discussion / Pull Request was opened.
|
42 |
+
Possible values are: `"model"`, `"dataset"`, `"space"`.
|
43 |
+
author (`str`):
|
44 |
+
The username of the Discussion / Pull Request author.
|
45 |
+
Can be `"deleted"` if the user has been deleted since.
|
46 |
+
is_pull_request (`bool`):
|
47 |
+
Whether or not this is a Pull Request.
|
48 |
+
created_at (`datetime`):
|
49 |
+
The `datetime` of creation of the Discussion / Pull Request.
|
50 |
+
endpoint (`str`):
|
51 |
+
Endpoint of the Hub. Default is https://huggingface.co.
|
52 |
+
git_reference (`str`, *optional*):
|
53 |
+
(property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
|
54 |
+
url (`str`):
|
55 |
+
(property) URL of the discussion on the Hub.
|
56 |
+
"""
|
57 |
+
|
58 |
+
title: str
|
59 |
+
status: DiscussionStatus
|
60 |
+
num: int
|
61 |
+
repo_id: str
|
62 |
+
repo_type: str
|
63 |
+
author: str
|
64 |
+
is_pull_request: bool
|
65 |
+
created_at: datetime
|
66 |
+
endpoint: str
|
67 |
+
|
68 |
+
@property
|
69 |
+
def git_reference(self) -> Optional[str]:
|
70 |
+
"""
|
71 |
+
If this is a Pull Request, returns the git reference to which changes can be pushed.
|
72 |
+
Returns `None` otherwise.
|
73 |
+
"""
|
74 |
+
if self.is_pull_request:
|
75 |
+
return f"refs/pr/{self.num}"
|
76 |
+
return None
|
77 |
+
|
78 |
+
@property
|
79 |
+
def url(self) -> str:
|
80 |
+
"""Returns the URL of the discussion on the Hub."""
|
81 |
+
if self.repo_type is None or self.repo_type == REPO_TYPE_MODEL:
|
82 |
+
return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}"
|
83 |
+
return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}"
|
84 |
+
|
85 |
+
|
86 |
+
@dataclass
|
87 |
+
class DiscussionWithDetails(Discussion):
|
88 |
+
"""
|
89 |
+
Subclass of [`Discussion`].
|
90 |
+
|
91 |
+
Attributes:
|
92 |
+
title (`str`):
|
93 |
+
The title of the Discussion / Pull Request
|
94 |
+
status (`str`):
|
95 |
+
The status of the Discussion / Pull Request.
|
96 |
+
It can be one of:
|
97 |
+
* `"open"`
|
98 |
+
* `"closed"`
|
99 |
+
* `"merged"` (only for Pull Requests)
|
100 |
+
* `"draft"` (only for Pull Requests)
|
101 |
+
num (`int`):
|
102 |
+
The number of the Discussion / Pull Request.
|
103 |
+
repo_id (`str`):
|
104 |
+
The id (`"{namespace}/{repo_name}"`) of the repo on which
|
105 |
+
the Discussion / Pull Request was opened.
|
106 |
+
repo_type (`str`):
|
107 |
+
The type of the repo on which the Discussion / Pull Request was opened.
|
108 |
+
Possible values are: `"model"`, `"dataset"`, `"space"`.
|
109 |
+
author (`str`):
|
110 |
+
The username of the Discussion / Pull Request author.
|
111 |
+
Can be `"deleted"` if the user has been deleted since.
|
112 |
+
is_pull_request (`bool`):
|
113 |
+
Whether or not this is a Pull Request.
|
114 |
+
created_at (`datetime`):
|
115 |
+
The `datetime` of creation of the Discussion / Pull Request.
|
116 |
+
events (`list` of [`DiscussionEvent`]):
|
117 |
+
The list of [`DiscussionEvent`]s in this Discussion or Pull Request.
|
118 |
+
conflicting_files (`Union[List[str], bool, None]`, *optional*):
|
119 |
+
A list of conflicting files if this is a Pull Request.
|
120 |
+
`None` if `self.is_pull_request` is `False`.
|
121 |
+
`True` if there are conflicting files but the list can't be retrieved.
|
122 |
+
target_branch (`str`, *optional*):
|
123 |
+
The branch into which changes are to be merged if this is a
|
124 |
+
Pull Request. `None` if `self.is_pull_request` is `False`.
|
125 |
+
merge_commit_oid (`str`, *optional*):
|
126 |
+
If this is a merged Pull Request, this is set to the OID / SHA of
|
127 |
+
the merge commit, `None` otherwise.
|
128 |
+
diff (`str`, *optional*):
|
129 |
+
The git diff if this is a Pull Request, `None` otherwise.
|
130 |
+
endpoint (`str`):
|
131 |
+
Endpoint of the Hub. Default is https://huggingface.co.
|
132 |
+
git_reference (`str`, *optional*):
|
133 |
+
(property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
|
134 |
+
url (`str`):
|
135 |
+
(property) URL of the discussion on the Hub.
|
136 |
+
"""
|
137 |
+
|
138 |
+
events: List["DiscussionEvent"]
|
139 |
+
conflicting_files: Union[List[str], bool, None]
|
140 |
+
target_branch: Optional[str]
|
141 |
+
merge_commit_oid: Optional[str]
|
142 |
+
diff: Optional[str]
|
143 |
+
|
144 |
+
|
145 |
+
@dataclass
|
146 |
+
class DiscussionEvent:
|
147 |
+
"""
|
148 |
+
An event in a Discussion or Pull Request.
|
149 |
+
|
150 |
+
Use concrete classes:
|
151 |
+
* [`DiscussionComment`]
|
152 |
+
* [`DiscussionStatusChange`]
|
153 |
+
* [`DiscussionCommit`]
|
154 |
+
* [`DiscussionTitleChange`]
|
155 |
+
|
156 |
+
Attributes:
|
157 |
+
id (`str`):
|
158 |
+
The ID of the event. A hexadecimal string.
|
159 |
+
type (`str`):
|
160 |
+
The type of the event.
|
161 |
+
created_at (`datetime`):
|
162 |
+
A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
|
163 |
+
object holding the creation timestamp for the event.
|
164 |
+
author (`str`):
|
165 |
+
The username of the Discussion / Pull Request author.
|
166 |
+
Can be `"deleted"` if the user has been deleted since.
|
167 |
+
"""
|
168 |
+
|
169 |
+
id: str
|
170 |
+
type: str
|
171 |
+
created_at: datetime
|
172 |
+
author: str
|
173 |
+
|
174 |
+
_event: dict
|
175 |
+
"""Stores the original event data, in case we need to access it later."""
|
176 |
+
|
177 |
+
|
178 |
+
@dataclass
|
179 |
+
class DiscussionComment(DiscussionEvent):
|
180 |
+
"""A comment in a Discussion / Pull Request.
|
181 |
+
|
182 |
+
Subclass of [`DiscussionEvent`].
|
183 |
+
|
184 |
+
|
185 |
+
Attributes:
|
186 |
+
id (`str`):
|
187 |
+
The ID of the event. A hexadecimal string.
|
188 |
+
type (`str`):
|
189 |
+
The type of the event.
|
190 |
+
created_at (`datetime`):
|
191 |
+
A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
|
192 |
+
object holding the creation timestamp for the event.
|
193 |
+
author (`str`):
|
194 |
+
The username of the Discussion / Pull Request author.
|
195 |
+
Can be `"deleted"` if the user has been deleted since.
|
196 |
+
content (`str`):
|
197 |
+
The raw markdown content of the comment. Mentions, links and images are not rendered.
|
198 |
+
edited (`bool`):
|
199 |
+
Whether or not this comment has been edited.
|
200 |
+
hidden (`bool`):
|
201 |
+
Whether or not this comment has been hidden.
|
202 |
+
"""
|
203 |
+
|
204 |
+
content: str
|
205 |
+
edited: bool
|
206 |
+
hidden: bool
|
207 |
+
|
208 |
+
@property
|
209 |
+
def rendered(self) -> str:
|
210 |
+
"""The rendered comment, as an HTML string."""
|
211 |
+
return self._event["data"]["latest"]["html"]
|
212 |
+
|
213 |
+
@property
|
214 |
+
def last_edited_at(self) -> datetime:
|
215 |
+
"""The last edit time, as a `datetime` object."""
|
216 |
+
return parse_datetime(self._event["data"]["latest"]["updatedAt"])
|
217 |
+
|
218 |
+
@property
|
219 |
+
def last_edited_by(self) -> str:
|
220 |
+
"""The username of the last editor, or `"deleted"` if the account no longer exists."""
|
221 |
+
return self._event["data"]["latest"].get("author", {}).get("name", "deleted")
|
222 |
+
|
223 |
+
@property
|
224 |
+
def edit_history(self) -> List[dict]:
|
225 |
+
"""The edit history of the comment"""
|
226 |
+
return self._event["data"]["history"]
|
227 |
+
|
228 |
+
@property
|
229 |
+
def number_of_edits(self) -> int:
|
230 |
+
return len(self.edit_history)
|
231 |
+
|
232 |
+
|
233 |
+
@dataclass
|
234 |
+
class DiscussionStatusChange(DiscussionEvent):
|
235 |
+
"""A change of status in a Discussion / Pull Request.
|
236 |
+
|
237 |
+
Subclass of [`DiscussionEvent`].
|
238 |
+
|
239 |
+
Attributes:
|
240 |
+
id (`str`):
|
241 |
+
The ID of the event. A hexadecimal string.
|
242 |
+
type (`str`):
|
243 |
+
The type of the event.
|
244 |
+
created_at (`datetime`):
|
245 |
+
A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
|
246 |
+
object holding the creation timestamp for the event.
|
247 |
+
author (`str`):
|
248 |
+
The username of the Discussion / Pull Request author.
|
249 |
+
Can be `"deleted"` if the user has been deleted since.
|
250 |
+
new_status (`str`):
|
251 |
+
The status of the Discussion / Pull Request after the change.
|
252 |
+
It can be one of:
|
253 |
+
* `"open"`
|
254 |
+
* `"closed"`
|
255 |
+
* `"merged"` (only for Pull Requests)
|
256 |
+
"""
|
257 |
+
|
258 |
+
new_status: str
|
259 |
+
|
260 |
+
|
261 |
+
@dataclass
|
262 |
+
class DiscussionCommit(DiscussionEvent):
|
263 |
+
"""A commit in a Pull Request.
|
264 |
+
|
265 |
+
Subclass of [`DiscussionEvent`].
|
266 |
+
|
267 |
+
Attributes:
|
268 |
+
id (`str`):
|
269 |
+
The ID of the event. A hexadecimal string.
|
270 |
+
type (`str`):
|
271 |
+
The type of the event.
|
272 |
+
created_at (`datetime`):
|
273 |
+
A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
|
274 |
+
object holding the creation timestamp for the event.
|
275 |
+
author (`str`):
|
276 |
+
The username of the Discussion / Pull Request author.
|
277 |
+
Can be `"deleted"` if the user has been deleted since.
|
278 |
+
summary (`str`):
|
279 |
+
The summary of the commit.
|
280 |
+
oid (`str`):
|
281 |
+
The OID / SHA of the commit, as a hexadecimal string.
|
282 |
+
"""
|
283 |
+
|
284 |
+
summary: str
|
285 |
+
oid: str
|
286 |
+
|
287 |
+
|
288 |
+
@dataclass
|
289 |
+
class DiscussionTitleChange(DiscussionEvent):
|
290 |
+
"""A rename event in a Discussion / Pull Request.
|
291 |
+
|
292 |
+
Subclass of [`DiscussionEvent`].
|
293 |
+
|
294 |
+
Attributes:
|
295 |
+
id (`str`):
|
296 |
+
The ID of the event. A hexadecimal string.
|
297 |
+
type (`str`):
|
298 |
+
The type of the event.
|
299 |
+
created_at (`datetime`):
|
300 |
+
A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
|
301 |
+
object holding the creation timestamp for the event.
|
302 |
+
author (`str`):
|
303 |
+
The username of the Discussion / Pull Request author.
|
304 |
+
Can be `"deleted"` if the user has been deleted since.
|
305 |
+
old_title (`str`):
|
306 |
+
The previous title for the Discussion / Pull Request.
|
307 |
+
new_title (`str`):
|
308 |
+
The new title.
|
309 |
+
"""
|
310 |
+
|
311 |
+
old_title: str
|
312 |
+
new_title: str
|
313 |
+
|
314 |
+
|
315 |
+
def deserialize_event(event: dict) -> DiscussionEvent:
|
316 |
+
"""Instantiates a [`DiscussionEvent`] from a dict"""
|
317 |
+
event_id: str = event["id"]
|
318 |
+
event_type: str = event["type"]
|
319 |
+
created_at = parse_datetime(event["createdAt"])
|
320 |
+
|
321 |
+
common_args = dict(
|
322 |
+
id=event_id,
|
323 |
+
type=event_type,
|
324 |
+
created_at=created_at,
|
325 |
+
author=event.get("author", {}).get("name", "deleted"),
|
326 |
+
_event=event,
|
327 |
+
)
|
328 |
+
|
329 |
+
if event_type == "comment":
|
330 |
+
return DiscussionComment(
|
331 |
+
**common_args,
|
332 |
+
edited=event["data"]["edited"],
|
333 |
+
hidden=event["data"]["hidden"],
|
334 |
+
content=event["data"]["latest"]["raw"],
|
335 |
+
)
|
336 |
+
if event_type == "status-change":
|
337 |
+
return DiscussionStatusChange(
|
338 |
+
**common_args,
|
339 |
+
new_status=event["data"]["status"],
|
340 |
+
)
|
341 |
+
if event_type == "commit":
|
342 |
+
return DiscussionCommit(
|
343 |
+
**common_args,
|
344 |
+
summary=event["data"]["subject"],
|
345 |
+
oid=event["data"]["oid"],
|
346 |
+
)
|
347 |
+
if event_type == "title-change":
|
348 |
+
return DiscussionTitleChange(
|
349 |
+
**common_args,
|
350 |
+
old_title=event["data"]["from"],
|
351 |
+
new_title=event["data"]["to"],
|
352 |
+
)
|
353 |
+
|
354 |
+
return DiscussionEvent(**common_args)
|
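To make the dataclasses above concrete, here is a hedged sketch of feeding a hand-written event payload through `deserialize_event`. The dictionary is fabricated and only mirrors the fields the function actually reads (`id`, `type`, `createdAt`, `author`, `data`); real payloads come from the Hub API.

```python
# Sketch only: `event` is a made-up payload mimicking a "comment" event.
from huggingface_hub.community import DiscussionComment, deserialize_event

event = {
    "id": "abc123",
    "type": "comment",
    "createdAt": "2023-01-01T00:00:00.000Z",
    "author": {"name": "some-user"},  # hypothetical username
    "data": {
        "edited": False,
        "hidden": False,
        "latest": {
            "raw": "Nice work!",
            "html": "<p>Nice work!</p>",
            "updatedAt": "2023-01-01T00:00:00.000Z",
        },
        "history": [],
    },
}

comment = deserialize_event(event)
assert isinstance(comment, DiscussionComment)
print(comment.author, comment.content)  # some-user Nice work!
```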
lib/python3.11/site-packages/huggingface_hub/constants.py
ADDED
@@ -0,0 +1,213 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import typing
|
4 |
+
from typing import Literal, Optional, Tuple
|
5 |
+
|
6 |
+
|
7 |
+
# Possible values for env variables
|
8 |
+
|
9 |
+
|
10 |
+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
11 |
+
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
12 |
+
|
13 |
+
|
14 |
+
def _is_true(value: Optional[str]) -> bool:
|
15 |
+
if value is None:
|
16 |
+
return False
|
17 |
+
return value.upper() in ENV_VARS_TRUE_VALUES
|
18 |
+
|
19 |
+
|
20 |
+
def _as_int(value: Optional[str]) -> Optional[int]:
|
21 |
+
if value is None:
|
22 |
+
return None
|
23 |
+
return int(value)
|
24 |
+
|
25 |
+
|
26 |
+
# Constants for file downloads
|
27 |
+
|
28 |
+
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
|
29 |
+
TF2_WEIGHTS_NAME = "tf_model.h5"
|
30 |
+
TF_WEIGHTS_NAME = "model.ckpt"
|
31 |
+
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
32 |
+
CONFIG_NAME = "config.json"
|
33 |
+
REPOCARD_NAME = "README.md"
|
34 |
+
DEFAULT_ETAG_TIMEOUT = 10
|
35 |
+
DEFAULT_DOWNLOAD_TIMEOUT = 10
|
36 |
+
DEFAULT_REQUEST_TIMEOUT = 10
|
37 |
+
DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024
|
38 |
+
HF_TRANSFER_CONCURRENCY = 100
|
39 |
+
|
40 |
+
# Constants for safetensors repos
|
41 |
+
|
42 |
+
SAFETENSORS_SINGLE_FILE = "model.safetensors"
|
43 |
+
SAFETENSORS_INDEX_FILE = "model.safetensors.index.json"
|
44 |
+
SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000
|
45 |
+
|
46 |
+
# Git-related constants
|
47 |
+
|
48 |
+
DEFAULT_REVISION = "main"
|
49 |
+
REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}")
|
50 |
+
|
51 |
+
HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"
|
52 |
+
|
53 |
+
_staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING"))
|
54 |
+
|
55 |
+
ENDPOINT = os.getenv("HF_ENDPOINT") or ("https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co")
|
56 |
+
|
57 |
+
HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
|
58 |
+
HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit"
|
59 |
+
HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag"
|
60 |
+
HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size"
|
61 |
+
|
62 |
+
INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co")
|
63 |
+
|
64 |
+
# See https://huggingface.co/docs/inference-endpoints/index
|
65 |
+
INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2"
|
66 |
+
|
67 |
+
|
68 |
+
REPO_ID_SEPARATOR = "--"
|
69 |
+
# ^ this substring is not allowed in repo_ids on hf.co
|
70 |
+
# and is the canonical one we use for serialization of repo ids elsewhere.
|
71 |
+
|
72 |
+
|
73 |
+
REPO_TYPE_DATASET = "dataset"
|
74 |
+
REPO_TYPE_SPACE = "space"
|
75 |
+
REPO_TYPE_MODEL = "model"
|
76 |
+
REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
|
77 |
+
SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"]
|
78 |
+
|
79 |
+
REPO_TYPES_URL_PREFIXES = {
|
80 |
+
REPO_TYPE_DATASET: "datasets/",
|
81 |
+
REPO_TYPE_SPACE: "spaces/",
|
82 |
+
}
|
83 |
+
REPO_TYPES_MAPPING = {
|
84 |
+
"datasets": REPO_TYPE_DATASET,
|
85 |
+
"spaces": REPO_TYPE_SPACE,
|
86 |
+
"models": REPO_TYPE_MODEL,
|
87 |
+
}
|
88 |
+
|
89 |
+
DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
|
90 |
+
DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)
|
91 |
+
DiscussionStatusFilter = Literal["all", "open", "closed"]
|
92 |
+
DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter)
|
93 |
+
|
94 |
+
# default cache
|
95 |
+
default_home = os.path.join(os.path.expanduser("~"), ".cache")
|
96 |
+
HF_HOME = os.path.expanduser(
|
97 |
+
os.getenv(
|
98 |
+
"HF_HOME",
|
99 |
+
os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"),
|
100 |
+
)
|
101 |
+
)
|
102 |
+
hf_cache_home = HF_HOME # for backward compatibility. TODO: remove this in 1.0.0
|
103 |
+
|
104 |
+
default_cache_path = os.path.join(HF_HOME, "hub")
|
105 |
+
default_assets_cache_path = os.path.join(HF_HOME, "assets")
|
106 |
+
|
107 |
+
# Legacy env variables
|
108 |
+
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
|
109 |
+
HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path)
|
110 |
+
|
111 |
+
# New env variables
|
112 |
+
HF_HUB_CACHE = os.getenv("HF_HUB_CACHE", HUGGINGFACE_HUB_CACHE)
|
113 |
+
HF_ASSETS_CACHE = os.getenv("HF_ASSETS_CACHE", HUGGINGFACE_ASSETS_CACHE)
|
114 |
+
|
115 |
+
HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE"))
|
116 |
+
|
117 |
+
# Opt-out from telemetry requests
|
118 |
+
HF_HUB_DISABLE_TELEMETRY = (
|
119 |
+
_is_true(os.environ.get("HF_HUB_DISABLE_TELEMETRY")) # HF-specific env variable
|
120 |
+
or _is_true(os.environ.get("DISABLE_TELEMETRY"))
|
121 |
+
or _is_true(os.environ.get("DO_NOT_TRACK")) # https://consoledonottrack.com/
|
122 |
+
)
|
123 |
+
|
124 |
+
# In the past, token was stored in a hardcoded location
|
125 |
+
# `_OLD_HF_TOKEN_PATH` is deprecated and will be removed "at some point".
|
126 |
+
# See https://github.com/huggingface/huggingface_hub/issues/1232
|
127 |
+
_OLD_HF_TOKEN_PATH = os.path.expanduser("~/.huggingface/token")
|
128 |
+
HF_TOKEN_PATH = os.path.join(HF_HOME, "token")
|
129 |
+
|
130 |
+
|
131 |
+
if _staging_mode:
|
132 |
+
# In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens
|
133 |
+
_staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging")
|
134 |
+
HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub")
|
135 |
+
_OLD_HF_TOKEN_PATH = os.path.join(_staging_home, "_old_token")
|
136 |
+
HF_TOKEN_PATH = os.path.join(_staging_home, "token")
|
137 |
+
|
138 |
+
# Here, `True` will disable progress bars globally without possibility of enabling it
|
139 |
+
# programmatically. `False` will enable them without possibility of disabling them.
|
140 |
+
# If environment variable is not set (None), then the user is free to enable/disable
|
141 |
+
# them programmatically.
|
142 |
+
# TL;DR: env variable has priority over code
|
143 |
+
__HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS")
|
144 |
+
HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = (
|
145 |
+
_is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None
|
146 |
+
)
|
147 |
+
|
148 |
+
# Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
|
149 |
+
HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))
|
150 |
+
|
151 |
+
# Disable warning when using experimental features
|
152 |
+
HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))
|
153 |
+
|
154 |
+
# Disable sending the cached token by default in all HTTP requests to the Hub
|
155 |
+
HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))
|
156 |
+
|
157 |
+
# Enable fast-download using external dependency "hf_transfer"
|
158 |
+
# See:
|
159 |
+
# - https://pypi.org/project/hf-transfer/
|
160 |
+
# - https://github.com/huggingface/hf_transfer (private)
|
161 |
+
HF_HUB_ENABLE_HF_TRANSFER: bool = _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER"))
|
162 |
+
|
163 |
+
|
164 |
+
# Used if download to `local_dir` and `local_dir_use_symlinks="auto"`
|
165 |
+
# Files smaller than 5MB are copy-pasted while bigger files are symlinked. The idea is to save disk-usage by symlinking
|
166 |
+
# huge files (i.e. LFS files most of the time) while allowing small files to be manually edited in local folder.
|
167 |
+
HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = (
|
168 |
+
_as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024
|
169 |
+
)
|
170 |
+
|
171 |
+
# Used to override the etag timeout on a system level
|
172 |
+
HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT
|
173 |
+
|
174 |
+
# Used to override the get request timeout on a system level
|
175 |
+
HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT
|
176 |
+
|
177 |
+
# List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are
|
178 |
+
# deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by
|
179 |
+
# default. We still keep the full list of supported frameworks in case we want to scan all of them.
|
180 |
+
MAIN_INFERENCE_API_FRAMEWORKS = [
|
181 |
+
"diffusers",
|
182 |
+
"sentence-transformers",
|
183 |
+
"text-generation-inference",
|
184 |
+
"transformers",
|
185 |
+
]
|
186 |
+
|
187 |
+
ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [
|
188 |
+
"adapter-transformers",
|
189 |
+
"allennlp",
|
190 |
+
"asteroid",
|
191 |
+
"bertopic",
|
192 |
+
"doctr",
|
193 |
+
"espnet",
|
194 |
+
"fairseq",
|
195 |
+
"fastai",
|
196 |
+
"fasttext",
|
197 |
+
"flair",
|
198 |
+
"generic",
|
199 |
+
"k2",
|
200 |
+
"keras",
|
201 |
+
"mindspore",
|
202 |
+
"nemo",
|
203 |
+
"open_clip",
|
204 |
+
"paddlenlp",
|
205 |
+
"peft",
|
206 |
+
"pyannote-audio",
|
207 |
+
"sklearn",
|
208 |
+
"spacy",
|
209 |
+
"span-marker",
|
210 |
+
"speechbrain",
|
211 |
+
"stanza",
|
212 |
+
"timm",
|
213 |
+
]
|
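A small illustration of how the environment-driven constants above resolve. Note that `constants.py` reads the environment at import time, so the variables must be set before `huggingface_hub` is imported; the paths below are hypothetical.

```python
import os

# Must be set before importing huggingface_hub (constants are read at import time).
os.environ["HF_HUB_OFFLINE"] = "1"      # any of "1", "ON", "YES", "TRUE" counts as true
os.environ["HF_HOME"] = "/tmp/hf-home"  # hypothetical cache root

from huggingface_hub import constants

print(constants.HF_HUB_OFFLINE)  # True
print(constants.HF_TOKEN_PATH)   # /tmp/hf-home/token
print(constants.HF_HUB_CACHE)    # /tmp/hf-home/hub (unless HF_HUB_CACHE is set explicitly)
```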
lib/python3.11/site-packages/huggingface_hub/fastai_utils.py
ADDED
@@ -0,0 +1,425 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from pickle import DEFAULT_PROTOCOL, PicklingError
|
5 |
+
from typing import Any, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from packaging import version
|
8 |
+
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
from huggingface_hub.constants import CONFIG_NAME
|
11 |
+
from huggingface_hub.hf_api import HfApi
|
12 |
+
from huggingface_hub.utils import (
|
13 |
+
SoftTemporaryDirectory,
|
14 |
+
get_fastai_version,
|
15 |
+
get_fastcore_version,
|
16 |
+
get_python_version,
|
17 |
+
)
|
18 |
+
|
19 |
+
from .utils import logging, validate_hf_hub_args
|
20 |
+
from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility...
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def _check_fastai_fastcore_versions(
|
27 |
+
fastai_min_version: str = "2.4",
|
28 |
+
fastcore_min_version: str = "1.3.27",
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Checks that the installed fastai and fastcore versions are compatible for pickle serialization.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
fastai_min_version (`str`, *optional*):
|
35 |
+
The minimum fastai version supported.
|
36 |
+
fastcore_min_version (`str`, *optional*):
|
37 |
+
The minimum fastcore version supported.
|
38 |
+
|
39 |
+
<Tip>
|
40 |
+
Raises the following error:
|
41 |
+
|
42 |
+
- [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
|
43 |
+
if the fastai or fastcore libraries are not available or are of an invalid version.
|
44 |
+
|
45 |
+
</Tip>
|
46 |
+
"""
|
47 |
+
|
48 |
+
if (get_fastcore_version() or get_fastai_version()) == "N/A":
|
49 |
+
raise ImportError(
|
50 |
+
f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
|
51 |
+
f" required. Currently using fastai=={get_fastai_version()} and"
|
52 |
+
f" fastcore=={get_fastcore_version()}."
|
53 |
+
)
|
54 |
+
|
55 |
+
current_fastai_version = version.Version(get_fastai_version())
|
56 |
+
current_fastcore_version = version.Version(get_fastcore_version())
|
57 |
+
|
58 |
+
if current_fastai_version < version.Version(fastai_min_version):
|
59 |
+
raise ImportError(
|
60 |
+
"`push_to_hub_fastai` and `from_pretrained_fastai` require a"
|
61 |
+
f" fastai>={fastai_min_version} version, but you are using fastai version"
|
62 |
+
f" {get_fastai_version()} which is incompatible. Upgrade with `pip install"
|
63 |
+
" fastai==2.5.6`."
|
64 |
+
)
|
65 |
+
|
66 |
+
if current_fastcore_version < version.Version(fastcore_min_version):
|
67 |
+
raise ImportError(
|
68 |
+
"`push_to_hub_fastai` and `from_pretrained_fastai` require a"
|
69 |
+
f" fastcore>={fastcore_min_version} version, but you are using fastcore"
|
70 |
+
f" version {get_fastcore_version()} which is incompatible. Upgrade with"
|
71 |
+
" `pip install fastcore==1.3.27`."
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def _check_fastai_fastcore_pyproject_versions(
|
76 |
+
storage_folder: str,
|
77 |
+
fastai_min_version: str = "2.4",
|
78 |
+
fastcore_min_version: str = "1.3.27",
|
79 |
+
):
|
80 |
+
"""
|
81 |
+
Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions
|
82 |
+
that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist
|
83 |
+
or does not contain versions for fastai and fastcore, then it logs a warning.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
storage_folder (`str`):
|
87 |
+
Folder to look for the `pyproject.toml` file.
|
88 |
+
fastai_min_version (`str`, *optional*):
|
89 |
+
The minimum fastai version supported.
|
90 |
+
fastcore_min_version (`str`, *optional*):
|
91 |
+
The minimum fastcore version supported.
|
92 |
+
|
93 |
+
<Tip>
|
94 |
+
Raises the following errors:
|
95 |
+
|
96 |
+
- [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
|
97 |
+
if the `toml` module is not installed.
|
98 |
+
- [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
|
99 |
+
if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.
|
100 |
+
|
101 |
+
</Tip>
|
102 |
+
"""
|
103 |
+
|
104 |
+
try:
|
105 |
+
import toml
|
106 |
+
except ModuleNotFoundError:
|
107 |
+
raise ImportError(
|
108 |
+
"`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module."
|
109 |
+
" Install it with `pip install toml`."
|
110 |
+
)
|
111 |
+
|
112 |
+
# Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages.
|
113 |
+
if not os.path.isfile(f"{storage_folder}/pyproject.toml"):
|
114 |
+
logger.warning(
|
115 |
+
"There is no `pyproject.toml` in the repository that contains the fastai"
|
116 |
+
" `Learner`. The `pyproject.toml` would allow us to verify that your fastai"
|
117 |
+
" and fastcore versions are compatible with those of the model you want to"
|
118 |
+
" load."
|
119 |
+
)
|
120 |
+
return
|
121 |
+
pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml")
|
122 |
+
|
123 |
+
if "build-system" not in pyproject_toml.keys():
|
124 |
+
logger.warning(
|
125 |
+
"There is no `build-system` section in the pyproject.toml of the repository"
|
126 |
+
" that contains the fastai `Learner`. The `build-system` would allow us to"
|
127 |
+
" verify that your fastai and fastcore versions are compatible with those"
|
128 |
+
" of the model you want to load."
|
129 |
+
)
|
130 |
+
return
|
131 |
+
build_system_toml = pyproject_toml["build-system"]
|
132 |
+
|
133 |
+
if "requires" not in build_system_toml.keys():
|
134 |
+
logger.warning(
|
135 |
+
"There is no `requires` section in the pyproject.toml of the repository"
|
136 |
+
" that contains the fastai `Learner`. The `requires` would allow us to"
|
137 |
+
" verify that your fastai and fastcore versions are compatible with those"
|
138 |
+
" of the model you want to load."
|
139 |
+
)
|
140 |
+
return
|
141 |
+
package_versions = build_system_toml["requires"]
|
142 |
+
|
143 |
+
# Extract the fastai and fastcore versions from `pyproject.toml` if available.
|
144 |
+
# If the package is specified but not the version (e.g. "fastai" instead of "fastai=2.4"), the default versions are the highest.
|
145 |
+
fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")]
|
146 |
+
if len(fastai_packages) == 0:
|
147 |
+
logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.")
|
148 |
+
# fastai_version is an empty string if not specified
|
149 |
+
else:
|
150 |
+
fastai_version = str(fastai_packages[0]).partition("=")[2]
|
151 |
+
if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version):
|
152 |
+
raise ImportError(
|
153 |
+
"`from_pretrained_fastai` requires"
|
154 |
+
f" fastai>={fastai_min_version} version but the model to load uses"
|
155 |
+
f" {fastai_version} which is incompatible."
|
156 |
+
)
|
157 |
+
|
158 |
+
fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")]
|
159 |
+
if len(fastcore_packages) == 0:
|
160 |
+
logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.")
|
161 |
+
# fastcore_version is an empty string if not specified
|
162 |
+
else:
|
163 |
+
fastcore_version = str(fastcore_packages[0]).partition("=")[2]
|
164 |
+
if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version):
|
165 |
+
raise ImportError(
|
166 |
+
"`from_pretrained_fastai` requires"
|
167 |
+
f" fastcore>={fastcore_min_version} version, but you are using fastcore"
|
168 |
+
f" version {fastcore_version} which is incompatible."
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
README_TEMPLATE = """---
|
173 |
+
tags:
|
174 |
+
- fastai
|
175 |
+
---
|
176 |
+
|
177 |
+
# Amazing!
|
178 |
+
|
179 |
+
🥳 Congratulations on hosting your fastai model on the Hugging Face Hub!
|
180 |
+
|
181 |
+
# Some next steps
|
182 |
+
1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))!
|
183 |
+
|
184 |
+
2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)).
|
185 |
+
|
186 |
+
3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)!
|
187 |
+
|
188 |
+
Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card.
|
189 |
+
|
190 |
+
|
191 |
+
---
|
192 |
+
|
193 |
+
|
194 |
+
# Model card
|
195 |
+
|
196 |
+
## Model description
|
197 |
+
More information needed
|
198 |
+
|
199 |
+
## Intended uses & limitations
|
200 |
+
More information needed
|
201 |
+
|
202 |
+
## Training and evaluation data
|
203 |
+
More information needed
|
204 |
+
"""
|
205 |
+
|
206 |
+
PYPROJECT_TEMPLATE = f"""[build-system]
|
207 |
+
requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"]
|
208 |
+
build-backend = "setuptools.build_meta:__legacy__"
|
209 |
+
"""
|
210 |
+
|
211 |
+
|
212 |
+
def _create_model_card(repo_dir: Path):
|
213 |
+
"""
|
214 |
+
Creates a model card for the repository.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
repo_dir (`Path`):
|
218 |
+
Directory where model card is created.
|
219 |
+
"""
|
220 |
+
readme_path = repo_dir / "README.md"
|
221 |
+
|
222 |
+
if not readme_path.exists():
|
223 |
+
with readme_path.open("w", encoding="utf-8") as f:
|
224 |
+
f.write(README_TEMPLATE)
|
225 |
+
|
226 |
+
|
227 |
+
def _create_model_pyproject(repo_dir: Path):
|
228 |
+
"""
|
229 |
+
Creates a `pyproject.toml` for the repository.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
repo_dir (`Path`):
|
233 |
+
Directory where `pyproject.toml` is created.
|
234 |
+
"""
|
235 |
+
pyproject_path = repo_dir / "pyproject.toml"
|
236 |
+
|
237 |
+
if not pyproject_path.exists():
|
238 |
+
with pyproject_path.open("w", encoding="utf-8") as f:
|
239 |
+
f.write(PYPROJECT_TEMPLATE)
|
240 |
+
|
241 |
+
|
242 |
+
def _save_pretrained_fastai(
|
243 |
+
learner,
|
244 |
+
save_directory: Union[str, Path],
|
245 |
+
config: Optional[Dict[str, Any]] = None,
|
246 |
+
):
|
247 |
+
"""
|
248 |
+
Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
learner (`Learner`):
|
252 |
+
The `fastai.Learner` you'd like to save.
|
253 |
+
save_directory (`str` or `Path`):
|
254 |
+
Specific directory in which you want to save the fastai learner.
|
255 |
+
config (`dict`, *optional*):
|
256 |
+
Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.
|
257 |
+
|
258 |
+
<Tip>
|
259 |
+
|
260 |
+
Raises the following error:
|
261 |
+
|
262 |
+
- [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
|
263 |
+
if the config file provided is not a dictionary.
|
264 |
+
|
265 |
+
</Tip>
|
266 |
+
"""
|
267 |
+
_check_fastai_fastcore_versions()
|
268 |
+
|
269 |
+
os.makedirs(save_directory, exist_ok=True)
|
270 |
+
|
271 |
+
# if the user provides config then we update it with the fastai and fastcore versions in CONFIG_TEMPLATE.
|
272 |
+
if config is not None:
|
273 |
+
if not isinstance(config, dict):
|
274 |
+
raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
|
275 |
+
path = os.path.join(save_directory, CONFIG_NAME)
|
276 |
+
with open(path, "w") as f:
|
277 |
+
json.dump(config, f)
|
278 |
+
|
279 |
+
_create_model_card(Path(save_directory))
|
280 |
+
_create_model_pyproject(Path(save_directory))
|
281 |
+
|
282 |
+
# learner.export saves the model in `self.path`.
|
283 |
+
learner.path = Path(save_directory)
|
284 |
+
os.makedirs(save_directory, exist_ok=True)
|
285 |
+
try:
|
286 |
+
learner.export(
|
287 |
+
fname="model.pkl",
|
288 |
+
pickle_protocol=DEFAULT_PROTOCOL,
|
289 |
+
)
|
290 |
+
except PicklingError:
|
291 |
+
raise PicklingError(
|
292 |
+
"You are using a lambda function, i.e., an anonymous function. `pickle`"
|
293 |
+
" cannot pickle function objects and requires that all functions have"
|
294 |
+
" names. One possible solution is to name the function."
|
295 |
+
)
|
296 |
+
|
297 |
+
|
298 |
+
@validate_hf_hub_args
|
299 |
+
def from_pretrained_fastai(
|
300 |
+
repo_id: str,
|
301 |
+
revision: Optional[str] = None,
|
302 |
+
):
|
303 |
+
"""
|
304 |
+
Load pretrained fastai model from the Hub or from a local directory.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
repo_id (`str`):
|
308 |
+
The location where the pickled fastai.Learner is. It can be either of the two:
|
309 |
+
- Hosted on the Hugging Face Hub. E.g.: 'espejelomar/fastai-pet-breeds-classification' or 'distilgpt2'.
|
310 |
+
You can add a `revision` by appending `@` at the end of `repo_id`. E.g.: `dbmdz/bert-base-german-cased@main`.
|
311 |
+
Revision is the specific model version to use. Since we use a git-based system for storing models and other
|
312 |
+
artifacts on the Hugging Face Hub, it can be a branch name, a tag name, or a commit id.
|
313 |
+
- Hosted locally. `repo_id` would be a directory containing the pickle and a pyproject.toml
|
314 |
+
indicating the fastai and fastcore versions used to build the `fastai.Learner`. E.g.: `./my_model_directory/`.
|
315 |
+
revision (`str`, *optional*):
|
316 |
+
Revision at which the repo's files are downloaded. See documentation of `snapshot_download`.
|
317 |
+
|
318 |
+
Returns:
|
319 |
+
The `fastai.Learner` model in the `repo_id` repo.
|
320 |
+
"""
|
321 |
+
_check_fastai_fastcore_versions()
|
322 |
+
|
323 |
+
# Load the `repo_id` repo.
|
324 |
+
# `snapshot_download` returns the folder where the model was stored.
|
325 |
+
# `cache_dir` will be the default '/root/.cache/huggingface/hub'
|
326 |
+
if not os.path.isdir(repo_id):
|
327 |
+
storage_folder = snapshot_download(
|
328 |
+
repo_id=repo_id,
|
329 |
+
revision=revision,
|
330 |
+
library_name="fastai",
|
331 |
+
library_version=get_fastai_version(),
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
storage_folder = repo_id
|
335 |
+
|
336 |
+
_check_fastai_fastcore_pyproject_versions(storage_folder)
|
337 |
+
|
338 |
+
from fastai.learner import load_learner # type: ignore
|
339 |
+
|
340 |
+
return load_learner(os.path.join(storage_folder, "model.pkl"))
|
341 |
+
|
342 |
+
|
343 |
+
@validate_hf_hub_args
|
344 |
+
def push_to_hub_fastai(
|
345 |
+
learner,
|
346 |
+
*,
|
347 |
+
repo_id: str,
|
348 |
+
commit_message: str = "Push FastAI model using huggingface_hub.",
|
349 |
+
private: bool = False,
|
350 |
+
token: Optional[str] = None,
|
351 |
+
config: Optional[dict] = None,
|
352 |
+
branch: Optional[str] = None,
|
353 |
+
create_pr: Optional[bool] = None,
|
354 |
+
allow_patterns: Optional[Union[List[str], str]] = None,
|
355 |
+
ignore_patterns: Optional[Union[List[str], str]] = None,
|
356 |
+
delete_patterns: Optional[Union[List[str], str]] = None,
|
357 |
+
api_endpoint: Optional[str] = None,
|
358 |
+
):
|
359 |
+
"""
|
360 |
+
Upload learner checkpoint files to the Hub.
|
361 |
+
|
362 |
+
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
|
363 |
+
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
|
364 |
+
details.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
learner (`Learner`):
|
368 |
+
The `fastai.Learner` you'd like to push to the Hub.
|
369 |
+
repo_id (`str`):
|
370 |
+
The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be your individual account or an organization to which you have write access (for example, 'stanfordnlp/stanza-de').
|
371 |
+
commit_message (`str`, *optional*):
|
372 |
+
Message to commit while pushing. Defaults to `"Push FastAI model using huggingface_hub."`.
|
373 |
+
private (`bool`, *optional*, defaults to `False`):
|
374 |
+
Whether or not the repository created should be private.
|
375 |
+
token (`str`, *optional*):
|
376 |
+
The Hugging Face account token to use as HTTP bearer authorization for remote files. If `None`, you will be prompted for a token.
|
377 |
+
config (`dict`, *optional*):
|
378 |
+
Configuration object to be saved alongside the model weights.
|
379 |
+
branch (`str`, *optional*):
|
380 |
+
The git branch on which to push the model. This defaults to
|
381 |
+
the default branch as specified in your repository, which
|
382 |
+
defaults to `"main"`.
|
383 |
+
create_pr (`boolean`, *optional*):
|
384 |
+
Whether or not to create a Pull Request from `branch` with that commit.
|
385 |
+
Defaults to `False`.
|
386 |
+
api_endpoint (`str`, *optional*):
|
387 |
+
The API endpoint to use when pushing the model to the hub.
|
388 |
+
allow_patterns (`List[str]` or `str`, *optional*):
|
389 |
+
If provided, only files matching at least one pattern are pushed.
|
390 |
+
ignore_patterns (`List[str]` or `str`, *optional*):
|
391 |
+
If provided, files matching any of the patterns are not pushed.
|
392 |
+
delete_patterns (`List[str]` or `str`, *optional*):
|
393 |
+
If provided, remote files matching any of the patterns will be deleted from the repo.
|
394 |
+
|
395 |
+
Returns:
|
396 |
+
The url of the commit of your model in the given repository.
|
397 |
+
|
398 |
+
<Tip>
|
399 |
+
|
400 |
+
Raises the following error:
|
401 |
+
|
402 |
+
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
403 |
+
if the user is not logged in to the Hugging Face Hub.
|
404 |
+
|
405 |
+
</Tip>
|
406 |
+
"""
|
407 |
+
_check_fastai_fastcore_versions()
|
408 |
+
api = HfApi(endpoint=api_endpoint)
|
409 |
+
repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id
|
410 |
+
|
411 |
+
# Push the files to the repo in a single commit
|
412 |
+
with SoftTemporaryDirectory() as tmp:
|
413 |
+
saved_path = Path(tmp) / repo_id
|
414 |
+
_save_pretrained_fastai(learner, saved_path, config=config)
|
415 |
+
return api.upload_folder(
|
416 |
+
repo_id=repo_id,
|
417 |
+
token=token,
|
418 |
+
folder_path=saved_path,
|
419 |
+
commit_message=commit_message,
|
420 |
+
revision=branch,
|
421 |
+
create_pr=create_pr,
|
422 |
+
allow_patterns=allow_patterns,
|
423 |
+
ignore_patterns=ignore_patterns,
|
424 |
+
delete_patterns=delete_patterns,
|
425 |
+
)
|
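Below is a hedged usage sketch of the two public fastai helpers defined above. It assumes fastai>=2.4 and fastcore>=1.3.27 are installed, that `learn` is a `fastai.Learner` you have already trained (not defined here), and that `my-user/fastai-demo` is a hypothetical repo you can write to.

```python
from huggingface_hub import from_pretrained_fastai, push_to_hub_fastai

# `learn` is assumed to be an existing, trained fastai.Learner.
# Upload: writes README.md, pyproject.toml and model.pkl into a temporary
# folder and pushes them to the Hub in a single commit (creating the repo if needed).
push_to_hub_fastai(learner=learn, repo_id="my-user/fastai-demo")

# Download: snapshots the repo, checks the fastai/fastcore versions pinned in
# pyproject.toml, then loads model.pkl with fastai's `load_learner`.
reloaded = from_pretrained_fastai("my-user/fastai-demo")
prediction = reloaded.predict(sample)  # `sample` is whatever your Learner expects
```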
lib/python3.11/site-packages/huggingface_hub/file_download.py
ADDED
@@ -0,0 +1,1727 @@
1 |
+
import copy
|
2 |
+
import fnmatch
|
3 |
+
import inspect
|
4 |
+
import io
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import re
|
8 |
+
import shutil
|
9 |
+
import stat
|
10 |
+
import tempfile
|
11 |
+
import time
|
12 |
+
import uuid
|
13 |
+
import warnings
|
14 |
+
from contextlib import contextmanager
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from functools import partial
|
17 |
+
from pathlib import Path
|
18 |
+
from typing import Any, BinaryIO, Dict, Generator, Literal, Optional, Tuple, Union
|
19 |
+
from urllib.parse import quote, urlparse
|
20 |
+
|
21 |
+
import requests
|
22 |
+
from filelock import FileLock
|
23 |
+
|
24 |
+
from huggingface_hub import constants
|
25 |
+
|
26 |
+
from . import __version__ # noqa: F401 # for backward compatibility
|
27 |
+
from .constants import (
|
28 |
+
DEFAULT_ETAG_TIMEOUT,
|
29 |
+
DEFAULT_REQUEST_TIMEOUT,
|
30 |
+
DEFAULT_REVISION,
|
31 |
+
DOWNLOAD_CHUNK_SIZE,
|
32 |
+
ENDPOINT,
|
33 |
+
HF_HUB_CACHE,
|
34 |
+
HF_HUB_DISABLE_SYMLINKS_WARNING,
|
35 |
+
HF_HUB_DOWNLOAD_TIMEOUT,
|
36 |
+
HF_HUB_ENABLE_HF_TRANSFER,
|
37 |
+
HF_HUB_ETAG_TIMEOUT,
|
38 |
+
HF_TRANSFER_CONCURRENCY,
|
39 |
+
HUGGINGFACE_CO_URL_TEMPLATE,
|
40 |
+
HUGGINGFACE_HEADER_X_LINKED_ETAG,
|
41 |
+
HUGGINGFACE_HEADER_X_LINKED_SIZE,
|
42 |
+
HUGGINGFACE_HEADER_X_REPO_COMMIT,
|
43 |
+
HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility
|
44 |
+
REPO_ID_SEPARATOR,
|
45 |
+
REPO_TYPES,
|
46 |
+
REPO_TYPES_URL_PREFIXES,
|
47 |
+
)
|
48 |
+
from .utils import (
|
49 |
+
EntryNotFoundError,
|
50 |
+
FileMetadataError,
|
51 |
+
GatedRepoError,
|
52 |
+
LocalEntryNotFoundError,
|
53 |
+
OfflineModeIsEnabled,
|
54 |
+
RepositoryNotFoundError,
|
55 |
+
RevisionNotFoundError,
|
56 |
+
SoftTemporaryDirectory,
|
57 |
+
build_hf_headers,
|
58 |
+
get_fastai_version, # noqa: F401 # for backward compatibility
|
59 |
+
get_fastcore_version, # noqa: F401 # for backward compatibility
|
60 |
+
get_graphviz_version, # noqa: F401 # for backward compatibility
|
61 |
+
get_jinja_version, # noqa: F401 # for backward compatibility
|
62 |
+
get_pydot_version, # noqa: F401 # for backward compatibility
|
63 |
+
get_session,
|
64 |
+
get_tf_version, # noqa: F401 # for backward compatibility
|
65 |
+
get_torch_version, # noqa: F401 # for backward compatibility
|
66 |
+
hf_raise_for_status,
|
67 |
+
is_fastai_available, # noqa: F401 # for backward compatibility
|
68 |
+
is_fastcore_available, # noqa: F401 # for backward compatibility
|
69 |
+
is_graphviz_available, # noqa: F401 # for backward compatibility
|
70 |
+
is_jinja_available, # noqa: F401 # for backward compatibility
|
71 |
+
is_pydot_available, # noqa: F401 # for backward compatibility
|
72 |
+
is_tf_available, # noqa: F401 # for backward compatibility
|
73 |
+
is_torch_available, # noqa: F401 # for backward compatibility
|
74 |
+
logging,
|
75 |
+
reset_sessions,
|
76 |
+
tqdm,
|
77 |
+
validate_hf_hub_args,
|
78 |
+
)
|
79 |
+
from .utils._deprecation import _deprecate_method
|
80 |
+
from .utils._headers import _http_user_agent
|
81 |
+
from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
|
82 |
+
from .utils._typing import HTTP_METHOD_T
|
83 |
+
from .utils.insecure_hashlib import sha256
|
84 |
+
|
85 |
+
|
86 |
+
logger = logging.get_logger(__name__)
|
87 |
+
|
88 |
+
# Regex to get filename from a "Content-Disposition" header for CDN-served files
|
89 |
+
HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P<filename>.*?)";')
|
90 |
+
|
91 |
+
|
92 |
+
_are_symlinks_supported_in_dir: Dict[str, bool] = {}
|
93 |
+
|
94 |
+
|
95 |
+
def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
|
96 |
+
"""Return whether symlinks are supported on the machine.
|
97 |
+
|
98 |
+
Since symlinks support can change depending on the mounted disk, we need to check
|
99 |
+
on the precise cache folder. By default, the default HF cache directory is checked.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
cache_dir (`str`, `Path`, *optional*):
|
103 |
+
Path to the folder where cached files are stored.
|
104 |
+
|
105 |
+
Returns: [bool] Whether symlinks are supported in the directory.
|
106 |
+
"""
|
107 |
+
# Defaults to HF cache
|
108 |
+
if cache_dir is None:
|
109 |
+
cache_dir = HF_HUB_CACHE
|
110 |
+
cache_dir = str(Path(cache_dir).expanduser().resolve()) # make it unique
|
111 |
+
|
112 |
+
# Check symlink compatibility only once (per cache directory) at first time use
|
113 |
+
if cache_dir not in _are_symlinks_supported_in_dir:
|
114 |
+
_are_symlinks_supported_in_dir[cache_dir] = True
|
115 |
+
|
116 |
+
os.makedirs(cache_dir, exist_ok=True)
|
117 |
+
with SoftTemporaryDirectory(dir=cache_dir) as tmpdir:
|
118 |
+
src_path = Path(tmpdir) / "dummy_file_src"
|
119 |
+
src_path.touch()
|
120 |
+
dst_path = Path(tmpdir) / "dummy_file_dst"
|
121 |
+
|
122 |
+
# Relative source path as in `_create_symlink`
|
123 |
+
relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
|
124 |
+
try:
|
125 |
+
os.symlink(relative_src, dst_path)
|
126 |
+
except OSError:
|
127 |
+
# Likely running on Windows
|
128 |
+
_are_symlinks_supported_in_dir[cache_dir] = False
|
129 |
+
|
130 |
+
if not HF_HUB_DISABLE_SYMLINKS_WARNING:
|
131 |
+
message = (
|
132 |
+
"`huggingface_hub` cache-system uses symlinks by default to"
|
133 |
+
" efficiently store duplicated files but your machine does not"
|
134 |
+
f" support them in {cache_dir}. Caching files will still work"
|
135 |
+
" but in a degraded version that might require more space on"
|
136 |
+
" your disk. This warning can be disabled by setting the"
|
137 |
+
" `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For"
|
138 |
+
" more details, see"
|
139 |
+
" https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
|
140 |
+
)
|
141 |
+
if os.name == "nt":
|
142 |
+
message += (
|
143 |
+
"\nTo support symlinks on Windows, you either need to"
|
144 |
+
" activate Developer Mode or to run Python as an"
|
145 |
+
" administrator. In order to activate developer mode,"
|
146 |
+
" see this article:"
|
147 |
+
" https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
|
148 |
+
)
|
149 |
+
warnings.warn(message)
|
150 |
+
|
151 |
+
return _are_symlinks_supported_in_dir[cache_dir]
|

# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()
_CACHED_NO_EXIST_T = Any
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")


@dataclass(frozen=True)
class HfFileMetadata:
    """Data structure containing information about a file versioned on the Hub.

    Returned by [`get_hf_file_metadata`] based on a URL.

    Args:
        commit_hash (`str`, *optional*):
            The commit_hash related to the file.
        etag (`str`, *optional*):
            Etag of the file on the server.
        location (`str`):
            Location where to download the file. Can be a Hub url or not (CDN).
        size (`int`, *optional*):
            Size of the file. In case of an LFS file, contains the size of the actual
            LFS file, not the pointer.
    """

    commit_hash: Optional[str]
    etag: Optional[str]
    location: str
    size: Optional[int]
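As a sketch of how this dataclass is typically consumed: `get_hf_file_metadata` (exported by `huggingface_hub` and referenced in the docstring above) returns one of these objects, whose fields can be read directly. The repo and file below are just a public example.

```python
from huggingface_hub import get_hf_file_metadata, hf_hub_url

meta = get_hf_file_metadata(hf_hub_url("gpt2", "config.json"))
print(meta.commit_hash, meta.etag, meta.size)
print(meta.location)  # either a huggingface.co URL or a CDN URL for LFS files
```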

@validate_hf_hub_args
def hf_hub_url(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    endpoint: Optional[str] = None,
) -> str:
    """Construct the URL of a file from the given information.

    The resolved address can either be a huggingface.co-hosted url, or a link to
    Cloudfront (a Content Delivery Network, or CDN) for large files which are
    more than a few MBs.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) name and a repo name separated
            by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the repo.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash.

    Example:

    ```python
    >>> from huggingface_hub import hf_hub_url

    >>> hf_hub_url(
    ...     repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin"
    ... )
    'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin'
    ```

    <Tip>

    Notes:

        Cloudfront is replicated over the globe so downloads are way faster for
        the end user (and it also lowers our bandwidth costs).

        Cloudfront aggressively caches files by default (default TTL is 24
        hours), however this is not an issue here because we implement a
        git-based versioning system on huggingface.co, which means that we store
        the files on S3/Cloudfront in a content-addressable way (i.e., the file
        name is its hash). Using content-addressable filenames means cache can't
        ever be stale.

        In terms of client-side caching from this library, we base our caching
        on the objects' entity tag (`ETag`), which is an identifier of a
        specific version of a resource [1]_. An object's ETag is: its git-sha1
        if stored in git, or its sha256 if stored in git-lfs.

    </Tip>

    References:

    - [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
    """
    if subfolder == "":
        subfolder = None
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if repo_type not in REPO_TYPES:
        raise ValueError("Invalid repo type")

    if repo_type in REPO_TYPES_URL_PREFIXES:
        repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id

    if revision is None:
        revision = DEFAULT_REVISION
    url = HUGGINGFACE_CO_URL_TEMPLATE.format(
        repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename)
    )
    # Update endpoint if provided
    if endpoint is not None and url.startswith(ENDPOINT):
        url = endpoint + url[len(ENDPOINT) :]
    return url
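The docstring example above covers the default model case; a short additional sketch shows how `repo_type`, `subfolder` and `revision` change the resolved URL. The repo ids are made-up placeholders.

```python
from huggingface_hub import hf_hub_url

# Dataset repos get a "datasets/" prefix; subfolder is joined into the filename part.
hf_hub_url("username/my-dataset", "train.csv", subfolder="data", repo_type="dataset")
# 'https://huggingface.co/datasets/username/my-dataset/resolve/main/data/train.csv'

# Pinning a revision replaces the "main" segment; slashes in the revision are percent-encoded.
hf_hub_url("gpt2", "config.json", revision="refs/pr/1")
# 'https://huggingface.co/gpt2/resolve/refs%2Fpr%2F1/config.json'
```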

def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """Generate a local filename from a url.

    Convert `url` into a hashed filename in a reproducible way. If `etag` is
    specified, append its hash to the url's, delimited by a period. If the url
    ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
    identify it as a HDF5 file (see
    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)

    Args:
        url (`str`):
            The address to the file.
        etag (`str`, *optional*):
            The ETag of the file.

    Returns:
        The generated filename.
    """
    url_bytes = url.encode("utf-8")
    filename = sha256(url_bytes).hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        filename += "." + sha256(etag_bytes).hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename
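A minimal sketch of the naming scheme, using an arbitrary URL and ETag value for illustration:

```python
from huggingface_hub.file_download import url_to_filename

url = "https://huggingface.co/gpt2/resolve/main/config.json"
print(url_to_filename(url))                   # sha256 of the URL only
print(url_to_filename(url, etag='"abc123"'))  # "<sha256(url)>.<sha256(etag)>"

# The ".h5" suffix is kept so TensorFlow can still recognise Keras weight files.
h5_url = "https://huggingface.co/gpt2/resolve/main/tf_model.h5"
print(url_to_filename(h5_url).endswith(".h5"))  # True
```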

def filename_to_url(
    filename,
    cache_dir: Optional[str] = None,
    legacy_cache_layout: bool = False,
) -> Tuple[str, str]:
    """
    Return the url and etag (which may be `None`) stored for `filename`. Raise
    `EnvironmentError` if `filename` or its stored metadata do not exist.

    Args:
        filename (`str`):
            The name of the file
        cache_dir (`str`, *optional*):
            The cache directory to use instead of the default one.
        legacy_cache_layout (`bool`, *optional*, defaults to `False`):
            If `True`, uses the legacy file cache layout i.e. just call `hf_hub_url`
            then `cached_download`. This is deprecated as the new cache layout is
            more powerful.
    """
    if not legacy_cache_layout:
        warnings.warn(
            "`filename_to_url` uses the legacy cache file layout",
            FutureWarning,
        )

    if cache_dir is None:
        cache_dir = HF_HUB_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError(f"file {cache_path} not found")

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError(f"file {meta_path} not found")

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag

@_deprecate_method(version="0.22.0", message="Use `huggingface_hub.utils.build_hf_headers` instead.")
def http_user_agent(
    *,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
) -> str:
    """Deprecated in favor of [`build_hf_headers`]."""
    return _http_user_agent(
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
    )


def _request_wrapper(
    method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params
) -> requests.Response:
    """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when
    `allow_redirects=False`.

    Args:
        method (`str`):
            HTTP method, such as 'GET' or 'HEAD'.
        url (`str`):
            The URL of the resource to fetch.
        follow_relative_redirects (`bool`, *optional*, defaults to `False`):
            If True, relative redirection (redirection to the same site) will be resolved even when the `allow_redirects`
            kwarg is set to False. Useful when we want to follow a redirection to a renamed repository without
            following redirection to a CDN.
        **params (`dict`, *optional*):
            Params to pass to `requests.request`.
    """
    # Recursively follow relative redirects
    if follow_relative_redirects:
        response = _request_wrapper(
            method=method,
            url=url,
            follow_relative_redirects=False,
            **params,
        )

        # If redirection, we redirect only relative paths.
        # This is useful in case of a renamed repository.
        if 300 <= response.status_code <= 399:
            parsed_target = urlparse(response.headers["Location"])
            if parsed_target.netloc == "":
                # This means it is a relative 'Location' header, as allowed by RFC 7231.
                # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
                # We want to follow this relative redirect!
                #
                # Highly inspired by `resolve_redirects` from requests library.
                # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
                next_url = urlparse(url)._replace(path=parsed_target.path).geturl()
                return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params)
        return response

    # Perform request and return if status_code is not in the retry list.
    response = get_session().request(method=method, url=url, **params)
    hf_raise_for_status(response)
    return response
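A hedged sketch of the pattern used later in this module: issue a HEAD request that does not follow CDN redirects, while still resolving a same-site redirect such as a renamed repository. The URL is an arbitrary public example.

```python
from huggingface_hub.file_download import _request_wrapper

response = _request_wrapper(
    method="HEAD",
    url="https://huggingface.co/gpt2/resolve/main/config.json",
    allow_redirects=False,           # passed through to requests
    follow_relative_redirects=True,  # handled by the wrapper itself
)
print(response.status_code, response.headers.get("ETag"))
```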
412 |
+
|
413 |
+
def http_get(
|
414 |
+
url: str,
|
415 |
+
temp_file: BinaryIO,
|
416 |
+
*,
|
417 |
+
proxies=None,
|
418 |
+
resume_size: float = 0,
|
419 |
+
headers: Optional[Dict[str, str]] = None,
|
420 |
+
expected_size: Optional[int] = None,
|
421 |
+
_nb_retries: int = 5,
|
422 |
+
):
|
423 |
+
"""
|
424 |
+
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
|
425 |
+
|
426 |
+
If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a
|
427 |
+
transient error (network outage?). We log a warning message and try to resume the download a few times before
|
428 |
+
giving up. The method gives up after 5 attempts if no new data has being received from the server.
|
429 |
+
"""
|
430 |
+
hf_transfer = None
|
431 |
+
if HF_HUB_ENABLE_HF_TRANSFER:
|
432 |
+
if resume_size != 0:
|
433 |
+
warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method")
|
434 |
+
elif proxies is not None:
|
435 |
+
warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method")
|
436 |
+
else:
|
437 |
+
try:
|
438 |
+
import hf_transfer # type: ignore[no-redef]
|
439 |
+
except ImportError:
|
440 |
+
raise ValueError(
|
441 |
+
"Fast download using 'hf_transfer' is enabled"
|
442 |
+
" (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is not"
|
443 |
+
" available in your environment. Try `pip install hf_transfer`."
|
444 |
+
)
|
445 |
+
|
446 |
+
initial_headers = headers
|
447 |
+
headers = copy.deepcopy(headers) or {}
|
448 |
+
if resume_size > 0:
|
449 |
+
headers["Range"] = "bytes=%d-" % (resume_size,)
|
450 |
+
|
451 |
+
r = _request_wrapper(
|
452 |
+
method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=HF_HUB_DOWNLOAD_TIMEOUT
|
453 |
+
)
|
454 |
+
hf_raise_for_status(r)
|
455 |
+
content_length = r.headers.get("Content-Length")
|
456 |
+
|
457 |
+
# NOTE: 'total' is the total number of bytes to download, not the number of bytes in the file.
|
458 |
+
# If the file is compressed, the number of bytes in the saved file will be higher than 'total'.
|
459 |
+
total = resume_size + int(content_length) if content_length is not None else None
|
460 |
+
|
461 |
+
displayed_name = url
|
462 |
+
content_disposition = r.headers.get("Content-Disposition")
|
463 |
+
if content_disposition is not None:
|
464 |
+
match = HEADER_FILENAME_PATTERN.search(content_disposition)
|
465 |
+
if match is not None:
|
466 |
+
# Means file is on CDN
|
467 |
+
displayed_name = match.groupdict()["filename"]
|
468 |
+
|
469 |
+
# Truncate filename if too long to display
|
470 |
+
if len(displayed_name) > 40:
|
471 |
+
displayed_name = f"(…){displayed_name[-40:]}"
|
472 |
+
|
473 |
+
consistency_error_message = (
|
474 |
+
f"Consistency check failed: file should be of size {expected_size} but has size"
|
475 |
+
f" {{actual_size}} ({displayed_name}).\nWe are sorry for the inconvenience. Please retry download and"
|
476 |
+
" pass `force_download=True, resume_download=False` as argument.\nIf the issue persists, please let us"
|
477 |
+
" know by opening an issue on https://github.com/huggingface/huggingface_hub."
|
478 |
+
)
|
479 |
+
|
480 |
+
# Stream file to buffer
|
481 |
+
with tqdm(
|
482 |
+
unit="B",
|
483 |
+
unit_scale=True,
|
484 |
+
total=total,
|
485 |
+
initial=resume_size,
|
486 |
+
desc=displayed_name,
|
487 |
+
disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
|
488 |
+
) as progress:
|
489 |
+
if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
|
490 |
+
supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
|
491 |
+
if not supports_callback:
|
492 |
+
warnings.warn(
|
493 |
+
"You are using an outdated version of `hf_transfer`. "
|
494 |
+
"Consider upgrading to latest version to enable progress bars "
|
495 |
+
"using `pip install -U hf_transfer`."
|
496 |
+
)
|
497 |
+
try:
|
498 |
+
hf_transfer.download(
|
499 |
+
url=url,
|
500 |
+
filename=temp_file.name,
|
501 |
+
max_files=HF_TRANSFER_CONCURRENCY,
|
502 |
+
chunk_size=DOWNLOAD_CHUNK_SIZE,
|
503 |
+
headers=headers,
|
504 |
+
parallel_failures=3,
|
505 |
+
max_retries=5,
|
506 |
+
**({"callback": progress.update} if supports_callback else {}),
|
507 |
+
)
|
508 |
+
except Exception as e:
|
509 |
+
raise RuntimeError(
|
510 |
+
"An error occurred while downloading using `hf_transfer`. Consider"
|
511 |
+
" disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
|
512 |
+
) from e
|
513 |
+
if not supports_callback:
|
514 |
+
progress.update(total)
|
515 |
+
if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
|
516 |
+
raise EnvironmentError(
|
517 |
+
consistency_error_message.format(
|
518 |
+
actual_size=os.path.getsize(temp_file.name),
|
519 |
+
)
|
520 |
+
)
|
521 |
+
return
|
522 |
+
new_resume_size = resume_size
|
523 |
+
try:
|
524 |
+
for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
525 |
+
if chunk: # filter out keep-alive new chunks
|
526 |
+
progress.update(len(chunk))
|
527 |
+
temp_file.write(chunk)
|
528 |
+
new_resume_size += len(chunk)
|
529 |
+
# Some data has been downloaded from the server so we reset the number of retries.
|
530 |
+
_nb_retries = 5
|
531 |
+
except (requests.ConnectionError, requests.ReadTimeout) as e:
|
532 |
+
# If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
|
533 |
+
# a transient error (network outage?). We log a warning message and try to resume the download a few times
|
534 |
+
# before giving up. Tre retry mechanism is basic but should be enough in most cases.
|
535 |
+
if _nb_retries <= 0:
|
536 |
+
logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
|
537 |
+
raise
|
538 |
+
logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
|
539 |
+
time.sleep(1)
|
540 |
+
reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects
|
541 |
+
return http_get(
|
542 |
+
url=url,
|
543 |
+
temp_file=temp_file,
|
544 |
+
proxies=proxies,
|
545 |
+
resume_size=new_resume_size,
|
546 |
+
headers=initial_headers,
|
547 |
+
expected_size=expected_size,
|
548 |
+
_nb_retries=_nb_retries - 1,
|
549 |
+
)
|
550 |
+
|
551 |
+
if expected_size is not None and expected_size != temp_file.tell():
|
552 |
+
raise EnvironmentError(
|
553 |
+
consistency_error_message.format(
|
554 |
+
actual_size=temp_file.tell(),
|
555 |
+
)
|
556 |
+
)
|
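A minimal sketch of calling this helper directly, streaming an arbitrary public file into a named temporary file; the retry-and-resume logic described in the docstring happens inside `http_get` itself.

```python
import tempfile
from huggingface_hub.file_download import http_get

with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
    http_get("https://huggingface.co/gpt2/resolve/main/config.json", tmp)

print(tmp.name)  # the downloaded bytes now live in this temporary file
```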

@validate_hf_hub_args
def cached_download(
    url: str,
    *,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    user_agent: Union[Dict, str, None] = None,
    force_download: bool = False,
    force_filename: Optional[str] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
    resume_download: bool = False,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    legacy_cache_layout: bool = False,
) -> str:
    """
    Download from a given URL and cache it if it's not already present in the
    local cache.

    Given a URL, this function looks for the corresponding file in the local
    cache. If it's not there, download it. Then return the path to the cached
    file.

    Will raise errors tailored to the Hugging Face Hub.

    Args:
        url (`str`):
            The path to the file to be downloaded.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in
            the local cache.
        force_filename (`str`, *optional*):
            Use this name instead of a generated file name.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        resume_download (`bool`, *optional*, defaults to `False`):
            If `True`, resume a previously interrupted download.
        token (`bool`, `str`, *optional*):
            A token to be used for the download.
            - If `True`, the token is read from the HuggingFace config
              folder.
            - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        legacy_cache_layout (`bool`, *optional*, defaults to `False`):
            Set this parameter to `True` to mention that you'd like to continue
            using the old cache layout. Putting this to `True` manually will not raise
            any warning when using `cached_download`. We recommend using
            `hf_hub_download` to take advantage of the new cache.

    Returns:
        Local path (string) of file or if networking is off, last version of
        file cached on disk.

    <Tip>

    Raises the following errors:

        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if `token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
          if ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if some parameter value is invalid
        - [`~utils.RepositoryNotFoundError`]
          If the repository to download from cannot be found. This may be because it doesn't exist,
          or because it is set to `private` and you do not have access.
        - [`~utils.RevisionNotFoundError`]
          If the revision to download from cannot be found.
        - [`~utils.EntryNotFoundError`]
          If the file to download cannot be found.
        - [`~utils.LocalEntryNotFoundError`]
          If network is disabled or unavailable and file is not found in cache.

    </Tip>
    """
    if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
        # Respect environment variable above user value
        etag_timeout = HF_HUB_ETAG_TIMEOUT

    if not legacy_cache_layout:
        warnings.warn(
            "'cached_download' is the legacy way to download files from the HF hub, please consider upgrading to"
            " 'hf_hub_download'",
            FutureWarning,
        )

    if cache_dir is None:
        cache_dir = HF_HUB_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
    )

    url_to_download = url
    etag = None
    expected_size = None
    if not local_files_only:
        try:
            # Temporary header: we want the full (decompressed) content size returned to be able to check the
            # downloaded file size
            headers["Accept-Encoding"] = "identity"
            r = _request_wrapper(
                method="HEAD",
                url=url,
                headers=headers,
                allow_redirects=False,
                follow_relative_redirects=True,
                proxies=proxies,
                timeout=etag_timeout,
            )
            headers.pop("Accept-Encoding", None)
            hf_raise_for_status(r)
            etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise FileMetadataError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # We get the expected size of the file, to check the download went well.
            expected_size = _int_or_none(r.headers.get("Content-Length"))
            # In case of a redirect, save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            # Useful for lfs blobs that are stored on a CDN.
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
                headers.pop("authorization", None)
                expected_size = None  # redirected -> can't know the expected size
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ):
            # Otherwise, our Internet connection is down.
            # etag is None
            pass

    filename = force_filename if force_filename is not None else url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path) and not force_download:
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0 and not force_download and force_filename is None:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise LocalEntryNotFoundError(
                        "Cannot find the requested files in the cached path and"
                        " outgoing traffic has been disabled. To enable model look-ups"
                        " and downloads online, set 'local_files_only' to False."
                    )
                else:
                    raise LocalEntryNotFoundError(
                        "Connection error, and we cannot find the requested files in"
                        " the cached path. Please try again or make sure your Internet"
                        " connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it is an extended path by using the "\\?\" prefix.
    if os.name == "nt" and len(os.path.abspath(lock_path)) > 255:
        lock_path = "\\\\?\\" + os.path.abspath(lock_path)

    if os.name == "nt" and len(os.path.abspath(cache_path)) > 255:
        cache_path = "\\\\?\\" + os.path.abspath(cache_path)

    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> Generator[io.BufferedWriter, None, None]:
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(  # type: ignore
                tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
            )
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("downloading %s to %s", url, temp_file.name)

            http_get(
                url_to_download,
                temp_file,
                proxies=proxies,
                resume_size=resume_size,
                headers=headers,
                expected_size=expected_size,
            )

        logger.info("storing %s in cache at %s", url, cache_path)
        _chmod_and_replace(temp_file.name, cache_path)

        if force_filename is None:
            logger.info("creating metadata file for %s", cache_path)
            meta = {"url": url, "etag": etag}
            meta_path = cache_path + ".json"
            with open(meta_path, "w") as meta_file:
                json.dump(meta, meta_file)

    return cache_path
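A hedged sketch of the legacy flow this function implements: build the URL with `hf_hub_url`, then fetch it into the flat, hash-named cache. Passing `legacy_cache_layout=True` acknowledges the old layout and suppresses the `FutureWarning`; the repo and file are illustrative.

```python
from huggingface_hub import cached_download, hf_hub_url

path = cached_download(
    hf_hub_url("gpt2", "config.json"),
    legacy_cache_layout=True,
)
print(path)  # e.g. <cache_dir>/<sha256(url)>.<sha256(etag)>
```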

def _normalize_etag(etag: Optional[str]) -> Optional[str]:
    """Normalize ETag HTTP header, so it can be used to create nice filepaths.

    The HTTP spec allows two forms of ETag:
      ETag: W/"<etag_value>"
      ETag: "<etag_value>"

    For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
    more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.

    Args:
        etag (`str`, *optional*): HTTP header

    Returns:
        `str` or `None`: string that can be used as a nice directory name.
        Returns `None` if input is None.
    """
    if etag is None:
        return None
    return etag.lstrip("W/").strip('"')
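Both header forms mentioned in the docstring collapse to the bare value; the hex string below is an arbitrary example.

```python
from huggingface_hub.file_download import _normalize_etag

print(_normalize_etag('"27ae96f3d2b1b6d3b07a7b0d5e3e9044"'))    # strong ETag
print(_normalize_etag('W/"27ae96f3d2b1b6d3b07a7b0d5e3e9044"'))  # weak ETag
# Both print: 27ae96f3d2b1b6d3b07a7b0d5e3e9044
print(_normalize_etag(None))  # None is passed through unchanged
```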

def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
    """Alias method used in `transformers` conversion script."""
    return _create_symlink(src=src, dst=dst, new_blob=new_blob)


def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
    """Create a symbolic link named dst pointing to src.

    By default, it will try to create a symlink using a relative path. Relative paths have 2 advantages:
    - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder will
      not break.
    - Relative paths seem to be better handled on Windows. Issue was reported 3 times in less than a week when
      changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398,
      https://github.com/huggingface/diffusers/issues/2729 and https://github.com/huggingface/transformers/pull/22228.
      NOTE: The issue with absolute paths doesn't happen on admin mode.
    When creating a symlink from the cache to a local folder, it is possible that a relative path cannot be created.
    This happens when paths are not on the same volume. In that case, we use absolute paths.


    The result layout looks something like
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd

    If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by
    having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file
    (`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing
    cache, the file is duplicated on the disk.

    In case symlinks are not supported, a warning message is displayed to the user once when loading `huggingface_hub`.
    The warning message can be disabled with the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable.
    """
    try:
        os.remove(dst)
    except OSError:
        pass

    abs_src = os.path.abspath(os.path.expanduser(src))
    abs_dst = os.path.abspath(os.path.expanduser(dst))
    abs_dst_folder = os.path.dirname(abs_dst)

    # Use relative_src in priority
    try:
        relative_src = os.path.relpath(abs_src, abs_dst_folder)
    except ValueError:
        # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink to a
        # local_dir instead of within the cache directory.
        # See https://docs.python.org/3/library/os.path.html#os.path.relpath
        relative_src = None

    try:
        commonpath = os.path.commonpath([abs_src, abs_dst])
        _support_symlinks = are_symlinks_supported(commonpath)
    except ValueError:
        # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos.
        # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
        _support_symlinks = os.name != "nt"
    except PermissionError:
        # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
        # by the user via `local_dir`. Let's test symlink support there)
        _support_symlinks = are_symlinks_supported(abs_dst_folder)

    # Symlinks are supported => let's create a symlink.
    if _support_symlinks:
        src_rel_or_abs = relative_src or abs_src
        logger.debug(f"Creating pointer from {src_rel_or_abs} to {abs_dst}")
        try:
            os.symlink(src_rel_or_abs, abs_dst)
            return
        except FileExistsError:
            if os.path.islink(abs_dst) and os.path.realpath(abs_dst) == os.path.realpath(abs_src):
                # `abs_dst` already exists and is a symlink to the `abs_src` blob. It is most likely that the file has
                # been cached twice concurrently (exactly between `os.remove` and `os.symlink`). Do nothing.
                return
            else:
                # Very unlikely to happen. Means a file `dst` has been created exactly between `os.remove` and
                # `os.symlink` and is not a symlink to the `abs_src` blob file. Raise exception.
                raise
        except PermissionError:
            # Permission error means src and dst are not in the same volume (e.g. download to local dir) and symlink
            # is supported on both volumes but not between them. Let's just make a hard copy in that case.
            pass

    # Symlinks are not supported => let's move or copy the file.
    if new_blob:
        logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
        shutil.move(abs_src, abs_dst)
    else:
        logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
        shutil.copyfile(abs_src, abs_dst)

def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
    """Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash.

    Does nothing if `revision` is already a proper `commit_hash` or reference is already cached.
    """
    if revision != commit_hash:
        ref_path = Path(storage_folder) / "refs" / revision
        ref_path.parent.mkdir(parents=True, exist_ok=True)
        if not ref_path.exists() or commit_hash != ref_path.read_text():
            # Update ref only if it has been updated. Could cause useless error in case
            # repo is already cached and user doesn't have write access to cache folder.
            # See https://github.com/huggingface/huggingface_hub/issues/1216.
            ref_path.write_text(commit_hash)

@validate_hf_hub_args
def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
    """Return a serialized version of a hf.co repo name and type, safe for disk storage
    as a single non-nested folder.

    Example: models--julien-c--EsperBERTo-small
    """
    # remove all `/` occurrences to correctly convert repo to directory name
    parts = [f"{repo_type}s", *repo_id.split("/")]
    return REPO_ID_SEPARATOR.join(parts)
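The docstring example covers the model case; only the pluralized type prefix changes for other repo types. The dataset repo id below is a made-up placeholder.

```python
from huggingface_hub.file_download import repo_folder_name

print(repo_folder_name(repo_id="julien-c/EsperBERTo-small", repo_type="model"))
# models--julien-c--EsperBERTo-small
print(repo_folder_name(repo_id="username/my-dataset", repo_type="dataset"))
# datasets--username--my-dataset
```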

def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None:
    """Check disk usage and log a warning if there is not enough disk space to download the file.

    Args:
        expected_size (`int`):
            The expected size of the file in bytes.
        target_dir (`str`):
            The directory where the file will be stored after downloading.
    """

    target_dir = Path(target_dir)  # format as `Path`
    for path in [target_dir] + list(target_dir.parents):  # first check target_dir, then each parents one by one
        try:
            target_dir_free = shutil.disk_usage(path).free
            if target_dir_free < expected_size:
                warnings.warn(
                    "Not enough free disk space to download the file. "
                    f"The expected file size is: {expected_size / 1e6:.2f} MB. "
                    f"The target location {target_dir} only has {target_dir_free / 1e6:.2f} MB free disk space."
                )
            return
        except OSError:  # OSError if the path does not exist or disk usage cannot be checked
            pass
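A small sketch of the behaviour described above: the helper only warns, it never raises, so a download can still be attempted afterwards. The 10 GB figure is an arbitrary illustration.

```python
from huggingface_hub.file_download import _check_disk_space

# Emits a UserWarning if the current directory's filesystem has less than ~10 GB free.
_check_disk_space(expected_size=10 * 1024**3, target_dir=".")
```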

@validate_hf_hub_args
def hf_hub_download(
    repo_id: str,
    filename: str,
    *,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    user_agent: Union[Dict, str, None] = None,
    force_download: bool = False,
    force_filename: Optional[str] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
    resume_download: bool = False,
    token: Union[bool, str, None] = None,
    local_files_only: bool = False,
    legacy_cache_layout: bool = False,
    endpoint: Optional[str] = None,
) -> str:
    """Download a given file if it's not already present in the local cache.

    The new cache file layout looks like this:
    - The cache directory contains one subfolder per repo_id (namespaced by repo type)
    - inside each repo folder:
        - refs is a list of the latest known revision => commit_hash pairs
        - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on
          whether they're LFS files or not)
        - snapshots contains one subfolder per commit, each "commit" contains the subset of the files
          that have been resolved at that particular commit. Each filename is a symlink to the blob
          at that particular commit.

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
    how you want to move those files:
    - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
      files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
      is to be able to manually edit and save small files without corrupting the cache while saving disk space for
      binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
      environment variable.
    - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
      This is optimal in terms of disk usage but files must not be manually edited.
    - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
      local dir. This means disk usage is not optimized.
    - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
      files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
      they will be re-downloaded entirely.

    ```
    [  96]  .
    └── [ 160]  models--julien-c--EsperBERTo-small
        ├── [ 160]  blobs
        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
        ├── [  96]  refs
        │   └── [  40]  main
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
    ```

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        filename (`str`):
            The name of the file in the repo.
        subfolder (`str`, *optional*):
            An optional value corresponding to a folder inside the model repo.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded file will be placed under this directory, either as a symlink (default) or
            a regular file (see description for more details).
        local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
            To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
            duplicated or symlinked to the local directory depending on its size. If set to `True`, a symlink will be
            created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if it
            already exists) or downloaded from the Hub and not cached. See description for more details.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in
            the local cache.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        resume_download (`bool`, *optional*, defaults to `False`):
            If `True`, resume a previously interrupted download.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
            - If `True`, the token is read from the HuggingFace config
              folder.
            - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        legacy_cache_layout (`bool`, *optional*, defaults to `False`):
            If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`]
            then `cached_download`. This is deprecated as the new cache layout is
            more powerful.

    Returns:
        Local path (string) of file or if networking is off, last version of
        file cached on disk.

    <Tip>

    Raises the following errors:

        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if `token=True` and the token cannot be found.
        - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
          if ETag cannot be determined.
        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if some parameter value is invalid
        - [`~utils.RepositoryNotFoundError`]
          If the repository to download from cannot be found. This may be because it doesn't exist,
          or because it is set to `private` and you do not have access.
        - [`~utils.RevisionNotFoundError`]
          If the revision to download from cannot be found.
        - [`~utils.EntryNotFoundError`]
          If the file to download cannot be found.
        - [`~utils.LocalEntryNotFoundError`]
          If network is disabled or unavailable and file is not found in cache.

    </Tip>
    """
    if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
        # Respect environment variable above user value
        etag_timeout = HF_HUB_ETAG_TIMEOUT

    if force_filename is not None:
        warnings.warn(
            "The `force_filename` parameter is deprecated as a new caching system, "
            "which keeps the filenames as they are on the Hub, is now in place.",
            FutureWarning,
        )
        legacy_cache_layout = True

    if legacy_cache_layout:
        url = hf_hub_url(
            repo_id,
            filename,
            subfolder=subfolder,
            repo_type=repo_type,
            revision=revision,
            endpoint=endpoint,
        )

        return cached_download(
            url,
            library_name=library_name,
            library_version=library_version,
            cache_dir=cache_dir,
            user_agent=user_agent,
            force_download=force_download,
            force_filename=force_filename,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
            legacy_cache_layout=legacy_cache_layout,
        )

    if cache_dir is None:
        cache_dir = HF_HUB_CACHE
    if revision is None:
        revision = DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if isinstance(local_dir, Path):
        local_dir = str(local_dir)
    locks_dir = os.path.join(cache_dir, ".locks")

    if subfolder == "":
        subfolder = None
    if subfolder is not None:
        # This is used to create a URL, and not a local path, hence the forward slash.
        filename = f"{subfolder}/{filename}"

    if repo_type is None:
        repo_type = "model"
    if repo_type not in REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
    os.makedirs(storage_folder, exist_ok=True)

    # cross platform transcription of filename, to be used as a local file path.
    relative_filename = os.path.join(*filename.split("/"))
    if os.name == "nt":
        if relative_filename.startswith("..\\") or "\\..\\" in relative_filename:
            raise ValueError(
                f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository"
                " owner to rename this file."
            )

    # if user provides a commit_hash and they already have the file on disk,
    # shortcut everything.
    if REGEX_COMMIT_HASH.match(revision):
        pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
        if os.path.exists(pointer_path):
            if local_dir is not None:
                return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
            return pointer_path

    url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)

    headers = build_hf_headers(
        token=token,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
    )

    url_to_download = url
    etag = None
    commit_hash = None
    expected_size = None
    head_call_error: Optional[Exception] = None
    if not local_files_only:
        try:
            try:
                metadata = get_hf_file_metadata(
                    url=url,
                    token=token,
                    proxies=proxies,
                    timeout=etag_timeout,
                    library_name=library_name,
                    library_version=library_version,
                    user_agent=user_agent,
                )
            except EntryNotFoundError as http_error:
                # Cache the non-existence of the file and raise
                commit_hash = http_error.response.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
                if commit_hash is not None and not legacy_cache_layout:
                    no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
                    no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
                    no_exist_file_path.touch()
                    _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
                raise

            # Commit hash must exist
            commit_hash = metadata.commit_hash
            if commit_hash is None:
                raise FileMetadataError(
                    "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue"
                    " prevents you from downloading resources from https://huggingface.co. Please check your firewall"
                    " and proxy settings and make sure your SSL certificates are updated."
                )

            # Etag must exist
            etag = metadata.etag
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise FileMetadataError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )

            # Expected (uncompressed) size
            expected_size = metadata.size

            # In case of a redirect, save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            # Useful for lfs blobs that are stored on a CDN.
            if metadata.location != url:
                url_to_download = metadata.location
                # Remove authorization header when downloading a LFS blob
                headers.pop("authorization", None)
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ) as error:
            # Otherwise, our Internet connection is down.
            # etag is None
            head_call_error = error
            pass
        except (RevisionNotFoundError, EntryNotFoundError):
            # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted)
            raise
        except requests.HTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            head_call_error = error
            pass
        except FileMetadataError as error:
            # Multiple reasons for a FileMetadataError:
            # - Wrong network configuration (proxy, firewall, SSL certificates)
            # - Inconsistency on the Hub
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            head_call_error = error
            pass

    # etag can be None for several reasons:
    # 1. we passed local_files_only.
    # 2. we don't have a connection
    # 3. Hub is down (HTTP 500 or 504)
    # 4. repo is not found (for example private or gated) and invalid/missing token sent
    # 5. Hub is blocked by a firewall or proxy is not set correctly.
    # => Try to get the last downloaded one from the specified revision.
    #
    # If the specified revision is a commit hash, look inside "snapshots".
    # If the specified revision is a branch or tag, look inside "refs".
    if etag is None:
        # In those cases, we cannot force download.
        if force_download:
            raise ValueError(
                "We have no connection or you passed local_files_only, so force_download is not an accepted option."
            )

        # Try to get "commit_hash" from "revision"
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.isfile(ref_path):
                with open(ref_path) as f:
                    commit_hash = f.read()

        # Return pointer file if exists
        if commit_hash is not None:
            pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
            if os.path.exists(pointer_path):
                if local_dir is not None:
                    return _to_local_dir(
                        pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks
                    )
                return pointer_path

        # If we couldn't find an appropriate file on disk, raise an error.
        # If files cannot be found and local_files_only=True,
        # the models might've been found if local_files_only=False
        # Notify the user about that
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable"
                " hf.co look-ups and downloads online, set 'local_files_only' to False."
            )
        elif isinstance(head_call_error, RepositoryNotFoundError) or isinstance(head_call_error, GatedRepoError):
            # Repo not found => let's raise the actual error
            raise head_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the file on the Hub and we cannot find the requested files"
                " in the local cache. Please check your connection and try again or make sure your Internet connection"
                " is on."
            ) from head_call_error

    # From now on, etag and commit_hash are not None.
    assert etag is not None, "etag must have been retrieved from server"
    assert commit_hash is not None, "commit_hash must have been retrieved from server"
    blob_path = os.path.join(storage_folder, "blobs", etag)
    pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)

    os.makedirs(os.path.dirname(blob_path), exist_ok=True)
    os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
    # if passed revision is not identical to commit_hash
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

    if os.path.exists(pointer_path) and not force_download:
        if local_dir is not None:
            return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
        return pointer_path

    if os.path.exists(blob_path) and not force_download:
        # we have the blob already, but not the pointer
        if local_dir is not None:  # to local dir
            return _to_local_dir(blob_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
        else:  # or in snapshot cache
            _create_symlink(blob_path, pointer_path, new_blob=False)
            return pointer_path

    # Prevent parallel downloads of the same file with a lock.
    # etag could be duplicated across repos,
    lock_path = os.path.join(locks_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type), f"{etag}.lock")

    # Some Windows versions do not allow for paths longer than 255 characters.
    # In this case, we must specify it is an extended path by using the "\\?\" prefix.
    if os.name == "nt" and len(os.path.abspath(lock_path)) > 255:
        lock_path = "\\\\?\\" + os.path.abspath(lock_path)

    if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
        blob_path = "\\\\?\\" + os.path.abspath(blob_path)

    Path(lock_path).parent.mkdir(parents=True, exist_ok=True)
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if os.path.exists(pointer_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            if local_dir is not None:
                return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
            return pointer_path

        if resume_download:
            incomplete_path = blob_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> Generator[io.BufferedWriter, None, None]:
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(  # type: ignore
                tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
            )
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("downloading %s to %s", url, temp_file.name)
|
1447 |
+
|
1448 |
+
if expected_size is not None: # might be None if HTTP header not set correctly
|
1449 |
+
# Check tmp path
|
1450 |
+
_check_disk_space(expected_size, os.path.dirname(temp_file.name))
|
1451 |
+
|
1452 |
+
# Check destination
|
1453 |
+
_check_disk_space(expected_size, os.path.dirname(blob_path))
|
1454 |
+
if local_dir is not None:
|
1455 |
+
_check_disk_space(expected_size, local_dir)
|
1456 |
+
|
1457 |
+
http_get(
|
1458 |
+
url_to_download,
|
1459 |
+
temp_file,
|
1460 |
+
proxies=proxies,
|
1461 |
+
resume_size=resume_size,
|
1462 |
+
headers=headers,
|
1463 |
+
expected_size=expected_size,
|
1464 |
+
)
|
1465 |
+
|
1466 |
+
if local_dir is None:
|
1467 |
+
logger.debug(f"Storing {url} in cache at {blob_path}")
|
1468 |
+
_chmod_and_replace(temp_file.name, blob_path)
|
1469 |
+
_create_symlink(blob_path, pointer_path, new_blob=True)
|
1470 |
+
else:
|
1471 |
+
local_dir_filepath = os.path.join(local_dir, relative_filename)
|
1472 |
+
os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
|
1473 |
+
|
1474 |
+
# If "auto" (default) copy-paste small files to ease manual editing but symlink big files to save disk
|
1475 |
+
# In both cases, blob file is cached.
|
1476 |
+
is_big_file = os.stat(temp_file.name).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD
|
1477 |
+
if local_dir_use_symlinks is True or (local_dir_use_symlinks == "auto" and is_big_file):
|
1478 |
+
logger.debug(f"Storing {url} in cache at {blob_path}")
|
1479 |
+
_chmod_and_replace(temp_file.name, blob_path)
|
1480 |
+
logger.debug("Create symlink to local dir")
|
1481 |
+
_create_symlink(blob_path, local_dir_filepath, new_blob=False)
|
1482 |
+
elif local_dir_use_symlinks == "auto" and not is_big_file:
|
1483 |
+
logger.debug(f"Storing {url} in cache at {blob_path}")
|
1484 |
+
_chmod_and_replace(temp_file.name, blob_path)
|
1485 |
+
logger.debug("Duplicate in local dir (small file and use_symlink set to 'auto')")
|
1486 |
+
shutil.copyfile(blob_path, local_dir_filepath)
|
1487 |
+
else:
|
1488 |
+
logger.debug(f"Storing {url} in local_dir at {local_dir_filepath} (not cached).")
|
1489 |
+
_chmod_and_replace(temp_file.name, local_dir_filepath)
|
1490 |
+
pointer_path = local_dir_filepath # for return value
|
1491 |
+
|
1492 |
+
return pointer_path
|
1493 |
+
|
1494 |
+
|
1495 |
+
@validate_hf_hub_args
|
1496 |
+
def try_to_load_from_cache(
|
1497 |
+
repo_id: str,
|
1498 |
+
filename: str,
|
1499 |
+
cache_dir: Union[str, Path, None] = None,
|
1500 |
+
revision: Optional[str] = None,
|
1501 |
+
repo_type: Optional[str] = None,
|
1502 |
+
) -> Union[str, _CACHED_NO_EXIST_T, None]:
|
1503 |
+
"""
|
1504 |
+
Explores the cache to return the latest cached file for a given revision if found.
|
1505 |
+
|
1506 |
+
This function will not raise any exception if the file in not cached.
|
1507 |
+
|
1508 |
+
Args:
|
1509 |
+
cache_dir (`str` or `os.PathLike`):
|
1510 |
+
The folder where the cached files lie.
|
1511 |
+
repo_id (`str`):
|
1512 |
+
The ID of the repo on huggingface.co.
|
1513 |
+
filename (`str`):
|
1514 |
+
The filename to look for inside `repo_id`.
|
1515 |
+
revision (`str`, *optional*):
|
1516 |
+
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
|
1517 |
+
provided either.
|
1518 |
+
repo_type (`str`, *optional*):
|
1519 |
+
The type of the repository. Will default to `"model"`.
|
1520 |
+
|
1521 |
+
Returns:
|
1522 |
+
`Optional[str]` or `_CACHED_NO_EXIST`:
|
1523 |
+
Will return `None` if the file was not cached. Otherwise:
|
1524 |
+
- The exact path to the cached file if it's found in the cache
|
1525 |
+
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
|
1526 |
+
cached.
|
1527 |
+
|
1528 |
+
Example:
|
1529 |
+
|
1530 |
+
```python
|
1531 |
+
from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST
|
1532 |
+
|
1533 |
+
filepath = try_to_load_from_cache()
|
1534 |
+
if isinstance(filepath, str):
|
1535 |
+
# file exists and is cached
|
1536 |
+
...
|
1537 |
+
elif filepath is _CACHED_NO_EXIST:
|
1538 |
+
# non-existence of file is cached
|
1539 |
+
...
|
1540 |
+
else:
|
1541 |
+
# file is not cached
|
1542 |
+
...
|
1543 |
+
```
|
1544 |
+
"""
|
1545 |
+
if revision is None:
|
1546 |
+
revision = "main"
|
1547 |
+
if repo_type is None:
|
1548 |
+
repo_type = "model"
|
1549 |
+
if repo_type not in REPO_TYPES:
|
1550 |
+
raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
|
1551 |
+
if cache_dir is None:
|
1552 |
+
cache_dir = HF_HUB_CACHE
|
1553 |
+
|
1554 |
+
object_id = repo_id.replace("/", "--")
|
1555 |
+
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
|
1556 |
+
if not os.path.isdir(repo_cache):
|
1557 |
+
# No cache for this model
|
1558 |
+
return None
|
1559 |
+
|
1560 |
+
refs_dir = os.path.join(repo_cache, "refs")
|
1561 |
+
snapshots_dir = os.path.join(repo_cache, "snapshots")
|
1562 |
+
no_exist_dir = os.path.join(repo_cache, ".no_exist")
|
1563 |
+
|
1564 |
+
# Resolve refs (for instance to convert main to the associated commit sha)
|
1565 |
+
if os.path.isdir(refs_dir):
|
1566 |
+
revision_file = os.path.join(refs_dir, revision)
|
1567 |
+
if os.path.isfile(revision_file):
|
1568 |
+
with open(revision_file) as f:
|
1569 |
+
revision = f.read()
|
1570 |
+
|
1571 |
+
# Check if file is cached as "no_exist"
|
1572 |
+
if os.path.isfile(os.path.join(no_exist_dir, revision, filename)):
|
1573 |
+
return _CACHED_NO_EXIST
|
1574 |
+
|
1575 |
+
# Check if revision folder exists
|
1576 |
+
if not os.path.exists(snapshots_dir):
|
1577 |
+
return None
|
1578 |
+
cached_shas = os.listdir(snapshots_dir)
|
1579 |
+
if revision not in cached_shas:
|
1580 |
+
# No cache for this revision and we won't try to return a random revision
|
1581 |
+
return None
|
1582 |
+
|
1583 |
+
# Check if file exists in cache
|
1584 |
+
cached_file = os.path.join(snapshots_dir, revision, filename)
|
1585 |
+
return cached_file if os.path.isfile(cached_file) else None
|
1586 |
+
|
1587 |
+
|
1588 |
+
@validate_hf_hub_args
|
1589 |
+
def get_hf_file_metadata(
|
1590 |
+
url: str,
|
1591 |
+
token: Union[bool, str, None] = None,
|
1592 |
+
proxies: Optional[Dict] = None,
|
1593 |
+
timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
|
1594 |
+
library_name: Optional[str] = None,
|
1595 |
+
library_version: Optional[str] = None,
|
1596 |
+
user_agent: Union[Dict, str, None] = None,
|
1597 |
+
) -> HfFileMetadata:
|
1598 |
+
"""Fetch metadata of a file versioned on the Hub for a given url.
|
1599 |
+
|
1600 |
+
Args:
|
1601 |
+
url (`str`):
|
1602 |
+
File url, for example returned by [`hf_hub_url`].
|
1603 |
+
token (`str` or `bool`, *optional*):
|
1604 |
+
A token to be used for the download.
|
1605 |
+
- If `True`, the token is read from the HuggingFace config
|
1606 |
+
folder.
|
1607 |
+
- If `False` or `None`, no token is provided.
|
1608 |
+
- If a string, it's used as the authentication token.
|
1609 |
+
proxies (`dict`, *optional*):
|
1610 |
+
Dictionary mapping protocol to the URL of the proxy passed to
|
1611 |
+
`requests.request`.
|
1612 |
+
timeout (`float`, *optional*, defaults to 10):
|
1613 |
+
How many seconds to wait for the server to send metadata before giving up.
|
1614 |
+
library_name (`str`, *optional*):
|
1615 |
+
The name of the library to which the object corresponds.
|
1616 |
+
library_version (`str`, *optional*):
|
1617 |
+
The version of the library.
|
1618 |
+
user_agent (`dict`, `str`, *optional*):
|
1619 |
+
The user-agent info in the form of a dictionary or a string.
|
1620 |
+
|
1621 |
+
Returns:
|
1622 |
+
A [`HfFileMetadata`] object containing metadata such as location, etag, size and
|
1623 |
+
commit_hash.
|
1624 |
+
"""
|
1625 |
+
headers = build_hf_headers(
|
1626 |
+
token=token, library_name=library_name, library_version=library_version, user_agent=user_agent
|
1627 |
+
)
|
1628 |
+
headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file
|
1629 |
+
|
1630 |
+
# Retrieve metadata
|
1631 |
+
r = _request_wrapper(
|
1632 |
+
method="HEAD",
|
1633 |
+
url=url,
|
1634 |
+
headers=headers,
|
1635 |
+
allow_redirects=False,
|
1636 |
+
follow_relative_redirects=True,
|
1637 |
+
proxies=proxies,
|
1638 |
+
timeout=timeout,
|
1639 |
+
)
|
1640 |
+
hf_raise_for_status(r)
|
1641 |
+
|
1642 |
+
# Return
|
1643 |
+
return HfFileMetadata(
|
1644 |
+
commit_hash=r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT),
|
1645 |
+
# We favor a custom header indicating the etag of the linked resource, and
|
1646 |
+
# we fallback to the regular etag header.
|
1647 |
+
etag=_normalize_etag(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")),
|
1648 |
+
# Either from response headers (if redirected) or defaults to request url
|
1649 |
+
# Do not use directly `url`, as `_request_wrapper` might have followed relative
|
1650 |
+
# redirects.
|
1651 |
+
location=r.headers.get("Location") or r.request.url, # type: ignore
|
1652 |
+
size=_int_or_none(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")),
|
1653 |
+
)
|
1654 |
+
|
1655 |
+
|
1656 |
+
def _int_or_none(value: Optional[str]) -> Optional[int]:
|
1657 |
+
try:
|
1658 |
+
return int(value) # type: ignore
|
1659 |
+
except (TypeError, ValueError):
|
1660 |
+
return None
|
1661 |
+
|
1662 |
+
|
1663 |
+
def _chmod_and_replace(src: str, dst: str) -> None:
|
1664 |
+
"""Set correct permission before moving a blob from tmp directory to cache dir.
|
1665 |
+
|
1666 |
+
Do not take into account the `umask` from the process as there is no convenient way
|
1667 |
+
to get it that is thread-safe.
|
1668 |
+
|
1669 |
+
See:
|
1670 |
+
- About umask: https://docs.python.org/3/library/os.html#os.umask
|
1671 |
+
- Thread-safety: https://stackoverflow.com/a/70343066
|
1672 |
+
- About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591
|
1673 |
+
- Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141
|
1674 |
+
- Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215
|
1675 |
+
"""
|
1676 |
+
# Get umask by creating a temporary file in the cached repo folder.
|
1677 |
+
tmp_file = Path(dst).parent.parent / f"tmp_{uuid.uuid4()}"
|
1678 |
+
try:
|
1679 |
+
tmp_file.touch()
|
1680 |
+
cache_dir_mode = Path(tmp_file).stat().st_mode
|
1681 |
+
os.chmod(src, stat.S_IMODE(cache_dir_mode))
|
1682 |
+
finally:
|
1683 |
+
tmp_file.unlink()
|
1684 |
+
|
1685 |
+
shutil.move(src, dst)
|
1686 |
+
|
1687 |
+
|
1688 |
+
def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
|
1689 |
+
# Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
|
1690 |
+
snapshot_path = os.path.join(storage_folder, "snapshots")
|
1691 |
+
pointer_path = os.path.join(snapshot_path, revision, relative_filename)
|
1692 |
+
if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents:
|
1693 |
+
raise ValueError(
|
1694 |
+
"Invalid pointer path: cannot create pointer path in snapshot folder if"
|
1695 |
+
f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and"
|
1696 |
+
f" `relative_filename='{relative_filename}'`."
|
1697 |
+
)
|
1698 |
+
return pointer_path
|
1699 |
+
|
1700 |
+
|
1701 |
+
def _to_local_dir(
|
1702 |
+
path: str, local_dir: str, relative_filename: str, use_symlinks: Union[bool, Literal["auto"]]
|
1703 |
+
) -> str:
|
1704 |
+
"""Place a file in a local dir (different than cache_dir).
|
1705 |
+
|
1706 |
+
Either symlink to blob file in cache or duplicate file depending on `use_symlinks` and file size.
|
1707 |
+
"""
|
1708 |
+
# Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
|
1709 |
+
local_dir_filepath = os.path.join(local_dir, relative_filename)
|
1710 |
+
if Path(os.path.abspath(local_dir)) not in Path(os.path.abspath(local_dir_filepath)).parents:
|
1711 |
+
raise ValueError(
|
1712 |
+
f"Cannot copy file '{relative_filename}' to local dir '{local_dir}': file would not be in the local"
|
1713 |
+
" directory."
|
1714 |
+
)
|
1715 |
+
|
1716 |
+
os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
|
1717 |
+
real_blob_path = os.path.realpath(path)
|
1718 |
+
|
1719 |
+
# If "auto" (default) copy-paste small files to ease manual editing but symlink big files to save disk
|
1720 |
+
if use_symlinks == "auto":
|
1721 |
+
use_symlinks = os.stat(real_blob_path).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD
|
1722 |
+
|
1723 |
+
if use_symlinks:
|
1724 |
+
_create_symlink(real_blob_path, local_dir_filepath, new_blob=False)
|
1725 |
+
else:
|
1726 |
+
shutil.copyfile(real_blob_path, local_dir_filepath)
|
1727 |
+
return local_dir_filepath
|
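Note (not part of the committed files above): a minimal usage sketch of the cache and metadata helpers defined in file_download.py, assuming an installed huggingface_hub. The repo id "gpt2" and filename "config.json" are only illustrative placeholders, not anything referenced by this commit.

```python
from huggingface_hub import (
    hf_hub_download,
    hf_hub_url,
    get_hf_file_metadata,
    try_to_load_from_cache,
)

# Download a single file (or reuse the cached copy) and get its local path.
local_path = hf_hub_download(repo_id="gpt2", filename="config.json")

# Look the file up in the local cache only, without any network call.
cached = try_to_load_from_cache(repo_id="gpt2", filename="config.json")
if isinstance(cached, str):
    print("already cached at", cached)

# Fetch commit hash, etag and size for a file URL via a HEAD request.
metadata = get_hf_file_metadata(hf_hub_url(repo_id="gpt2", filename="config.json"))
print(metadata.commit_hash, metadata.etag, metadata.size)
```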
lib/python3.11/site-packages/huggingface_hub/hf_api.py
ADDED
The diff for this file is too large to render.
See raw diff
lib/python3.11/site-packages/huggingface_hub/hf_file_system.py
ADDED
@@ -0,0 +1,670 @@
import copy
import os
import re
import tempfile
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
from urllib.parse import quote, unquote

import fsspec

from ._commit_api import CommitOperationCopy, CommitOperationDelete
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
from .file_download import hf_hub_url
from .hf_api import HfApi, LastCommitInfo, RepoFile
from .utils import (
    EntryNotFoundError,
    HFValidationError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    hf_raise_for_status,
    http_backoff,
)


# Regex used to match special revisions with "/" in them (see #1710)
SPECIAL_REFS_REVISION_REGEX = re.compile(
    r"""
    (^refs\/convert\/\w+)  # `refs/convert/parquet` revisions
    |
    (^refs\/pr\/\d+)  # PR revisions
    """,
    re.VERBOSE,
)


@dataclass
class HfFileSystemResolvedPath:
    """Data structure containing information about a resolved Hugging Face file system path."""

    repo_type: str
    repo_id: str
    revision: str
    path_in_repo: str
    # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision.
    # Used to reconstruct the unresolved path to return to the user.
    _raw_revision: Optional[str] = field(default=None, repr=False)

    def unresolve(self) -> str:
        repo_path = REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
        if self._raw_revision:
            return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/")
        elif self.revision != DEFAULT_REVISION:
            return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/")
        else:
            return f"{repo_path}/{self.path_in_repo}".rstrip("/")


class HfFileSystem(fsspec.AbstractFileSystem):
    """
    Access a remote Hugging Face Hub repository as if were a local file system.

    Args:
        token (`str`, *optional*):
            Authentication token, obtained with [`HfApi.login`] method. Will default to the stored token.

    Usage:

    ```python
    >>> from huggingface_hub import HfFileSystem

    >>> fs = HfFileSystem()

    >>> # List files
    >>> fs.glob("my-username/my-model/*.bin")
    ['my-username/my-model/pytorch_model.bin']
    >>> fs.ls("datasets/my-username/my-dataset", detail=False)
    ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']

    >>> # Read/write files
    >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
    ...     data = f.read()
    >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
    ...     f.write(data)
    ```
    """

    root_marker = ""
    protocol = "hf"

    def __init__(
        self,
        *args,
        endpoint: Optional[str] = None,
        token: Optional[str] = None,
        **storage_options,
    ):
        super().__init__(*args, **storage_options)
        self.endpoint = endpoint or ENDPOINT
        self.token = token
        self._api = HfApi(endpoint=endpoint, token=token)
        # Maps (repo_type, repo_id, revision) to a 2-tuple with:
        # * the 1st element indicating whether the repositoy and the revision exist
        # * the 2nd element being the exception raised if the repository or revision doesn't exist
        self._repo_and_revision_exists_cache: Dict[
            Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
        ] = {}

    def _repo_and_revision_exist(
        self, repo_type: str, repo_id: str, revision: Optional[str]
    ) -> Tuple[bool, Optional[Exception]]:
        if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
            try:
                self._api.repo_info(repo_id, revision=revision, repo_type=repo_type)
            except (RepositoryNotFoundError, HFValidationError) as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
            except RevisionNotFoundError as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
            else:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
        return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]

    def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
        def _align_revision_in_path_with_revision(
            revision_in_path: Optional[str], revision: Optional[str]
        ) -> Optional[str]:
            if revision is not None:
                if revision_in_path is not None and revision_in_path != revision:
                    raise ValueError(
                        f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
                        " are not the same."
                    )
            else:
                revision = revision_in_path
            return revision

        path = self._strip_protocol(path)
        if not path:
            # can't list repositories at root
            raise NotImplementedError("Access to repositories lists is not implemented.")
        elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
            if "/" not in path:
                # can't list repositories at the repository type level
                raise NotImplementedError("Access to repositories lists is not implemented.")
            repo_type, path = path.split("/", 1)
            repo_type = REPO_TYPES_MAPPING[repo_type]
        else:
            repo_type = REPO_TYPE_MODEL
        if path.count("/") > 0:
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                if "/" in revision_in_path:
                    match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
                    if match is not None and revision in (None, match.group()):
                        # Handle `refs/convert/parquet` and PR revisions separately
                        path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
                        revision_in_path = match.group()
                    else:
                        revision_in_path, path_in_repo = revision_in_path.split("/", 1)
                else:
                    path_in_repo = ""
                revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    _raise_file_not_found(path, err)
            else:
                revision_in_path = None
                repo_id_with_namespace = "/".join(path.split("/")[:2])
                path_in_repo_with_namespace = "/".join(path.split("/")[2:])
                repo_id_without_namespace = path.split("/")[0]
                path_in_repo_without_namespace = "/".join(path.split("/")[1:])
                repo_id = repo_id_with_namespace
                path_in_repo = path_in_repo_with_namespace
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
                        repo_id = repo_id_without_namespace
                        path_in_repo = path_in_repo_without_namespace
                        repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
                        if not repo_and_revision_exist:
                            _raise_file_not_found(path, err)
                    else:
                        _raise_file_not_found(path, err)
        else:
            repo_id = path
            path_in_repo = ""
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
            else:
                revision_in_path = None
            repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
            if not repo_and_revision_exist:
                raise NotImplementedError("Access to repositories lists is not implemented.")

        revision = revision if revision is not None else DEFAULT_REVISION
        return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path)

    def invalidate_cache(self, path: Optional[str] = None) -> None:
        if not path:
            self.dircache.clear()
            self._repo_and_revision_exists_cache.clear()
        else:
            path = self.resolve_path(path).unresolve()
            while path:
                self.dircache.pop(path, None)
                path = self._parent(path)

    def _open(
        self,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        **kwargs,
    ) -> "HfFileSystemFile":
        if "a" in mode:
            raise NotImplementedError("Appending to remote files is not yet supported.")
        return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)

    def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        self._api.delete_file(
            path_in_repo=resolved_path.path_in_repo,
            repo_id=resolved_path.repo_id,
            token=self.token,
            repo_type=resolved_path.repo_type,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message"),
            commit_description=kwargs.get("commit_description"),
        )
        self.invalidate_cache(path=resolved_path.unresolve())

    def rm(
        self,
        path: str,
        recursive: bool = False,
        maxdepth: Optional[int] = None,
        revision: Optional[str] = None,
        **kwargs,
    ) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id
        paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision)
        paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)]
        operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]
        commit_message = f"Delete {path} "
        commit_message += "recursively " if recursive else ""
        commit_message += f"up to depth {maxdepth} " if maxdepth is not None else ""
        # TODO: use `commit_description` to list all the deleted paths?
        self._api.create_commit(
            repo_id=resolved_path.repo_id,
            repo_type=resolved_path.repo_type,
            token=self.token,
            operations=operations,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message", commit_message),
            commit_description=kwargs.get("commit_description"),
        )
        self.invalidate_cache(path=resolved_path.unresolve())

    def ls(
        self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
    ) -> List[Union[str, Dict[str, Any]]]:
        """List the contents of a directory."""
        resolved_path = self.resolve_path(path, revision=revision)
        path = resolved_path.unresolve()
        kwargs = {"expand_info": detail, **kwargs}
        try:
            out = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs)
        except EntryNotFoundError:
            # Path could be a file
            if not resolved_path.path_in_repo:
                _raise_file_not_found(path, None)
            out = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs)
            out = [o for o in out if o["name"] == path]
            if len(out) == 0:
                _raise_file_not_found(path, None)
        return out if detail else [o["name"] for o in out]

    def _ls_tree(
        self,
        path: str,
        recursive: bool = False,
        refresh: bool = False,
        revision: Optional[str] = None,
        expand_info: bool = True,
    ):
        resolved_path = self.resolve_path(path, revision=revision)
        path = resolved_path.unresolve()
        root_path = HfFileSystemResolvedPath(
            resolved_path.repo_type,
            resolved_path.repo_id,
            resolved_path.revision,
            path_in_repo="",
            _raw_revision=resolved_path._raw_revision,
        ).unresolve()

        out = []
        if path in self.dircache and not refresh:
            cached_path_infos = self.dircache[path]
            out.extend(cached_path_infos)
            dirs_not_in_dircache = []
            if recursive:
                # Use BFS to traverse the cache and build the "recursive "output
                # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same)
                dirs_to_visit = deque(
                    [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
                )
                while dirs_to_visit:
                    dir_info = dirs_to_visit.popleft()
                    if dir_info["name"] not in self.dircache:
                        dirs_not_in_dircache.append(dir_info["name"])
                    else:
                        cached_path_infos = self.dircache[dir_info["name"]]
                        out.extend(cached_path_infos)
                        dirs_to_visit.extend(
                            [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
                        )

            dirs_not_expanded = []
            if expand_info:
                # Check if there are directories with non-expanded entries
                dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None]

            if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded):
                # If the dircache is incomplete, find the common path of the missing and non-expanded entries
                # and extend the output with the result of `_ls_tree(common_path, recursive=True)`
                common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded)
                # Get the parent directory if the common prefix itself is not a directory
                common_path = (
                    common_prefix.rstrip("/")
                    if common_prefix.endswith("/")
                    or common_prefix == root_path
                    or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded)
                    else self._parent(common_prefix)
                )
                out = [o for o in out if not o["name"].startswith(common_path + "/")]
                for cached_path in self.dircache:
                    if cached_path.startswith(common_path + "/"):
                        self.dircache.pop(cached_path, None)
                self.dircache.pop(common_path, None)
                out.extend(
                    self._ls_tree(
                        common_path,
                        recursive=recursive,
                        refresh=True,
                        revision=revision,
                        expand_info=expand_info,
                    )
                )
        else:
            tree = self._api.list_repo_tree(
                resolved_path.repo_id,
                resolved_path.path_in_repo,
                recursive=recursive,
                expand=expand_info,
                revision=resolved_path.revision,
                repo_type=resolved_path.repo_type,
            )
            for path_info in tree:
                if isinstance(path_info, RepoFile):
                    cache_path_info = {
                        "name": root_path + "/" + path_info.path,
                        "size": path_info.size,
                        "type": "file",
                        "blob_id": path_info.blob_id,
                        "lfs": path_info.lfs,
                        "last_commit": path_info.last_commit,
                        "security": path_info.security,
                    }
                else:
                    cache_path_info = {
                        "name": root_path + "/" + path_info.path,
                        "size": 0,
                        "type": "directory",
                        "tree_id": path_info.tree_id,
                        "last_commit": path_info.last_commit,
                    }
                parent_path = self._parent(cache_path_info["name"])
                self.dircache.setdefault(parent_path, []).append(cache_path_info)
                out.append(cache_path_info)
        return copy.deepcopy(out)  # copy to not let users modify the dircache

    def glob(self, path, **kwargs):
        # Set expand_info=False by default to get a x10 speed boost
        kwargs = {"expand_info": kwargs.get("detail", False), **kwargs}
        path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
        return super().glob(path, **kwargs)

    def find(
        self,
        path: str,
        maxdepth: Optional[int] = None,
        withdirs: bool = False,
        detail: bool = False,
        refresh: bool = False,
        revision: Optional[str] = None,
        **kwargs,
    ) -> Union[List[str], Dict[str, Dict[str, Any]]]:
        if maxdepth:
            return super().find(
                path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs
            )
        resolved_path = self.resolve_path(path, revision=revision)
        path = resolved_path.unresolve()
        kwargs = {"expand_info": detail, **kwargs}
        try:
            out = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved_path.revision, **kwargs)
        except EntryNotFoundError:
            # Path could be a file
            if self.info(path, revision=revision, **kwargs)["type"] == "file":
                out = {path: {}}
            else:
                out = {}
        else:
            if not withdirs:
                out = [o for o in out if o["type"] != "directory"]
            else:
                # If `withdirs=True`, include the directory itself to be consistent with the spec
                path_info = self.info(path, revision=resolved_path.revision, **kwargs)
                out = [path_info] + out if path_info["type"] == "directory" else out
            out = {o["name"]: o for o in out}
        names = sorted(out)
        if not detail:
            return names
        else:
            return {name: out[name] for name in names}

    def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path1 = self.resolve_path(path1, revision=revision)
        resolved_path2 = self.resolve_path(path2, revision=revision)

        same_repo = (
            resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
        )

        if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None:
            commit_message = f"Copy {path1} to {path2}"
            self._api.create_commit(
                repo_id=resolved_path1.repo_id,
                repo_type=resolved_path1.repo_type,
                revision=resolved_path2.revision,
                commit_message=kwargs.get("commit_message", commit_message),
                commit_description=kwargs.get("commit_description", ""),
                operations=[
                    CommitOperationCopy(
                        src_path_in_repo=resolved_path1.path_in_repo,
                        path_in_repo=resolved_path2.path_in_repo,
                        src_revision=resolved_path1.revision,
                    )
                ],
            )
        else:
            with self.open(path1, "rb", revision=resolved_path1.revision) as f:
                content = f.read()
            commit_message = f"Copy {path1} to {path2}"
            self._api.upload_file(
                path_or_fileobj=content,
                path_in_repo=resolved_path2.path_in_repo,
                repo_id=resolved_path2.repo_id,
                token=self.token,
                repo_type=resolved_path2.repo_type,
                revision=resolved_path2.revision,
                commit_message=kwargs.get("commit_message", commit_message),
                commit_description=kwargs.get("commit_description"),
            )
        self.invalidate_cache(path=resolved_path1.unresolve())
        self.invalidate_cache(path=resolved_path2.unresolve())

    def modified(self, path: str, **kwargs) -> datetime:
        info = self.info(path, **kwargs)
        return info["last_commit"]["date"]

    def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        resolved_path = self.resolve_path(path, revision=revision)
        path = resolved_path.unresolve()
        expand_info = kwargs.get(
            "expand_info", True
        )  # don't expose it as a parameter in the public API to follow the spec
        if not resolved_path.path_in_repo:
            # Path is the root directory
            out = {
                "name": path,
                "size": 0,
                "type": "directory",
            }
            if expand_info:
                last_commit = self._api.list_repo_commits(
                    resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision
                )[-1]
                out = {
                    **out,
                    "tree_id": None,  # TODO: tree_id of the root directory?
                    "last_commit": LastCommitInfo(
                        oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at
                    ),
                }
        else:
            out = None
            parent_path = self._parent(path)
            if parent_path in self.dircache:
                # Check if the path is in the cache
                out1 = [o for o in self.dircache[parent_path] if o["name"] == path]
                if not out1:
                    _raise_file_not_found(path, None)
                out = out1[0]
            if refresh or out is None or (expand_info and out and out["last_commit"] is None):
                paths_info = self._api.get_paths_info(
                    resolved_path.repo_id,
                    resolved_path.path_in_repo,
                    expand=expand_info,
                    revision=resolved_path.revision,
                    repo_type=resolved_path.repo_type,
                )
                if not paths_info:
                    _raise_file_not_found(path, None)
                path_info = paths_info[0]
                root_path = HfFileSystemResolvedPath(
                    resolved_path.repo_type,
                    resolved_path.repo_id,
                    resolved_path.revision,
                    path_in_repo="",
                    _raw_revision=resolved_path._raw_revision,
                ).unresolve()
                if isinstance(path_info, RepoFile):
                    out = {
                        "name": root_path + "/" + path_info.path,
                        "size": path_info.size,
                        "type": "file",
                        "blob_id": path_info.blob_id,
                        "lfs": path_info.lfs,
                        "last_commit": path_info.last_commit,
                        "security": path_info.security,
                    }
                else:
                    out = {
                        "name": root_path + "/" + path_info.path,
                        "size": 0,
                        "type": "directory",
                        "tree_id": path_info.tree_id,
                        "last_commit": path_info.last_commit,
                    }
        if not expand_info:
            out = {k: out[k] for k in ["name", "size", "type"]}
        assert out is not None
        return copy.deepcopy(out)  # copy to not let users modify the dircache

    def exists(self, path, **kwargs):
        """Is there a file at the given path"""
        try:
            self.info(path, expand_info=False, **kwargs)
            return True
        except:  # noqa: E722
            # any exception allowed bar FileNotFoundError?
            return False

    def isdir(self, path):
        """Is this entry directory-like?"""
        try:
            return self.info(path, expand_info=False)["type"] == "directory"
        except OSError:
            return False

    def isfile(self, path):
        """Is this entry file-like?"""
        try:
            return self.info(path, expand_info=False)["type"] == "file"
        except:  # noqa: E722
            return False

    @property
    def transaction(self):
        """A context within which files are committed together upon exit

        Requires the file class to implement `.commit()` and `.discard()`
        for the normal and exception cases.
        """
        # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231
        # See https://github.com/huggingface/huggingface_hub/issues/1733
        raise NotImplementedError("Transactional commits are not supported.")

    def start_transaction(self):
        """Begin write transaction for deferring files, non-context version"""
        # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241
        # See https://github.com/huggingface/huggingface_hub/issues/1733
        raise NotImplementedError("Transactional commits are not supported.")


class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
    def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
        super().__init__(fs, path, **kwargs)
        self.fs: HfFileSystem

        try:
            self.resolved_path = fs.resolve_path(path, revision=revision)
        except FileNotFoundError as e:
            if "w" in kwargs.get("mode", ""):
                raise FileNotFoundError(
                    f"{e}.\nMake sure the repository and revision exist before writing data."
                ) from e

    def __del__(self):
        if not hasattr(self, "resolved_path"):
            # Means that the constructor failed. Nothing to do.
            return
        return super().__del__()

    def _fetch_range(self, start: int, end: int) -> bytes:
        headers = {
            "range": f"bytes={start}-{end - 1}",
            **self.fs._api._build_hf_headers(),
        }
        url = hf_hub_url(
            repo_id=self.resolved_path.repo_id,
            revision=self.resolved_path.revision,
            filename=self.resolved_path.path_in_repo,
            repo_type=self.resolved_path.repo_type,
            endpoint=self.fs.endpoint,
        )
        r = http_backoff("GET", url, headers=headers)
        hf_raise_for_status(r)
        return r.content

    def _initiate_upload(self) -> None:
        self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False)

    def _upload_chunk(self, final: bool = False) -> None:
        self.buffer.seek(0)
        block = self.buffer.read()
        self.temp_file.write(block)
        if final:
            self.temp_file.close()
            self.fs._api.upload_file(
                path_or_fileobj=self.temp_file.name,
                path_in_repo=self.resolved_path.path_in_repo,
                repo_id=self.resolved_path.repo_id,
                token=self.fs.token,
                repo_type=self.resolved_path.repo_type,
                revision=self.resolved_path.revision,
                commit_message=self.kwargs.get("commit_message"),
                commit_description=self.kwargs.get("commit_description"),
            )
            os.remove(self.temp_file.name)
            self.fs.invalidate_cache(
                path=self.resolved_path.unresolve(),
            )


def safe_revision(revision: str) -> str:
    return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)


def safe_quote(s: str) -> str:
    return quote(s, safe="")


def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
    msg = path
    if isinstance(err, RepositoryNotFoundError):
        msg = f"{path} (repository not found)"
    elif isinstance(err, RevisionNotFoundError):
        msg = f"{path} (revision not found)"
    elif isinstance(err, HFValidationError):
        msg = f"{path} (invalid repository id)"
    raise FileNotFoundError(msg) from err
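Note (not part of the committed files above): a short sketch of how the HfFileSystem class defined in hf_file_system.py is typically used through its fsspec-style interface. "gpt2" is only an illustrative public repo id, not something referenced by this commit.

```python
from huggingface_hub import HfFileSystem

fs = HfFileSystem()

# List the files at the root of a model repository.
print(fs.ls("gpt2", detail=False))

# Read a file straight from the Hub as if it were local.
with fs.open("gpt2/config.json", "r") as f:
    print(f.read()[:200])
```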
lib/python3.11/site-packages/huggingface_hub/hub_mixin.py
ADDED
@@ -0,0 +1,368 @@
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Type, TypeVar, Union

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from .file_download import hf_hub_download, is_torch_available
from .hf_api import HfApi
from .utils import HfHubHTTPError, SoftTemporaryDirectory, logging, validate_hf_hub_args


if is_torch_available():
    import torch  # type: ignore

logger = logging.get_logger(__name__)

# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")


class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
    have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
    of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
    """

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[dict] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict`, *optional*):
                Model configuration specified as a key/value dictionary.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory)

        # saving config
        if isinstance(config, dict):
            (save_directory / CONFIG_NAME).write_text(json.dumps(config))

        if push_to_hub:
            kwargs = kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your model.
        Check out our [integration guide](../guides/integrations) for instructions.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: bool = False,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Huggingface Hub and instantiate it.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                    [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
        """
        model_id = pretrained_model_name_or_path
        config_file: Optional[str] = None
        if os.path.isdir(model_id):
            if CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, CONFIG_NAME)
            else:
                logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif isinstance(model_id, str):
            try:
                config_file = hf_hub_download(
                    repo_id=str(model_id),
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError:
                logger.info(f"{CONFIG_NAME} not found in HuggingFace Hub.")

        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
            model_kwargs.update({"config": config})

        return cls._from_pretrained(
            model_id=str(model_id),
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            **model_kwargs,
        )

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Overwrite this method in subclass to define how to load your model from pretrained.

        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
        method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
        parameter to set on which device the model should be loaded.

        Check out our [integration guide](../guides/integrations) for more instructions.

        Args:
            model_id (`str`):
                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        config: Optional[dict] = None,
        commit_message: str = "Push model using huggingface_hub.",
        private: bool = False,
        api_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        branch: Optional[str] = None,
        create_pr: Optional[bool] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        delete_patterns: Optional[Union[List[str], str]] = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
|
238 |
+
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
|
239 |
+
details.
|
240 |
+
|
241 |
+
|
242 |
+
Args:
|
243 |
+
repo_id (`str`):
|
244 |
+
ID of the repository to push to (example: `"username/my-model"`).
|
245 |
+
config (`dict`, *optional*):
|
246 |
+
Configuration object to be saved alongside the model weights.
|
247 |
+
commit_message (`str`, *optional*):
|
248 |
+
Message to commit while pushing.
|
249 |
+
private (`bool`, *optional*, defaults to `False`):
|
250 |
+
Whether the repository created should be private.
|
251 |
+
api_endpoint (`str`, *optional*):
|
252 |
+
The API endpoint to use when pushing the model to the hub.
|
253 |
+
token (`str`, *optional*):
|
254 |
+
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
255 |
+
cached when running `huggingface-cli login`.
|
256 |
+
branch (`str`, *optional*):
|
257 |
+
The git branch on which to push the model. This defaults to `"main"`.
|
258 |
+
create_pr (`boolean`, *optional*):
|
259 |
+
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
|
260 |
+
allow_patterns (`List[str]` or `str`, *optional*):
|
261 |
+
If provided, only files matching at least one pattern are pushed.
|
262 |
+
ignore_patterns (`List[str]` or `str`, *optional*):
|
263 |
+
If provided, files matching any of the patterns are not pushed.
|
264 |
+
delete_patterns (`List[str]` or `str`, *optional*):
|
265 |
+
If provided, remote files matching any of the patterns will be deleted from the repo.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
The url of the commit of your model in the given repository.
|
269 |
+
"""
|
270 |
+
api = HfApi(endpoint=api_endpoint, token=token)
|
271 |
+
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
|
272 |
+
|
273 |
+
# Push the files to the repo in a single commit
|
274 |
+
with SoftTemporaryDirectory() as tmp:
|
275 |
+
saved_path = Path(tmp) / repo_id
|
276 |
+
self.save_pretrained(saved_path, config=config)
|
277 |
+
return api.upload_folder(
|
278 |
+
repo_id=repo_id,
|
279 |
+
repo_type="model",
|
280 |
+
folder_path=saved_path,
|
281 |
+
commit_message=commit_message,
|
282 |
+
revision=branch,
|
283 |
+
create_pr=create_pr,
|
284 |
+
allow_patterns=allow_patterns,
|
285 |
+
ignore_patterns=ignore_patterns,
|
286 |
+
delete_patterns=delete_patterns,
|
287 |
+
)
|
288 |
+
|
289 |
+
|
290 |
+
class PyTorchModelHubMixin(ModelHubMixin):
|
291 |
+
"""
|
292 |
+
Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
|
293 |
+
is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
|
294 |
+
you should first set it back in training mode with `model.train()`.
|
295 |
+
|
296 |
+
Example:
|
297 |
+
|
298 |
+
```python
|
299 |
+
>>> import torch
|
300 |
+
>>> import torch.nn as nn
|
301 |
+
>>> from huggingface_hub import PyTorchModelHubMixin
|
302 |
+
|
303 |
+
|
304 |
+
>>> class MyModel(nn.Module, PyTorchModelHubMixin):
|
305 |
+
... def __init__(self):
|
306 |
+
... super().__init__()
|
307 |
+
... self.param = nn.Parameter(torch.rand(3, 4))
|
308 |
+
... self.linear = nn.Linear(4, 5)
|
309 |
+
|
310 |
+
... def forward(self, x):
|
311 |
+
... return self.linear(x + self.param)
|
312 |
+
>>> model = MyModel()
|
313 |
+
|
314 |
+
# Save model weights to local directory
|
315 |
+
>>> model.save_pretrained("my-awesome-model")
|
316 |
+
|
317 |
+
# Push model weights to the Hub
|
318 |
+
>>> model.push_to_hub("my-awesome-model")
|
319 |
+
|
320 |
+
# Download and initialize weights from the Hub
|
321 |
+
>>> model = MyModel.from_pretrained("username/my-awesome-model")
|
322 |
+
```
|
323 |
+
"""
|
324 |
+
|
325 |
+
def _save_pretrained(self, save_directory: Path) -> None:
|
326 |
+
"""Save weights from a Pytorch model to a local directory."""
|
327 |
+
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
|
328 |
+
torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
|
329 |
+
|
330 |
+
@classmethod
|
331 |
+
def _from_pretrained(
|
332 |
+
cls,
|
333 |
+
*,
|
334 |
+
model_id: str,
|
335 |
+
revision: Optional[str],
|
336 |
+
cache_dir: Optional[Union[str, Path]],
|
337 |
+
force_download: bool,
|
338 |
+
proxies: Optional[Dict],
|
339 |
+
resume_download: bool,
|
340 |
+
local_files_only: bool,
|
341 |
+
token: Union[str, bool, None],
|
342 |
+
map_location: str = "cpu",
|
343 |
+
strict: bool = False,
|
344 |
+
**model_kwargs,
|
345 |
+
):
|
346 |
+
"""Load Pytorch pretrained weights and return the loaded model."""
|
347 |
+
if os.path.isdir(model_id):
|
348 |
+
print("Loading weights from local directory")
|
349 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
350 |
+
else:
|
351 |
+
model_file = hf_hub_download(
|
352 |
+
repo_id=model_id,
|
353 |
+
filename=PYTORCH_WEIGHTS_NAME,
|
354 |
+
revision=revision,
|
355 |
+
cache_dir=cache_dir,
|
356 |
+
force_download=force_download,
|
357 |
+
proxies=proxies,
|
358 |
+
resume_download=resume_download,
|
359 |
+
token=token,
|
360 |
+
local_files_only=local_files_only,
|
361 |
+
)
|
362 |
+
model = cls(**model_kwargs)
|
363 |
+
|
364 |
+
state_dict = torch.load(model_file, map_location=torch.device(map_location))
|
365 |
+
model.load_state_dict(state_dict, strict=strict) # type: ignore
|
366 |
+
model.eval() # type: ignore
|
367 |
+
|
368 |
+
return model
|
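The `_from_pretrained` docstring above describes the intended extension point for custom integrations. As a rough illustration only (not part of `hub_mixin.py`), a minimal subclass that stores its state as a JSON file might look like the sketch below; the class name `MyJsonModel`, the `weights.json` filename, and the constructor signature are assumptions made for this example.

```python
import json
import os
from pathlib import Path
from typing import Optional

from huggingface_hub import ModelHubMixin, hf_hub_download


class MyJsonModel(ModelHubMixin):
    """Hypothetical mixin subclass whose "weights" are a plain dict serialized to JSON."""

    def __init__(self, weights: Optional[dict] = None, **kwargs):
        self.weights = weights or {}

    def _save_pretrained(self, save_directory: Path) -> None:
        # Called by `save_pretrained`; only responsible for writing the model files.
        (save_directory / "weights.json").write_text(json.dumps(self.weights))

    @classmethod
    def _from_pretrained(cls, *, model_id, revision, cache_dir, force_download,
                         proxies, resume_download, local_files_only, token,
                         **model_kwargs):
        # Resolve the weights file either from a local directory or from the Hub,
        # as the docstring above suggests, then instantiate the class.
        if os.path.isdir(model_id):
            weights_file = os.path.join(model_id, "weights.json")
        else:
            weights_file = hf_hub_download(
                repo_id=model_id,
                filename="weights.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        with open(weights_file, "r", encoding="utf-8") as f:
            weights = json.load(f)
        return cls(weights=weights, **model_kwargs)
```

With such a subclass, `MyJsonModel(...).push_to_hub("username/repo")` and `MyJsonModel.from_pretrained("username/repo")` would reuse the upload/download plumbing shown above without further changes.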
lib/python3.11/site-packages/huggingface_hub/inference/__init__.py ADDED (file without changes)
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc ADDED (binary file, 240 Bytes)
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-311.pyc ADDED (binary file, 90.8 kB)
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc ADDED (binary file, 14.6 kB)
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_text_generation.cpython-311.pyc ADDED (binary file, 23.9 kB)
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_types.cpython-311.pyc ADDED (binary file, 7.6 kB)
lib/python3.11/site-packages/huggingface_hub/inference/_client.py ADDED
@@ -0,0 +1,1990 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Related resources:
# https://huggingface.co/tasks
# https://huggingface.co/docs/huggingface.js/inference/README
# https://github.com/huggingface/huggingface.js/tree/main/packages/inference/src
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python
# https://github.com/huggingface/text-generation-inference/blob/main/clients/python/text_generation/client.py
# https://huggingface.slack.com/archives/C03E4DQ9LAJ/p1680169099087869
# https://github.com/huggingface/unity-api#tasks
#
# Some TODO:
# - validate inputs/options/parameters? with Pydantic for instance? or only optionally?
# - add all tasks
#
# NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some
# examples of how it translates:
# - Timeout / Server unavailable is handled by the client in a single "timeout" parameter.
# - Files can be provided as bytes, file paths, or URLs and the client will try to "guess" the type.
# - Images are parsed as PIL.Image for easier manipulation.
# - Provides a "recommended model" for each task => suboptimal but user-wise quicker to get a first script running.
# - Only the main parameters are publicly exposed. Power users can always read the docs for more options.
import logging
import time
import warnings
from dataclasses import asdict
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Union,
    overload,
)

from requests import HTTPError
from requests.structures import CaseInsensitiveDict

from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub.inference._common import (
    TASKS_EXPECTING_IMAGES,
    ContentT,
    InferenceTimeoutError,
    ModelStatus,
    _b64_encode,
    _b64_to_image,
    _bytes_to_dict,
    _bytes_to_image,
    _bytes_to_list,
    _fetch_recommended_models,
    _import_numpy,
    _is_tgi_server,
    _open_as_binary,
    _set_as_non_tgi,
    _stream_text_generation_response,
)
from huggingface_hub.inference._text_generation import (
    TextGenerationParameters,
    TextGenerationRequest,
    TextGenerationResponse,
    TextGenerationStreamResponse,
    raise_text_generation_error,
)
from huggingface_hub.inference._types import (
    ClassificationOutput,
    ConversationalOutput,
    FillMaskOutput,
    ImageSegmentationOutput,
    ObjectDetectionOutput,
    QuestionAnsweringOutput,
    TableQuestionAnsweringOutput,
    TokenClassificationOutput,
)
from huggingface_hub.utils import (
    BadRequestError,
    build_hf_headers,
    get_session,
    hf_raise_for_status,
)


if TYPE_CHECKING:
    import numpy as np
    from PIL import Image

logger = logging.getLogger(__name__)


class InferenceClient:
    """
    Initialize a new Inference Client.

    [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
    seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.

    Args:
        model (`str`, `optional`):
            The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
            or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
            automatically selected for the task.
        token (`str`, *optional*):
            Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
            your token to the server.
        timeout (`float`, `optional`):
            The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
            API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
        headers (`Dict[str, str]`, `optional`):
            Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
            Values in this dictionary will override the default values.
        cookies (`Dict[str, str]`, `optional`):
            Additional cookies to send to the server.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        token: Union[str, bool, None] = None,
        timeout: Optional[float] = None,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
    ) -> None:
        self.model: Optional[str] = model
        self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  # contains 'authorization' + 'user-agent'
        if headers is not None:
            self.headers.update(headers)
        self.cookies = cookies
        self.timeout = timeout

    def __repr__(self):
        return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"

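As an aside (not part of `_client.py` itself): the constructor parameters documented above map directly onto a typical call. A minimal usage sketch follows; the model id, prompt, and timeout value are placeholders, and depending on the model a valid Hub token may be required.

```python
from huggingface_hub import InferenceClient

# Point the client at a specific model (or omit `model` to let each task method
# pick its recommended default). `timeout` bounds how long we wait for a cold
# model to finish loading before giving up.
client = InferenceClient(model="bigcode/starcoder", timeout=60)

# Task-specific helpers wrap the low-level `post()` call and decode the response.
print(client.text_generation("def fibonacci(n):", max_new_tokens=32))
```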
148 |
+
@overload
|
149 |
+
def post( # type: ignore[misc]
|
150 |
+
self,
|
151 |
+
*,
|
152 |
+
json: Optional[Union[str, Dict, List]] = None,
|
153 |
+
data: Optional[ContentT] = None,
|
154 |
+
model: Optional[str] = None,
|
155 |
+
task: Optional[str] = None,
|
156 |
+
stream: Literal[False] = ...,
|
157 |
+
) -> bytes:
|
158 |
+
pass
|
159 |
+
|
160 |
+
@overload
|
161 |
+
def post(
|
162 |
+
self,
|
163 |
+
*,
|
164 |
+
json: Optional[Union[str, Dict, List]] = None,
|
165 |
+
data: Optional[ContentT] = None,
|
166 |
+
model: Optional[str] = None,
|
167 |
+
task: Optional[str] = None,
|
168 |
+
stream: Literal[True] = ...,
|
169 |
+
) -> Iterable[bytes]:
|
170 |
+
pass
|
171 |
+
|
172 |
+
def post(
|
173 |
+
self,
|
174 |
+
*,
|
175 |
+
json: Optional[Union[str, Dict, List]] = None,
|
176 |
+
data: Optional[ContentT] = None,
|
177 |
+
model: Optional[str] = None,
|
178 |
+
task: Optional[str] = None,
|
179 |
+
stream: bool = False,
|
180 |
+
) -> Union[bytes, Iterable[bytes]]:
|
181 |
+
"""
|
182 |
+
Make a POST request to the inference server.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
json (`Union[str, Dict, List]`, *optional*):
|
186 |
+
The JSON data to send in the request body, specific to each task. Defaults to None.
|
187 |
+
data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
|
188 |
+
The content to send in the request body, specific to each task.
|
189 |
+
It can be raw bytes, a pointer to an opened file, a local file path,
|
190 |
+
or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed,
|
191 |
+
`data` will take precedence. At least `json` or `data` must be provided. Defaults to None.
|
192 |
+
model (`str`, *optional*):
|
193 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
194 |
+
Inference Endpoint. Will override the model defined at the instance level. Defaults to None.
|
195 |
+
task (`str`, *optional*):
|
196 |
+
The task to perform on the inference. All available tasks can be found
|
197 |
+
[here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not
|
198 |
+
provided. At least `model` or `task` must be provided. Defaults to None.
|
199 |
+
stream (`bool`, *optional*):
|
200 |
+
Whether to iterate over streaming APIs.
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
bytes: The raw bytes returned by the server.
|
204 |
+
|
205 |
+
Raises:
|
206 |
+
[`InferenceTimeoutError`]:
|
207 |
+
If the model is unavailable or the request times out.
|
208 |
+
`HTTPError`:
|
209 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
210 |
+
"""
|
211 |
+
url = self._resolve_url(model, task)
|
212 |
+
|
213 |
+
if data is not None and json is not None:
|
214 |
+
warnings.warn("Ignoring `json` as `data` is passed as binary.")
|
215 |
+
|
216 |
+
# Set Accept header if relevant
|
217 |
+
headers = self.headers.copy()
|
218 |
+
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
|
219 |
+
headers["Accept"] = "image/png"
|
220 |
+
|
221 |
+
t0 = time.time()
|
222 |
+
timeout = self.timeout
|
223 |
+
while True:
|
224 |
+
with _open_as_binary(data) as data_as_binary:
|
225 |
+
try:
|
226 |
+
response = get_session().post(
|
227 |
+
url,
|
228 |
+
json=json,
|
229 |
+
data=data_as_binary,
|
230 |
+
headers=headers,
|
231 |
+
cookies=self.cookies,
|
232 |
+
timeout=self.timeout,
|
233 |
+
stream=stream,
|
234 |
+
)
|
235 |
+
except TimeoutError as error:
|
236 |
+
# Convert any `TimeoutError` to a `InferenceTimeoutError`
|
237 |
+
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
|
238 |
+
|
239 |
+
try:
|
240 |
+
hf_raise_for_status(response)
|
241 |
+
return response.iter_lines() if stream else response.content
|
242 |
+
except HTTPError as error:
|
243 |
+
if error.response.status_code == 422 and task is not None:
|
244 |
+
error.args = (
|
245 |
+
f"{error.args[0]}\nMake sure '{task}' task is supported by the model.",
|
246 |
+
) + error.args[1:]
|
247 |
+
if error.response.status_code == 503:
|
248 |
+
# If Model is unavailable, either raise a TimeoutError...
|
249 |
+
if timeout is not None and time.time() - t0 > timeout:
|
250 |
+
raise InferenceTimeoutError(
|
251 |
+
f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:"
|
252 |
+
f" {self.timeout}).",
|
253 |
+
request=error.request,
|
254 |
+
response=error.response,
|
255 |
+
) from error
|
256 |
+
# ...or wait 1s and retry
|
257 |
+
logger.info(f"Waiting for model to be loaded on the server: {error}")
|
258 |
+
time.sleep(1)
|
259 |
+
if timeout is not None:
|
260 |
+
timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
|
261 |
+
continue
|
262 |
+
raise
|
263 |
+
|
264 |
+
def audio_classification(
|
265 |
+
self,
|
266 |
+
audio: ContentT,
|
267 |
+
*,
|
268 |
+
model: Optional[str] = None,
|
269 |
+
) -> List[ClassificationOutput]:
|
270 |
+
"""
|
271 |
+
Perform audio classification on the provided audio content.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
audio (Union[str, Path, bytes, BinaryIO]):
|
275 |
+
The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an
|
276 |
+
audio file.
|
277 |
+
model (`str`, *optional*):
|
278 |
+
The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub
|
279 |
+
or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
|
280 |
+
audio classification will be used.
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
`List[Dict]`: The classification output containing the predicted label and its confidence.
|
284 |
+
|
285 |
+
Raises:
|
286 |
+
[`InferenceTimeoutError`]:
|
287 |
+
If the model is unavailable or the request times out.
|
288 |
+
`HTTPError`:
|
289 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
290 |
+
|
291 |
+
Example:
|
292 |
+
```py
|
293 |
+
>>> from huggingface_hub import InferenceClient
|
294 |
+
>>> client = InferenceClient()
|
295 |
+
>>> client.audio_classification("audio.flac")
|
296 |
+
[{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
|
297 |
+
```
|
298 |
+
"""
|
299 |
+
response = self.post(data=audio, model=model, task="audio-classification")
|
300 |
+
return _bytes_to_list(response)
|
301 |
+
|
302 |
+
def automatic_speech_recognition(
|
303 |
+
self,
|
304 |
+
audio: ContentT,
|
305 |
+
*,
|
306 |
+
model: Optional[str] = None,
|
307 |
+
) -> str:
|
308 |
+
"""
|
309 |
+
Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
audio (Union[str, Path, bytes, BinaryIO]):
|
313 |
+
The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file.
|
314 |
+
model (`str`, *optional*):
|
315 |
+
The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
316 |
+
Inference Endpoint. If not provided, the default recommended model for ASR will be used.
|
317 |
+
|
318 |
+
Returns:
|
319 |
+
str: The transcribed text.
|
320 |
+
|
321 |
+
Raises:
|
322 |
+
[`InferenceTimeoutError`]:
|
323 |
+
If the model is unavailable or the request times out.
|
324 |
+
`HTTPError`:
|
325 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
326 |
+
|
327 |
+
Example:
|
328 |
+
```py
|
329 |
+
>>> from huggingface_hub import InferenceClient
|
330 |
+
>>> client = InferenceClient()
|
331 |
+
>>> client.automatic_speech_recognition("hello_world.flac")
|
332 |
+
"hello world"
|
333 |
+
```
|
334 |
+
"""
|
335 |
+
response = self.post(data=audio, model=model, task="automatic-speech-recognition")
|
336 |
+
return _bytes_to_dict(response)["text"]
|
337 |
+
|
338 |
+
def conversational(
|
339 |
+
self,
|
340 |
+
text: str,
|
341 |
+
generated_responses: Optional[List[str]] = None,
|
342 |
+
past_user_inputs: Optional[List[str]] = None,
|
343 |
+
*,
|
344 |
+
parameters: Optional[Dict[str, Any]] = None,
|
345 |
+
model: Optional[str] = None,
|
346 |
+
) -> ConversationalOutput:
|
347 |
+
"""
|
348 |
+
Generate conversational responses based on the given input text (i.e. chat with the API).
|
349 |
+
|
350 |
+
Args:
|
351 |
+
text (`str`):
|
352 |
+
The last input from the user in the conversation.
|
353 |
+
generated_responses (`List[str]`, *optional*):
|
354 |
+
A list of strings corresponding to the earlier replies from the model. Defaults to None.
|
355 |
+
past_user_inputs (`List[str]`, *optional*):
|
356 |
+
A list of strings corresponding to the earlier replies from the user. Should be the same length as
|
357 |
+
`generated_responses`. Defaults to None.
|
358 |
+
parameters (`Dict[str, Any]`, *optional*):
|
359 |
+
Additional parameters for the conversational task. Defaults to None. For more details about the available
|
360 |
+
parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
|
361 |
+
model (`str`, *optional*):
|
362 |
+
The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
363 |
+
a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
|
364 |
+
Defaults to None.
|
365 |
+
|
366 |
+
Returns:
|
367 |
+
`Dict`: The generated conversational output.
|
368 |
+
|
369 |
+
Raises:
|
370 |
+
[`InferenceTimeoutError`]:
|
371 |
+
If the model is unavailable or the request times out.
|
372 |
+
`HTTPError`:
|
373 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
374 |
+
|
375 |
+
Example:
|
376 |
+
```py
|
377 |
+
>>> from huggingface_hub import InferenceClient
|
378 |
+
>>> client = InferenceClient()
|
379 |
+
>>> output = client.conversational("Hi, who are you?")
|
380 |
+
>>> output
|
381 |
+
{'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
|
382 |
+
>>> client.conversational(
|
383 |
+
... "Wow, that's scary!",
|
384 |
+
... generated_responses=output["conversation"]["generated_responses"],
|
385 |
+
... past_user_inputs=output["conversation"]["past_user_inputs"],
|
386 |
+
... )
|
387 |
+
```
|
388 |
+
"""
|
389 |
+
payload: Dict[str, Any] = {"inputs": {"text": text}}
|
390 |
+
if generated_responses is not None:
|
391 |
+
payload["inputs"]["generated_responses"] = generated_responses
|
392 |
+
if past_user_inputs is not None:
|
393 |
+
payload["inputs"]["past_user_inputs"] = past_user_inputs
|
394 |
+
if parameters is not None:
|
395 |
+
payload["parameters"] = parameters
|
396 |
+
response = self.post(json=payload, model=model, task="conversational")
|
397 |
+
return _bytes_to_dict(response) # type: ignore
|
398 |
+
|
399 |
+
def visual_question_answering(
|
400 |
+
self,
|
401 |
+
image: ContentT,
|
402 |
+
question: str,
|
403 |
+
*,
|
404 |
+
model: Optional[str] = None,
|
405 |
+
) -> List[str]:
|
406 |
+
"""
|
407 |
+
Answering open-ended questions based on an image.
|
408 |
+
|
409 |
+
Args:
|
410 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
411 |
+
The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
|
412 |
+
question (`str`):
|
413 |
+
Question to be answered.
|
414 |
+
model (`str`, *optional*):
|
415 |
+
The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
416 |
+
a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
|
417 |
+
Defaults to None.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
`List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
|
421 |
+
|
422 |
+
Raises:
|
423 |
+
`InferenceTimeoutError`:
|
424 |
+
If the model is unavailable or the request times out.
|
425 |
+
`HTTPError`:
|
426 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
427 |
+
|
428 |
+
Example:
|
429 |
+
```py
|
430 |
+
>>> from huggingface_hub import InferenceClient
|
431 |
+
>>> client = InferenceClient()
|
432 |
+
>>> client.visual_question_answering(
|
433 |
+
... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
|
434 |
+
... question="What is the animal doing?"
|
435 |
+
... )
|
436 |
+
[{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
|
437 |
+
```
|
438 |
+
"""
|
439 |
+
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
|
440 |
+
response = self.post(json=payload, model=model, task="visual-question-answering")
|
441 |
+
return _bytes_to_list(response)
|
442 |
+
|
443 |
+
def document_question_answering(
|
444 |
+
self,
|
445 |
+
image: ContentT,
|
446 |
+
question: str,
|
447 |
+
*,
|
448 |
+
model: Optional[str] = None,
|
449 |
+
) -> List[QuestionAnsweringOutput]:
|
450 |
+
"""
|
451 |
+
Answer questions on document images.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
455 |
+
The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
|
456 |
+
question (`str`):
|
457 |
+
Question to be answered.
|
458 |
+
model (`str`, *optional*):
|
459 |
+
The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
460 |
+
a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
|
461 |
+
Defaults to None.
|
462 |
+
|
463 |
+
Returns:
|
464 |
+
`List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
|
465 |
+
|
466 |
+
Raises:
|
467 |
+
[`InferenceTimeoutError`]:
|
468 |
+
If the model is unavailable or the request times out.
|
469 |
+
`HTTPError`:
|
470 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
471 |
+
|
472 |
+
Example:
|
473 |
+
```py
|
474 |
+
>>> from huggingface_hub import InferenceClient
|
475 |
+
>>> client = InferenceClient()
|
476 |
+
>>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
|
477 |
+
[{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
|
478 |
+
```
|
479 |
+
"""
|
480 |
+
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
|
481 |
+
response = self.post(json=payload, model=model, task="document-question-answering")
|
482 |
+
return _bytes_to_list(response)
|
483 |
+
|
484 |
+
def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
|
485 |
+
"""
|
486 |
+
Generate embeddings for a given text.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
text (`str`):
|
490 |
+
The text to embed.
|
491 |
+
model (`str`, *optional*):
|
492 |
+
The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
493 |
+
a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
|
494 |
+
Defaults to None.
|
495 |
+
|
496 |
+
Returns:
|
497 |
+
`np.ndarray`: The embedding representing the input text as a float32 numpy array.
|
498 |
+
|
499 |
+
Raises:
|
500 |
+
[`InferenceTimeoutError`]:
|
501 |
+
If the model is unavailable or the request times out.
|
502 |
+
`HTTPError`:
|
503 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
504 |
+
|
505 |
+
Example:
|
506 |
+
```py
|
507 |
+
>>> from huggingface_hub import InferenceClient
|
508 |
+
>>> client = InferenceClient()
|
509 |
+
>>> client.feature_extraction("Hi, who are you?")
|
510 |
+
array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ],
|
511 |
+
[-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ],
|
512 |
+
...,
|
513 |
+
[ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
|
514 |
+
```
|
515 |
+
"""
|
516 |
+
response = self.post(json={"inputs": text}, model=model, task="feature-extraction")
|
517 |
+
np = _import_numpy()
|
518 |
+
return np.array(_bytes_to_dict(response), dtype="float32")
|
519 |
+
|
520 |
+
def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]:
|
521 |
+
"""
|
522 |
+
Fill in a hole with a missing word (token to be precise).
|
523 |
+
|
524 |
+
Args:
|
525 |
+
text (`str`):
|
526 |
+
a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask).
|
527 |
+
model (`str`, *optional*):
|
528 |
+
The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
529 |
+
a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
|
530 |
+
Defaults to None.
|
531 |
+
|
532 |
+
Returns:
|
533 |
+
`List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated
|
534 |
+
probability, token reference, and completed text.
|
535 |
+
|
536 |
+
Raises:
|
537 |
+
[`InferenceTimeoutError`]:
|
538 |
+
If the model is unavailable or the request times out.
|
539 |
+
`HTTPError`:
|
540 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
541 |
+
|
542 |
+
Example:
|
543 |
+
```py
|
544 |
+
>>> from huggingface_hub import InferenceClient
|
545 |
+
>>> client = InferenceClient()
|
546 |
+
>>> client.fill_mask("The goal of life is <mask>.")
|
547 |
+
[{'score': 0.06897063553333282,
|
548 |
+
'token': 11098,
|
549 |
+
'token_str': ' happiness',
|
550 |
+
'sequence': 'The goal of life is happiness.'},
|
551 |
+
{'score': 0.06554922461509705,
|
552 |
+
'token': 45075,
|
553 |
+
'token_str': ' immortality',
|
554 |
+
'sequence': 'The goal of life is immortality.'}]
|
555 |
+
```
|
556 |
+
"""
|
557 |
+
response = self.post(json={"inputs": text}, model=model, task="fill-mask")
|
558 |
+
return _bytes_to_list(response)
|
559 |
+
|
560 |
+
def image_classification(
|
561 |
+
self,
|
562 |
+
image: ContentT,
|
563 |
+
*,
|
564 |
+
model: Optional[str] = None,
|
565 |
+
) -> List[ClassificationOutput]:
|
566 |
+
"""
|
567 |
+
Perform image classification on the given image using the specified model.
|
568 |
+
|
569 |
+
Args:
|
570 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
571 |
+
The image to classify. It can be raw bytes, an image file, or a URL to an online image.
|
572 |
+
model (`str`, *optional*):
|
573 |
+
The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
574 |
+
deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
|
575 |
+
|
576 |
+
Returns:
|
577 |
+
`List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
|
578 |
+
|
579 |
+
Raises:
|
580 |
+
[`InferenceTimeoutError`]:
|
581 |
+
If the model is unavailable or the request times out.
|
582 |
+
`HTTPError`:
|
583 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
584 |
+
|
585 |
+
Example:
|
586 |
+
```py
|
587 |
+
>>> from huggingface_hub import InferenceClient
|
588 |
+
>>> client = InferenceClient()
|
589 |
+
>>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
|
590 |
+
[{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
|
591 |
+
```
|
592 |
+
"""
|
593 |
+
response = self.post(data=image, model=model, task="image-classification")
|
594 |
+
return _bytes_to_list(response)
|
595 |
+
|
596 |
+
def image_segmentation(
|
597 |
+
self,
|
598 |
+
image: ContentT,
|
599 |
+
*,
|
600 |
+
model: Optional[str] = None,
|
601 |
+
) -> List[ImageSegmentationOutput]:
|
602 |
+
"""
|
603 |
+
Perform image segmentation on the given image using the specified model.
|
604 |
+
|
605 |
+
<Tip warning={true}>
|
606 |
+
|
607 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
608 |
+
|
609 |
+
</Tip>
|
610 |
+
|
611 |
+
Args:
|
612 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
613 |
+
The image to segment. It can be raw bytes, an image file, or a URL to an online image.
|
614 |
+
model (`str`, *optional*):
|
615 |
+
The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
616 |
+
deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
|
617 |
+
|
618 |
+
Returns:
|
619 |
+
`List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes.
|
620 |
+
|
621 |
+
Raises:
|
622 |
+
[`InferenceTimeoutError`]:
|
623 |
+
If the model is unavailable or the request times out.
|
624 |
+
`HTTPError`:
|
625 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
626 |
+
|
627 |
+
Example:
|
628 |
+
```py
|
629 |
+
>>> from huggingface_hub import InferenceClient
|
630 |
+
>>> client = InferenceClient()
|
631 |
+
>>> client.image_segmentation("cat.jpg"):
|
632 |
+
[{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...]
|
633 |
+
```
|
634 |
+
"""
|
635 |
+
|
636 |
+
# Segment
|
637 |
+
response = self.post(data=image, model=model, task="image-segmentation")
|
638 |
+
output = _bytes_to_dict(response)
|
639 |
+
|
640 |
+
# Parse masks as PIL Image
|
641 |
+
if not isinstance(output, list):
|
642 |
+
raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
|
643 |
+
for item in output:
|
644 |
+
item["mask"] = _b64_to_image(item["mask"])
|
645 |
+
return output
|
646 |
+
|
647 |
+
def image_to_image(
|
648 |
+
self,
|
649 |
+
image: ContentT,
|
650 |
+
prompt: Optional[str] = None,
|
651 |
+
*,
|
652 |
+
negative_prompt: Optional[str] = None,
|
653 |
+
height: Optional[int] = None,
|
654 |
+
width: Optional[int] = None,
|
655 |
+
num_inference_steps: Optional[int] = None,
|
656 |
+
guidance_scale: Optional[float] = None,
|
657 |
+
model: Optional[str] = None,
|
658 |
+
**kwargs,
|
659 |
+
) -> "Image":
|
660 |
+
"""
|
661 |
+
Perform image-to-image translation using a specified model.
|
662 |
+
|
663 |
+
<Tip warning={true}>
|
664 |
+
|
665 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
666 |
+
|
667 |
+
</Tip>
|
668 |
+
|
669 |
+
Args:
|
670 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
671 |
+
The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
|
672 |
+
prompt (`str`, *optional*):
|
673 |
+
The text prompt to guide the image generation.
|
674 |
+
negative_prompt (`str`, *optional*):
|
675 |
+
A negative prompt to guide the translation process.
|
676 |
+
height (`int`, *optional*):
|
677 |
+
The height in pixels of the generated image.
|
678 |
+
width (`int`, *optional*):
|
679 |
+
The width in pixels of the generated image.
|
680 |
+
num_inference_steps (`int`, *optional*):
|
681 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
682 |
+
expense of slower inference.
|
683 |
+
guidance_scale (`float`, *optional*):
|
684 |
+
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
685 |
+
usually at the expense of lower image quality.
|
686 |
+
model (`str`, *optional*):
|
687 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
688 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
689 |
+
|
690 |
+
Returns:
|
691 |
+
`Image`: The translated image.
|
692 |
+
|
693 |
+
Raises:
|
694 |
+
[`InferenceTimeoutError`]:
|
695 |
+
If the model is unavailable or the request times out.
|
696 |
+
`HTTPError`:
|
697 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
698 |
+
|
699 |
+
Example:
|
700 |
+
```py
|
701 |
+
>>> from huggingface_hub import InferenceClient
|
702 |
+
>>> client = InferenceClient()
|
703 |
+
>>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
|
704 |
+
>>> image.save("tiger.jpg")
|
705 |
+
```
|
706 |
+
"""
|
707 |
+
parameters = {
|
708 |
+
"prompt": prompt,
|
709 |
+
"negative_prompt": negative_prompt,
|
710 |
+
"height": height,
|
711 |
+
"width": width,
|
712 |
+
"num_inference_steps": num_inference_steps,
|
713 |
+
"guidance_scale": guidance_scale,
|
714 |
+
**kwargs,
|
715 |
+
}
|
716 |
+
if all(parameter is None for parameter in parameters.values()):
|
717 |
+
# Either only an image to send => send as raw bytes
|
718 |
+
data = image
|
719 |
+
payload: Optional[Dict[str, Any]] = None
|
720 |
+
else:
|
721 |
+
# Or an image + some parameters => use base64 encoding
|
722 |
+
data = None
|
723 |
+
payload = {"inputs": _b64_encode(image)}
|
724 |
+
for key, value in parameters.items():
|
725 |
+
if value is not None:
|
726 |
+
payload.setdefault("parameters", {})[key] = value
|
727 |
+
|
728 |
+
response = self.post(json=payload, data=data, model=model, task="image-to-image")
|
729 |
+
return _bytes_to_image(response)
|
730 |
+
|
731 |
+
def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
|
732 |
+
"""
|
733 |
+
Takes an input image and return text.
|
734 |
+
|
735 |
+
Models can have very different outputs depending on your use case (image captioning, optical character recognition
|
736 |
+
(OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
|
737 |
+
|
738 |
+
Args:
|
739 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
740 |
+
The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
|
741 |
+
model (`str`, *optional*):
|
742 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
743 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
744 |
+
|
745 |
+
Returns:
|
746 |
+
`str`: The generated text.
|
747 |
+
|
748 |
+
Raises:
|
749 |
+
[`InferenceTimeoutError`]:
|
750 |
+
If the model is unavailable or the request times out.
|
751 |
+
`HTTPError`:
|
752 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
753 |
+
|
754 |
+
Example:
|
755 |
+
```py
|
756 |
+
>>> from huggingface_hub import InferenceClient
|
757 |
+
>>> client = InferenceClient()
|
758 |
+
>>> client.image_to_text("cat.jpg")
|
759 |
+
'a cat standing in a grassy field '
|
760 |
+
>>> client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
|
761 |
+
'a dog laying on the grass next to a flower pot '
|
762 |
+
```
|
763 |
+
"""
|
764 |
+
response = self.post(data=image, model=model, task="image-to-text")
|
765 |
+
return _bytes_to_dict(response)[0]["generated_text"]
|
766 |
+
|
767 |
+
def list_deployed_models(
|
768 |
+
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
|
769 |
+
) -> Dict[str, List[str]]:
|
770 |
+
"""
|
771 |
+
List models currently deployed on the Inference API service.
|
772 |
+
|
773 |
+
This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
|
774 |
+
are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
|
775 |
+
specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
|
776 |
+
in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
|
777 |
+
frameworks are checked, the more time it will take.
|
778 |
+
|
779 |
+
<Tip>
|
780 |
+
|
781 |
+
This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
|
782 |
+
check its availability, you can directly use [`~InferenceClient.get_model_status`].
|
783 |
+
|
784 |
+
</Tip>
|
785 |
+
|
786 |
+
Args:
|
787 |
+
frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
|
788 |
+
The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
|
789 |
+
"all", all available frameworks will be tested. It is also possible to provide a single framework or a
|
790 |
+
custom set of frameworks to check.
|
791 |
+
|
792 |
+
Returns:
|
793 |
+
`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
|
794 |
+
|
795 |
+
Example:
|
796 |
+
```python
|
797 |
+
>>> from huggingface_hub import InferenceClient
|
798 |
+
>>> client = InferenceClient()
|
799 |
+
|
800 |
+
# Discover zero-shot-classification models currently deployed
|
801 |
+
>>> models = client.list_deployed_models()
|
802 |
+
>>> models["zero-shot-classification"]
|
803 |
+
['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
|
804 |
+
|
805 |
+
# List from only 1 framework
|
806 |
+
>>> client.list_deployed_models("text-generation-inference")
|
807 |
+
{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
|
808 |
+
```
|
809 |
+
"""
|
810 |
+
# Resolve which frameworks to check
|
811 |
+
if frameworks is None:
|
812 |
+
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
|
813 |
+
elif frameworks == "all":
|
814 |
+
frameworks = ALL_INFERENCE_API_FRAMEWORKS
|
815 |
+
elif isinstance(frameworks, str):
|
816 |
+
frameworks = [frameworks]
|
817 |
+
frameworks = list(set(frameworks))
|
818 |
+
|
819 |
+
# Fetch them iteratively
|
820 |
+
models_by_task: Dict[str, List[str]] = {}
|
821 |
+
|
822 |
+
def _unpack_response(framework: str, items: List[Dict]) -> None:
|
823 |
+
for model in items:
|
824 |
+
if framework == "sentence-transformers":
|
825 |
+
# Model running with the `sentence-transformers` framework can work with both tasks even if not
|
826 |
+
# branded as such in the API response
|
827 |
+
models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
|
828 |
+
models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
|
829 |
+
else:
|
830 |
+
models_by_task.setdefault(model["task"], []).append(model["model_id"])
|
831 |
+
|
832 |
+
for framework in frameworks:
|
833 |
+
response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers)
|
834 |
+
hf_raise_for_status(response)
|
835 |
+
_unpack_response(framework, response.json())
|
836 |
+
|
837 |
+
# Sort alphabetically for discoverability and return
|
838 |
+
for task, models in models_by_task.items():
|
839 |
+
models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
|
840 |
+
return models_by_task
|
841 |
+
|
842 |
+
def object_detection(
|
843 |
+
self,
|
844 |
+
image: ContentT,
|
845 |
+
*,
|
846 |
+
model: Optional[str] = None,
|
847 |
+
) -> List[ObjectDetectionOutput]:
|
848 |
+
"""
|
849 |
+
Perform object detection on the given image using the specified model.
|
850 |
+
|
851 |
+
<Tip warning={true}>
|
852 |
+
|
853 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
854 |
+
|
855 |
+
</Tip>
|
856 |
+
|
857 |
+
Args:
|
858 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
859 |
+
The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
|
860 |
+
model (`str`, *optional*):
|
861 |
+
The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
862 |
+
deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
|
863 |
+
|
864 |
+
Returns:
|
865 |
+
`List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes.
|
866 |
+
|
867 |
+
Raises:
|
868 |
+
[`InferenceTimeoutError`]:
|
869 |
+
If the model is unavailable or the request times out.
|
870 |
+
`HTTPError`:
|
871 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
872 |
+
`ValueError`:
|
873 |
+
If the request output is not a List.
|
874 |
+
|
875 |
+
Example:
|
876 |
+
```py
|
877 |
+
>>> from huggingface_hub import InferenceClient
|
878 |
+
>>> client = InferenceClient()
|
879 |
+
>>> client.object_detection("people.jpg"):
|
880 |
+
[{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ]
|
881 |
+
```
|
882 |
+
"""
|
883 |
+
# detect objects
|
884 |
+
response = self.post(data=image, model=model, task="object-detection")
|
885 |
+
output = _bytes_to_dict(response)
|
886 |
+
if not isinstance(output, list):
|
887 |
+
raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
|
888 |
+
return output
|
889 |
+
|
890 |
+
def question_answering(
|
891 |
+
self, question: str, context: str, *, model: Optional[str] = None
|
892 |
+
) -> QuestionAnsweringOutput:
|
893 |
+
"""
|
894 |
+
Retrieve the answer to a question from a given text.
|
895 |
+
|
896 |
+
Args:
|
897 |
+
question (`str`):
|
898 |
+
Question to be answered.
|
899 |
+
context (`str`):
|
900 |
+
The context of the question.
|
901 |
+
model (`str`):
|
902 |
+
The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
903 |
+
a deployed Inference Endpoint.
|
904 |
+
|
905 |
+
Returns:
|
906 |
+
`Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
|
907 |
+
|
908 |
+
Raises:
|
909 |
+
[`InferenceTimeoutError`]:
|
910 |
+
If the model is unavailable or the request times out.
|
911 |
+
`HTTPError`:
|
912 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
913 |
+
|
914 |
+
Example:
|
915 |
+
```py
|
916 |
+
>>> from huggingface_hub import InferenceClient
|
917 |
+
>>> client = InferenceClient()
|
918 |
+
>>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
|
919 |
+
{'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
|
920 |
+
```
|
921 |
+
"""
|
922 |
+
|
923 |
+
payload: Dict[str, Any] = {"question": question, "context": context}
|
924 |
+
response = self.post(
|
925 |
+
json=payload,
|
926 |
+
model=model,
|
927 |
+
task="question-answering",
|
928 |
+
)
|
929 |
+
return _bytes_to_dict(response) # type: ignore
|
930 |
+
|
931 |
+
def sentence_similarity(
|
932 |
+
self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
|
933 |
+
) -> List[float]:
|
934 |
+
"""
|
935 |
+
Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.
|
936 |
+
|
937 |
+
Args:
|
938 |
+
sentence (`str`):
|
939 |
+
The main sentence to compare to others.
|
940 |
+
other_sentences (`List[str]`):
|
941 |
+
The list of sentences to compare to.
|
942 |
+
model (`str`, *optional*):
|
943 |
+
The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
944 |
+
a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used.
|
945 |
+
Defaults to None.
|
946 |
+
|
947 |
+
Returns:
|
948 |
+
`List[float]`: The similarity scores computed between the main sentence and each of the other sentences.
|
949 |
+
|
950 |
+
Raises:
|
951 |
+
[`InferenceTimeoutError`]:
|
952 |
+
If the model is unavailable or the request times out.
|
953 |
+
`HTTPError`:
|
954 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
955 |
+
|
956 |
+
Example:
|
957 |
+
```py
|
958 |
+
>>> from huggingface_hub import InferenceClient
|
959 |
+
>>> client = InferenceClient()
|
960 |
+
>>> client.sentence_similarity(
|
961 |
+
... "Machine learning is so easy.",
|
962 |
+
... other_sentences=[
|
963 |
+
... "Deep learning is so straightforward.",
|
964 |
+
... "This is so difficult, like rocket science.",
|
965 |
+
... "I can't believe how much I struggled with this.",
|
966 |
+
... ],
|
967 |
+
... )
|
968 |
+
[0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
|
969 |
+
```
|
970 |
+
"""
|
971 |
+
response = self.post(
|
972 |
+
json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}},
|
973 |
+
model=model,
|
974 |
+
task="sentence-similarity",
|
975 |
+
)
|
976 |
+
return _bytes_to_list(response)
|
977 |
+
|
978 |
+
def summarization(
|
979 |
+
self,
|
980 |
+
text: str,
|
981 |
+
*,
|
982 |
+
parameters: Optional[Dict[str, Any]] = None,
|
983 |
+
model: Optional[str] = None,
|
984 |
+
) -> str:
|
985 |
+
"""
|
986 |
+
Generate a summary of a given text using a specified model.
|
987 |
+
|
988 |
+
Args:
|
989 |
+
text (`str`):
|
990 |
+
The input text to summarize.
|
991 |
+
parameters (`Dict[str, Any]`, *optional*):
|
992 |
+
Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task)
|
993 |
+
for more details.
|
994 |
+
model (`str`, *optional*):
|
995 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
996 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
997 |
+
|
998 |
+
Returns:
|
999 |
+
`str`: The generated summary text.
|
1000 |
+
|
1001 |
+
Raises:
|
1002 |
+
[`InferenceTimeoutError`]:
|
1003 |
+
If the model is unavailable or the request times out.
|
1004 |
+
`HTTPError`:
|
1005 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1006 |
+
|
1007 |
+
Example:
|
1008 |
+
```py
|
1009 |
+
>>> from huggingface_hub import InferenceClient
|
1010 |
+
>>> client = InferenceClient()
|
1011 |
+
>>> client.summarization("The Eiffel tower...")
|
1012 |
+
'The Eiffel tower is one of the most famous landmarks in the world....'
|
1013 |
+
```
|
1014 |
+
"""
|
1015 |
+
payload: Dict[str, Any] = {"inputs": text}
|
1016 |
+
if parameters is not None:
|
1017 |
+
payload["parameters"] = parameters
|
1018 |
+
response = self.post(json=payload, model=model, task="summarization")
|
1019 |
+
return _bytes_to_dict(response)[0]["summary_text"]
|
1020 |
+
|
1021 |
+
def table_question_answering(
|
1022 |
+
self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
|
1023 |
+
) -> TableQuestionAnsweringOutput:
|
1024 |
+
"""
|
1025 |
+
Retrieve the answer to a question from information given in a table.
|
1026 |
+
|
1027 |
+
Args:
|
1028 |
+
table (`Dict[str, Any]`):
|
1029 |
+
A table of data represented as a dict of lists, where the keys are the column headers and the lists contain all the
|
1030 |
+
values. All lists must have the same size.
|
1031 |
+
query (`str`):
|
1032 |
+
The query in plain text that you want to ask the table.
|
1033 |
+
model (`str`):
|
1034 |
+
The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
|
1035 |
+
Hub or a URL to a deployed Inference Endpoint.
|
1036 |
+
|
1037 |
+
Returns:
|
1038 |
+
`Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used.
|
1039 |
+
|
1040 |
+
Raises:
|
1041 |
+
[`InferenceTimeoutError`]:
|
1042 |
+
If the model is unavailable or the request times out.
|
1043 |
+
`HTTPError`:
|
1044 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1045 |
+
|
1046 |
+
Example:
|
1047 |
+
```py
|
1048 |
+
>>> from huggingface_hub import InferenceClient
|
1049 |
+
>>> client = InferenceClient()
|
1050 |
+
>>> query = "How many stars does the transformers repository have?"
|
1051 |
+
>>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
|
1052 |
+
>>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
|
1053 |
+
{'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
|
1054 |
+
```
|
1055 |
+
"""
|
1056 |
+
response = self.post(
|
1057 |
+
json={
|
1058 |
+
"query": query,
|
1059 |
+
"table": table,
|
1060 |
+
},
|
1061 |
+
model=model,
|
1062 |
+
task="table-question-answering",
|
1063 |
+
)
|
1064 |
+
return _bytes_to_dict(response) # type: ignore
|
1065 |
+
|
1066 |
+
def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
|
1067 |
+
"""
|
1068 |
+
Classifying a target category (a group) based on a set of attributes.
|
1069 |
+
|
1070 |
+
Args:
|
1071 |
+
table (`Dict[str, Any]`):
|
1072 |
+
Set of attributes to classify.
|
1073 |
+
model (`str`):
|
1074 |
+
The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1075 |
+
a deployed Inference Endpoint.
|
1076 |
+
|
1077 |
+
Returns:
|
1078 |
+
`List`: a list of labels, one per row in the initial table.
|
1079 |
+
|
1080 |
+
Raises:
|
1081 |
+
[`InferenceTimeoutError`]:
|
1082 |
+
If the model is unavailable or the request times out.
|
1083 |
+
`HTTPError`:
|
1084 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1085 |
+
|
1086 |
+
Example:
|
1087 |
+
```py
|
1088 |
+
>>> from huggingface_hub import InferenceClient
|
1089 |
+
>>> client = InferenceClient()
|
1090 |
+
>>> table = {
|
1091 |
+
... "fixed_acidity": ["7.4", "7.8", "10.3"],
|
1092 |
+
... "volatile_acidity": ["0.7", "0.88", "0.32"],
|
1093 |
+
... "citric_acid": ["0", "0", "0.45"],
|
1094 |
+
... "residual_sugar": ["1.9", "2.6", "6.4"],
|
1095 |
+
... "chlorides": ["0.076", "0.098", "0.073"],
|
1096 |
+
... "free_sulfur_dioxide": ["11", "25", "5"],
|
1097 |
+
... "total_sulfur_dioxide": ["34", "67", "13"],
|
1098 |
+
... "density": ["0.9978", "0.9968", "0.9976"],
|
1099 |
+
... "pH": ["3.51", "3.2", "3.23"],
|
1100 |
+
... "sulphates": ["0.56", "0.68", "0.82"],
|
1101 |
+
... "alcohol": ["9.4", "9.8", "12.6"],
|
1102 |
+
... }
|
1103 |
+
>>> client.tabular_classification(table=table, model="julien-c/wine-quality")
|
1104 |
+
["5", "5", "5"]
|
1105 |
+
```
|
1106 |
+
"""
|
1107 |
+
response = self.post(json={"table": table}, model=model, task="tabular-classification")
|
1108 |
+
return _bytes_to_list(response)
|
1109 |
+
|
1110 |
+
def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]:
|
1111 |
+
"""
|
1112 |
+
Predicting a numerical target value given a set of attributes/features in a table.
|
1113 |
+
|
1114 |
+
Args:
|
1115 |
+
table (`Dict[str, Any]`):
|
1116 |
+
Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
|
1117 |
+
model (`str`):
|
1118 |
+
The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1119 |
+
a deployed Inference Endpoint.
|
1120 |
+
|
1121 |
+
Returns:
|
1122 |
+
`List`: a list of predicted numerical target values.
|
1123 |
+
|
1124 |
+
Raises:
|
1125 |
+
[`InferenceTimeoutError`]:
|
1126 |
+
If the model is unavailable or the request times out.
|
1127 |
+
`HTTPError`:
|
1128 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1129 |
+
|
1130 |
+
Example:
|
1131 |
+
```py
|
1132 |
+
>>> from huggingface_hub import InferenceClient
|
1133 |
+
>>> client = InferenceClient()
|
1134 |
+
>>> table = {
|
1135 |
+
... "Height": ["11.52", "12.48", "12.3778"],
|
1136 |
+
... "Length1": ["23.2", "24", "23.9"],
|
1137 |
+
... "Length2": ["25.4", "26.3", "26.5"],
|
1138 |
+
... "Length3": ["30", "31.2", "31.1"],
|
1139 |
+
... "Species": ["Bream", "Bream", "Bream"],
|
1140 |
+
... "Width": ["4.02", "4.3056", "4.6961"],
|
1141 |
+
... }
|
1142 |
+
>>> client.tabular_regression(table, model="scikit-learn/Fish-Weight")
|
1143 |
+
[110, 120, 130]
|
1144 |
+
```
|
1145 |
+
"""
|
1146 |
+
response = self.post(json={"table": table}, model=model, task="tabular-regression")
|
1147 |
+
return _bytes_to_list(response)
|
1148 |
+
|
1149 |
+
def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
|
1150 |
+
"""
|
1151 |
+
Perform text classification (e.g. sentiment-analysis) on the given text.
|
1152 |
+
|
1153 |
+
Args:
|
1154 |
+
text (`str`):
|
1155 |
+
A string to be classified.
|
1156 |
+
model (`str`, *optional*):
|
1157 |
+
The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1158 |
+
a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used.
|
1159 |
+
Defaults to None.
|
1160 |
+
|
1161 |
+
Returns:
|
1162 |
+
`List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
|
1163 |
+
|
1164 |
+
Raises:
|
1165 |
+
[`InferenceTimeoutError`]:
|
1166 |
+
If the model is unavailable or the request times out.
|
1167 |
+
`HTTPError`:
|
1168 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1169 |
+
|
1170 |
+
Example:
|
1171 |
+
```py
|
1172 |
+
>>> from huggingface_hub import InferenceClient
|
1173 |
+
>>> client = InferenceClient()
|
1174 |
+
>>> client.text_classification("I like you")
|
1175 |
+
[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
|
1176 |
+
```
|
1177 |
+
"""
|
1178 |
+
response = self.post(json={"inputs": text}, model=model, task="text-classification")
|
1179 |
+
return _bytes_to_list(response)[0]
|
1180 |
+
|
1181 |
+
@overload
|
1182 |
+
def text_generation( # type: ignore
|
1183 |
+
self,
|
1184 |
+
prompt: str,
|
1185 |
+
*,
|
1186 |
+
details: Literal[False] = ...,
|
1187 |
+
stream: Literal[False] = ...,
|
1188 |
+
model: Optional[str] = None,
|
1189 |
+
do_sample: bool = False,
|
1190 |
+
max_new_tokens: int = 20,
|
1191 |
+
best_of: Optional[int] = None,
|
1192 |
+
repetition_penalty: Optional[float] = None,
|
1193 |
+
return_full_text: bool = False,
|
1194 |
+
seed: Optional[int] = None,
|
1195 |
+
stop_sequences: Optional[List[str]] = None,
|
1196 |
+
temperature: Optional[float] = None,
|
1197 |
+
top_k: Optional[int] = None,
|
1198 |
+
top_p: Optional[float] = None,
|
1199 |
+
truncate: Optional[int] = None,
|
1200 |
+
typical_p: Optional[float] = None,
|
1201 |
+
watermark: bool = False,
|
1202 |
+
) -> str:
|
1203 |
+
...
|
1204 |
+
|
1205 |
+
@overload
|
1206 |
+
def text_generation( # type: ignore
|
1207 |
+
self,
|
1208 |
+
prompt: str,
|
1209 |
+
*,
|
1210 |
+
details: Literal[True] = ...,
|
1211 |
+
stream: Literal[False] = ...,
|
1212 |
+
model: Optional[str] = None,
|
1213 |
+
do_sample: bool = False,
|
1214 |
+
max_new_tokens: int = 20,
|
1215 |
+
best_of: Optional[int] = None,
|
1216 |
+
repetition_penalty: Optional[float] = None,
|
1217 |
+
return_full_text: bool = False,
|
1218 |
+
seed: Optional[int] = None,
|
1219 |
+
stop_sequences: Optional[List[str]] = None,
|
1220 |
+
temperature: Optional[float] = None,
|
1221 |
+
top_k: Optional[int] = None,
|
1222 |
+
top_p: Optional[float] = None,
|
1223 |
+
truncate: Optional[int] = None,
|
1224 |
+
typical_p: Optional[float] = None,
|
1225 |
+
watermark: bool = False,
|
1226 |
+
) -> TextGenerationResponse:
|
1227 |
+
...
|
1228 |
+
|
1229 |
+
@overload
|
1230 |
+
def text_generation( # type: ignore
|
1231 |
+
self,
|
1232 |
+
prompt: str,
|
1233 |
+
*,
|
1234 |
+
details: Literal[False] = ...,
|
1235 |
+
stream: Literal[True] = ...,
|
1236 |
+
model: Optional[str] = None,
|
1237 |
+
do_sample: bool = False,
|
1238 |
+
max_new_tokens: int = 20,
|
1239 |
+
best_of: Optional[int] = None,
|
1240 |
+
repetition_penalty: Optional[float] = None,
|
1241 |
+
return_full_text: bool = False,
|
1242 |
+
seed: Optional[int] = None,
|
1243 |
+
stop_sequences: Optional[List[str]] = None,
|
1244 |
+
temperature: Optional[float] = None,
|
1245 |
+
top_k: Optional[int] = None,
|
1246 |
+
top_p: Optional[float] = None,
|
1247 |
+
truncate: Optional[int] = None,
|
1248 |
+
typical_p: Optional[float] = None,
|
1249 |
+
watermark: bool = False,
|
1250 |
+
) -> Iterable[str]:
|
1251 |
+
...
|
1252 |
+
|
1253 |
+
@overload
|
1254 |
+
def text_generation(
|
1255 |
+
self,
|
1256 |
+
prompt: str,
|
1257 |
+
*,
|
1258 |
+
details: Literal[True] = ...,
|
1259 |
+
stream: Literal[True] = ...,
|
1260 |
+
model: Optional[str] = None,
|
1261 |
+
do_sample: bool = False,
|
1262 |
+
max_new_tokens: int = 20,
|
1263 |
+
best_of: Optional[int] = None,
|
1264 |
+
repetition_penalty: Optional[float] = None,
|
1265 |
+
return_full_text: bool = False,
|
1266 |
+
seed: Optional[int] = None,
|
1267 |
+
stop_sequences: Optional[List[str]] = None,
|
1268 |
+
temperature: Optional[float] = None,
|
1269 |
+
top_k: Optional[int] = None,
|
1270 |
+
top_p: Optional[float] = None,
|
1271 |
+
truncate: Optional[int] = None,
|
1272 |
+
typical_p: Optional[float] = None,
|
1273 |
+
watermark: bool = False,
|
1274 |
+
) -> Iterable[TextGenerationStreamResponse]:
|
1275 |
+
...
|
1276 |
+
|
1277 |
+
def text_generation(
|
1278 |
+
self,
|
1279 |
+
prompt: str,
|
1280 |
+
*,
|
1281 |
+
details: bool = False,
|
1282 |
+
stream: bool = False,
|
1283 |
+
model: Optional[str] = None,
|
1284 |
+
do_sample: bool = False,
|
1285 |
+
max_new_tokens: int = 20,
|
1286 |
+
best_of: Optional[int] = None,
|
1287 |
+
repetition_penalty: Optional[float] = None,
|
1288 |
+
return_full_text: bool = False,
|
1289 |
+
seed: Optional[int] = None,
|
1290 |
+
stop_sequences: Optional[List[str]] = None,
|
1291 |
+
temperature: Optional[float] = None,
|
1292 |
+
top_k: Optional[int] = None,
|
1293 |
+
top_p: Optional[float] = None,
|
1294 |
+
truncate: Optional[int] = None,
|
1295 |
+
typical_p: Optional[float] = None,
|
1296 |
+
watermark: bool = False,
|
1297 |
+
decoder_input_details: bool = False,
|
1298 |
+
) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
|
1299 |
+
"""
|
1300 |
+
Given a prompt, generate the following text.
|
1301 |
+
|
1302 |
+
It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allows
|
1303 |
+
early failures.
|
1304 |
+
|
1305 |
+
The API endpoint is expected to run the `text-generation-inference` backend (TGI). This backend is the
|
1306 |
+
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
|
1307 |
+
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
|
1308 |
+
not exactly the same. This method is compatible with both approaches but some parameters are only available for
|
1309 |
+
`text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
|
1310 |
+
continues correctly.
|
1311 |
+
|
1312 |
+
To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
|
1313 |
+
|
1314 |
+
Args:
|
1315 |
+
prompt (`str`):
|
1316 |
+
Input text.
|
1317 |
+
details (`bool`, *optional*):
|
1318 |
+
By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
|
1319 |
+
probabilities, seed, finish reason, etc.). Only available for models running with the
|
1320 |
+
`text-generation-inference` backend.
|
1321 |
+
stream (`bool`, *optional*):
|
1322 |
+
By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
|
1323 |
+
tokens to be returned. Only available for models running with the `text-generation-inference`
|
1324 |
+
backend.
|
1325 |
+
model (`str`, *optional*):
|
1326 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1327 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1328 |
+
do_sample (`bool`):
|
1329 |
+
Activate logits sampling
|
1330 |
+
max_new_tokens (`int`):
|
1331 |
+
Maximum number of generated tokens
|
1332 |
+
best_of (`int`):
|
1333 |
+
Generate best_of sequences and return the one with the highest token logprobs
|
1334 |
+
repetition_penalty (`float`):
|
1335 |
+
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
1336 |
+
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
1337 |
+
return_full_text (`bool`):
|
1338 |
+
Whether to prepend the prompt to the generated text
|
1339 |
+
seed (`int`):
|
1340 |
+
Random sampling seed
|
1341 |
+
stop_sequences (`List[str]`):
|
1342 |
+
Stop generating tokens if a member of `stop_sequences` is generated
|
1343 |
+
temperature (`float`):
|
1344 |
+
The value used to modulate the logits distribution.
|
1345 |
+
top_k (`int`):
|
1346 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
1347 |
+
top_p (`float`):
|
1348 |
+
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
1349 |
+
higher are kept for generation.
|
1350 |
+
truncate (`int`):
|
1351 |
+
Truncate input tokens to the given size
|
1352 |
+
typical_p (`float`):
|
1353 |
+
Typical Decoding mass
|
1354 |
+
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
1355 |
+
watermark (`bool`):
|
1356 |
+
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
1357 |
+
decoder_input_details (`bool`):
|
1358 |
+
Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
|
1359 |
+
into account. Defaults to `False`.
|
1360 |
+
|
1361 |
+
Returns:
|
1362 |
+
`Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
|
1363 |
+
Generated text returned from the server:
|
1364 |
+
- if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
|
1365 |
+
- if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
|
1366 |
+
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
|
1367 |
+
- if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
|
1368 |
+
|
1369 |
+
Raises:
|
1370 |
+
`ValidationError`:
|
1371 |
+
If input values are not valid. No HTTP call is made to the server.
|
1372 |
+
[`InferenceTimeoutError`]:
|
1373 |
+
If the model is unavailable or the request times out.
|
1374 |
+
`HTTPError`:
|
1375 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1376 |
+
|
1377 |
+
Example:
|
1378 |
+
```py
|
1379 |
+
>>> from huggingface_hub import InferenceClient
|
1380 |
+
>>> client = InferenceClient()
|
1381 |
+
|
1382 |
+
# Case 1: generate text
|
1383 |
+
>>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
|
1384 |
+
'100% open source and built to be easy to use.'
|
1385 |
+
|
1386 |
+
# Case 2: iterate over the generated tokens. Useful for large generation.
|
1387 |
+
>>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
|
1388 |
+
... print(token)
|
1389 |
+
100
|
1390 |
+
%
|
1391 |
+
open
|
1392 |
+
source
|
1393 |
+
and
|
1394 |
+
built
|
1395 |
+
to
|
1396 |
+
be
|
1397 |
+
easy
|
1398 |
+
to
|
1399 |
+
use
|
1400 |
+
.
|
1401 |
+
|
1402 |
+
# Case 3: get more details about the generation process.
|
1403 |
+
>>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
|
1404 |
+
TextGenerationResponse(
|
1405 |
+
generated_text='100% open source and built to be easy to use.',
|
1406 |
+
details=Details(
|
1407 |
+
finish_reason=<FinishReason.Length: 'length'>,
|
1408 |
+
generated_tokens=12,
|
1409 |
+
seed=None,
|
1410 |
+
prefill=[
|
1411 |
+
InputToken(id=487, text='The', logprob=None),
|
1412 |
+
InputToken(id=53789, text=' hugging', logprob=-13.171875),
|
1413 |
+
(...)
|
1414 |
+
InputToken(id=204, text=' ', logprob=-7.0390625)
|
1415 |
+
],
|
1416 |
+
tokens=[
|
1417 |
+
Token(id=1425, text='100', logprob=-1.0175781, special=False),
|
1418 |
+
Token(id=16, text='%', logprob=-0.0463562, special=False),
|
1419 |
+
(...)
|
1420 |
+
Token(id=25, text='.', logprob=-0.5703125, special=False)
|
1421 |
+
],
|
1422 |
+
best_of_sequences=None
|
1423 |
+
)
|
1424 |
+
)
|
1425 |
+
|
1426 |
+
# Case 4: iterate over the generated tokens with more details.
|
1427 |
+
# Last object is more complete, containing the full generated text and the finish reason.
|
1428 |
+
>>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
|
1429 |
+
... print(details)
|
1430 |
+
...
|
1431 |
+
TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
|
1432 |
+
TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
|
1433 |
+
TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
|
1434 |
+
TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
|
1435 |
+
TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
|
1436 |
+
TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
|
1437 |
+
TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
|
1438 |
+
TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
|
1439 |
+
TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
|
1440 |
+
TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
|
1441 |
+
TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
|
1442 |
+
TextGenerationStreamResponse(token=Token(
|
1443 |
+
id=25,
|
1444 |
+
text='.',
|
1445 |
+
logprob=-0.5703125,
|
1446 |
+
special=False),
|
1447 |
+
generated_text='100% open source and built to be easy to use.',
|
1448 |
+
details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
|
1449 |
+
)
|
1450 |
+
```
|
1451 |
+
"""
|
1452 |
+
# NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
|
1453 |
+
# like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
|
1454 |
+
|
1455 |
+
if decoder_input_details and not details:
|
1456 |
+
warnings.warn(
|
1457 |
+
"`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
|
1458 |
+
" the output from the server will be truncated."
|
1459 |
+
)
|
1460 |
+
decoder_input_details = False
|
1461 |
+
|
1462 |
+
# Validate parameters
|
1463 |
+
parameters = TextGenerationParameters(
|
1464 |
+
best_of=best_of,
|
1465 |
+
details=details,
|
1466 |
+
do_sample=do_sample,
|
1467 |
+
max_new_tokens=max_new_tokens,
|
1468 |
+
repetition_penalty=repetition_penalty,
|
1469 |
+
return_full_text=return_full_text,
|
1470 |
+
seed=seed,
|
1471 |
+
stop=stop_sequences if stop_sequences is not None else [],
|
1472 |
+
temperature=temperature,
|
1473 |
+
top_k=top_k,
|
1474 |
+
top_p=top_p,
|
1475 |
+
truncate=truncate,
|
1476 |
+
typical_p=typical_p,
|
1477 |
+
watermark=watermark,
|
1478 |
+
decoder_input_details=decoder_input_details,
|
1479 |
+
)
|
1480 |
+
request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
|
1481 |
+
payload = asdict(request)
|
1482 |
+
|
1483 |
+
# Remove some parameters if not a TGI server
|
1484 |
+
if not _is_tgi_server(model):
|
1485 |
+
ignored_parameters = []
|
1486 |
+
for key in "watermark", "stop", "details", "decoder_input_details":
|
1487 |
+
if payload["parameters"][key] is not None:
|
1488 |
+
ignored_parameters.append(key)
|
1489 |
+
del payload["parameters"][key]
|
1490 |
+
if len(ignored_parameters) > 0:
|
1491 |
+
warnings.warn(
|
1492 |
+
"API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
|
1493 |
+
f" {ignored_parameters}.",
|
1494 |
+
UserWarning,
|
1495 |
+
)
|
1496 |
+
if details:
|
1497 |
+
warnings.warn(
|
1498 |
+
"API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
|
1499 |
+
" be ignored meaning only the generated text will be returned.",
|
1500 |
+
UserWarning,
|
1501 |
+
)
|
1502 |
+
details = False
|
1503 |
+
if stream:
|
1504 |
+
raise ValueError(
|
1505 |
+
"API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
|
1506 |
+
" Please pass `stream=False` as input."
|
1507 |
+
)
|
1508 |
+
|
1509 |
+
# Handle errors separately for more precise error messages
|
1510 |
+
try:
|
1511 |
+
bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
|
1512 |
+
except HTTPError as e:
|
1513 |
+
if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e):
|
1514 |
+
_set_as_non_tgi(model)
|
1515 |
+
return self.text_generation( # type: ignore
|
1516 |
+
prompt=prompt,
|
1517 |
+
details=details,
|
1518 |
+
stream=stream,
|
1519 |
+
model=model,
|
1520 |
+
do_sample=do_sample,
|
1521 |
+
max_new_tokens=max_new_tokens,
|
1522 |
+
best_of=best_of,
|
1523 |
+
repetition_penalty=repetition_penalty,
|
1524 |
+
return_full_text=return_full_text,
|
1525 |
+
seed=seed,
|
1526 |
+
stop_sequences=stop_sequences,
|
1527 |
+
temperature=temperature,
|
1528 |
+
top_k=top_k,
|
1529 |
+
top_p=top_p,
|
1530 |
+
truncate=truncate,
|
1531 |
+
typical_p=typical_p,
|
1532 |
+
watermark=watermark,
|
1533 |
+
decoder_input_details=decoder_input_details,
|
1534 |
+
)
|
1535 |
+
raise_text_generation_error(e)
|
1536 |
+
|
1537 |
+
# Parse output
|
1538 |
+
if stream:
|
1539 |
+
return _stream_text_generation_response(bytes_output, details) # type: ignore
|
1540 |
+
|
1541 |
+
data = _bytes_to_dict(bytes_output)[0]
|
1542 |
+
return TextGenerationResponse(**data) if details else data["generated_text"]
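# Illustrative usage sketch (not part of the module): stream tokens from a model that the
# examples above show as served by text-generation-inference. Assumes network access and,
# for gated models, a configured token.
from huggingface_hub import InferenceClient

client = InferenceClient()
for token in client.text_generation(
    "The huggingface_hub library is ",
    model="bigcode/starcoder",  # model id taken from the examples earlier in this file
    max_new_tokens=20,
    stream=True,
):
    print(token, end="", flush=True)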
|
1543 |
+
|
1544 |
+
def text_to_image(
|
1545 |
+
self,
|
1546 |
+
prompt: str,
|
1547 |
+
*,
|
1548 |
+
negative_prompt: Optional[str] = None,
|
1549 |
+
height: Optional[float] = None,
|
1550 |
+
width: Optional[float] = None,
|
1551 |
+
num_inference_steps: Optional[float] = None,
|
1552 |
+
guidance_scale: Optional[float] = None,
|
1553 |
+
model: Optional[str] = None,
|
1554 |
+
**kwargs,
|
1555 |
+
) -> "Image":
|
1556 |
+
"""
|
1557 |
+
Generate an image based on a given text using a specified model.
|
1558 |
+
|
1559 |
+
<Tip warning={true}>
|
1560 |
+
|
1561 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
1562 |
+
|
1563 |
+
</Tip>
|
1564 |
+
|
1565 |
+
Args:
|
1566 |
+
prompt (`str`):
|
1567 |
+
The prompt to generate an image from.
|
1568 |
+
negative_prompt (`str`, *optional*):
|
1569 |
+
An optional negative prompt for the image generation.
|
1570 |
+
height (`float`, *optional*):
|
1571 |
+
The height in pixels of the image to generate.
|
1572 |
+
width (`float`, *optional*):
|
1573 |
+
The width in pixels of the image to generate.
|
1574 |
+
num_inference_steps (`int`, *optional*):
|
1575 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1576 |
+
expense of slower inference.
|
1577 |
+
guidance_scale (`float`, *optional*):
|
1578 |
+
A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
1579 |
+
usually at the expense of lower image quality.
|
1580 |
+
model (`str`, *optional*):
|
1581 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1582 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1583 |
+
|
1584 |
+
Returns:
|
1585 |
+
`Image`: The generated image.
|
1586 |
+
|
1587 |
+
Raises:
|
1588 |
+
[`InferenceTimeoutError`]:
|
1589 |
+
If the model is unavailable or the request times out.
|
1590 |
+
`HTTPError`:
|
1591 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1592 |
+
|
1593 |
+
Example:
|
1594 |
+
```py
|
1595 |
+
>>> from huggingface_hub import InferenceClient
|
1596 |
+
>>> client = InferenceClient()
|
1597 |
+
|
1598 |
+
>>> image = client.text_to_image("An astronaut riding a horse on the moon.")
|
1599 |
+
>>> image.save("astronaut.png")
|
1600 |
+
|
1601 |
+
>>> image = client.text_to_image(
|
1602 |
+
... "An astronaut riding a horse on the moon.",
|
1603 |
+
... negative_prompt="low resolution, blurry",
|
1604 |
+
... model="stabilityai/stable-diffusion-2-1",
|
1605 |
+
... )
|
1606 |
+
>>> image.save("better_astronaut.png")
|
1607 |
+
```
|
1608 |
+
"""
|
1609 |
+
payload = {"inputs": prompt}
|
1610 |
+
parameters = {
|
1611 |
+
"negative_prompt": negative_prompt,
|
1612 |
+
"height": height,
|
1613 |
+
"width": width,
|
1614 |
+
"num_inference_steps": num_inference_steps,
|
1615 |
+
"guidance_scale": guidance_scale,
|
1616 |
+
**kwargs,
|
1617 |
+
}
|
1618 |
+
for key, value in parameters.items():
|
1619 |
+
if value is not None:
|
1620 |
+
payload.setdefault("parameters", {})[key] = value # type: ignore
|
1621 |
+
response = self.post(json=payload, model=model, task="text-to-image")
|
1622 |
+
return _bytes_to_image(response)
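# Illustrative usage sketch (not part of the module): combine the documented parameters.
# Requires Pillow; the model id is the one used in the docstring example above.
from huggingface_hub import InferenceClient

client = InferenceClient()
image = client.text_to_image(
    "An astronaut riding a horse on the moon.",
    model="stabilityai/stable-diffusion-2-1",
    negative_prompt="low resolution, blurry",
    guidance_scale=9.0,        # higher values follow the prompt more closely
    num_inference_steps=50,    # more denoising steps: better quality, slower inference
)
image.save("astronaut_hq.png")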
|
1623 |
+
|
1624 |
+
def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
|
1625 |
+
"""
|
1626 |
+
Synthesize an audio of a voice pronouncing a given text.
|
1627 |
+
|
1628 |
+
Args:
|
1629 |
+
text (`str`):
|
1630 |
+
The text to synthesize.
|
1631 |
+
model (`str`, *optional*):
|
1632 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1633 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1634 |
+
|
1635 |
+
Returns:
|
1636 |
+
`bytes`: The generated audio.
|
1637 |
+
|
1638 |
+
Raises:
|
1639 |
+
[`InferenceTimeoutError`]:
|
1640 |
+
If the model is unavailable or the request times out.
|
1641 |
+
`HTTPError`:
|
1642 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1643 |
+
|
1644 |
+
Example:
|
1645 |
+
```py
|
1646 |
+
>>> from pathlib import Path
|
1647 |
+
>>> from huggingface_hub import InferenceClient
|
1648 |
+
>>> client = InferenceClient()
|
1649 |
+
|
1650 |
+
>>> audio = client.text_to_speech("Hello world")
|
1651 |
+
>>> Path("hello_world.flac").write_bytes(audio)
|
1652 |
+
```
|
1653 |
+
"""
|
1654 |
+
return self.post(json={"inputs": text}, model=model, task="text-to-speech")
|
1655 |
+
|
1656 |
+
def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]:
|
1657 |
+
"""
|
1658 |
+
Perform token classification on the given text.
|
1659 |
+
Usually used for sentence parsing, either grammatical or as Named Entity Recognition (NER), to understand keywords contained within the text.
|
1660 |
+
|
1661 |
+
Args:
|
1662 |
+
text (`str`):
|
1663 |
+
A string to be classified.
|
1664 |
+
model (`str`, *optional*):
|
1665 |
+
The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1666 |
+
a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used.
|
1667 |
+
Defaults to None.
|
1668 |
+
|
1669 |
+
Returns:
|
1670 |
+
`List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index.
|
1671 |
+
|
1672 |
+
Raises:
|
1673 |
+
[`InferenceTimeoutError`]:
|
1674 |
+
If the model is unavailable or the request times out.
|
1675 |
+
`HTTPError`:
|
1676 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1677 |
+
|
1678 |
+
Example:
|
1679 |
+
```py
|
1680 |
+
>>> from huggingface_hub import InferenceClient
|
1681 |
+
>>> client = InferenceClient()
|
1682 |
+
>>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
|
1683 |
+
[{'entity_group': 'PER',
|
1684 |
+
'score': 0.9971321225166321,
|
1685 |
+
'word': 'Sarah Jessica Parker',
|
1686 |
+
'start': 11,
|
1687 |
+
'end': 31},
|
1688 |
+
{'entity_group': 'PER',
|
1689 |
+
'score': 0.9773476123809814,
|
1690 |
+
'word': 'Jessica',
|
1691 |
+
'start': 52,
|
1692 |
+
'end': 59}]
|
1693 |
+
```
|
1694 |
+
"""
|
1695 |
+
payload: Dict[str, Any] = {"inputs": text}
|
1696 |
+
response = self.post(
|
1697 |
+
json=payload,
|
1698 |
+
model=model,
|
1699 |
+
task="token-classification",
|
1700 |
+
)
|
1701 |
+
return _bytes_to_list(response)
|
1702 |
+
|
1703 |
+
def translation(
|
1704 |
+
self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
|
1705 |
+
) -> str:
|
1706 |
+
"""
|
1707 |
+
Convert text from one language to another.
|
1708 |
+
|
1709 |
+
Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
|
1710 |
+
your specific use case. Source and target languages usually depend on the model.
|
1711 |
+
However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
|
1712 |
+
you can use `src_lang` and `tgt_lang` arguments to pass the relevant information.
|
1713 |
+
You can find this information in the model card.
|
1714 |
+
|
1715 |
+
Args:
|
1716 |
+
text (`str`):
|
1717 |
+
A string to be translated.
|
1718 |
+
model (`str`, *optional*):
|
1719 |
+
The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1720 |
+
a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
|
1721 |
+
Defaults to None.
|
1722 |
+
src_lang (`str`, *optional*):
|
1723 |
+
Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
|
1724 |
+
tgt_lang (`str`, *optional*):
|
1725 |
+
Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
|
1726 |
+
|
1727 |
+
Returns:
|
1728 |
+
`str`: The generated translated text.
|
1729 |
+
|
1730 |
+
Raises:
|
1731 |
+
[`InferenceTimeoutError`]:
|
1732 |
+
If the model is unavailable or the request times out.
|
1733 |
+
`HTTPError`:
|
1734 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1735 |
+
`ValueError`:
|
1736 |
+
If only one of the `src_lang` and `tgt_lang` arguments is provided.
|
1737 |
+
|
1738 |
+
Example:
|
1739 |
+
```py
|
1740 |
+
>>> from huggingface_hub import InferenceClient
|
1741 |
+
>>> client = InferenceClient()
|
1742 |
+
>>> client.translation("My name is Wolfgang and I live in Berlin")
|
1743 |
+
'Mein Name ist Wolfgang und ich lebe in Berlin.'
|
1744 |
+
>>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
|
1745 |
+
"Je m'appelle Wolfgang et je vis à Berlin."
|
1746 |
+
```
|
1747 |
+
|
1748 |
+
Specifying languages:
|
1749 |
+
```py
|
1750 |
+
>>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
|
1751 |
+
"Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica"
|
1752 |
+
```
|
1753 |
+
"""
|
1754 |
+
# Throw error if only one of `src_lang` and `tgt_lang` was given
|
1755 |
+
if src_lang is not None and tgt_lang is None:
|
1756 |
+
raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")
|
1757 |
+
|
1758 |
+
if src_lang is None and tgt_lang is not None:
|
1759 |
+
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
|
1760 |
+
|
1761 |
+
# If both `src_lang` and `tgt_lang` are given, pass them to the request body
|
1762 |
+
payload: Dict = {"inputs": text}
|
1763 |
+
if src_lang and tgt_lang:
|
1764 |
+
payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
|
1765 |
+
response = self.post(json=payload, model=model, task="translation")
|
1766 |
+
return _bytes_to_dict(response)[0]["translation_text"]
|
1767 |
+
|
1768 |
+
def zero_shot_classification(
|
1769 |
+
self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
|
1770 |
+
) -> List[ClassificationOutput]:
|
1771 |
+
"""
|
1772 |
+
Provide as input a text and a set of candidate labels to classify the input text.
|
1773 |
+
|
1774 |
+
Args:
|
1775 |
+
text (`str`):
|
1776 |
+
The input text to classify.
|
1777 |
+
labels (`List[str]`):
|
1778 |
+
List of string possible labels. There must be at least 2 labels.
|
1779 |
+
multi_label (`bool`):
|
1780 |
+
Boolean that is set to True if classes can overlap.
|
1781 |
+
model (`str`, *optional*):
|
1782 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1783 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1784 |
+
|
1785 |
+
Returns:
|
1786 |
+
`List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
|
1787 |
+
|
1788 |
+
Raises:
|
1789 |
+
[`InferenceTimeoutError`]:
|
1790 |
+
If the model is unavailable or the request times out.
|
1791 |
+
`HTTPError`:
|
1792 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1793 |
+
|
1794 |
+
Example:
|
1795 |
+
```py
|
1796 |
+
>>> from huggingface_hub import InferenceClient
|
1797 |
+
>>> client = InferenceClient()
|
1798 |
+
>>> text = (
|
1799 |
+
... "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
|
1800 |
+
... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
|
1801 |
+
... " mysteries when he went for a run up a hill in Nice, France."
|
1802 |
+
... )
|
1803 |
+
>>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
|
1804 |
+
>>> client.zero_shot_classification(text, labels)
|
1805 |
+
[
|
1806 |
+
{"label": "scientific discovery", "score": 0.7961668968200684},
|
1807 |
+
{"label": "space & cosmos", "score": 0.18570658564567566},
|
1808 |
+
{"label": "microbiology", "score": 0.00730885099619627},
|
1809 |
+
{"label": "archeology", "score": 0.006258360575884581},
|
1810 |
+
{"label": "robots", "score": 0.004559356719255447},
|
1811 |
+
]
|
1812 |
+
>>> client.zero_shot_classification(text, labels, multi_label=True)
|
1813 |
+
[
|
1814 |
+
{"label": "scientific discovery", "score": 0.9829297661781311},
|
1815 |
+
{"label": "space & cosmos", "score": 0.755190908908844},
|
1816 |
+
{"label": "microbiology", "score": 0.0005462635890580714},
|
1817 |
+
{"label": "archeology", "score": 0.00047131875180639327},
|
1818 |
+
{"label": "robots", "score": 0.00030448526376858354},
|
1819 |
+
]
|
1820 |
+
```
|
1821 |
+
"""
|
1822 |
+
# Raise ValueError if input is less than 2 labels
|
1823 |
+
if len(labels) < 2:
|
1824 |
+
raise ValueError("You must specify at least 2 classes to compare.")
|
1825 |
+
|
1826 |
+
response = self.post(
|
1827 |
+
json={
|
1828 |
+
"inputs": text,
|
1829 |
+
"parameters": {
|
1830 |
+
"candidate_labels": ",".join(labels),
|
1831 |
+
"multi_label": multi_label,
|
1832 |
+
},
|
1833 |
+
},
|
1834 |
+
model=model,
|
1835 |
+
task="zero-shot-classification",
|
1836 |
+
)
|
1837 |
+
output = _bytes_to_dict(response)
|
1838 |
+
return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])]
|
1839 |
+
|
1840 |
+
def zero_shot_image_classification(
|
1841 |
+
self, image: ContentT, labels: List[str], *, model: Optional[str] = None
|
1842 |
+
) -> List[ClassificationOutput]:
|
1843 |
+
"""
|
1844 |
+
Provide input image and text labels to predict text labels for the image.
|
1845 |
+
|
1846 |
+
Args:
|
1847 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
1848 |
+
The input image to classify. It can be raw bytes, an image file, or a URL to an online image.
|
1849 |
+
labels (`List[str]`):
|
1850 |
+
List of string possible labels. There must be at least 2 labels.
|
1851 |
+
model (`str`, *optional*):
|
1852 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1853 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1854 |
+
|
1855 |
+
Returns:
|
1856 |
+
`List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
|
1857 |
+
|
1858 |
+
Raises:
|
1859 |
+
[`InferenceTimeoutError`]:
|
1860 |
+
If the model is unavailable or the request times out.
|
1861 |
+
`HTTPError`:
|
1862 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1863 |
+
|
1864 |
+
Example:
|
1865 |
+
```py
|
1866 |
+
>>> from huggingface_hub import InferenceClient
|
1867 |
+
>>> client = InferenceClient()
|
1868 |
+
|
1869 |
+
>>> client.zero_shot_image_classification(
|
1870 |
+
... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
|
1871 |
+
... labels=["dog", "cat", "horse"],
|
1872 |
+
... )
|
1873 |
+
[{"label": "dog", "score": 0.956}, ...]
|
1874 |
+
```
|
1875 |
+
"""
|
1876 |
+
# Raise ValueError if input is less than 2 labels
|
1877 |
+
if len(labels) < 2:
|
1878 |
+
raise ValueError("You must specify at least 2 classes to compare.")
|
1879 |
+
|
1880 |
+
response = self.post(
|
1881 |
+
json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}},
|
1882 |
+
model=model,
|
1883 |
+
task="zero-shot-image-classification",
|
1884 |
+
)
|
1885 |
+
return _bytes_to_list(response)
|
1886 |
+
|
1887 |
+
def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
|
1888 |
+
model = model or self.model
|
1889 |
+
|
1890 |
+
# If model is already a URL, ignore `task` and return directly
|
1891 |
+
if model is not None and (model.startswith("http://") or model.startswith("https://")):
|
1892 |
+
return model
|
1893 |
+
|
1894 |
+
# If no model but task is set => fetch the recommended one for this task
|
1895 |
+
if model is None:
|
1896 |
+
if task is None:
|
1897 |
+
raise ValueError(
|
1898 |
+
"You must specify at least a model (repo_id or URL) or a task, either when instantiating"
|
1899 |
+
" `InferenceClient` or when making a request."
|
1900 |
+
)
|
1901 |
+
model = self.get_recommended_model(task)
|
1902 |
+
logger.info(
|
1903 |
+
f"Using recommended model {model} for task {task}. Note that it is"
|
1904 |
+
f" encouraged to explicitly set `model='{model}'` as the recommended"
|
1905 |
+
" models list might get updated without prior notice."
|
1906 |
+
)
|
1907 |
+
|
1908 |
+
# Compute InferenceAPI url
|
1909 |
+
return (
|
1910 |
+
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
|
1911 |
+
f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
|
1912 |
+
if task in ("feature-extraction", "sentence-similarity")
|
1913 |
+
# Otherwise, we use the default endpoint
|
1914 |
+
else f"{INFERENCE_ENDPOINT}/models/{model}"
|
1915 |
+
)
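# Resolution behaviour implied by the logic above (model ids are illustrative):
#   - model is already a URL      -> returned unchanged, `task` is ignored
#   - task="text-generation"      -> f"{INFERENCE_ENDPOINT}/models/<model>"
#   - task="feature-extraction"   -> f"{INFERENCE_ENDPOINT}/pipeline/feature-extraction/<model>"
#     (same for "sentence-similarity", the only other multi-task case)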
|
1916 |
+
|
1917 |
+
@staticmethod
|
1918 |
+
def get_recommended_model(task: str) -> str:
|
1919 |
+
"""
|
1920 |
+
Get the model Hugging Face recommends for the input task.
|
1921 |
+
|
1922 |
+
Args:
|
1923 |
+
task (`str`):
|
1924 |
+
The Hugging Face task for which to get the recommended model.
|
1925 |
+
All available tasks can be found [here](https://huggingface.co/tasks).
|
1926 |
+
|
1927 |
+
Returns:
|
1928 |
+
`str`: Name of the model recommended for the input task.
|
1929 |
+
|
1930 |
+
Raises:
|
1931 |
+
`ValueError`: If Hugging Face has no recommendation for the input task.
|
1932 |
+
"""
|
1933 |
+
model = _fetch_recommended_models().get(task)
|
1934 |
+
if model is None:
|
1935 |
+
raise ValueError(
|
1936 |
+
f"Task {task} has no recommended model. Please specify a model"
|
1937 |
+
" explicitly. Visit https://huggingface.co/tasks for more info."
|
1938 |
+
)
|
1939 |
+
return model
|
1940 |
+
|
1941 |
+
def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
|
1942 |
+
"""
|
1943 |
+
Get the status of a model hosted on the Inference API.
|
1944 |
+
|
1945 |
+
<Tip>
|
1946 |
+
|
1947 |
+
This endpoint is mostly useful when you already know which model you want to use and want to check its
|
1948 |
+
availability. To discover models that are already deployed, use [`~InferenceClient.list_deployed_models`] instead.
|
1949 |
+
|
1950 |
+
</Tip>
|
1951 |
+
|
1952 |
+
Args:
|
1953 |
+
model (`str`, *optional*):
|
1954 |
+
Identifier of the model whose status will be checked. If no model is provided,
|
1955 |
+
the model associated with this instance of [`InferenceClient`] will be used. Only the Inference API service can be checked, so the
|
1956 |
+
identifier cannot be a URL.
|
1957 |
+
|
1958 |
+
|
1959 |
+
Returns:
|
1960 |
+
[`ModelStatus`]: An instance of the ModelStatus dataclass, containing information
|
1961 |
+
about the state of the model: load, state, compute type and framework.
|
1962 |
+
|
1963 |
+
Example:
|
1964 |
+
```py
|
1965 |
+
>>> from huggingface_hub import InferenceClient
|
1966 |
+
>>> client = InferenceClient()
|
1967 |
+
>>> client.get_model_status("bigcode/starcoder")
|
1968 |
+
ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
|
1969 |
+
```
|
1970 |
+
"""
|
1971 |
+
model = model or self.model
|
1972 |
+
if model is None:
|
1973 |
+
raise ValueError("Model id not provided.")
|
1974 |
+
if model.startswith("https://"):
|
1975 |
+
raise NotImplementedError("Model status is only available for Inference API endpoints.")
|
1976 |
+
url = f"{INFERENCE_ENDPOINT}/status/{model}"
|
1977 |
+
|
1978 |
+
response = get_session().get(url, headers=self.headers)
|
1979 |
+
hf_raise_for_status(response)
|
1980 |
+
response_data = response.json()
|
1981 |
+
|
1982 |
+
if "error" in response_data:
|
1983 |
+
raise ValueError(response_data["error"])
|
1984 |
+
|
1985 |
+
return ModelStatus(
|
1986 |
+
loaded=response_data["loaded"],
|
1987 |
+
state=response_data["state"],
|
1988 |
+
compute_type=response_data["compute_type"],
|
1989 |
+
framework=response_data["framework"],
|
1990 |
+
)
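# Illustrative pre-flight check (not part of the module), using the model id from the
# docstring example above: warn the caller when the first request will trigger a load.
from huggingface_hub import InferenceClient

client = InferenceClient()
status = client.get_model_status("bigcode/starcoder")
if not status.loaded:
    print(f"Model is {status.state!r}; the first inference call may take longer while it loads.")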
|
lib/python3.11/site-packages/huggingface_hub/inference/_common.py
ADDED
@@ -0,0 +1,327 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Contains utilities used by both the sync and async inference clients."""
|
16 |
+
import base64
|
17 |
+
import io
|
18 |
+
import json
|
19 |
+
import logging
|
20 |
+
from contextlib import contextmanager
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from pathlib import Path
|
23 |
+
from typing import (
|
24 |
+
TYPE_CHECKING,
|
25 |
+
Any,
|
26 |
+
AsyncIterable,
|
27 |
+
BinaryIO,
|
28 |
+
ContextManager,
|
29 |
+
Dict,
|
30 |
+
Generator,
|
31 |
+
Iterable,
|
32 |
+
List,
|
33 |
+
Literal,
|
34 |
+
Optional,
|
35 |
+
Set,
|
36 |
+
Union,
|
37 |
+
overload,
|
38 |
+
)
|
39 |
+
|
40 |
+
from requests import HTTPError
|
41 |
+
|
42 |
+
from ..constants import ENDPOINT
|
43 |
+
from ..utils import (
|
44 |
+
build_hf_headers,
|
45 |
+
get_session,
|
46 |
+
hf_raise_for_status,
|
47 |
+
is_aiohttp_available,
|
48 |
+
is_numpy_available,
|
49 |
+
is_pillow_available,
|
50 |
+
)
|
51 |
+
from ._text_generation import TextGenerationStreamResponse, _parse_text_generation_error
|
52 |
+
|
53 |
+
|
54 |
+
if TYPE_CHECKING:
|
55 |
+
from aiohttp import ClientResponse, ClientSession
|
56 |
+
from PIL import Image
|
57 |
+
|
58 |
+
# TYPES
|
59 |
+
UrlT = str
|
60 |
+
PathT = Union[str, Path]
|
61 |
+
BinaryT = Union[bytes, BinaryIO]
|
62 |
+
ContentT = Union[BinaryT, PathT, UrlT]
|
63 |
+
|
64 |
+
# Used to set an Accept: image/png header
|
65 |
+
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
|
66 |
+
|
67 |
+
logger = logging.getLogger(__name__)
|
68 |
+
|
69 |
+
|
70 |
+
# Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
|
71 |
+
@dataclass
|
72 |
+
class ModelStatus:
|
73 |
+
"""
|
74 |
+
This dataclass represents the model status in the Hugging Face Inference API.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
loaded (`bool`):
|
78 |
+
If the model is currently loaded into Hugging Face's InferenceAPI. Models
|
79 |
+
are loaded on-demand, leading to the user's first request taking longer.
|
80 |
+
If a model is loaded, you can be assured that it is in a healthy state.
|
81 |
+
state (`str`):
|
82 |
+
The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
|
83 |
+
If a model's state is 'Loadable', it's not too big and has a supported
|
84 |
+
backend. Loadable models are automatically loaded when the user first
|
85 |
+
requests inference on the endpoint. This means it is transparent for the
|
86 |
+
user to load a model, except that the first call takes longer to complete.
|
87 |
+
compute_type (`str`):
|
88 |
+
The type of compute resource the model is using or will use, such as 'gpu' or 'cpu'.
|
89 |
+
framework (`str`):
|
90 |
+
The name of the framework that the model was built with, such as 'transformers'
|
91 |
+
or 'text-generation-inference'.
|
92 |
+
"""
|
93 |
+
|
94 |
+
loaded: bool
|
95 |
+
state: str
|
96 |
+
compute_type: str
|
97 |
+
framework: str
|
98 |
+
|
99 |
+
|
100 |
+
class InferenceTimeoutError(HTTPError, TimeoutError):
|
101 |
+
"""Error raised when a model is unavailable or the request times out."""
|
102 |
+
|
103 |
+
|
104 |
+
## IMPORT UTILS
|
105 |
+
|
106 |
+
|
107 |
+
def _import_aiohttp():
|
108 |
+
# Make sure `aiohttp` is installed on the machine.
|
109 |
+
if not is_aiohttp_available():
|
110 |
+
raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
|
111 |
+
import aiohttp
|
112 |
+
|
113 |
+
return aiohttp
|
114 |
+
|
115 |
+
|
116 |
+
def _import_numpy():
|
117 |
+
"""Make sure `numpy` is installed on the machine."""
|
118 |
+
if not is_numpy_available():
|
119 |
+
raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
|
120 |
+
import numpy
|
121 |
+
|
122 |
+
return numpy
|
123 |
+
|
124 |
+
|
125 |
+
def _import_pil_image():
|
126 |
+
"""Make sure `PIL` is installed on the machine."""
|
127 |
+
if not is_pillow_available():
|
128 |
+
raise ImportError(
|
129 |
+
"Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
|
130 |
+
" post-processed, use `client.post(...)` and get the raw response from the server."
|
131 |
+
)
|
132 |
+
from PIL import Image
|
133 |
+
|
134 |
+
return Image
|
135 |
+
|
136 |
+
|
137 |
+
## RECOMMENDED MODELS
|
138 |
+
|
139 |
+
# Will be globally fetched only once (see '_fetch_recommended_models')
|
140 |
+
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None
|
141 |
+
|
142 |
+
|
143 |
+
def _fetch_recommended_models() -> Dict[str, Optional[str]]:
|
144 |
+
global _RECOMMENDED_MODELS
|
145 |
+
if _RECOMMENDED_MODELS is None:
|
146 |
+
response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
|
147 |
+
hf_raise_for_status(response)
|
148 |
+
_RECOMMENDED_MODELS = {
|
149 |
+
task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
|
150 |
+
}
|
151 |
+
return _RECOMMENDED_MODELS
|
152 |
+
|
153 |
+
|
154 |
+
def _first_or_none(items: List[Any]) -> Optional[Any]:
|
155 |
+
try:
|
156 |
+
return items[0] or None
|
157 |
+
except IndexError:
|
158 |
+
return None
|
159 |
+
|
160 |
+
|
161 |
+
## ENCODING / DECODING UTILS
|
162 |
+
|
163 |
+
|
164 |
+
@overload
|
165 |
+
def _open_as_binary(content: ContentT) -> ContextManager[BinaryT]:
|
166 |
+
... # means "if input is not None, output is not None"
|
167 |
+
|
168 |
+
|
169 |
+
@overload
|
170 |
+
def _open_as_binary(content: Literal[None]) -> ContextManager[Literal[None]]:
|
171 |
+
... # means "if input is None, output is None"
|
172 |
+
|
173 |
+
|
174 |
+
@contextmanager # type: ignore
|
175 |
+
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
|
176 |
+
"""Open `content` as a binary file, either from a URL, a local path, or raw bytes.
|
177 |
+
|
178 |
+
Do nothing if `content` is None,
|
179 |
+
|
180 |
+
TODO: handle a PIL.Image as input
|
181 |
+
TODO: handle base64 as input
|
182 |
+
"""
|
183 |
+
# If content is a string => must be either a URL or a path
|
184 |
+
if isinstance(content, str):
|
185 |
+
if content.startswith("https://") or content.startswith("http://"):
|
186 |
+
logger.debug(f"Downloading content from {content}")
|
187 |
+
yield get_session().get(content).content # TODO: retrieve as stream and pipe to post request ?
|
188 |
+
return
|
189 |
+
content = Path(content)
|
190 |
+
if not content.exists():
|
191 |
+
raise FileNotFoundError(
|
192 |
+
f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
|
193 |
+
" file. To pass raw content, please encode it as bytes first."
|
194 |
+
)
|
195 |
+
|
196 |
+
# If content is a Path => open it
|
197 |
+
if isinstance(content, Path):
|
198 |
+
logger.debug(f"Opening content from {content}")
|
199 |
+
with content.open("rb") as f:
|
200 |
+
yield f
|
201 |
+
else:
|
202 |
+
# Otherwise: already a file-like object or None
|
203 |
+
yield content
|
204 |
+
|
205 |
+
|
206 |
+
def _b64_encode(content: ContentT) -> str:
|
207 |
+
"""Encode a raw file (image, audio) into base64. Can be byes, an opened file, a path or a URL."""
|
208 |
+
with _open_as_binary(content) as data:
|
209 |
+
data_as_bytes = data if isinstance(data, bytes) else data.read()
|
210 |
+
return base64.b64encode(data_as_bytes).decode()
|
211 |
+
|
212 |
+
|
213 |
+
def _b64_to_image(encoded_image: str) -> "Image":
|
214 |
+
"""Parse a base64-encoded string into a PIL Image."""
|
215 |
+
Image = _import_pil_image()
|
216 |
+
return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
|
217 |
+
|
218 |
+
|
219 |
+
def _bytes_to_list(content: bytes) -> List:
|
220 |
+
"""Parse bytes from a Response object into a Python list.
|
221 |
+
|
222 |
+
Expects the response body to be JSON-encoded data.
|
223 |
+
|
224 |
+
NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
|
225 |
+
dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
|
226 |
+
"""
|
227 |
+
return json.loads(content.decode())
|
228 |
+
|
229 |
+
|
230 |
+
def _bytes_to_dict(content: bytes) -> Dict:
|
231 |
+
"""Parse bytes from a Response object into a Python dictionary.
|
232 |
+
|
233 |
+
Expects the response body to be JSON-encoded data.
|
234 |
+
|
235 |
+
NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
|
236 |
+
list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
|
237 |
+
"""
|
238 |
+
return json.loads(content.decode())
|
239 |
+
|
240 |
+
|
241 |
+
def _bytes_to_image(content: bytes) -> "Image":
|
242 |
+
"""Parse bytes from a Response object into a PIL Image.
|
243 |
+
|
244 |
+
Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
|
245 |
+
"""
|
246 |
+
Image = _import_pil_image()
|
247 |
+
return Image.open(io.BytesIO(content))
|
248 |
+
|
249 |
+
|
250 |
+
## STREAMING UTILS
|
251 |
+
|
252 |
+
|
253 |
+
def _stream_text_generation_response(
|
254 |
+
bytes_output_as_lines: Iterable[bytes], details: bool
|
255 |
+
) -> Union[Iterable[str], Iterable[TextGenerationStreamResponse]]:
|
256 |
+
# Parse ServerSentEvents
|
257 |
+
for byte_payload in bytes_output_as_lines:
|
258 |
+
# Skip line
|
259 |
+
if byte_payload == b"\n":
|
260 |
+
continue
|
261 |
+
|
262 |
+
payload = byte_payload.decode("utf-8")
|
263 |
+
|
264 |
+
# Event data
|
265 |
+
if payload.startswith("data:"):
|
266 |
+
# Decode payload
|
267 |
+
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
268 |
+
# Either an error as being returned
|
269 |
+
if json_payload.get("error") is not None:
|
270 |
+
raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
|
271 |
+
# Or parse token payload
|
272 |
+
output = TextGenerationStreamResponse(**json_payload)
|
273 |
+
yield output.token.text if not details else output
|
274 |
+
|
275 |
+
|
276 |
+
async def _async_stream_text_generation_response(
|
277 |
+
bytes_output_as_lines: AsyncIterable[bytes], details: bool
|
278 |
+
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
|
279 |
+
# Parse ServerSentEvents
|
280 |
+
async for byte_payload in bytes_output_as_lines:
|
281 |
+
# Skip line
|
282 |
+
if byte_payload == b"\n":
|
283 |
+
continue
|
284 |
+
|
285 |
+
payload = byte_payload.decode("utf-8")
|
286 |
+
|
287 |
+
# Event data
|
288 |
+
if payload.startswith("data:"):
|
289 |
+
# Decode payload
|
290 |
+
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
|
291 |
+
# Either an error as being returned
|
292 |
+
if json_payload.get("error") is not None:
|
293 |
+
raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
|
294 |
+
# Or parse token payload
|
295 |
+
output = TextGenerationStreamResponse(**json_payload)
|
296 |
+
yield output.token.text if not details else output
|
297 |
+
|
298 |
+
|
299 |
+
async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
|
300 |
+
async for byte_payload in response.content:
|
301 |
+
yield byte_payload
|
302 |
+
await client.close()
|
303 |
+
|
304 |
+
|
305 |
+
# "TGI servers" are servers running with the `text-generation-inference` backend.
|
306 |
+
# This backend is the go-to solution to run large language models at scale. However,
|
307 |
+
# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
|
308 |
+
# solution is still in use.
|
309 |
+
#
|
310 |
+
# Both approaches have very similar APIs, but not exactly the same. What we do first in
|
311 |
+
# the `text_generation` method is to assume the model is served via TGI. If we realize
|
312 |
+
# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
|
313 |
+
# default API with a warning message. We remember for each model if it's a TGI server
|
314 |
+
# or not using `_NON_TGI_SERVERS` global variable.
|
315 |
+
#
|
316 |
+
# For more details, see https://github.com/huggingface/text-generation-inference and
|
317 |
+
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
|
318 |
+
|
319 |
+
_NON_TGI_SERVERS: Set[Optional[str]] = set()
|
320 |
+
|
321 |
+
|
322 |
+
def _set_as_non_tgi(model: Optional[str]) -> None:
|
323 |
+
_NON_TGI_SERVERS.add(model)
|
324 |
+
|
325 |
+
|
326 |
+
def _is_tgi_server(model: Optional[str]) -> bool:
|
327 |
+
return model not in _NON_TGI_SERVERS
|
lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py
ADDED
File without changes
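
The `_common.py` helpers added above are shared by both the sync and async inference clients. A minimal sketch of how they fit together (not part of the commit; assumes the package is installed and `sample.png` is a hypothetical local file):

```py
# Hedged sketch: exercising the content and TGI helpers defined in _common.py above.
from huggingface_hub.inference._common import (
    _b64_encode,
    _is_tgi_server,
    _open_as_binary,
    _set_as_non_tgi,
)

# _open_as_binary accepts raw bytes, a local path, or an http(s) URL.
with _open_as_binary("sample.png") as f:  # hypothetical local file
    raw_bytes = f.read()

# _b64_encode reuses the same content handling and returns a base64 string.
encoded = _b64_encode("sample.png")

# TGI bookkeeping: a model is assumed to be TGI-served until flagged otherwise.
assert _is_tgi_server("gpt2")
_set_as_non_tgi("gpt2")
assert not _is_tgi_server("gpt2")
```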
lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (251 Bytes).
lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-311.pyc
ADDED
Binary file (96.9 kB).
lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py
ADDED
@@ -0,0 +1,2020 @@
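
Before the generated source below, a minimal usage sketch of the `AsyncInferenceClient` it defines (not part of the commit; requires `pip install aiohttp`, and the model id is illustrative):

```py
# Hedged sketch: calling the async client defined in _async_client.py below.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(model="gpt2")  # illustrative model id
    # `post` is the low-level entry point used by every task-specific method.
    raw = await client.post(json={"inputs": "Hello"}, task="text-generation")
    print(raw[:80])


asyncio.run(main())
```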
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# WARNING
|
17 |
+
# This entire file has been adapted from the sync-client code in `src/huggingface_hub/inference/_client.py`.
|
18 |
+
# Any change in InferenceClient will be automatically reflected in AsyncInferenceClient.
|
19 |
+
# To re-generate the code, run `make style` or `python ./utils/generate_async_inference_client.py --update`.
|
20 |
+
# WARNING
|
21 |
+
import asyncio
|
22 |
+
import logging
|
23 |
+
import time
|
24 |
+
import warnings
|
25 |
+
from dataclasses import asdict
|
26 |
+
from typing import (
|
27 |
+
TYPE_CHECKING,
|
28 |
+
Any,
|
29 |
+
AsyncIterable,
|
30 |
+
Dict,
|
31 |
+
List,
|
32 |
+
Literal,
|
33 |
+
Optional,
|
34 |
+
Union,
|
35 |
+
overload,
|
36 |
+
)
|
37 |
+
|
38 |
+
from requests.structures import CaseInsensitiveDict
|
39 |
+
|
40 |
+
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
|
41 |
+
from huggingface_hub.inference._common import (
|
42 |
+
TASKS_EXPECTING_IMAGES,
|
43 |
+
ContentT,
|
44 |
+
InferenceTimeoutError,
|
45 |
+
ModelStatus,
|
46 |
+
_async_stream_text_generation_response,
|
47 |
+
_b64_encode,
|
48 |
+
_b64_to_image,
|
49 |
+
_bytes_to_dict,
|
50 |
+
_bytes_to_image,
|
51 |
+
_bytes_to_list,
|
52 |
+
_fetch_recommended_models,
|
53 |
+
_import_numpy,
|
54 |
+
_is_tgi_server,
|
55 |
+
_open_as_binary,
|
56 |
+
_set_as_non_tgi,
|
57 |
+
)
|
58 |
+
from huggingface_hub.inference._text_generation import (
|
59 |
+
TextGenerationParameters,
|
60 |
+
TextGenerationRequest,
|
61 |
+
TextGenerationResponse,
|
62 |
+
TextGenerationStreamResponse,
|
63 |
+
raise_text_generation_error,
|
64 |
+
)
|
65 |
+
from huggingface_hub.inference._types import (
|
66 |
+
ClassificationOutput,
|
67 |
+
ConversationalOutput,
|
68 |
+
FillMaskOutput,
|
69 |
+
ImageSegmentationOutput,
|
70 |
+
ObjectDetectionOutput,
|
71 |
+
QuestionAnsweringOutput,
|
72 |
+
TableQuestionAnsweringOutput,
|
73 |
+
TokenClassificationOutput,
|
74 |
+
)
|
75 |
+
from huggingface_hub.utils import (
|
76 |
+
build_hf_headers,
|
77 |
+
)
|
78 |
+
|
79 |
+
from .._common import _async_yield_from, _import_aiohttp
|
80 |
+
|
81 |
+
|
82 |
+
if TYPE_CHECKING:
|
83 |
+
import numpy as np
|
84 |
+
from PIL import Image
|
85 |
+
|
86 |
+
logger = logging.getLogger(__name__)
|
87 |
+
|
88 |
+
|
89 |
+
class AsyncInferenceClient:
|
90 |
+
"""
|
91 |
+
Initialize a new Inference Client.
|
92 |
+
|
93 |
+
[`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
|
94 |
+
seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
model (`str`, `optional`):
|
98 |
+
The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
|
99 |
+
or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
|
100 |
+
automatically selected for the task.
|
101 |
+
token (`str`, *optional*):
|
102 |
+
Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
|
103 |
+
your token to the server.
|
104 |
+
timeout (`float`, `optional`):
|
105 |
+
The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
|
106 |
+
API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
|
107 |
+
headers (`Dict[str, str]`, `optional`):
|
108 |
+
Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
|
109 |
+
Values in this dictionary will override the default values.
|
110 |
+
cookies (`Dict[str, str]`, `optional`):
|
111 |
+
Additional cookies to send to the server.
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
model: Optional[str] = None,
|
117 |
+
token: Union[str, bool, None] = None,
|
118 |
+
timeout: Optional[float] = None,
|
119 |
+
headers: Optional[Dict[str, str]] = None,
|
120 |
+
cookies: Optional[Dict[str, str]] = None,
|
121 |
+
) -> None:
|
122 |
+
self.model: Optional[str] = model
|
123 |
+
self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
|
124 |
+
if headers is not None:
|
125 |
+
self.headers.update(headers)
|
126 |
+
self.cookies = cookies
|
127 |
+
self.timeout = timeout
|
128 |
+
|
129 |
+
def __repr__(self):
|
130 |
+
return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
|
131 |
+
|
132 |
+
@overload
|
133 |
+
async def post( # type: ignore[misc]
|
134 |
+
self,
|
135 |
+
*,
|
136 |
+
json: Optional[Union[str, Dict, List]] = None,
|
137 |
+
data: Optional[ContentT] = None,
|
138 |
+
model: Optional[str] = None,
|
139 |
+
task: Optional[str] = None,
|
140 |
+
stream: Literal[False] = ...,
|
141 |
+
) -> bytes:
|
142 |
+
pass
|
143 |
+
|
144 |
+
@overload
|
145 |
+
async def post(
|
146 |
+
self,
|
147 |
+
*,
|
148 |
+
json: Optional[Union[str, Dict, List]] = None,
|
149 |
+
data: Optional[ContentT] = None,
|
150 |
+
model: Optional[str] = None,
|
151 |
+
task: Optional[str] = None,
|
152 |
+
stream: Literal[True] = ...,
|
153 |
+
) -> AsyncIterable[bytes]:
|
154 |
+
pass
|
155 |
+
|
156 |
+
async def post(
|
157 |
+
self,
|
158 |
+
*,
|
159 |
+
json: Optional[Union[str, Dict, List]] = None,
|
160 |
+
data: Optional[ContentT] = None,
|
161 |
+
model: Optional[str] = None,
|
162 |
+
task: Optional[str] = None,
|
163 |
+
stream: bool = False,
|
164 |
+
) -> Union[bytes, AsyncIterable[bytes]]:
|
165 |
+
"""
|
166 |
+
Make a POST request to the inference server.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
json (`Union[str, Dict, List]`, *optional*):
|
170 |
+
The JSON data to send in the request body, specific to each task. Defaults to None.
|
171 |
+
data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
|
172 |
+
The content to send in the request body, specific to each task.
|
173 |
+
It can be raw bytes, a pointer to an opened file, a local file path,
|
174 |
+
or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed,
|
175 |
+
`data` will take precedence. At least `json` or `data` must be provided. Defaults to None.
|
176 |
+
model (`str`, *optional*):
|
177 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
178 |
+
Inference Endpoint. Will override the model defined at the instance level. Defaults to None.
|
179 |
+
task (`str`, *optional*):
|
180 |
+
The task to perform on the inference. All available tasks can be found
|
181 |
+
[here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not
|
182 |
+
provided. At least `model` or `task` must be provided. Defaults to None.
|
183 |
+
stream (`bool`, *optional*):
|
184 |
+
Whether to iterate over streaming APIs.
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
bytes: The raw bytes returned by the server.
|
188 |
+
|
189 |
+
Raises:
|
190 |
+
[`InferenceTimeoutError`]:
|
191 |
+
If the model is unavailable or the request times out.
|
192 |
+
`aiohttp.ClientResponseError`:
|
193 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
194 |
+
"""
|
195 |
+
|
196 |
+
aiohttp = _import_aiohttp()
|
197 |
+
|
198 |
+
url = self._resolve_url(model, task)
|
199 |
+
|
200 |
+
if data is not None and json is not None:
|
201 |
+
warnings.warn("Ignoring `json` as `data` is passed as binary.")
|
202 |
+
|
203 |
+
# Set Accept header if relevant
|
204 |
+
headers = self.headers.copy()
|
205 |
+
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
|
206 |
+
headers["Accept"] = "image/png"
|
207 |
+
|
208 |
+
t0 = time.time()
|
209 |
+
timeout = self.timeout
|
210 |
+
while True:
|
211 |
+
with _open_as_binary(data) as data_as_binary:
|
212 |
+
# Do not use context manager as we don't want to close the connection immediately when returning
|
213 |
+
# a stream
|
214 |
+
client = aiohttp.ClientSession(
|
215 |
+
headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
|
216 |
+
)
|
217 |
+
|
218 |
+
try:
|
219 |
+
response = await client.post(url, json=json, data=data_as_binary)
|
220 |
+
response_error_payload = None
|
221 |
+
if response.status != 200:
|
222 |
+
try:
|
223 |
+
response_error_payload = await response.json() # get payload before connection closed
|
224 |
+
except Exception:
|
225 |
+
pass
|
226 |
+
response.raise_for_status()
|
227 |
+
if stream:
|
228 |
+
return _async_yield_from(client, response)
|
229 |
+
else:
|
230 |
+
content = await response.read()
|
231 |
+
await client.close()
|
232 |
+
return content
|
233 |
+
except asyncio.TimeoutError as error:
|
234 |
+
await client.close()
|
235 |
+
# Convert any `TimeoutError` to a `InferenceTimeoutError`
|
236 |
+
raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
|
237 |
+
except aiohttp.ClientResponseError as error:
|
238 |
+
error.response_error_payload = response_error_payload
|
239 |
+
await client.close()
|
240 |
+
if response.status == 422 and task is not None:
|
241 |
+
error.message += f". Make sure '{task}' task is supported by the model."
|
242 |
+
if response.status == 503:
|
243 |
+
# If Model is unavailable, either raise a TimeoutError...
|
244 |
+
if timeout is not None and time.time() - t0 > timeout:
|
245 |
+
raise InferenceTimeoutError(
|
246 |
+
f"Model not loaded on the server: {url}. Please retry with a higher timeout"
|
247 |
+
f" (current: {self.timeout}).",
|
248 |
+
request=error.request,
|
249 |
+
response=error.response,
|
250 |
+
) from error
|
251 |
+
# ...or wait 1s and retry
|
252 |
+
logger.info(f"Waiting for model to be loaded on the server: {error}")
|
253 |
+
time.sleep(1)
|
254 |
+
if timeout is not None:
|
255 |
+
timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
|
256 |
+
continue
|
257 |
+
raise error
|
258 |
+
|
259 |
+
async def audio_classification(
|
260 |
+
self,
|
261 |
+
audio: ContentT,
|
262 |
+
*,
|
263 |
+
model: Optional[str] = None,
|
264 |
+
) -> List[ClassificationOutput]:
|
265 |
+
"""
|
266 |
+
Perform audio classification on the provided audio content.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
audio (Union[str, Path, bytes, BinaryIO]):
|
270 |
+
The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an
|
271 |
+
audio file.
|
272 |
+
model (`str`, *optional*):
|
273 |
+
The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub
|
274 |
+
or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
|
275 |
+
audio classification will be used.
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
`List[Dict]`: The classification output containing the predicted label and its confidence.
|
279 |
+
|
280 |
+
Raises:
|
281 |
+
[`InferenceTimeoutError`]:
|
282 |
+
If the model is unavailable or the request times out.
|
283 |
+
`aiohttp.ClientResponseError`:
|
284 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
285 |
+
|
286 |
+
Example:
|
287 |
+
```py
|
288 |
+
# Must be run in an async context
|
289 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
290 |
+
>>> client = AsyncInferenceClient()
|
291 |
+
>>> await client.audio_classification("audio.flac")
|
292 |
+
[{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
|
293 |
+
```
|
294 |
+
"""
|
295 |
+
response = await self.post(data=audio, model=model, task="audio-classification")
|
296 |
+
return _bytes_to_list(response)
|
297 |
+
|
298 |
+
async def automatic_speech_recognition(
|
299 |
+
self,
|
300 |
+
audio: ContentT,
|
301 |
+
*,
|
302 |
+
model: Optional[str] = None,
|
303 |
+
) -> str:
|
304 |
+
"""
|
305 |
+
Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
audio (Union[str, Path, bytes, BinaryIO]):
|
309 |
+
The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file.
|
310 |
+
model (`str`, *optional*):
|
311 |
+
The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
312 |
+
Inference Endpoint. If not provided, the default recommended model for ASR will be used.
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
str: The transcribed text.
|
316 |
+
|
317 |
+
Raises:
|
318 |
+
[`InferenceTimeoutError`]:
|
319 |
+
If the model is unavailable or the request times out.
|
320 |
+
`aiohttp.ClientResponseError`:
|
321 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
322 |
+
|
323 |
+
Example:
|
324 |
+
```py
|
325 |
+
# Must be run in an async context
|
326 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
327 |
+
>>> client = AsyncInferenceClient()
|
328 |
+
>>> await client.automatic_speech_recognition("hello_world.flac")
|
329 |
+
"hello world"
|
330 |
+
```
|
331 |
+
"""
|
332 |
+
response = await self.post(data=audio, model=model, task="automatic-speech-recognition")
|
333 |
+
return _bytes_to_dict(response)["text"]
|
334 |
+
|
335 |
+
async def conversational(
|
336 |
+
self,
|
337 |
+
text: str,
|
338 |
+
generated_responses: Optional[List[str]] = None,
|
339 |
+
past_user_inputs: Optional[List[str]] = None,
|
340 |
+
*,
|
341 |
+
parameters: Optional[Dict[str, Any]] = None,
|
342 |
+
model: Optional[str] = None,
|
343 |
+
) -> ConversationalOutput:
|
344 |
+
"""
|
345 |
+
Generate conversational responses based on the given input text (i.e. chat with the API).
|
346 |
+
|
347 |
+
Args:
|
348 |
+
text (`str`):
|
349 |
+
The last input from the user in the conversation.
|
350 |
+
generated_responses (`List[str]`, *optional*):
|
351 |
+
A list of strings corresponding to the earlier replies from the model. Defaults to None.
|
352 |
+
past_user_inputs (`List[str]`, *optional*):
|
353 |
+
A list of strings corresponding to the earlier replies from the user. Should be the same length as
|
354 |
+
`generated_responses`. Defaults to None.
|
355 |
+
parameters (`Dict[str, Any]`, *optional*):
|
356 |
+
Additional parameters for the conversational task. Defaults to None. For more details about the available
|
357 |
+
parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
|
358 |
+
model (`str`, *optional*):
|
359 |
+
The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
360 |
+
a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
|
361 |
+
Defaults to None.
|
362 |
+
|
363 |
+
Returns:
|
364 |
+
`Dict`: The generated conversational output.
|
365 |
+
|
366 |
+
Raises:
|
367 |
+
[`InferenceTimeoutError`]:
|
368 |
+
If the model is unavailable or the request times out.
|
369 |
+
`aiohttp.ClientResponseError`:
|
370 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
371 |
+
|
372 |
+
Example:
|
373 |
+
```py
|
374 |
+
# Must be run in an async context
|
375 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
376 |
+
>>> client = AsyncInferenceClient()
|
377 |
+
>>> output = await client.conversational("Hi, who are you?")
|
378 |
+
>>> output
|
379 |
+
{'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 async for open-end generation.']}
|
380 |
+
>>> await client.conversational(
|
381 |
+
... "Wow, that's scary!",
|
382 |
+
... generated_responses=output["conversation"]["generated_responses"],
|
383 |
+
... past_user_inputs=output["conversation"]["past_user_inputs"],
|
384 |
+
... )
|
385 |
+
```
|
386 |
+
"""
|
387 |
+
payload: Dict[str, Any] = {"inputs": {"text": text}}
|
388 |
+
if generated_responses is not None:
|
389 |
+
payload["inputs"]["generated_responses"] = generated_responses
|
390 |
+
if past_user_inputs is not None:
|
391 |
+
payload["inputs"]["past_user_inputs"] = past_user_inputs
|
392 |
+
if parameters is not None:
|
393 |
+
payload["parameters"] = parameters
|
394 |
+
response = await self.post(json=payload, model=model, task="conversational")
|
395 |
+
return _bytes_to_dict(response) # type: ignore
|
396 |
+
|
397 |
+
async def visual_question_answering(
|
398 |
+
self,
|
399 |
+
image: ContentT,
|
400 |
+
question: str,
|
401 |
+
*,
|
402 |
+
model: Optional[str] = None,
|
403 |
+
) -> List[str]:
|
404 |
+
"""
|
405 |
+
Answering open-ended questions based on an image.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
409 |
+
The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
|
410 |
+
question (`str`):
|
411 |
+
Question to be answered.
|
412 |
+
model (`str`, *optional*):
|
413 |
+
The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
414 |
+
a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
|
415 |
+
Defaults to None.
|
416 |
+
|
417 |
+
Returns:
|
418 |
+
`List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
|
419 |
+
|
420 |
+
Raises:
|
421 |
+
`InferenceTimeoutError`:
|
422 |
+
If the model is unavailable or the request times out.
|
423 |
+
`aiohttp.ClientResponseError`:
|
424 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
425 |
+
|
426 |
+
Example:
|
427 |
+
```py
|
428 |
+
# Must be run in an async context
|
429 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
430 |
+
>>> client = AsyncInferenceClient()
|
431 |
+
>>> await client.visual_question_answering(
|
432 |
+
... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
|
433 |
+
... question="What is the animal doing?"
|
434 |
+
... )
|
435 |
+
[{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
|
436 |
+
```
|
437 |
+
"""
|
438 |
+
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
|
439 |
+
response = await self.post(json=payload, model=model, task="visual-question-answering")
|
440 |
+
return _bytes_to_list(response)
|
441 |
+
|
442 |
+
async def document_question_answering(
|
443 |
+
self,
|
444 |
+
image: ContentT,
|
445 |
+
question: str,
|
446 |
+
*,
|
447 |
+
model: Optional[str] = None,
|
448 |
+
) -> List[QuestionAnsweringOutput]:
|
449 |
+
"""
|
450 |
+
Answer questions on document images.
|
451 |
+
|
452 |
+
Args:
|
453 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
454 |
+
The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
|
455 |
+
question (`str`):
|
456 |
+
Question to be answered.
|
457 |
+
model (`str`, *optional*):
|
458 |
+
The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
459 |
+
a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
|
460 |
+
Defaults to None.
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
`List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
|
464 |
+
|
465 |
+
Raises:
|
466 |
+
[`InferenceTimeoutError`]:
|
467 |
+
If the model is unavailable or the request times out.
|
468 |
+
`aiohttp.ClientResponseError`:
|
469 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
470 |
+
|
471 |
+
Example:
|
472 |
+
```py
|
473 |
+
# Must be run in an async context
|
474 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
475 |
+
>>> client = AsyncInferenceClient()
|
476 |
+
>>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
|
477 |
+
[{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
|
478 |
+
```
|
479 |
+
"""
|
480 |
+
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
|
481 |
+
response = await self.post(json=payload, model=model, task="document-question-answering")
|
482 |
+
return _bytes_to_list(response)
|
483 |
+
|
484 |
+
async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
|
485 |
+
"""
|
486 |
+
Generate embeddings for a given text.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
text (`str`):
|
490 |
+
The text to embed.
|
491 |
+
model (`str`, *optional*):
|
492 |
+
The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
493 |
+
a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
|
494 |
+
Defaults to None.
|
495 |
+
|
496 |
+
Returns:
|
497 |
+
`np.ndarray`: The embedding representing the input text as a float32 numpy array.
|
498 |
+
|
499 |
+
Raises:
|
500 |
+
[`InferenceTimeoutError`]:
|
501 |
+
If the model is unavailable or the request times out.
|
502 |
+
`aiohttp.ClientResponseError`:
|
503 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
504 |
+
|
505 |
+
Example:
|
506 |
+
```py
|
507 |
+
# Must be run in an async context
|
508 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
509 |
+
>>> client = AsyncInferenceClient()
|
510 |
+
>>> await client.feature_extraction("Hi, who are you?")
|
511 |
+
array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ],
|
512 |
+
[-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ],
|
513 |
+
...,
|
514 |
+
[ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
|
515 |
+
```
|
516 |
+
"""
|
517 |
+
response = await self.post(json={"inputs": text}, model=model, task="feature-extraction")
|
518 |
+
np = _import_numpy()
|
519 |
+
return np.array(_bytes_to_dict(response), dtype="float32")
|
520 |
+
|
521 |
+
async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]:
|
522 |
+
"""
|
523 |
+
Fill in a hole with a missing word (token to be precise).
|
524 |
+
|
525 |
+
Args:
|
526 |
+
text (`str`):
|
527 |
+
a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask).
|
528 |
+
model (`str`, *optional*):
|
529 |
+
The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
530 |
+
a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
|
531 |
+
Defaults to None.
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
`List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated
|
535 |
+
probability, token reference, and completed text.
|
536 |
+
|
537 |
+
Raises:
|
538 |
+
[`InferenceTimeoutError`]:
|
539 |
+
If the model is unavailable or the request times out.
|
540 |
+
`aiohttp.ClientResponseError`:
|
541 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
542 |
+
|
543 |
+
Example:
|
544 |
+
```py
|
545 |
+
# Must be run in an async context
|
546 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
547 |
+
>>> client = AsyncInferenceClient()
|
548 |
+
>>> await client.fill_mask("The goal of life is <mask>.")
|
549 |
+
[{'score': 0.06897063553333282,
|
550 |
+
'token': 11098,
|
551 |
+
'token_str': ' happiness',
|
552 |
+
'sequence': 'The goal of life is happiness.'},
|
553 |
+
{'score': 0.06554922461509705,
|
554 |
+
'token': 45075,
|
555 |
+
'token_str': ' immortality',
|
556 |
+
'sequence': 'The goal of life is immortality.'}]
|
557 |
+
```
|
558 |
+
"""
|
559 |
+
response = await self.post(json={"inputs": text}, model=model, task="fill-mask")
|
560 |
+
return _bytes_to_list(response)
|
561 |
+
|
562 |
+
async def image_classification(
|
563 |
+
self,
|
564 |
+
image: ContentT,
|
565 |
+
*,
|
566 |
+
model: Optional[str] = None,
|
567 |
+
) -> List[ClassificationOutput]:
|
568 |
+
"""
|
569 |
+
Perform image classification on the given image using the specified model.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
573 |
+
The image to classify. It can be raw bytes, an image file, or a URL to an online image.
|
574 |
+
model (`str`, *optional*):
|
575 |
+
The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
576 |
+
deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
|
577 |
+
|
578 |
+
Returns:
|
579 |
+
`List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
|
580 |
+
|
581 |
+
Raises:
|
582 |
+
[`InferenceTimeoutError`]:
|
583 |
+
If the model is unavailable or the request times out.
|
584 |
+
`aiohttp.ClientResponseError`:
|
585 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
586 |
+
|
587 |
+
Example:
|
588 |
+
```py
|
589 |
+
# Must be run in an async context
|
590 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
591 |
+
>>> client = AsyncInferenceClient()
|
592 |
+
>>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
|
593 |
+
[{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
|
594 |
+
```
|
595 |
+
"""
|
596 |
+
response = await self.post(data=image, model=model, task="image-classification")
|
597 |
+
return _bytes_to_list(response)
|
598 |
+
|
599 |
+
async def image_segmentation(
|
600 |
+
self,
|
601 |
+
image: ContentT,
|
602 |
+
*,
|
603 |
+
model: Optional[str] = None,
|
604 |
+
) -> List[ImageSegmentationOutput]:
|
605 |
+
"""
|
606 |
+
Perform image segmentation on the given image using the specified model.
|
607 |
+
|
608 |
+
<Tip warning={true}>
|
609 |
+
|
610 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
611 |
+
|
612 |
+
</Tip>
|
613 |
+
|
614 |
+
Args:
|
615 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
616 |
+
The image to segment. It can be raw bytes, an image file, or a URL to an online image.
|
617 |
+
model (`str`, *optional*):
|
618 |
+
The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
619 |
+
deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
`List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes.
|
623 |
+
|
624 |
+
Raises:
|
625 |
+
[`InferenceTimeoutError`]:
|
626 |
+
If the model is unavailable or the request times out.
|
627 |
+
`aiohttp.ClientResponseError`:
|
628 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
629 |
+
|
630 |
+
Example:
|
631 |
+
```py
|
632 |
+
# Must be run in an async context
|
633 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
634 |
+
>>> client = AsyncInferenceClient()
|
635 |
+
>>> await client.image_segmentation("cat.jpg"):
|
636 |
+
[{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...]
|
637 |
+
```
|
638 |
+
"""
|
639 |
+
|
640 |
+
# Segment
|
641 |
+
response = await self.post(data=image, model=model, task="image-segmentation")
|
642 |
+
output = _bytes_to_dict(response)
|
643 |
+
|
644 |
+
# Parse masks as PIL Image
|
645 |
+
if not isinstance(output, list):
|
646 |
+
raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
|
647 |
+
for item in output:
|
648 |
+
item["mask"] = _b64_to_image(item["mask"])
|
649 |
+
return output
|
650 |
+
|
651 |
+
async def image_to_image(
|
652 |
+
self,
|
653 |
+
image: ContentT,
|
654 |
+
prompt: Optional[str] = None,
|
655 |
+
*,
|
656 |
+
negative_prompt: Optional[str] = None,
|
657 |
+
height: Optional[int] = None,
|
658 |
+
width: Optional[int] = None,
|
659 |
+
num_inference_steps: Optional[int] = None,
|
660 |
+
guidance_scale: Optional[float] = None,
|
661 |
+
model: Optional[str] = None,
|
662 |
+
**kwargs,
|
663 |
+
) -> "Image":
|
664 |
+
"""
|
665 |
+
Perform image-to-image translation using a specified model.
|
666 |
+
|
667 |
+
<Tip warning={true}>
|
668 |
+
|
669 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
670 |
+
|
671 |
+
</Tip>
|
672 |
+
|
673 |
+
Args:
|
674 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
675 |
+
The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
|
676 |
+
prompt (`str`, *optional*):
|
677 |
+
The text prompt to guide the image generation.
|
678 |
+
negative_prompt (`str`, *optional*):
|
679 |
+
A negative prompt to guide the translation process.
|
680 |
+
height (`int`, *optional*):
|
681 |
+
The height in pixels of the generated image.
|
682 |
+
width (`int`, *optional*):
|
683 |
+
The width in pixels of the generated image.
|
684 |
+
num_inference_steps (`int`, *optional*):
|
685 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
686 |
+
expense of slower inference.
|
687 |
+
guidance_scale (`float`, *optional*):
|
688 |
+
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
689 |
+
usually at the expense of lower image quality.
|
690 |
+
model (`str`, *optional*):
|
691 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
692 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
693 |
+
|
694 |
+
Returns:
|
695 |
+
`Image`: The translated image.
|
696 |
+
|
697 |
+
Raises:
|
698 |
+
[`InferenceTimeoutError`]:
|
699 |
+
If the model is unavailable or the request times out.
|
700 |
+
`aiohttp.ClientResponseError`:
|
701 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
702 |
+
|
703 |
+
Example:
|
704 |
+
```py
|
705 |
+
# Must be run in an async context
|
706 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
707 |
+
>>> client = AsyncInferenceClient()
|
708 |
+
>>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
|
709 |
+
>>> image.save("tiger.jpg")
|
710 |
+
```
|
711 |
+
"""
|
712 |
+
parameters = {
|
713 |
+
"prompt": prompt,
|
714 |
+
"negative_prompt": negative_prompt,
|
715 |
+
"height": height,
|
716 |
+
"width": width,
|
717 |
+
"num_inference_steps": num_inference_steps,
|
718 |
+
"guidance_scale": guidance_scale,
|
719 |
+
**kwargs,
|
720 |
+
}
|
721 |
+
if all(parameter is None for parameter in parameters.values()):
|
722 |
+
# Either only an image to send => send as raw bytes
|
723 |
+
data = image
|
724 |
+
payload: Optional[Dict[str, Any]] = None
|
725 |
+
else:
|
726 |
+
# Or an image + some parameters => use base64 encoding
|
727 |
+
data = None
|
728 |
+
payload = {"inputs": _b64_encode(image)}
|
729 |
+
for key, value in parameters.items():
|
730 |
+
if value is not None:
|
731 |
+
payload.setdefault("parameters", {})[key] = value
|
732 |
+
|
733 |
+
response = await self.post(json=payload, data=data, model=model, task="image-to-image")
|
734 |
+
return _bytes_to_image(response)
|
735 |
+
|
736 |
+
async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
|
737 |
+
"""
|
738 |
+
Takes an input image and return text.
|
739 |
+
|
740 |
+
Models can have very different outputs depending on your use case (image captioning, optical character recognition
|
741 |
+
(OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
|
742 |
+
|
743 |
+
Args:
|
744 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
745 |
+
The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
|
746 |
+
model (`str`, *optional*):
|
747 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
748 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
749 |
+
|
750 |
+
Returns:
|
751 |
+
`str`: The generated text.
|
752 |
+
|
753 |
+
Raises:
|
754 |
+
[`InferenceTimeoutError`]:
|
755 |
+
If the model is unavailable or the request times out.
|
756 |
+
`aiohttp.ClientResponseError`:
|
757 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
758 |
+
|
759 |
+
Example:
|
760 |
+
```py
|
761 |
+
# Must be run in an async context
|
762 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
763 |
+
>>> client = AsyncInferenceClient()
|
764 |
+
>>> await client.image_to_text("cat.jpg")
|
765 |
+
'a cat standing in a grassy field '
|
766 |
+
>>> await client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
|
767 |
+
'a dog laying on the grass next to a flower pot '
|
768 |
+
```
|
769 |
+
"""
|
770 |
+
response = await self.post(data=image, model=model, task="image-to-text")
|
771 |
+
return _bytes_to_dict(response)[0]["generated_text"]
|
772 |
+
|
773 |
+
async def list_deployed_models(
|
774 |
+
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
|
775 |
+
) -> Dict[str, List[str]]:
|
776 |
+
"""
|
777 |
+
List models currently deployed on the Inference API service.
|
778 |
+
|
779 |
+
This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
|
780 |
+
are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
|
781 |
+
specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
|
782 |
+
in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
|
783 |
+
frameworks are checked, the more time it will take.
|
784 |
+
|
785 |
+
<Tip>
|
786 |
+
|
787 |
+
This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
|
788 |
+
check its availability, you can directly use [`~InferenceClient.get_model_status`].
|
789 |
+
|
790 |
+
</Tip>
|
791 |
+
|
792 |
+
Args:
|
793 |
+
frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
|
794 |
+
The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
|
795 |
+
"all", all available frameworks will be tested. It is also possible to provide a single framework or a
|
796 |
+
custom set of frameworks to check.
|
797 |
+
|
798 |
+
Returns:
|
799 |
+
`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
|
800 |
+
|
801 |
+
Example:
|
802 |
+
```py
|
803 |
+
# Must be run in an async contextthon
|
804 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
805 |
+
>>> client = AsyncInferenceClient()
|
806 |
+
|
807 |
+
# Discover zero-shot-classification models currently deployed
|
808 |
+
>>> models = await client.list_deployed_models()
|
809 |
+
>>> models["zero-shot-classification"]
|
810 |
+
['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
|
811 |
+
|
812 |
+
# List from only 1 framework
|
813 |
+
>>> await client.list_deployed_models("text-generation-inference")
|
814 |
+
{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
|
815 |
+
```
|
816 |
+
"""
|
817 |
+
# Resolve which frameworks to check
|
818 |
+
if frameworks is None:
|
819 |
+
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
|
820 |
+
elif frameworks == "all":
|
821 |
+
frameworks = ALL_INFERENCE_API_FRAMEWORKS
|
822 |
+
elif isinstance(frameworks, str):
|
823 |
+
frameworks = [frameworks]
|
824 |
+
frameworks = list(set(frameworks))
|
825 |
+
|
826 |
+
# Fetch them iteratively
|
827 |
+
models_by_task: Dict[str, List[str]] = {}
|
828 |
+
|
829 |
+
def _unpack_response(framework: str, items: List[Dict]) -> None:
|
830 |
+
for model in items:
|
831 |
+
if framework == "sentence-transformers":
|
832 |
+
# Model running with the `sentence-transformers` framework can work with both tasks even if not
|
833 |
+
# branded as such in the API response
|
834 |
+
models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
|
835 |
+
models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
|
836 |
+
else:
|
837 |
+
models_by_task.setdefault(model["task"], []).append(model["model_id"])
|
838 |
+
|
839 |
+
async def _fetch_framework(framework: str) -> None:
|
840 |
+
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
|
841 |
+
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
|
842 |
+
response.raise_for_status()
|
843 |
+
_unpack_response(framework, await response.json())
|
844 |
+
|
845 |
+
import asyncio
|
846 |
+
|
847 |
+
await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])
|
848 |
+
|
849 |
+
# Sort alphabetically for discoverability and return
|
850 |
+
for task, models in models_by_task.items():
|
851 |
+
models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
|
852 |
+
return models_by_task
|
853 |
+
|
854 |
+
async def object_detection(
|
855 |
+
self,
|
856 |
+
image: ContentT,
|
857 |
+
*,
|
858 |
+
model: Optional[str] = None,
|
859 |
+
) -> List[ObjectDetectionOutput]:
|
860 |
+
"""
|
861 |
+
Perform object detection on the given image using the specified model.
|
862 |
+
|
863 |
+
<Tip warning={true}>
|
864 |
+
|
865 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
866 |
+
|
867 |
+
</Tip>
|
868 |
+
|
869 |
+
Args:
|
870 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
871 |
+
The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
|
872 |
+
model (`str`, *optional*):
|
873 |
+
The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
|
874 |
+
deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
|
875 |
+
|
876 |
+
Returns:
|
877 |
+
`List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes.
|
878 |
+
|
879 |
+
Raises:
|
880 |
+
[`InferenceTimeoutError`]:
|
881 |
+
If the model is unavailable or the request times out.
|
882 |
+
`aiohttp.ClientResponseError`:
|
883 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
884 |
+
`ValueError`:
|
885 |
+
If the request output is not a List.
|
886 |
+
|
887 |
+
Example:
|
888 |
+
```py
|
889 |
+
# Must be run in an async context
|
890 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
891 |
+
>>> client = AsyncInferenceClient()
|
892 |
+
>>> await client.object_detection("people.jpg"):
|
893 |
+
[{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ]
|
894 |
+
```
|
895 |
+
"""
|
896 |
+
# detect objects
|
897 |
+
response = await self.post(data=image, model=model, task="object-detection")
|
898 |
+
output = _bytes_to_dict(response)
|
899 |
+
if not isinstance(output, list):
|
900 |
+
raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
|
901 |
+
return output
|
902 |
+
|
903 |
+
async def question_answering(
|
904 |
+
self, question: str, context: str, *, model: Optional[str] = None
|
905 |
+
) -> QuestionAnsweringOutput:
|
906 |
+
"""
|
907 |
+
Retrieve the answer to a question from a given text.
|
908 |
+
|
909 |
+
Args:
|
910 |
+
question (`str`):
|
911 |
+
Question to be answered.
|
912 |
+
context (`str`):
|
913 |
+
The context of the question.
|
914 |
+
model (`str`):
|
915 |
+
The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
916 |
+
a deployed Inference Endpoint.
|
917 |
+
|
918 |
+
Returns:
|
919 |
+
`Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
|
920 |
+
|
921 |
+
Raises:
|
922 |
+
[`InferenceTimeoutError`]:
|
923 |
+
If the model is unavailable or the request times out.
|
924 |
+
`aiohttp.ClientResponseError`:
|
925 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
926 |
+
|
927 |
+
Example:
|
928 |
+
```py
|
929 |
+
# Must be run in an async context
|
930 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
931 |
+
>>> client = AsyncInferenceClient()
|
932 |
+
>>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
|
933 |
+
{'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
|
934 |
+
```
|
935 |
+
"""
|
936 |
+
|
937 |
+
payload: Dict[str, Any] = {"question": question, "context": context}
|
938 |
+
response = await self.post(
|
939 |
+
json=payload,
|
940 |
+
model=model,
|
941 |
+
task="question-answering",
|
942 |
+
)
|
943 |
+
return _bytes_to_dict(response) # type: ignore
|
944 |
+
|
945 |
+
async def sentence_similarity(
|
946 |
+
self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
|
947 |
+
) -> List[float]:
|
948 |
+
"""
|
949 |
+
Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.
|
950 |
+
|
951 |
+
Args:
|
952 |
+
sentence (`str`):
|
953 |
+
The main sentence to compare to others.
|
954 |
+
other_sentences (`List[str]`):
|
955 |
+
The list of sentences to compare to.
|
956 |
+
model (`str`, *optional*):
|
957 |
+
The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
958 |
+
a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used.
|
959 |
+
Defaults to None.
|
960 |
+
|
961 |
+
Returns:
|
962 |
+
`List[float]`: The similarity scores between the source sentence and each of the compared sentences.
|
963 |
+
|
964 |
+
Raises:
|
965 |
+
[`InferenceTimeoutError`]:
|
966 |
+
If the model is unavailable or the request times out.
|
967 |
+
`aiohttp.ClientResponseError`:
|
968 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
969 |
+
|
970 |
+
Example:
|
971 |
+
```py
|
972 |
+
# Must be run in an async context
|
973 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
974 |
+
>>> client = AsyncInferenceClient()
|
975 |
+
>>> await client.sentence_similarity(
|
976 |
+
... "Machine learning is so easy.",
|
977 |
+
... other_sentences=[
|
978 |
+
... "Deep learning is so straightforward.",
|
979 |
+
... "This is so difficult, like rocket science.",
|
980 |
+
... "I can't believe how much I struggled with this.",
|
981 |
+
... ],
|
982 |
+
... )
|
983 |
+
[0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
|
984 |
+
```
|
985 |
+
"""
|
986 |
+
response = await self.post(
|
987 |
+
json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}},
|
988 |
+
model=model,
|
989 |
+
task="sentence-similarity",
|
990 |
+
)
|
991 |
+
return _bytes_to_list(response)
|
992 |
+
|
993 |
+
async def summarization(
|
994 |
+
self,
|
995 |
+
text: str,
|
996 |
+
*,
|
997 |
+
parameters: Optional[Dict[str, Any]] = None,
|
998 |
+
model: Optional[str] = None,
|
999 |
+
) -> str:
|
1000 |
+
"""
|
1001 |
+
Generate a summary of a given text using a specified model.
|
1002 |
+
|
1003 |
+
Args:
|
1004 |
+
text (`str`):
|
1005 |
+
The input text to summarize.
|
1006 |
+
parameters (`Dict[str, Any]`, *optional*):
|
1007 |
+
Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task)
|
1008 |
+
for more details.
|
1009 |
+
model (`str`, *optional*):
|
1010 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1011 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1012 |
+
|
1013 |
+
Returns:
|
1014 |
+
`str`: The generated summary text.
|
1015 |
+
|
1016 |
+
Raises:
|
1017 |
+
[`InferenceTimeoutError`]:
|
1018 |
+
If the model is unavailable or the request times out.
|
1019 |
+
`aiohttp.ClientResponseError`:
|
1020 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1021 |
+
|
1022 |
+
Example:
|
1023 |
+
```py
|
1024 |
+
# Must be run in an async context
|
1025 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1026 |
+
>>> client = AsyncInferenceClient()
|
1027 |
+
>>> await client.summarization("The Eiffel tower...")
|
1028 |
+
'The Eiffel tower is one of the most famous landmarks in the world....'
|
1029 |
+
```
|
1030 |
+
"""
|
1031 |
+
payload: Dict[str, Any] = {"inputs": text}
|
1032 |
+
if parameters is not None:
|
1033 |
+
payload["parameters"] = parameters
|
1034 |
+
response = await self.post(json=payload, model=model, task="summarization")
|
1035 |
+
return _bytes_to_dict(response)[0]["summary_text"]
|
1036 |
+
|
1037 |
+
async def table_question_answering(
|
1038 |
+
self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
|
1039 |
+
) -> TableQuestionAnsweringOutput:
|
1040 |
+
"""
|
1041 |
+
Retrieve the answer to a question from information given in a table.
|
1042 |
+
|
1043 |
+
Args:
|
1044 |
+
table (`Dict[str, Any]`):
|
1045 |
+
A table of data represented as a dict of lists, where the keys are the column headers and the lists contain
|
1046 |
+
all the column values. All lists must have the same size.
|
1047 |
+
query (`str`):
|
1048 |
+
The query in plain text that you want to ask the table.
|
1049 |
+
model (`str`):
|
1050 |
+
The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
|
1051 |
+
Hub or a URL to a deployed Inference Endpoint.
|
1052 |
+
|
1053 |
+
Returns:
|
1054 |
+
`Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used.
|
1055 |
+
|
1056 |
+
Raises:
|
1057 |
+
[`InferenceTimeoutError`]:
|
1058 |
+
If the model is unavailable or the request times out.
|
1059 |
+
`aiohttp.ClientResponseError`:
|
1060 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1061 |
+
|
1062 |
+
Example:
|
1063 |
+
```py
|
1064 |
+
# Must be run in an async context
|
1065 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1066 |
+
>>> client = AsyncInferenceClient()
|
1067 |
+
>>> query = "How many stars does the transformers repository have?"
|
1068 |
+
>>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
|
1069 |
+
>>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
|
1070 |
+
{'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
|
1071 |
+
```
|
1072 |
+
"""
|
1073 |
+
response = await self.post(
|
1074 |
+
json={
|
1075 |
+
"query": query,
|
1076 |
+
"table": table,
|
1077 |
+
},
|
1078 |
+
model=model,
|
1079 |
+
task="table-question-answering",
|
1080 |
+
)
|
1081 |
+
return _bytes_to_dict(response) # type: ignore
|
1082 |
+
|
1083 |
+
async def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
|
1084 |
+
"""
|
1085 |
+
Classifying a target category (a group) based on a set of attributes.
|
1086 |
+
|
1087 |
+
Args:
|
1088 |
+
table (`Dict[str, Any]`):
|
1089 |
+
Set of attributes to classify.
|
1090 |
+
model (`str`):
|
1091 |
+
The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1092 |
+
a deployed Inference Endpoint.
|
1093 |
+
|
1094 |
+
Returns:
|
1095 |
+
`List`: a list of labels, one per row in the initial table.
|
1096 |
+
|
1097 |
+
Raises:
|
1098 |
+
[`InferenceTimeoutError`]:
|
1099 |
+
If the model is unavailable or the request times out.
|
1100 |
+
`aiohttp.ClientResponseError`:
|
1101 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1102 |
+
|
1103 |
+
Example:
|
1104 |
+
```py
|
1105 |
+
# Must be run in an async context
|
1106 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1107 |
+
>>> client = AsyncInferenceClient()
|
1108 |
+
>>> table = {
|
1109 |
+
... "fixed_acidity": ["7.4", "7.8", "10.3"],
|
1110 |
+
... "volatile_acidity": ["0.7", "0.88", "0.32"],
|
1111 |
+
... "citric_acid": ["0", "0", "0.45"],
|
1112 |
+
... "residual_sugar": ["1.9", "2.6", "6.4"],
|
1113 |
+
... "chlorides": ["0.076", "0.098", "0.073"],
|
1114 |
+
... "free_sulfur_dioxide": ["11", "25", "5"],
|
1115 |
+
... "total_sulfur_dioxide": ["34", "67", "13"],
|
1116 |
+
... "density": ["0.9978", "0.9968", "0.9976"],
|
1117 |
+
... "pH": ["3.51", "3.2", "3.23"],
|
1118 |
+
... "sulphates": ["0.56", "0.68", "0.82"],
|
1119 |
+
... "alcohol": ["9.4", "9.8", "12.6"],
|
1120 |
+
... }
|
1121 |
+
>>> await client.tabular_classification(table=table, model="julien-c/wine-quality")
|
1122 |
+
["5", "5", "5"]
|
1123 |
+
```
|
1124 |
+
"""
|
1125 |
+
response = await self.post(json={"table": table}, model=model, task="tabular-classification")
|
1126 |
+
return _bytes_to_list(response)
|
1127 |
+
|
1128 |
+
async def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]:
|
1129 |
+
"""
|
1130 |
+
Predicting a numerical target value given a set of attributes/features in a table.
|
1131 |
+
|
1132 |
+
Args:
|
1133 |
+
table (`Dict[str, Any]`):
|
1134 |
+
Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
|
1135 |
+
model (`str`):
|
1136 |
+
The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1137 |
+
a deployed Inference Endpoint.
|
1138 |
+
|
1139 |
+
Returns:
|
1140 |
+
`List`: a list of predicted numerical target values.
|
1141 |
+
|
1142 |
+
Raises:
|
1143 |
+
[`InferenceTimeoutError`]:
|
1144 |
+
If the model is unavailable or the request times out.
|
1145 |
+
`aiohttp.ClientResponseError`:
|
1146 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1147 |
+
|
1148 |
+
Example:
|
1149 |
+
```py
|
1150 |
+
# Must be run in an async context
|
1151 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1152 |
+
>>> client = AsyncInferenceClient()
|
1153 |
+
>>> table = {
|
1154 |
+
... "Height": ["11.52", "12.48", "12.3778"],
|
1155 |
+
... "Length1": ["23.2", "24", "23.9"],
|
1156 |
+
... "Length2": ["25.4", "26.3", "26.5"],
|
1157 |
+
... "Length3": ["30", "31.2", "31.1"],
|
1158 |
+
... "Species": ["Bream", "Bream", "Bream"],
|
1159 |
+
... "Width": ["4.02", "4.3056", "4.6961"],
|
1160 |
+
... }
|
1161 |
+
>>> await client.tabular_regression(table, model="scikit-learn/Fish-Weight")
|
1162 |
+
[110, 120, 130]
|
1163 |
+
```
|
1164 |
+
"""
|
1165 |
+
response = await self.post(json={"table": table}, model=model, task="tabular-regression")
|
1166 |
+
return _bytes_to_list(response)
|
1167 |
+
|
1168 |
+
async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
|
1169 |
+
"""
|
1170 |
+
Perform text classification (e.g. sentiment-analysis) on the given text.
|
1171 |
+
|
1172 |
+
Args:
|
1173 |
+
text (`str`):
|
1174 |
+
A string to be classified.
|
1175 |
+
model (`str`, *optional*):
|
1176 |
+
The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1177 |
+
a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used.
|
1178 |
+
Defaults to None.
|
1179 |
+
|
1180 |
+
Returns:
|
1181 |
+
`List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
|
1182 |
+
|
1183 |
+
Raises:
|
1184 |
+
[`InferenceTimeoutError`]:
|
1185 |
+
If the model is unavailable or the request times out.
|
1186 |
+
`aiohttp.ClientResponseError`:
|
1187 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1188 |
+
|
1189 |
+
Example:
|
1190 |
+
```py
|
1191 |
+
# Must be run in an async context
|
1192 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1193 |
+
>>> client = AsyncInferenceClient()
|
1194 |
+
>>> await client.text_classification("I like you")
|
1195 |
+
[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
|
1196 |
+
```
|
1197 |
+
"""
|
1198 |
+
response = await self.post(json={"inputs": text}, model=model, task="text-classification")
|
1199 |
+
return _bytes_to_list(response)[0]
|
1200 |
+
|
1201 |
+
@overload
|
1202 |
+
async def text_generation( # type: ignore
|
1203 |
+
self,
|
1204 |
+
prompt: str,
|
1205 |
+
*,
|
1206 |
+
details: Literal[False] = ...,
|
1207 |
+
stream: Literal[False] = ...,
|
1208 |
+
model: Optional[str] = None,
|
1209 |
+
do_sample: bool = False,
|
1210 |
+
max_new_tokens: int = 20,
|
1211 |
+
best_of: Optional[int] = None,
|
1212 |
+
repetition_penalty: Optional[float] = None,
|
1213 |
+
return_full_text: bool = False,
|
1214 |
+
seed: Optional[int] = None,
|
1215 |
+
stop_sequences: Optional[List[str]] = None,
|
1216 |
+
temperature: Optional[float] = None,
|
1217 |
+
top_k: Optional[int] = None,
|
1218 |
+
top_p: Optional[float] = None,
|
1219 |
+
truncate: Optional[int] = None,
|
1220 |
+
typical_p: Optional[float] = None,
|
1221 |
+
watermark: bool = False,
|
1222 |
+
) -> str:
|
1223 |
+
...
|
1224 |
+
|
1225 |
+
@overload
|
1226 |
+
async def text_generation( # type: ignore
|
1227 |
+
self,
|
1228 |
+
prompt: str,
|
1229 |
+
*,
|
1230 |
+
details: Literal[True] = ...,
|
1231 |
+
stream: Literal[False] = ...,
|
1232 |
+
model: Optional[str] = None,
|
1233 |
+
do_sample: bool = False,
|
1234 |
+
max_new_tokens: int = 20,
|
1235 |
+
best_of: Optional[int] = None,
|
1236 |
+
repetition_penalty: Optional[float] = None,
|
1237 |
+
return_full_text: bool = False,
|
1238 |
+
seed: Optional[int] = None,
|
1239 |
+
stop_sequences: Optional[List[str]] = None,
|
1240 |
+
temperature: Optional[float] = None,
|
1241 |
+
top_k: Optional[int] = None,
|
1242 |
+
top_p: Optional[float] = None,
|
1243 |
+
truncate: Optional[int] = None,
|
1244 |
+
typical_p: Optional[float] = None,
|
1245 |
+
watermark: bool = False,
|
1246 |
+
) -> TextGenerationResponse:
|
1247 |
+
...
|
1248 |
+
|
1249 |
+
@overload
|
1250 |
+
async def text_generation( # type: ignore
|
1251 |
+
self,
|
1252 |
+
prompt: str,
|
1253 |
+
*,
|
1254 |
+
details: Literal[False] = ...,
|
1255 |
+
stream: Literal[True] = ...,
|
1256 |
+
model: Optional[str] = None,
|
1257 |
+
do_sample: bool = False,
|
1258 |
+
max_new_tokens: int = 20,
|
1259 |
+
best_of: Optional[int] = None,
|
1260 |
+
repetition_penalty: Optional[float] = None,
|
1261 |
+
return_full_text: bool = False,
|
1262 |
+
seed: Optional[int] = None,
|
1263 |
+
stop_sequences: Optional[List[str]] = None,
|
1264 |
+
temperature: Optional[float] = None,
|
1265 |
+
top_k: Optional[int] = None,
|
1266 |
+
top_p: Optional[float] = None,
|
1267 |
+
truncate: Optional[int] = None,
|
1268 |
+
typical_p: Optional[float] = None,
|
1269 |
+
watermark: bool = False,
|
1270 |
+
) -> AsyncIterable[str]:
|
1271 |
+
...
|
1272 |
+
|
1273 |
+
@overload
|
1274 |
+
async def text_generation(
|
1275 |
+
self,
|
1276 |
+
prompt: str,
|
1277 |
+
*,
|
1278 |
+
details: Literal[True] = ...,
|
1279 |
+
stream: Literal[True] = ...,
|
1280 |
+
model: Optional[str] = None,
|
1281 |
+
do_sample: bool = False,
|
1282 |
+
max_new_tokens: int = 20,
|
1283 |
+
best_of: Optional[int] = None,
|
1284 |
+
repetition_penalty: Optional[float] = None,
|
1285 |
+
return_full_text: bool = False,
|
1286 |
+
seed: Optional[int] = None,
|
1287 |
+
stop_sequences: Optional[List[str]] = None,
|
1288 |
+
temperature: Optional[float] = None,
|
1289 |
+
top_k: Optional[int] = None,
|
1290 |
+
top_p: Optional[float] = None,
|
1291 |
+
truncate: Optional[int] = None,
|
1292 |
+
typical_p: Optional[float] = None,
|
1293 |
+
watermark: bool = False,
|
1294 |
+
) -> AsyncIterable[TextGenerationStreamResponse]:
|
1295 |
+
...
|
1296 |
+
|
1297 |
+
async def text_generation(
|
1298 |
+
self,
|
1299 |
+
prompt: str,
|
1300 |
+
*,
|
1301 |
+
details: bool = False,
|
1302 |
+
stream: bool = False,
|
1303 |
+
model: Optional[str] = None,
|
1304 |
+
do_sample: bool = False,
|
1305 |
+
max_new_tokens: int = 20,
|
1306 |
+
best_of: Optional[int] = None,
|
1307 |
+
repetition_penalty: Optional[float] = None,
|
1308 |
+
return_full_text: bool = False,
|
1309 |
+
seed: Optional[int] = None,
|
1310 |
+
stop_sequences: Optional[List[str]] = None,
|
1311 |
+
temperature: Optional[float] = None,
|
1312 |
+
top_k: Optional[int] = None,
|
1313 |
+
top_p: Optional[float] = None,
|
1314 |
+
truncate: Optional[int] = None,
|
1315 |
+
typical_p: Optional[float] = None,
|
1316 |
+
watermark: bool = False,
|
1317 |
+
decoder_input_details: bool = False,
|
1318 |
+
) -> Union[str, TextGenerationResponse, AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
|
1319 |
+
"""
|
1320 |
+
Given a prompt, generate the following text.
|
1321 |
+
|
1322 |
+
It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allows
|
1323 |
+
early failures.
|
1324 |
+
|
1325 |
+
The API endpoint is expected to run with the `text-generation-inference` backend (TGI). This backend is the
|
1326 |
+
go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
|
1327 |
+
default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
|
1328 |
+
not exactly the same. This method is compatible with both approaches but some parameters are only available for
|
1329 |
+
`text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
|
1330 |
+
continues correctly.
|
1331 |
+
|
1332 |
+
To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
|
1333 |
+
|
1334 |
+
Args:
|
1335 |
+
prompt (`str`):
|
1336 |
+
Input text.
|
1337 |
+
details (`bool`, *optional*):
|
1338 |
+
By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
|
1339 |
+
probabilities, seed, finish reason, etc.). Only available for models running with the
|
1340 |
+
`text-generation-inference` backend.
|
1341 |
+
stream (`bool`, *optional*):
|
1342 |
+
By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
|
1343 |
+
tokens to be returned. Only available for models running with the `text-generation-inference`
|
1344 |
+
backend.
|
1345 |
+
model (`str`, *optional*):
|
1346 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1347 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1348 |
+
do_sample (`bool`):
|
1349 |
+
Activate logits sampling
|
1350 |
+
max_new_tokens (`int`):
|
1351 |
+
Maximum number of generated tokens
|
1352 |
+
best_of (`int`):
|
1353 |
+
Generate best_of sequences and return the one with the highest token logprobs
|
1354 |
+
repetition_penalty (`float`):
|
1355 |
+
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
1356 |
+
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
1357 |
+
return_full_text (`bool`):
|
1358 |
+
Whether to prepend the prompt to the generated text
|
1359 |
+
seed (`int`):
|
1360 |
+
Random sampling seed
|
1361 |
+
stop_sequences (`List[str]`):
|
1362 |
+
Stop generating tokens if a member of `stop_sequences` is generated
|
1363 |
+
temperature (`float`):
|
1364 |
+
The value used to modulate the logits distribution.
|
1365 |
+
top_k (`int`):
|
1366 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
1367 |
+
top_p (`float`):
|
1368 |
+
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
1369 |
+
higher are kept for generation.
|
1370 |
+
truncate (`int`):
|
1371 |
+
Truncate input tokens to the given size
|
1372 |
+
typical_p (`float`):
|
1373 |
+
Typical Decoding mass
|
1374 |
+
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
1375 |
+
watermark (`bool`):
|
1376 |
+
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
1377 |
+
decoder_input_details (`bool`):
|
1378 |
+
Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
|
1379 |
+
into account. Defaults to `False`.
|
1380 |
+
|
1381 |
+
Returns:
|
1382 |
+
`Union[str, TextGenerationResponse, AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]`:
|
1383 |
+
Generated text returned from the server:
|
1384 |
+
- if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
|
1385 |
+
- if `stream=True` and `details=False`, the generated text is returned token by token as an `AsyncIterable[str]`
|
1386 |
+
- if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
|
1387 |
+
- if `details=True` and `stream=True`, the generated text is returned token by token as an async iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
|
1388 |
+
|
1389 |
+
Raises:
|
1390 |
+
`ValidationError`:
|
1391 |
+
If input values are not valid. No HTTP call is made to the server.
|
1392 |
+
[`InferenceTimeoutError`]:
|
1393 |
+
If the model is unavailable or the request times out.
|
1394 |
+
`aiohttp.ClientResponseError`:
|
1395 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1396 |
+
|
1397 |
+
Example:
|
1398 |
+
```py
|
1399 |
+
# Must be run in an async context
|
1400 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1401 |
+
>>> client = AsyncInferenceClient()
|
1402 |
+
|
1403 |
+
# Case 1: generate text
|
1404 |
+
>>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
|
1405 |
+
'100% open source and built to be easy to use.'
|
1406 |
+
|
1407 |
+
# Case 2: iterate over the generated tokens. Useful for long generations.
|
1408 |
+
>>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
|
1409 |
+
... print(token)
|
1410 |
+
100
|
1411 |
+
%
|
1412 |
+
open
|
1413 |
+
source
|
1414 |
+
and
|
1415 |
+
built
|
1416 |
+
to
|
1417 |
+
be
|
1418 |
+
easy
|
1419 |
+
to
|
1420 |
+
use
|
1421 |
+
.
|
1422 |
+
|
1423 |
+
# Case 3: get more details about the generation process.
|
1424 |
+
>>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
|
1425 |
+
TextGenerationResponse(
|
1426 |
+
generated_text='100% open source and built to be easy to use.',
|
1427 |
+
details=Details(
|
1428 |
+
finish_reason=<FinishReason.Length: 'length'>,
|
1429 |
+
generated_tokens=12,
|
1430 |
+
seed=None,
|
1431 |
+
prefill=[
|
1432 |
+
InputToken(id=487, text='The', logprob=None),
|
1433 |
+
InputToken(id=53789, text=' hugging', logprob=-13.171875),
|
1434 |
+
(...)
|
1435 |
+
InputToken(id=204, text=' ', logprob=-7.0390625)
|
1436 |
+
],
|
1437 |
+
tokens=[
|
1438 |
+
Token(id=1425, text='100', logprob=-1.0175781, special=False),
|
1439 |
+
Token(id=16, text='%', logprob=-0.0463562, special=False),
|
1440 |
+
(...)
|
1441 |
+
Token(id=25, text='.', logprob=-0.5703125, special=False)
|
1442 |
+
],
|
1443 |
+
best_of_sequences=None
|
1444 |
+
)
|
1445 |
+
)
|
1446 |
+
|
1447 |
+
# Case 4: iterate over the generated tokens with more details.
|
1448 |
+
# Last object is more complete, containing the full generated text and the finish reason.
|
1449 |
+
>>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
|
1450 |
+
... print(details)
|
1451 |
+
...
|
1452 |
+
TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
|
1453 |
+
TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
|
1454 |
+
TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
|
1455 |
+
TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
|
1456 |
+
TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
|
1457 |
+
TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
|
1458 |
+
TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
|
1459 |
+
TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
|
1460 |
+
TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
|
1461 |
+
TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
|
1462 |
+
TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
|
1463 |
+
TextGenerationStreamResponse(token=Token(
|
1464 |
+
id=25,
|
1465 |
+
text='.',
|
1466 |
+
logprob=-0.5703125,
|
1467 |
+
special=False),
|
1468 |
+
generated_text='100% open source and built to be easy to use.',
|
1469 |
+
details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
|
1470 |
+
)
|
1471 |
+
```
|
1472 |
+
"""
|
1473 |
+
# NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
|
1474 |
+
# like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
|
1475 |
+
|
1476 |
+
if decoder_input_details and not details:
|
1477 |
+
warnings.warn(
|
1478 |
+
"`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
|
1479 |
+
" the output from the server will be truncated."
|
1480 |
+
)
|
1481 |
+
decoder_input_details = False
|
1482 |
+
|
1483 |
+
# Validate parameters
|
1484 |
+
parameters = TextGenerationParameters(
|
1485 |
+
best_of=best_of,
|
1486 |
+
details=details,
|
1487 |
+
do_sample=do_sample,
|
1488 |
+
max_new_tokens=max_new_tokens,
|
1489 |
+
repetition_penalty=repetition_penalty,
|
1490 |
+
return_full_text=return_full_text,
|
1491 |
+
seed=seed,
|
1492 |
+
stop=stop_sequences if stop_sequences is not None else [],
|
1493 |
+
temperature=temperature,
|
1494 |
+
top_k=top_k,
|
1495 |
+
top_p=top_p,
|
1496 |
+
truncate=truncate,
|
1497 |
+
typical_p=typical_p,
|
1498 |
+
watermark=watermark,
|
1499 |
+
decoder_input_details=decoder_input_details,
|
1500 |
+
)
|
1501 |
+
request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
|
1502 |
+
payload = asdict(request)
|
1503 |
+
|
1504 |
+
# Remove some parameters if not a TGI server
|
1505 |
+
if not _is_tgi_server(model):
|
1506 |
+
ignored_parameters = []
|
1507 |
+
for key in "watermark", "stop", "details", "decoder_input_details":
|
1508 |
+
if payload["parameters"][key] is not None:
|
1509 |
+
ignored_parameters.append(key)
|
1510 |
+
del payload["parameters"][key]
|
1511 |
+
if len(ignored_parameters) > 0:
|
1512 |
+
warnings.warn(
|
1513 |
+
"API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
|
1514 |
+
f" {ignored_parameters}.",
|
1515 |
+
UserWarning,
|
1516 |
+
)
|
1517 |
+
if details:
|
1518 |
+
warnings.warn(
|
1519 |
+
"API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
|
1520 |
+
" be ignored meaning only the generated text will be returned.",
|
1521 |
+
UserWarning,
|
1522 |
+
)
|
1523 |
+
details = False
|
1524 |
+
if stream:
|
1525 |
+
raise ValueError(
|
1526 |
+
"API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
|
1527 |
+
" Please pass `stream=False` as input."
|
1528 |
+
)
|
1529 |
+
|
1530 |
+
# Handle errors separately for more precise error messages
|
1531 |
+
try:
|
1532 |
+
bytes_output = await self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
|
1533 |
+
except _import_aiohttp().ClientResponseError as e:
|
1534 |
+
error_message = getattr(e, "response_error_payload", {}).get("error", "")
|
1535 |
+
if e.code == 400 and "The following `model_kwargs` are not used by the model" in error_message:
|
1536 |
+
_set_as_non_tgi(model)
|
1537 |
+
return await self.text_generation( # type: ignore
|
1538 |
+
prompt=prompt,
|
1539 |
+
details=details,
|
1540 |
+
stream=stream,
|
1541 |
+
model=model,
|
1542 |
+
do_sample=do_sample,
|
1543 |
+
max_new_tokens=max_new_tokens,
|
1544 |
+
best_of=best_of,
|
1545 |
+
repetition_penalty=repetition_penalty,
|
1546 |
+
return_full_text=return_full_text,
|
1547 |
+
seed=seed,
|
1548 |
+
stop_sequences=stop_sequences,
|
1549 |
+
temperature=temperature,
|
1550 |
+
top_k=top_k,
|
1551 |
+
top_p=top_p,
|
1552 |
+
truncate=truncate,
|
1553 |
+
typical_p=typical_p,
|
1554 |
+
watermark=watermark,
|
1555 |
+
decoder_input_details=decoder_input_details,
|
1556 |
+
)
|
1557 |
+
raise_text_generation_error(e)
|
1558 |
+
|
1559 |
+
# Parse output
|
1560 |
+
if stream:
|
1561 |
+
return _async_stream_text_generation_response(bytes_output, details) # type: ignore
|
1562 |
+
|
1563 |
+
data = _bytes_to_dict(bytes_output)[0]
|
1564 |
+
return TextGenerationResponse(**data) if details else data["generated_text"]
|
1565 |
+
|
1566 |
+
async def text_to_image(
|
1567 |
+
self,
|
1568 |
+
prompt: str,
|
1569 |
+
*,
|
1570 |
+
negative_prompt: Optional[str] = None,
|
1571 |
+
height: Optional[float] = None,
|
1572 |
+
width: Optional[float] = None,
|
1573 |
+
num_inference_steps: Optional[float] = None,
|
1574 |
+
guidance_scale: Optional[float] = None,
|
1575 |
+
model: Optional[str] = None,
|
1576 |
+
**kwargs,
|
1577 |
+
) -> "Image":
|
1578 |
+
"""
|
1579 |
+
Generate an image based on a given text using a specified model.
|
1580 |
+
|
1581 |
+
<Tip warning={true}>
|
1582 |
+
|
1583 |
+
You must have `PIL` installed if you want to work with images (`pip install Pillow`).
|
1584 |
+
|
1585 |
+
</Tip>
|
1586 |
+
|
1587 |
+
Args:
|
1588 |
+
prompt (`str`):
|
1589 |
+
The prompt to generate an image from.
|
1590 |
+
negative_prompt (`str`, *optional*):
|
1591 |
+
An optional negative prompt for the image generation.
|
1592 |
+
height (`float`, *optional*):
|
1593 |
+
The height in pixels of the image to generate.
|
1594 |
+
width (`float`, *optional*):
|
1595 |
+
The width in pixels of the image to generate.
|
1596 |
+
num_inference_steps (`int`, *optional*):
|
1597 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1598 |
+
expense of slower inference.
|
1599 |
+
guidance_scale (`float`, *optional*):
|
1600 |
+
A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
1601 |
+
usually at the expense of lower image quality.
|
1602 |
+
model (`str`, *optional*):
|
1603 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1604 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1605 |
+
|
1606 |
+
Returns:
|
1607 |
+
`Image`: The generated image.
|
1608 |
+
|
1609 |
+
Raises:
|
1610 |
+
[`InferenceTimeoutError`]:
|
1611 |
+
If the model is unavailable or the request times out.
|
1612 |
+
`aiohttp.ClientResponseError`:
|
1613 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1614 |
+
|
1615 |
+
Example:
|
1616 |
+
```py
|
1617 |
+
# Must be run in an async context
|
1618 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1619 |
+
>>> client = AsyncInferenceClient()
|
1620 |
+
|
1621 |
+
>>> image = await client.text_to_image("An astronaut riding a horse on the moon.")
|
1622 |
+
>>> image.save("astronaut.png")
|
1623 |
+
|
1624 |
+
>>> image = await client.text_to_image(
|
1625 |
+
... "An astronaut riding a horse on the moon.",
|
1626 |
+
... negative_prompt="low resolution, blurry",
|
1627 |
+
... model="stabilityai/stable-diffusion-2-1",
|
1628 |
+
... )
|
1629 |
+
>>> image.save("better_astronaut.png")
|
1630 |
+
```
|
1631 |
+
"""
|
1632 |
+
payload = {"inputs": prompt}
|
1633 |
+
parameters = {
|
1634 |
+
"negative_prompt": negative_prompt,
|
1635 |
+
"height": height,
|
1636 |
+
"width": width,
|
1637 |
+
"num_inference_steps": num_inference_steps,
|
1638 |
+
"guidance_scale": guidance_scale,
|
1639 |
+
**kwargs,
|
1640 |
+
}
|
1641 |
+
for key, value in parameters.items():
|
1642 |
+
if value is not None:
|
1643 |
+
payload.setdefault("parameters", {})[key] = value # type: ignore
|
1644 |
+
response = await self.post(json=payload, model=model, task="text-to-image")
|
1645 |
+
return _bytes_to_image(response)
|
1646 |
+
|
1647 |
+
async def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
|
1648 |
+
"""
|
1649 |
+
Synthesize audio of a voice pronouncing a given text.
|
1650 |
+
|
1651 |
+
Args:
|
1652 |
+
text (`str`):
|
1653 |
+
The text to synthesize.
|
1654 |
+
model (`str`, *optional*):
|
1655 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1656 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1657 |
+
|
1658 |
+
Returns:
|
1659 |
+
`bytes`: The generated audio.
|
1660 |
+
|
1661 |
+
Raises:
|
1662 |
+
[`InferenceTimeoutError`]:
|
1663 |
+
If the model is unavailable or the request times out.
|
1664 |
+
`aiohttp.ClientResponseError`:
|
1665 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1666 |
+
|
1667 |
+
Example:
|
1668 |
+
```py
|
1669 |
+
# Must be run in an async context
|
1670 |
+
>>> from pathlib import Path
|
1671 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1672 |
+
>>> client = AsyncInferenceClient()
|
1673 |
+
|
1674 |
+
>>> audio = await client.text_to_speech("Hello world")
|
1675 |
+
>>> Path("hello_world.flac").write_bytes(audio)
|
1676 |
+
```
|
1677 |
+
"""
|
1678 |
+
return await self.post(json={"inputs": text}, model=model, task="text-to-speech")
|
1679 |
+
|
1680 |
+
async def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]:
|
1681 |
+
"""
|
1682 |
+
Perform token classification on the given text.
|
1683 |
+
Usually used for sentence parsing, either grammatical or as Named Entity Recognition (NER), to understand keywords contained within the text.
|
1684 |
+
|
1685 |
+
Args:
|
1686 |
+
text (`str`):
|
1687 |
+
A string to be classified.
|
1688 |
+
model (`str`, *optional*):
|
1689 |
+
The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1690 |
+
a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used.
|
1691 |
+
Defaults to None.
|
1692 |
+
|
1693 |
+
Returns:
|
1694 |
+
`List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index.
|
1695 |
+
|
1696 |
+
Raises:
|
1697 |
+
[`InferenceTimeoutError`]:
|
1698 |
+
If the model is unavailable or the request times out.
|
1699 |
+
`aiohttp.ClientResponseError`:
|
1700 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1701 |
+
|
1702 |
+
Example:
|
1703 |
+
```py
|
1704 |
+
# Must be run in an async context
|
1705 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1706 |
+
>>> client = AsyncInferenceClient()
|
1707 |
+
>>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
|
1708 |
+
[{'entity_group': 'PER',
|
1709 |
+
'score': 0.9971321225166321,
|
1710 |
+
'word': 'Sarah Jessica Parker',
|
1711 |
+
'start': 11,
|
1712 |
+
'end': 31},
|
1713 |
+
{'entity_group': 'PER',
|
1714 |
+
'score': 0.9773476123809814,
|
1715 |
+
'word': 'Jessica',
|
1716 |
+
'start': 52,
|
1717 |
+
'end': 59}]
|
1718 |
+
```
|
1719 |
+
"""
|
1720 |
+
payload: Dict[str, Any] = {"inputs": text}
|
1721 |
+
response = await self.post(
|
1722 |
+
json=payload,
|
1723 |
+
model=model,
|
1724 |
+
task="token-classification",
|
1725 |
+
)
|
1726 |
+
return _bytes_to_list(response)
|
1727 |
+
|
1728 |
+
async def translation(
|
1729 |
+
self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
|
1730 |
+
) -> str:
|
1731 |
+
"""
|
1732 |
+
Convert text from one language to another.
|
1733 |
+
|
1734 |
+
Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
|
1735 |
+
your specific use case. Source and target languages usually depend on the model.
|
1736 |
+
However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
|
1737 |
+
you can use `src_lang` and `tgt_lang` arguments to pass the relevant information.
|
1738 |
+
You can find this information in the model card.
|
1739 |
+
|
1740 |
+
Args:
|
1741 |
+
text (`str`):
|
1742 |
+
A string to be translated.
|
1743 |
+
model (`str`, *optional*):
|
1744 |
+
The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
|
1745 |
+
a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
|
1746 |
+
Defaults to None.
|
1747 |
+
src_lang (`str`, *optional*):
|
1748 |
+
Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
|
1749 |
+
tgt_lang (`str`, *optional*):
|
1750 |
+
Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
|
1751 |
+
|
1752 |
+
Returns:
|
1753 |
+
`str`: The generated translated text.
|
1754 |
+
|
1755 |
+
Raises:
|
1756 |
+
[`InferenceTimeoutError`]:
|
1757 |
+
If the model is unavailable or the request times out.
|
1758 |
+
`aiohttp.ClientResponseError`:
|
1759 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1760 |
+
`ValueError`:
|
1761 |
+
If only one of the `src_lang` and `tgt_lang` arguments is provided.
|
1762 |
+
|
1763 |
+
Example:
|
1764 |
+
```py
|
1765 |
+
# Must be run in an async context
|
1766 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1767 |
+
>>> client = AsyncInferenceClient()
|
1768 |
+
>>> await client.translation("My name is Wolfgang and I live in Berlin")
|
1769 |
+
'Mein Name ist Wolfgang und ich lebe in Berlin.'
|
1770 |
+
>>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
|
1771 |
+
"Je m'appelle Wolfgang et je vis à Berlin."
|
1772 |
+
```
|
1773 |
+
|
1774 |
+
Specifying languages:
|
1775 |
+
```py
|
1776 |
+
>>> await client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
|
1777 |
+
"Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica"
|
1778 |
+
```
|
1779 |
+
"""
|
1780 |
+
# Throw error if only one of `src_lang` and `tgt_lang` was given
|
1781 |
+
if src_lang is not None and tgt_lang is None:
|
1782 |
+
raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")
|
1783 |
+
|
1784 |
+
if src_lang is None and tgt_lang is not None:
|
1785 |
+
raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
|
1786 |
+
|
1787 |
+
# If both `src_lang` and `tgt_lang` are given, pass them to the request body
|
1788 |
+
payload: Dict = {"inputs": text}
|
1789 |
+
if src_lang and tgt_lang:
|
1790 |
+
payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
|
1791 |
+
response = await self.post(json=payload, model=model, task="translation")
|
1792 |
+
return _bytes_to_dict(response)[0]["translation_text"]
|
1793 |
+
|
1794 |
+
async def zero_shot_classification(
|
1795 |
+
self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
|
1796 |
+
) -> List[ClassificationOutput]:
|
1797 |
+
"""
|
1798 |
+
Provide as input a text and a set of candidate labels to classify the input text.
|
1799 |
+
|
1800 |
+
Args:
|
1801 |
+
text (`str`):
|
1802 |
+
The input text to classify.
|
1803 |
+
labels (`List[str]`):
|
1804 |
+
List of candidate labels, as strings. There must be at least 2 labels.
|
1805 |
+
multi_label (`bool`):
|
1806 |
+
Boolean that is set to True if classes can overlap.
|
1807 |
+
model (`str`, *optional*):
|
1808 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1809 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1810 |
+
|
1811 |
+
Returns:
|
1812 |
+
`List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
|
1813 |
+
|
1814 |
+
Raises:
|
1815 |
+
[`InferenceTimeoutError`]:
|
1816 |
+
If the model is unavailable or the request times out.
|
1817 |
+
`aiohttp.ClientResponseError`:
|
1818 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1819 |
+
|
1820 |
+
Example:
|
1821 |
+
```py
|
1822 |
+
# Must be run in an async context
|
1823 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1824 |
+
>>> client = AsyncInferenceClient()
|
1825 |
+
>>> text = (
|
1826 |
+
... "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
|
1827 |
+
... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
|
1828 |
+
... " mysteries when he went for a run up a hill in Nice, France."
|
1829 |
+
... )
|
1830 |
+
>>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
|
1831 |
+
>>> await client.zero_shot_classification(text, labels)
|
1832 |
+
[
|
1833 |
+
{"label": "scientific discovery", "score": 0.7961668968200684},
|
1834 |
+
{"label": "space & cosmos", "score": 0.18570658564567566},
|
1835 |
+
{"label": "microbiology", "score": 0.00730885099619627},
|
1836 |
+
{"label": "archeology", "score": 0.006258360575884581},
|
1837 |
+
{"label": "robots", "score": 0.004559356719255447},
|
1838 |
+
]
|
1839 |
+
>>> await client.zero_shot_classification(text, labels, multi_label=True)
|
1840 |
+
[
|
1841 |
+
{"label": "scientific discovery", "score": 0.9829297661781311},
|
1842 |
+
{"label": "space & cosmos", "score": 0.755190908908844},
|
1843 |
+
{"label": "microbiology", "score": 0.0005462635890580714},
|
1844 |
+
{"label": "archeology", "score": 0.00047131875180639327},
|
1845 |
+
{"label": "robots", "score": 0.00030448526376858354},
|
1846 |
+
]
|
1847 |
+
```
|
1848 |
+
"""
|
1849 |
+
# Raise ValueError if input is less than 2 labels
|
1850 |
+
if len(labels) < 2:
|
1851 |
+
raise ValueError("You must specify at least 2 classes to compare.")
|
1852 |
+
|
1853 |
+
response = await self.post(
|
1854 |
+
json={
|
1855 |
+
"inputs": text,
|
1856 |
+
"parameters": {
|
1857 |
+
"candidate_labels": ",".join(labels),
|
1858 |
+
"multi_label": multi_label,
|
1859 |
+
},
|
1860 |
+
},
|
1861 |
+
model=model,
|
1862 |
+
task="zero-shot-classification",
|
1863 |
+
)
|
1864 |
+
output = _bytes_to_dict(response)
|
1865 |
+
return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])]
|
1866 |
+
|
1867 |
+
async def zero_shot_image_classification(
|
1868 |
+
self, image: ContentT, labels: List[str], *, model: Optional[str] = None
|
1869 |
+
) -> List[ClassificationOutput]:
|
1870 |
+
"""
|
1871 |
+
Provide input image and text labels to predict text labels for the image.
|
1872 |
+
|
1873 |
+
Args:
|
1874 |
+
image (`Union[str, Path, bytes, BinaryIO]`):
|
1875 |
+
The input image to classify. It can be raw bytes, an image file, or a URL to an online image.
|
1876 |
+
labels (`List[str]`):
|
1877 |
+
List of candidate labels, as strings. There must be at least 2 labels.
|
1878 |
+
model (`str`, *optional*):
|
1879 |
+
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
|
1880 |
+
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
|
1881 |
+
|
1882 |
+
Returns:
|
1883 |
+
`List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
|
1884 |
+
|
1885 |
+
Raises:
|
1886 |
+
[`InferenceTimeoutError`]:
|
1887 |
+
If the model is unavailable or the request times out.
|
1888 |
+
`aiohttp.ClientResponseError`:
|
1889 |
+
If the request fails with an HTTP error status code other than HTTP 503.
|
1890 |
+
|
1891 |
+
Example:
|
1892 |
+
```py
|
1893 |
+
# Must be run in an async context
|
1894 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1895 |
+
>>> client = AsyncInferenceClient()
|
1896 |
+
|
1897 |
+
>>> await client.zero_shot_image_classification(
|
1898 |
+
... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
|
1899 |
+
... labels=["dog", "cat", "horse"],
|
1900 |
+
... )
|
1901 |
+
[{"label": "dog", "score": 0.956}, ...]
|
1902 |
+
```
|
1903 |
+
"""
|
1904 |
+
# Raise ValueError if input is less than 2 labels
|
1905 |
+
if len(labels) < 2:
|
1906 |
+
raise ValueError("You must specify at least 2 classes to compare.")
|
1907 |
+
|
1908 |
+
response = await self.post(
|
1909 |
+
json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}},
|
1910 |
+
model=model,
|
1911 |
+
task="zero-shot-image-classification",
|
1912 |
+
)
|
1913 |
+
return _bytes_to_list(response)
|
1914 |
+
|
1915 |
+
def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
|
1916 |
+
model = model or self.model
|
1917 |
+
|
1918 |
+
# If model is already a URL, ignore `task` and return directly
|
1919 |
+
if model is not None and (model.startswith("http://") or model.startswith("https://")):
|
1920 |
+
return model
|
1921 |
+
|
1922 |
+
# If no model but a task is set => fetch the recommended model for this task
|
1923 |
+
if model is None:
|
1924 |
+
if task is None:
|
1925 |
+
raise ValueError(
|
1926 |
+
"You must specify at least a model (repo_id or URL) or a task, either when instantiating"
|
1927 |
+
" `InferenceClient` or when making a request."
|
1928 |
+
)
|
1929 |
+
model = self.get_recommended_model(task)
|
1930 |
+
logger.info(
|
1931 |
+
f"Using recommended model {model} for task {task}. Note that it is"
|
1932 |
+
f" encouraged to explicitly set `model='{model}'` as the recommended"
|
1933 |
+
" models list might get updated without prior notice."
|
1934 |
+
)
|
1935 |
+
|
1936 |
+
# Compute InferenceAPI url
|
1937 |
+
return (
|
1938 |
+
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
|
1939 |
+
f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
|
1940 |
+
if task in ("feature-extraction", "sentence-similarity")
|
1941 |
+
# Otherwise, we use the default endpoint
|
1942 |
+
else f"{INFERENCE_ENDPOINT}/models/{model}"
|
1943 |
+
)
|
1944 |
+
|
1945 |
+
@staticmethod
|
1946 |
+
def get_recommended_model(task: str) -> str:
|
1947 |
+
"""
|
1948 |
+
Get the model Hugging Face recommends for the input task.
|
1949 |
+
|
1950 |
+
Args:
|
1951 |
+
task (`str`):
|
1952 |
+
The Hugging Face task for which to get the recommended model.
|
1953 |
+
All available tasks can be found [here](https://huggingface.co/tasks).
|
1954 |
+
|
1955 |
+
Returns:
|
1956 |
+
`str`: Name of the model recommended for the input task.
|
1957 |
+
|
1958 |
+
Raises:
|
1959 |
+
`ValueError`: If Hugging Face has no recommendation for the input task.
|
1960 |
+
"""
|
1961 |
+
model = _fetch_recommended_models().get(task)
|
1962 |
+
if model is None:
|
1963 |
+
raise ValueError(
|
1964 |
+
f"Task {task} has no recommended model. Please specify a model"
|
1965 |
+
" explicitly. Visit https://huggingface.co/tasks for more info."
|
1966 |
+
)
|
1967 |
+
return model
|
1968 |
+
|
1969 |
+
async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
|
1970 |
+
"""
|
1971 |
+
Get the status of a model hosted on the Inference API.
|
1972 |
+
|
1973 |
+
<Tip>
|
1974 |
+
|
1975 |
+
This endpoint is mostly useful when you already know which model you want to use and want to check its
|
1976 |
+
availability. If you want to discover already deployed models, use [`~InferenceClient.list_deployed_models`] instead.
|
1977 |
+
|
1978 |
+
</Tip>
|
1979 |
+
|
1980 |
+
Args:
|
1981 |
+
model (`str`, *optional*):
|
1982 |
+
Identifier of the model whose status will be checked. If not provided,
|
1983 |
+
the model associated with this instance of [`InferenceClient`] will be used. Only the Inference API service can be checked, so the
|
1984 |
+
identifier cannot be a URL.
|
1985 |
+
|
1986 |
+
|
1987 |
+
Returns:
|
1988 |
+
[`ModelStatus`]: An instance of the ModelStatus dataclass, containing information
|
1989 |
+
about the state of the model: load, state, compute type and framework.
|
1990 |
+
|
1991 |
+
Example:
|
1992 |
+
```py
|
1993 |
+
# Must be run in an async context
|
1994 |
+
>>> from huggingface_hub import AsyncInferenceClient
|
1995 |
+
>>> client = AsyncInferenceClient()
|
1996 |
+
>>> await client.get_model_status("bigcode/starcoder")
|
1997 |
+
ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
|
1998 |
+
```
|
1999 |
+
"""
|
2000 |
+
model = model or self.model
|
2001 |
+
if model is None:
|
2002 |
+
raise ValueError("Model id not provided.")
|
2003 |
+
if model.startswith("https://"):
|
2004 |
+
raise NotImplementedError("Model status is only available for Inference API endpoints.")
|
2005 |
+
url = f"{INFERENCE_ENDPOINT}/status/{model}"
|
2006 |
+
|
2007 |
+
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
|
2008 |
+
response = await client.get(url)
|
2009 |
+
response.raise_for_status()
|
2010 |
+
response_data = await response.json()
|
2011 |
+
|
2012 |
+
if "error" in response_data:
|
2013 |
+
raise ValueError(response_data["error"])
|
2014 |
+
|
2015 |
+
return ModelStatus(
|
2016 |
+
loaded=response_data["loaded"],
|
2017 |
+
state=response_data["state"],
|
2018 |
+
compute_type=response_data["compute_type"],
|
2019 |
+
framework=response_data["framework"],
|
2020 |
+
)
|
lib/python3.11/site-packages/huggingface_hub/inference/_text_generation.py
ADDED
@@ -0,0 +1,546 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# Original implementation taken from the `text-generation` Python client (see https://pypi.org/project/text-generation/
|
17 |
+
# and https://github.com/huggingface/text-generation-inference/tree/main/clients/python)
|
18 |
+
#
|
19 |
+
# Changes compared to original implementation:
|
20 |
+
# - use pydantic.dataclasses instead of BaseModel
|
21 |
+
# - default to Python's dataclasses if Pydantic is not installed (same implementation but no validation)
|
22 |
+
# - added default values for all parameters (not needed with BaseModel but required for dataclasses)
|
23 |
+
# - integrated in `huggingface_hub.InferenceClient`
|
24 |
+
# - added `stream: bool` and `details: bool` in the `text_generation` method instead of having different methods for each use case
|
25 |
+
import warnings
|
26 |
+
from dataclasses import field
|
27 |
+
from enum import Enum
|
28 |
+
from typing import List, NoReturn, Optional
|
29 |
+
|
30 |
+
from requests import HTTPError
|
31 |
+
|
32 |
+
from ..utils import is_pydantic_available
|
33 |
+
|
34 |
+
|
35 |
+
if is_pydantic_available():
|
36 |
+
from pydantic import validator as pydantic_validator
|
37 |
+
from pydantic.dataclasses import dataclass
|
38 |
+
|
39 |
+
def validator(*args, **kwargs):
|
40 |
+
# Pydantic v1's `@validator` is deprecated in favor of `@field_validator`. In order to support both pydantic v1
|
41 |
+
# and v2 without changing the logic, we catch the warning message in pydantic v2 and ignore it. If we want to
|
42 |
+
# support pydantic v3 in the future, we will drop support for pydantic v1 and use `pydantic.field_validator`
|
43 |
+
# correctly.
|
44 |
+
#
|
45 |
+
# Related:
|
46 |
+
# - https://docs.pydantic.dev/latest/migration/#changes-to-validators
|
47 |
+
# - https://github.com/huggingface/huggingface_hub/pull/1837
|
48 |
+
with warnings.catch_warnings():
|
49 |
+
warnings.filterwarnings("ignore", message="Pydantic V1 style `@validator` validators are deprecated.")
|
50 |
+
return pydantic_validator(*args, **kwargs)
|
51 |
+
else:
|
52 |
+
# No validation if Pydantic is not installed
|
53 |
+
from dataclasses import dataclass # type: ignore
|
54 |
+
|
55 |
+
def validator(x): # type: ignore
|
56 |
+
return lambda y: y
|
57 |
+
|
58 |
+
|
59 |
+
@dataclass
|
60 |
+
class TextGenerationParameters:
|
61 |
+
"""
|
62 |
+
Parameters for text generation.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
do_sample (`bool`, *optional*):
|
66 |
+
Activate logits sampling. Defaults to False.
|
67 |
+
max_new_tokens (`int`, *optional*):
|
68 |
+
Maximum number of generated tokens. Defaults to 20.
|
69 |
+
repetition_penalty (`Optional[float]`, *optional*):
|
70 |
+
The parameter for repetition penalty. A value of 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf)
|
71 |
+
for more details. Defaults to None.
|
72 |
+
return_full_text (`bool`, *optional*):
|
73 |
+
Whether to prepend the prompt to the generated text. Defaults to False.
|
74 |
+
stop (`List[str]`, *optional*):
|
75 |
+
Stop generating tokens if a member of `stop_sequences` is generated. Defaults to an empty list.
|
76 |
+
seed (`Optional[int]`, *optional*):
|
77 |
+
Random sampling seed. Defaults to None.
|
78 |
+
temperature (`Optional[float]`, *optional*):
|
79 |
+
The value used to modulate the logits distribution. Defaults to None.
|
80 |
+
top_k (`Optional[int]`, *optional*):
|
81 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
|
82 |
+
top_p (`Optional[float]`, *optional*):
|
83 |
+
If set to a value less than 1, only the smallest set of most probable tokens with probabilities that add up
|
84 |
+
to `top_p` or higher are kept for generation. Defaults to None.
|
85 |
+
truncate (`Optional[int]`, *optional*):
|
86 |
+
Truncate input tokens to the given size. Defaults to None.
|
87 |
+
typical_p (`Optional[float]`, *optional*):
|
88 |
+
Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
|
89 |
+
for more information. Defaults to None.
|
90 |
+
best_of (`Optional[int]`, *optional*):
|
91 |
+
Generate `best_of` sequences and return the one with the highest token logprobs. Defaults to None.
|
92 |
+
watermark (`bool`, *optional*):
|
93 |
+
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226). Defaults to False.
|
94 |
+
details (`bool`, *optional*):
|
95 |
+
Get generation details. Defaults to False.
|
96 |
+
decoder_input_details (`bool`, *optional*):
|
97 |
+
Get decoder input token logprobs and ids. Defaults to False.
|
98 |
+
"""
|
99 |
+
|
100 |
+
# Activate logits sampling
|
101 |
+
do_sample: bool = False
|
102 |
+
# Maximum number of generated tokens
|
103 |
+
max_new_tokens: int = 20
|
104 |
+
# The parameter for repetition penalty. 1.0 means no penalty.
|
105 |
+
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
106 |
+
repetition_penalty: Optional[float] = None
|
107 |
+
# Whether to prepend the prompt to the generated text
|
108 |
+
return_full_text: bool = False
|
109 |
+
# Stop generating tokens if a member of `stop_sequences` is generated
|
110 |
+
stop: List[str] = field(default_factory=lambda: [])
|
111 |
+
# Random sampling seed
|
112 |
+
seed: Optional[int] = None
|
113 |
+
# The value used to module the logits distribution.
|
114 |
+
temperature: Optional[float] = None
|
115 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
116 |
+
top_k: Optional[int] = None
|
117 |
+
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
118 |
+
# higher are kept for generation.
|
119 |
+
top_p: Optional[float] = None
|
120 |
+
# truncate inputs tokens to the given size
|
121 |
+
truncate: Optional[int] = None
|
122 |
+
# Typical Decoding mass
|
123 |
+
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
124 |
+
typical_p: Optional[float] = None
|
125 |
+
# Generate best_of sequences and return the one if the highest token logprobs
|
126 |
+
best_of: Optional[int] = None
|
127 |
+
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
128 |
+
watermark: bool = False
|
129 |
+
# Get generation details
|
130 |
+
details: bool = False
|
131 |
+
# Get decoder input token logprobs and ids
|
132 |
+
decoder_input_details: bool = False
|
133 |
+
|
134 |
+
@validator("best_of")
|
135 |
+
def valid_best_of(cls, field_value, values):
|
136 |
+
if field_value is not None:
|
137 |
+
if field_value <= 0:
|
138 |
+
raise ValueError("`best_of` must be strictly positive")
|
139 |
+
if field_value > 1 and values["seed"] is not None:
|
140 |
+
raise ValueError("`seed` must not be set when `best_of` is > 1")
|
141 |
+
sampling = (
|
142 |
+
values["do_sample"]
|
143 |
+
| (values["temperature"] is not None)
|
144 |
+
| (values["top_k"] is not None)
|
145 |
+
| (values["top_p"] is not None)
|
146 |
+
| (values["typical_p"] is not None)
|
147 |
+
)
|
148 |
+
if field_value > 1 and not sampling:
|
149 |
+
raise ValueError("you must use sampling when `best_of` is > 1")
|
150 |
+
|
151 |
+
return field_value
|
152 |
+
|
153 |
+
@validator("repetition_penalty")
|
154 |
+
def valid_repetition_penalty(cls, v):
|
155 |
+
if v is not None and v <= 0:
|
156 |
+
raise ValueError("`repetition_penalty` must be strictly positive")
|
157 |
+
return v
|
158 |
+
|
159 |
+
@validator("seed")
|
160 |
+
def valid_seed(cls, v):
|
161 |
+
if v is not None and v < 0:
|
162 |
+
raise ValueError("`seed` must be positive")
|
163 |
+
return v
|
164 |
+
|
165 |
+
@validator("temperature")
|
166 |
+
def valid_temp(cls, v):
|
167 |
+
if v is not None and v <= 0:
|
168 |
+
raise ValueError("`temperature` must be strictly positive")
|
169 |
+
return v
|
170 |
+
|
171 |
+
@validator("top_k")
|
172 |
+
def valid_top_k(cls, v):
|
173 |
+
if v is not None and v <= 0:
|
174 |
+
raise ValueError("`top_k` must be strictly positive")
|
175 |
+
return v
|
176 |
+
|
177 |
+
@validator("top_p")
|
178 |
+
def valid_top_p(cls, v):
|
179 |
+
if v is not None and (v <= 0 or v >= 1.0):
|
180 |
+
raise ValueError("`top_p` must be > 0.0 and < 1.0")
|
181 |
+
return v
|
182 |
+
|
183 |
+
@validator("truncate")
|
184 |
+
def valid_truncate(cls, v):
|
185 |
+
if v is not None and v <= 0:
|
186 |
+
raise ValueError("`truncate` must be strictly positive")
|
187 |
+
return v
|
188 |
+
|
189 |
+
@validator("typical_p")
|
190 |
+
def valid_typical_p(cls, v):
|
191 |
+
if v is not None and (v <= 0 or v >= 1.0):
|
192 |
+
raise ValueError("`typical_p` must be > 0.0 and < 1.0")
|
193 |
+
return v
|
194 |
+
|
195 |
+
|
196 |
+
@dataclass
|
197 |
+
class TextGenerationRequest:
|
198 |
+
"""
|
199 |
+
Request object for text generation (only for internal use).
|
200 |
+
|
201 |
+
Args:
|
202 |
+
inputs (`str`):
|
203 |
+
The prompt for text generation.
|
204 |
+
parameters (`Optional[TextGenerationParameters]`, *optional*):
|
205 |
+
Generation parameters.
|
206 |
+
stream (`bool`, *optional*):
|
207 |
+
Whether to stream output tokens. Defaults to False.
|
208 |
+
"""
|
209 |
+
|
210 |
+
# Prompt
|
211 |
+
inputs: str
|
212 |
+
# Generation parameters
|
213 |
+
parameters: Optional[TextGenerationParameters] = None
|
214 |
+
# Whether to stream output tokens
|
215 |
+
stream: bool = False
|
216 |
+
|
217 |
+
@validator("inputs")
|
218 |
+
def valid_input(cls, v):
|
219 |
+
if not v:
|
220 |
+
raise ValueError("`inputs` cannot be empty")
|
221 |
+
return v
|
222 |
+
|
223 |
+
@validator("stream")
|
224 |
+
def valid_best_of_stream(cls, field_value, values):
|
225 |
+
parameters = values["parameters"]
|
226 |
+
if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
|
227 |
+
raise ValueError("`best_of` != 1 is not supported when `stream` == True")
|
228 |
+
return field_value
|
229 |
+
|
230 |
+
def __post_init__(self):
|
231 |
+
if not is_pydantic_available():
|
232 |
+
# If pydantic is not installed, we need to instantiate the nested dataclasses manually
|
233 |
+
if self.parameters is not None and isinstance(self.parameters, dict):
|
234 |
+
self.parameters = TextGenerationParameters(**self.parameters)
|
235 |
+
|
236 |
+
|
237 |
+
# Decoder input tokens
|
238 |
+
@dataclass
|
239 |
+
class InputToken:
|
240 |
+
"""
|
241 |
+
Represents an input token.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
id (`int`):
|
245 |
+
Token ID from the model tokenizer.
|
246 |
+
text (`str`):
|
247 |
+
Token text.
|
248 |
+
logprob (`float` or `None`):
|
249 |
+
Log probability of the token. Optional since the logprob of the first token cannot be computed.
|
250 |
+
"""
|
251 |
+
|
252 |
+
# Token ID from the model tokenizer
|
253 |
+
id: int
|
254 |
+
# Token text
|
255 |
+
text: str
|
256 |
+
# Logprob
|
257 |
+
# Optional since the logprob of the first token cannot be computed
|
258 |
+
logprob: Optional[float] = None
|
259 |
+
|
260 |
+
|
261 |
+
# Generated tokens
|
262 |
+
@dataclass
|
263 |
+
class Token:
|
264 |
+
"""
|
265 |
+
Represents a token.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
id (`int`):
|
269 |
+
Token ID from the model tokenizer.
|
270 |
+
text (`str`):
|
271 |
+
Token text.
|
272 |
+
logprob (`float`):
|
273 |
+
Log probability of the token.
|
274 |
+
special (`bool`):
|
275 |
+
Indicates whether the token is a special token. It can be used to ignore
|
276 |
+
tokens when concatenating.
|
277 |
+
"""
|
278 |
+
|
279 |
+
# Token ID from the model tokenizer
|
280 |
+
id: int
|
281 |
+
# Token text
|
282 |
+
text: str
|
283 |
+
# Logprob
|
284 |
+
logprob: float
|
285 |
+
# Is the token a special token
|
286 |
+
# Can be used to ignore tokens when concatenating
|
287 |
+
special: bool
|
288 |
+
|
289 |
+
|
290 |
+
# Generation finish reason
|
291 |
+
class FinishReason(str, Enum):
|
292 |
+
# number of generated tokens == `max_new_tokens`
|
293 |
+
Length = "length"
|
294 |
+
# the model generated its end of sequence token
|
295 |
+
EndOfSequenceToken = "eos_token"
|
296 |
+
# the model generated a text included in `stop_sequences`
|
297 |
+
StopSequence = "stop_sequence"
|
298 |
+
|
299 |
+
|
300 |
+
# Additional sequences when using the `best_of` parameter
|
301 |
+
@dataclass
|
302 |
+
class BestOfSequence:
|
303 |
+
"""
|
304 |
+
Represents a best-of sequence generated during text generation.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
generated_text (`str`):
|
308 |
+
The generated text.
|
309 |
+
finish_reason (`FinishReason`):
|
310 |
+
The reason for the generation to finish, represented by a `FinishReason` value.
|
311 |
+
generated_tokens (`int`):
|
312 |
+
The number of generated tokens in the sequence.
|
313 |
+
seed (`Optional[int]`):
|
314 |
+
The sampling seed if sampling was activated.
|
315 |
+
prefill (`List[InputToken]`):
|
316 |
+
The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list.
|
317 |
+
tokens (`List[Token]`):
|
318 |
+
The generated tokens. Defaults to an empty list.
|
319 |
+
"""
|
320 |
+
|
321 |
+
# Generated text
|
322 |
+
generated_text: str
|
323 |
+
# Generation finish reason
|
324 |
+
finish_reason: FinishReason
|
325 |
+
# Number of generated tokens
|
326 |
+
generated_tokens: int
|
327 |
+
# Sampling seed if sampling was activated
|
328 |
+
seed: Optional[int] = None
|
329 |
+
# Decoder input tokens, empty if decoder_input_details is False
|
330 |
+
prefill: List[InputToken] = field(default_factory=lambda: [])
|
331 |
+
# Generated tokens
|
332 |
+
tokens: List[Token] = field(default_factory=lambda: [])
|
333 |
+
|
334 |
+
def __post_init__(self):
|
335 |
+
if not is_pydantic_available():
|
336 |
+
# If pydantic is not installed, we need to instantiate the nested dataclasses manually
|
337 |
+
self.prefill = [
|
338 |
+
InputToken(**input_token) if isinstance(input_token, dict) else input_token
|
339 |
+
for input_token in self.prefill
|
340 |
+
]
|
341 |
+
self.tokens = [Token(**token) if isinstance(token, dict) else token for token in self.tokens]
|
342 |
+
|
343 |
+
|
344 |
+
# `generate` details
|
345 |
+
@dataclass
|
346 |
+
class Details:
|
347 |
+
"""
|
348 |
+
Represents details of a text generation.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
finish_reason (`FinishReason`):
|
352 |
+
The reason for the generation to finish, represented by a `FinishReason` value.
|
353 |
+
generated_tokens (`int`):
|
354 |
+
The number of generated tokens.
|
355 |
+
seed (`Optional[int]`):
|
356 |
+
The sampling seed if sampling was activated.
|
357 |
+
prefill (`List[InputToken]`, *optional*):
|
358 |
+
The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list.
|
359 |
+
tokens (`List[Token]`):
|
360 |
+
The generated tokens. Defaults to an empty list.
|
361 |
+
best_of_sequences (`Optional[List[BestOfSequence]]`):
|
362 |
+
Additional sequences when using the `best_of` parameter.
|
363 |
+
"""
|
364 |
+
|
365 |
+
# Generation finish reason
|
366 |
+
finish_reason: FinishReason
|
367 |
+
# Number of generated tokens
|
368 |
+
generated_tokens: int
|
369 |
+
# Sampling seed if sampling was activated
|
370 |
+
seed: Optional[int] = None
|
371 |
+
# Decoder input tokens, empty if decoder_input_details is False
|
372 |
+
prefill: List[InputToken] = field(default_factory=lambda: [])
|
373 |
+
# Generated tokens
|
374 |
+
tokens: List[Token] = field(default_factory=lambda: [])
|
375 |
+
# Additional sequences when using the `best_of` parameter
|
376 |
+
best_of_sequences: Optional[List[BestOfSequence]] = None
|
377 |
+
|
378 |
+
def __post_init__(self):
|
379 |
+
if not is_pydantic_available():
|
380 |
+
# If pydantic is not installed, we need to instantiate the nested dataclasses manually
|
381 |
+
self.prefill = [
|
382 |
+
InputToken(**input_token) if isinstance(input_token, dict) else input_token
|
383 |
+
for input_token in self.prefill
|
384 |
+
]
|
385 |
+
self.tokens = [Token(**token) if isinstance(token, dict) else token for token in self.tokens]
|
386 |
+
if self.best_of_sequences is not None:
|
387 |
+
self.best_of_sequences = [
|
388 |
+
BestOfSequence(**best_of_sequence) if isinstance(best_of_sequence, dict) else best_of_sequence
|
389 |
+
for best_of_sequence in self.best_of_sequences
|
390 |
+
]
|
391 |
+
|
392 |
+
|
393 |
+
# `generate` return value
|
394 |
+
@dataclass
|
395 |
+
class TextGenerationResponse:
|
396 |
+
"""
|
397 |
+
Represents a response for text generation.
|
398 |
+
|
399 |
+
Only returned when `details=True`, otherwise a string is returned.
|
400 |
+
|
401 |
+
Args:
|
402 |
+
generated_text (`str`):
|
403 |
+
The generated text.
|
404 |
+
details (`Optional[Details]`):
|
405 |
+
Generation details. Returned only if `details=True` is sent to the server.
|
406 |
+
"""
|
407 |
+
|
408 |
+
# Generated text
|
409 |
+
generated_text: str
|
410 |
+
# Generation details
|
411 |
+
details: Optional[Details] = None
|
412 |
+
|
413 |
+
def __post_init__(self):
|
414 |
+
if not is_pydantic_available():
|
415 |
+
# If pydantic is not installed, we need to instantiate the nested dataclasses manually
|
416 |
+
if self.details is not None and isinstance(self.details, dict):
|
417 |
+
self.details = Details(**self.details)
|
418 |
+
|
419 |
+
|
420 |
+
# `generate_stream` details
|
421 |
+
@dataclass
|
422 |
+
class StreamDetails:
|
423 |
+
"""
|
424 |
+
Represents details of a text generation stream.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
finish_reason (`FinishReason`):
|
428 |
+
The reason for the generation to finish, represented by a `FinishReason` value.
|
429 |
+
generated_tokens (`int`):
|
430 |
+
The number of generated tokens.
|
431 |
+
seed (`Optional[int]`):
|
432 |
+
The sampling seed if sampling was activated.
|
433 |
+
"""
|
434 |
+
|
435 |
+
# Generation finish reason
|
436 |
+
finish_reason: FinishReason
|
437 |
+
# Number of generated tokens
|
438 |
+
generated_tokens: int
|
439 |
+
# Sampling seed if sampling was activated
|
440 |
+
seed: Optional[int] = None
|
441 |
+
|
442 |
+
|
443 |
+
# `generate_stream` return value
|
444 |
+
@dataclass
|
445 |
+
class TextGenerationStreamResponse:
|
446 |
+
"""
|
447 |
+
Represents a response for streaming text generation.
|
448 |
+
|
449 |
+
Only returned when `details=True` and `stream=True`.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
token (`Token`):
|
453 |
+
The generated token.
|
454 |
+
generated_text (`Optional[str]`, *optional*):
|
455 |
+
The complete generated text. Only available when the generation is finished.
|
456 |
+
details (`Optional[StreamDetails]`, *optional*):
|
457 |
+
Generation details. Only available when the generation is finished.
|
458 |
+
"""
|
459 |
+
|
460 |
+
# Generated token
|
461 |
+
token: Token
|
462 |
+
# Complete generated text
|
463 |
+
# Only available when the generation is finished
|
464 |
+
generated_text: Optional[str] = None
|
465 |
+
# Generation details
|
466 |
+
# Only available when the generation is finished
|
467 |
+
details: Optional[StreamDetails] = None
|
468 |
+
|
469 |
+
def __post_init__(self):
|
470 |
+
if not is_pydantic_available():
|
471 |
+
# If pydantic is not installed, we need to instantiate the nested dataclasses manually
|
472 |
+
if isinstance(self.token, dict):
|
473 |
+
self.token = Token(**self.token)
|
474 |
+
if self.details is not None and isinstance(self.details, dict):
|
475 |
+
self.details = StreamDetails(**self.details)
|
476 |
+
|
477 |
+
|
478 |
+
# TEXT GENERATION ERRORS
|
479 |
+
# ----------------------
|
480 |
+
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
|
481 |
+
# inference project (https://github.com/huggingface/text-generation-inference).
|
482 |
+
# ----------------------
|
483 |
+
|
484 |
+
|
485 |
+
class TextGenerationError(HTTPError):
|
486 |
+
"""Generic error raised if text-generation went wrong."""
|
487 |
+
|
488 |
+
|
489 |
+
# Text Generation Inference Errors
|
490 |
+
class ValidationError(TextGenerationError):
|
491 |
+
"""Server-side validation error."""
|
492 |
+
|
493 |
+
|
494 |
+
class GenerationError(TextGenerationError):
|
495 |
+
pass
|
496 |
+
|
497 |
+
|
498 |
+
class OverloadedError(TextGenerationError):
|
499 |
+
pass
|
500 |
+
|
501 |
+
|
502 |
+
class IncompleteGenerationError(TextGenerationError):
|
503 |
+
pass
|
504 |
+
|
505 |
+
|
506 |
+
class UnknownError(TextGenerationError):
|
507 |
+
pass
|
508 |
+
|
509 |
+
|
510 |
+
def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
|
511 |
+
"""
|
512 |
+
Try to parse text-generation-inference error message and raise HTTPError in any case.
|
513 |
+
|
514 |
+
Args:
|
515 |
+
error (`HTTPError`):
|
516 |
+
The HTTPError that have been raised.
|
517 |
+
"""
|
518 |
+
# Try to parse a Text Generation Inference error
|
519 |
+
|
520 |
+
try:
|
521 |
+
# Hacky way to retrieve payload in case of aiohttp error
|
522 |
+
payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
|
523 |
+
error = payload.get("error")
|
524 |
+
error_type = payload.get("error_type")
|
525 |
+
except Exception: # no payload
|
526 |
+
raise http_error
|
527 |
+
|
528 |
+
# If error_type => more information than `hf_raise_for_status`
|
529 |
+
if error_type is not None:
|
530 |
+
exception = _parse_text_generation_error(error, error_type)
|
531 |
+
raise exception from http_error
|
532 |
+
|
533 |
+
# Otherwise, fallback to default error
|
534 |
+
raise http_error
|
535 |
+
|
536 |
+
|
537 |
+
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
|
538 |
+
if error_type == "generation":
|
539 |
+
return GenerationError(error) # type: ignore
|
540 |
+
if error_type == "incomplete_generation":
|
541 |
+
return IncompleteGenerationError(error) # type: ignore
|
542 |
+
if error_type == "overloaded":
|
543 |
+
return OverloadedError(error) # type: ignore
|
544 |
+
if error_type == "validation":
|
545 |
+
return ValidationError(error) # type: ignore
|
546 |
+
return UnknownError(error) # type: ignore
|
lib/python3.11/site-packages/huggingface_hub/inference/_types.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023-present, the HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from typing import TYPE_CHECKING, List, TypedDict
|
16 |
+
|
17 |
+
|
18 |
+
if TYPE_CHECKING:
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
|
22 |
+
class ClassificationOutput(TypedDict):
|
23 |
+
"""Dictionary containing the output of a [`~InferenceClient.audio_classification`] and [`~InferenceClient.image_classification`] task.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
label (`str`):
|
27 |
+
The label predicted by the model.
|
28 |
+
score (`float`):
|
29 |
+
The score of the label predicted by the model.
|
30 |
+
"""
|
31 |
+
|
32 |
+
label: str
|
33 |
+
score: float
|
34 |
+
|
35 |
+
|
36 |
+
class ConversationalOutputConversation(TypedDict):
|
37 |
+
"""Dictionary containing the "conversation" part of a [`~InferenceClient.conversational`] task.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
generated_responses (`List[str]`):
|
41 |
+
A list of the responses from the model.
|
42 |
+
past_user_inputs (`List[str]`):
|
43 |
+
A list of the inputs from the user. Must be the same length as `generated_responses`.
|
44 |
+
"""
|
45 |
+
|
46 |
+
generated_responses: List[str]
|
47 |
+
past_user_inputs: List[str]
|
48 |
+
|
49 |
+
|
50 |
+
class ConversationalOutput(TypedDict):
|
51 |
+
"""Dictionary containing the output of a [`~InferenceClient.conversational`] task.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
generated_text (`str`):
|
55 |
+
The last response from the model.
|
56 |
+
conversation (`ConversationalOutputConversation`):
|
57 |
+
The past conversation.
|
58 |
+
warnings (`List[str]`):
|
59 |
+
A list of warnings associated with the process.
|
60 |
+
"""
|
61 |
+
|
62 |
+
conversation: ConversationalOutputConversation
|
63 |
+
generated_text: str
|
64 |
+
warnings: List[str]
|
65 |
+
|
66 |
+
|
67 |
+
class FillMaskOutput(TypedDict):
|
68 |
+
"""Dictionary containing information about a [`~InferenceClient.fill_mask`] task.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
score (`float`):
|
72 |
+
The probability of the token.
|
73 |
+
token (`int`):
|
74 |
+
The id of the token.
|
75 |
+
token_str (`str`):
|
76 |
+
The string representation of the token.
|
77 |
+
sequence (`str`):
|
78 |
+
The actual sequence of tokens that ran against the model (may contain special tokens).
|
79 |
+
"""
|
80 |
+
|
81 |
+
score: float
|
82 |
+
token: int
|
83 |
+
token_str: str
|
84 |
+
sequence: str
|
85 |
+
|
86 |
+
|
87 |
+
class ImageSegmentationOutput(TypedDict):
|
88 |
+
"""Dictionary containing information about a [`~InferenceClient.image_segmentation`] task. In practice, image segmentation returns a
|
89 |
+
list of `ImageSegmentationOutput` with 1 item per mask.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
label (`str`):
|
93 |
+
The label corresponding to the mask.
|
94 |
+
mask (`Image`):
|
95 |
+
An Image object representing the mask predicted by the model.
|
96 |
+
score (`float`):
|
97 |
+
The score associated with the label for this mask.
|
98 |
+
"""
|
99 |
+
|
100 |
+
label: str
|
101 |
+
mask: "Image"
|
102 |
+
score: float
|
103 |
+
|
104 |
+
|
105 |
+
class ObjectDetectionOutput(TypedDict):
|
106 |
+
"""Dictionary containing information about a [`~InferenceClient.object_detection`] task.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
label (`str`):
|
110 |
+
The label corresponding to the detected object.
|
111 |
+
box (`dict`):
|
112 |
+
A dict response of bounding box coordinates of
|
113 |
+
the detected object: xmin, ymin, xmax, ymax
|
114 |
+
score (`float`):
|
115 |
+
The score corresponding to the detected object.
|
116 |
+
"""
|
117 |
+
|
118 |
+
label: str
|
119 |
+
box: dict
|
120 |
+
score: float
|
121 |
+
|
122 |
+
|
123 |
+
class QuestionAnsweringOutput(TypedDict):
|
124 |
+
"""Dictionary containing information about a [`~InferenceClient.question_answering`] task.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
score (`float`):
|
128 |
+
A float that represents how likely that the answer is correct.
|
129 |
+
start (`int`):
|
130 |
+
The index (string wise) of the start of the answer within context.
|
131 |
+
end (`int`):
|
132 |
+
The index (string wise) of the end of the answer within context.
|
133 |
+
answer (`str`):
|
134 |
+
A string that is the answer within the text.
|
135 |
+
"""
|
136 |
+
|
137 |
+
score: float
|
138 |
+
start: int
|
139 |
+
end: int
|
140 |
+
answer: str
|
141 |
+
|
142 |
+
|
143 |
+
class TableQuestionAnsweringOutput(TypedDict):
|
144 |
+
"""Dictionary containing information about a [`~InferenceClient.table_question_answering`] task.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
answer (`str`):
|
148 |
+
The plaintext answer.
|
149 |
+
coordinates (`List[List[int]]`):
|
150 |
+
A list of coordinates of the cells referenced in the answer.
|
151 |
+
cells (`List[int]`):
|
152 |
+
A list of coordinates of the cells contents.
|
153 |
+
aggregator (`str`):
|
154 |
+
The aggregator used to get the answer.
|
155 |
+
"""
|
156 |
+
|
157 |
+
answer: str
|
158 |
+
coordinates: List[List[int]]
|
159 |
+
cells: List[List[int]]
|
160 |
+
aggregator: str
|
161 |
+
|
162 |
+
|
163 |
+
class TokenClassificationOutput(TypedDict):
|
164 |
+
"""Dictionary containing the output of a [`~InferenceClient.token_classification`] task.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
entity_group (`str`):
|
168 |
+
The type for the entity being recognized (model specific).
|
169 |
+
score (`float`):
|
170 |
+
The score of the label predicted by the model.
|
171 |
+
word (`str`):
|
172 |
+
The string that was captured.
|
173 |
+
start (`int`):
|
174 |
+
The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
|
175 |
+
end (`int`):
|
176 |
+
The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
|
177 |
+
"""
|
178 |
+
|
179 |
+
entity_group: str
|
180 |
+
score: float
|
181 |
+
word: str
|
182 |
+
start: int
|
183 |
+
end: int
|
lib/python3.11/site-packages/huggingface_hub/inference_api.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from typing import Any, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
from .constants import INFERENCE_ENDPOINT
|
5 |
+
from .hf_api import HfApi
|
6 |
+
from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
|
7 |
+
from .utils._deprecation import _deprecate_method
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.get_logger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
ALL_TASKS = [
|
14 |
+
# NLP
|
15 |
+
"text-classification",
|
16 |
+
"token-classification",
|
17 |
+
"table-question-answering",
|
18 |
+
"question-answering",
|
19 |
+
"zero-shot-classification",
|
20 |
+
"translation",
|
21 |
+
"summarization",
|
22 |
+
"conversational",
|
23 |
+
"feature-extraction",
|
24 |
+
"text-generation",
|
25 |
+
"text2text-generation",
|
26 |
+
"fill-mask",
|
27 |
+
"sentence-similarity",
|
28 |
+
# Audio
|
29 |
+
"text-to-speech",
|
30 |
+
"automatic-speech-recognition",
|
31 |
+
"audio-to-audio",
|
32 |
+
"audio-classification",
|
33 |
+
"voice-activity-detection",
|
34 |
+
# Computer vision
|
35 |
+
"image-classification",
|
36 |
+
"object-detection",
|
37 |
+
"image-segmentation",
|
38 |
+
"text-to-image",
|
39 |
+
"image-to-image",
|
40 |
+
# Others
|
41 |
+
"tabular-classification",
|
42 |
+
"tabular-regression",
|
43 |
+
]
|
44 |
+
|
45 |
+
|
46 |
+
class InferenceApi:
|
47 |
+
"""Client to configure requests and make calls to the HuggingFace Inference API.
|
48 |
+
|
49 |
+
Example:
|
50 |
+
|
51 |
+
```python
|
52 |
+
>>> from huggingface_hub.inference_api import InferenceApi
|
53 |
+
|
54 |
+
>>> # Mask-fill example
|
55 |
+
>>> inference = InferenceApi("bert-base-uncased")
|
56 |
+
>>> inference(inputs="The goal of life is [MASK].")
|
57 |
+
[{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
|
58 |
+
|
59 |
+
>>> # Question Answering example
|
60 |
+
>>> inference = InferenceApi("deepset/roberta-base-squad2")
|
61 |
+
>>> inputs = {
|
62 |
+
... "question": "What's my name?",
|
63 |
+
... "context": "My name is Clara and I live in Berkeley.",
|
64 |
+
... }
|
65 |
+
>>> inference(inputs)
|
66 |
+
{'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}
|
67 |
+
|
68 |
+
>>> # Zero-shot example
|
69 |
+
>>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
|
70 |
+
>>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
|
71 |
+
>>> params = {"candidate_labels": ["refund", "legal", "faq"]}
|
72 |
+
>>> inference(inputs, params)
|
73 |
+
{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
|
74 |
+
|
75 |
+
>>> # Overriding configured task
|
76 |
+
>>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")
|
77 |
+
|
78 |
+
>>> # Text-to-image
|
79 |
+
>>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
|
80 |
+
>>> inference("cat")
|
81 |
+
<PIL.PngImagePlugin.PngImageFile image (...)>
|
82 |
+
|
83 |
+
>>> # Return as raw response to parse the output yourself
|
84 |
+
>>> inference = InferenceApi("mio/amadeus")
|
85 |
+
>>> response = inference("hello world", raw_response=True)
|
86 |
+
>>> response.headers
|
87 |
+
{"Content-Type": "audio/flac", ...}
|
88 |
+
>>> response.content # raw bytes from server
|
89 |
+
b'(...)'
|
90 |
+
```
|
91 |
+
"""
|
92 |
+
|
93 |
+
@validate_hf_hub_args
|
94 |
+
@_deprecate_method(
|
95 |
+
version="1.0",
|
96 |
+
message=(
|
97 |
+
"`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out"
|
98 |
+
" this guide to learn how to convert your script to use it:"
|
99 |
+
" https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client."
|
100 |
+
),
|
101 |
+
)
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
repo_id: str,
|
105 |
+
task: Optional[str] = None,
|
106 |
+
token: Optional[str] = None,
|
107 |
+
gpu: bool = False,
|
108 |
+
):
|
109 |
+
"""Inits headers and API call information.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
repo_id (``str``):
|
113 |
+
Id of repository (e.g. `user/bert-base-uncased`).
|
114 |
+
task (``str``, `optional`, defaults ``None``):
|
115 |
+
Whether to force a task instead of using task specified in the
|
116 |
+
repository.
|
117 |
+
token (`str`, `optional`):
|
118 |
+
The API token to use as HTTP bearer authorization. This is not
|
119 |
+
the authentication token. You can find the token in
|
120 |
+
https://huggingface.co/settings/token. Alternatively, you can
|
121 |
+
find both your organizations and personal API tokens using
|
122 |
+
`HfApi().whoami(token)`.
|
123 |
+
gpu (`bool`, `optional`, defaults `False`):
|
124 |
+
Whether to use GPU instead of CPU for inference(requires Startup
|
125 |
+
plan at least).
|
126 |
+
"""
|
127 |
+
self.options = {"wait_for_model": True, "use_gpu": gpu}
|
128 |
+
self.headers = build_hf_headers(token=token)
|
129 |
+
|
130 |
+
# Configure task
|
131 |
+
model_info = HfApi(token=token).model_info(repo_id=repo_id)
|
132 |
+
if not model_info.pipeline_tag and not task:
|
133 |
+
raise ValueError(
|
134 |
+
"Task not specified in the repository. Please add it to the model card"
|
135 |
+
" using pipeline_tag"
|
136 |
+
" (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
|
137 |
+
)
|
138 |
+
|
139 |
+
if task and task != model_info.pipeline_tag:
|
140 |
+
if task not in ALL_TASKS:
|
141 |
+
raise ValueError(f"Invalid task {task}. Make sure it's valid.")
|
142 |
+
|
143 |
+
logger.warning(
|
144 |
+
"You're using a different task than the one specified in the"
|
145 |
+
" repository. Be sure to know what you're doing :)"
|
146 |
+
)
|
147 |
+
self.task = task
|
148 |
+
else:
|
149 |
+
assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
|
150 |
+
self.task = model_info.pipeline_tag
|
151 |
+
|
152 |
+
self.api_url = f"{INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
|
153 |
+
|
154 |
+
def __repr__(self):
|
155 |
+
# Do not add headers to repr to avoid leaking token.
|
156 |
+
return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"
|
157 |
+
|
158 |
+
def __call__(
|
159 |
+
self,
|
160 |
+
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
|
161 |
+
params: Optional[Dict] = None,
|
162 |
+
data: Optional[bytes] = None,
|
163 |
+
raw_response: bool = False,
|
164 |
+
) -> Any:
|
165 |
+
"""Make a call to the Inference API.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
|
169 |
+
Inputs for the prediction.
|
170 |
+
params (`Dict`, *optional*):
|
171 |
+
Additional parameters for the models. Will be sent as `parameters` in the
|
172 |
+
payload.
|
173 |
+
data (`bytes`, *optional*):
|
174 |
+
Bytes content of the request. In this case, leave `inputs` and `params` empty.
|
175 |
+
raw_response (`bool`, defaults to `False`):
|
176 |
+
If `True`, the raw `Response` object is returned. You can parse its content
|
177 |
+
as preferred. By default, the content is parsed into a more practical format
|
178 |
+
(json dictionary or PIL Image for example).
|
179 |
+
"""
|
180 |
+
# Build payload
|
181 |
+
payload: Dict[str, Any] = {
|
182 |
+
"options": self.options,
|
183 |
+
}
|
184 |
+
if inputs:
|
185 |
+
payload["inputs"] = inputs
|
186 |
+
if params:
|
187 |
+
payload["parameters"] = params
|
188 |
+
|
189 |
+
# Make API call
|
190 |
+
response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data)
|
191 |
+
|
192 |
+
# Let the user handle the response
|
193 |
+
if raw_response:
|
194 |
+
return response
|
195 |
+
|
196 |
+
# By default, parse the response for the user.
|
197 |
+
content_type = response.headers.get("Content-Type") or ""
|
198 |
+
if content_type.startswith("image"):
|
199 |
+
if not is_pillow_available():
|
200 |
+
raise ImportError(
|
201 |
+
f"Task '{self.task}' returned as image but Pillow is not installed."
|
202 |
+
" Please install it (`pip install Pillow`) or pass"
|
203 |
+
" `raw_response=True` to get the raw `Response` object and parse"
|
204 |
+
" the image by yourself."
|
205 |
+
)
|
206 |
+
|
207 |
+
from PIL import Image
|
208 |
+
|
209 |
+
return Image.open(io.BytesIO(response.content))
|
210 |
+
elif content_type == "application/json":
|
211 |
+
return response.json()
|
212 |
+
else:
|
213 |
+
raise NotImplementedError(
|
214 |
+
f"{content_type} output type is not implemented yet. You can pass"
|
215 |
+
" `raw_response=True` to get the raw `Response` object and parse the"
|
216 |
+
" output by yourself."
|
217 |
+
)
|
lib/python3.11/site-packages/huggingface_hub/keras_mixin.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc as collections
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import warnings
|
5 |
+
from pathlib import Path
|
6 |
+
from shutil import copytree
|
7 |
+
from typing import Any, Dict, List, Optional, Union
|
8 |
+
|
9 |
+
from huggingface_hub import ModelHubMixin, snapshot_download
|
10 |
+
from huggingface_hub.utils import (
|
11 |
+
get_tf_version,
|
12 |
+
is_graphviz_available,
|
13 |
+
is_pydot_available,
|
14 |
+
is_tf_available,
|
15 |
+
yaml_dump,
|
16 |
+
)
|
17 |
+
|
18 |
+
from .constants import CONFIG_NAME
|
19 |
+
from .hf_api import HfApi
|
20 |
+
from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
if is_tf_available():
|
26 |
+
import tensorflow as tf # type: ignore
|
27 |
+
|
28 |
+
|
29 |
+
def _flatten_dict(dictionary, parent_key=""):
|
30 |
+
"""Flatten a nested dictionary.
|
31 |
+
Reference: https://stackoverflow.com/a/6027615/10319735
|
32 |
+
|
33 |
+
Args:
|
34 |
+
dictionary (`dict`):
|
35 |
+
The nested dictionary to be flattened.
|
36 |
+
parent_key (`str`):
|
37 |
+
The parent key to be prefixed to the children keys.
|
38 |
+
Necessary for recursing over the nested dictionary.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
The flattened dictionary.
|
42 |
+
"""
|
43 |
+
items = []
|
44 |
+
for key, value in dictionary.items():
|
45 |
+
new_key = f"{parent_key}.{key}" if parent_key else key
|
46 |
+
if isinstance(value, collections.MutableMapping):
|
47 |
+
items.extend(
|
48 |
+
_flatten_dict(
|
49 |
+
value,
|
50 |
+
new_key,
|
51 |
+
).items()
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
items.append((new_key, value))
|
55 |
+
return dict(items)
|
56 |
+
|
57 |
+
|
58 |
+
def _create_hyperparameter_table(model):
|
59 |
+
"""Parse hyperparameter dictionary into a markdown table."""
|
60 |
+
if model.optimizer is not None:
|
61 |
+
optimizer_params = model.optimizer.get_config()
|
62 |
+
# flatten the configuration
|
63 |
+
optimizer_params = _flatten_dict(optimizer_params)
|
64 |
+
optimizer_params["training_precision"] = tf.keras.mixed_precision.global_policy().name
|
65 |
+
table = "| Hyperparameters | Value |\n| :-- | :-- |\n"
|
66 |
+
for key, value in optimizer_params.items():
|
67 |
+
table += f"| {key} | {value} |\n"
|
68 |
+
else:
|
69 |
+
table = None
|
70 |
+
return table
|
71 |
+
|
72 |
+
|
73 |
+
def _plot_network(model, save_directory):
|
74 |
+
tf.keras.utils.plot_model(
|
75 |
+
model,
|
76 |
+
to_file=f"{save_directory}/model.png",
|
77 |
+
show_shapes=False,
|
78 |
+
show_dtype=False,
|
79 |
+
show_layer_names=True,
|
80 |
+
rankdir="TB",
|
81 |
+
expand_nested=False,
|
82 |
+
dpi=96,
|
83 |
+
layer_range=None,
|
84 |
+
)
|
85 |
+
|
86 |
+
|
87 |
+
def _create_model_card(
|
88 |
+
model,
|
89 |
+
repo_dir: Path,
|
90 |
+
plot_model: bool = True,
|
91 |
+
metadata: Optional[dict] = None,
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
Creates a model card for the repository.
|
95 |
+
|
96 |
+
Do not overwrite an existing README.md file.
|
97 |
+
"""
|
98 |
+
readme_path = repo_dir / "README.md"
|
99 |
+
if readme_path.exists():
|
100 |
+
return
|
101 |
+
|
102 |
+
hyperparameters = _create_hyperparameter_table(model)
|
103 |
+
if plot_model and is_graphviz_available() and is_pydot_available():
|
104 |
+
_plot_network(model, repo_dir)
|
105 |
+
if metadata is None:
|
106 |
+
metadata = {}
|
107 |
+
metadata["library_name"] = "keras"
|
108 |
+
model_card: str = "---\n"
|
109 |
+
model_card += yaml_dump(metadata, default_flow_style=False)
|
110 |
+
model_card += "---\n"
|
111 |
+
model_card += "\n## Model description\n\nMore information needed\n"
|
112 |
+
model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
|
113 |
+
model_card += "\n## Training and evaluation data\n\nMore information needed\n"
|
114 |
+
if hyperparameters is not None:
|
115 |
+
model_card += "\n## Training procedure\n"
|
116 |
+
model_card += "\n### Training hyperparameters\n"
|
117 |
+
model_card += "\nThe following hyperparameters were used during training:\n\n"
|
118 |
+
model_card += hyperparameters
|
119 |
+
model_card += "\n"
|
120 |
+
if plot_model and os.path.exists(f"{repo_dir}/model.png"):
|
121 |
+
model_card += "\n ## Model Plot\n"
|
122 |
+
model_card += "\n<details>"
|
123 |
+
model_card += "\n<summary>View Model Plot</summary>\n"
|
124 |
+
path_to_plot = "./model.png"
|
125 |
+
model_card += f"\n![Model Image]({path_to_plot})\n"
|
126 |
+
model_card += "\n</details>"
|
127 |
+
|
128 |
+
readme_path.write_text(model_card)
|
129 |
+
|
130 |
+
|
131 |
+
def save_pretrained_keras(
|
132 |
+
model,
|
133 |
+
save_directory: Union[str, Path],
|
134 |
+
config: Optional[Dict[str, Any]] = None,
|
135 |
+
include_optimizer: bool = False,
|
136 |
+
plot_model: bool = True,
|
137 |
+
tags: Optional[Union[list, str]] = None,
|
138 |
+
**model_save_kwargs,
|
139 |
+
):
|
140 |
+
"""
|
141 |
+
Saves a Keras model to save_directory in SavedModel format. Use this if
|
142 |
+
you're using the Functional or Sequential APIs.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
model (`Keras.Model`):
|
146 |
+
The [Keras
|
147 |
+
model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
|
148 |
+
you'd like to save. The model must be compiled and built.
|
149 |
+
save_directory (`str` or `Path`):
|
150 |
+
Specify directory in which you want to save the Keras model.
|
151 |
+
config (`dict`, *optional*):
|
152 |
+
Configuration object to be saved alongside the model weights.
|
153 |
+
include_optimizer(`bool`, *optional*, defaults to `False`):
|
154 |
+
Whether or not to include optimizer in serialization.
|
155 |
+
plot_model (`bool`, *optional*, defaults to `True`):
|
156 |
+
Setting this to `True` will plot the model and put it in the model
|
157 |
+
card. Requires graphviz and pydot to be installed.
|
158 |
+
tags (Union[`str`,`list`], *optional*):
|
159 |
+
List of tags that are related to model or string of a single tag. See example tags
|
160 |
+
[here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1).
|
161 |
+
model_save_kwargs(`dict`, *optional*):
|
162 |
+
model_save_kwargs will be passed to
|
163 |
+
[`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
|
164 |
+
"""
|
165 |
+
if is_tf_available():
|
166 |
+
import tensorflow as tf
|
167 |
+
else:
|
168 |
+
raise ImportError("Called a Tensorflow-specific function but could not import it.")
|
169 |
+
|
170 |
+
if not model.built:
|
171 |
+
raise ValueError("Model should be built before trying to save")
|
172 |
+
|
173 |
+
save_directory = Path(save_directory)
|
174 |
+
save_directory.mkdir(parents=True, exist_ok=True)
|
175 |
+
|
176 |
+
# saving config
|
177 |
+
if config:
|
178 |
+
if not isinstance(config, dict):
|
179 |
+
raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'")
|
180 |
+
|
181 |
+
with (save_directory / CONFIG_NAME).open("w") as f:
|
182 |
+
json.dump(config, f)
|
183 |
+
|
184 |
+
metadata = {}
|
185 |
+
if isinstance(tags, list):
|
186 |
+
metadata["tags"] = tags
|
187 |
+
elif isinstance(tags, str):
|
188 |
+
metadata["tags"] = [tags]
|
189 |
+
|
190 |
+
task_name = model_save_kwargs.pop("task_name", None)
|
191 |
+
if task_name is not None:
|
192 |
+
warnings.warn(
|
193 |
+
"`task_name` input argument is deprecated. Pass `tags` instead.",
|
194 |
+
FutureWarning,
|
195 |
+
)
|
196 |
+
if "tags" in metadata:
|
197 |
+
metadata["tags"].append(task_name)
|
198 |
+
else:
|
199 |
+
metadata["tags"] = [task_name]
|
200 |
+
|
201 |
+
if model.history is not None:
|
202 |
+
if model.history.history != {}:
|
203 |
+
path = save_directory / "history.json"
|
204 |
+
if path.exists():
|
205 |
+
warnings.warn(
|
206 |
+
"`history.json` file already exists, it will be overwritten by the history of this version.",
|
207 |
+
UserWarning,
|
208 |
+
)
|
209 |
+
with path.open("w", encoding="utf-8") as f:
|
210 |
+
json.dump(model.history.history, f, indent=2, sort_keys=True)
|
211 |
+
|
212 |
+
_create_model_card(model, save_directory, plot_model, metadata)
|
213 |
+
tf.keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
|
214 |
+
|
215 |
+
|
216 |
+
def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
|
217 |
+
r"""
|
218 |
+
Instantiate a pretrained Keras model from a pre-trained model from the Hub.
|
219 |
+
The model is expected to be in `SavedModel` format.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
223 |
+
Can be either:
|
224 |
+
- A string, the `model id` of a pretrained model hosted inside a
|
225 |
+
model repo on huggingface.co. Valid model ids can be located
|
226 |
+
at the root-level, like `bert-base-uncased`, or namespaced
|
227 |
+
under a user or organization name, like
|
228 |
+
`dbmdz/bert-base-german-cased`.
|
229 |
+
- You can add `revision` by appending `@` at the end of model_id
|
230 |
+
simply like this: `dbmdz/bert-base-german-cased@main` Revision
|
231 |
+
is the specific model version to use. It can be a branch name,
|
232 |
+
a tag name, or a commit id, since we use a git-based system
|
233 |
+
for storing models and other artifacts on huggingface.co, so
|
234 |
+
`revision` can be any identifier allowed by git.
|
235 |
+
- A path to a `directory` containing model weights saved using
|
236 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g.,
|
237 |
+
`./my_model_directory/`.
|
238 |
+
- `None` if you are both providing the configuration and state
|
239 |
+
dictionary (resp. with keyword arguments `config` and
|
240 |
+
`state_dict`).
|
241 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
242 |
+
Whether to force the (re-)download of the model weights and
|
243 |
+
configuration files, overriding the cached versions if they exist.
|
244 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
245 |
+
Whether to delete incompletely received files. Will attempt to
|
246 |
+
resume the download if such a file exists.
|
247 |
+
proxies (`Dict[str, str]`, *optional*):
|
248 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g.,
|
249 |
+
`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The
|
250 |
+
proxies are used on each request.
|
251 |
+
token (`str` or `bool`, *optional*):
|
252 |
+
The token to use as HTTP bearer authorization for remote files. If
|
253 |
+
`True`, will use the token generated when running `transformers-cli
|
254 |
+
login` (stored in `~/.huggingface`).
|
255 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
256 |
+
Path to a directory in which a downloaded pretrained model
|
257 |
+
configuration should be cached if the standard cache should not be
|
258 |
+
used.
|
259 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
260 |
+
Whether to only look at local files (i.e., do not try to download
|
261 |
+
the model).
|
262 |
+
model_kwargs (`Dict`, *optional*):
|
263 |
+
model_kwargs will be passed to the model during initialization
|
264 |
+
|
265 |
+
<Tip>
|
266 |
+
|
267 |
+
Passing `token=True` is required when you want to use a private
|
268 |
+
model.
|
269 |
+
|
270 |
+
</Tip>
|
271 |
+
"""
|
272 |
+
return KerasModelHubMixin.from_pretrained(*args, **kwargs)
|
273 |
+
|
274 |
+
|
275 |
+
@validate_hf_hub_args
|
276 |
+
def push_to_hub_keras(
|
277 |
+
model,
|
278 |
+
repo_id: str,
|
279 |
+
*,
|
280 |
+
config: Optional[dict] = None,
|
281 |
+
commit_message: str = "Push Keras model using huggingface_hub.",
|
282 |
+
private: bool = False,
|
283 |
+
api_endpoint: Optional[str] = None,
|
284 |
+
token: Optional[str] = None,
|
285 |
+
branch: Optional[str] = None,
|
286 |
+
create_pr: Optional[bool] = None,
|
287 |
+
allow_patterns: Optional[Union[List[str], str]] = None,
|
288 |
+
ignore_patterns: Optional[Union[List[str], str]] = None,
|
289 |
+
delete_patterns: Optional[Union[List[str], str]] = None,
|
290 |
+
log_dir: Optional[str] = None,
|
291 |
+
include_optimizer: bool = False,
|
292 |
+
tags: Optional[Union[list, str]] = None,
|
293 |
+
plot_model: bool = True,
|
294 |
+
**model_save_kwargs,
|
295 |
+
):
|
296 |
+
"""
|
297 |
+
Upload model checkpoint to the Hub.
|
298 |
+
|
299 |
+
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
|
300 |
+
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
|
301 |
+
details.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
model (`Keras.Model`):
|
305 |
+
The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the
|
306 |
+
Hub. The model must be compiled and built.
|
307 |
+
repo_id (`str`):
|
308 |
+
ID of the repository to push to (example: `"username/my-model"`).
|
309 |
+
commit_message (`str`, *optional*, defaults to "Add Keras model"):
|
310 |
+
Message to commit while pushing.
|
311 |
+
private (`bool`, *optional*, defaults to `False`):
|
312 |
+
Whether the repository created should be private.
|
313 |
+
api_endpoint (`str`, *optional*):
|
314 |
+
The API endpoint to use when pushing the model to the hub.
|
315 |
+
token (`str`, *optional*):
|
316 |
+
The token to use as HTTP bearer authorization for remote files. If
|
317 |
+
not set, will use the token set when logging in with
|
318 |
+
`huggingface-cli login` (stored in `~/.huggingface`).
|
319 |
+
branch (`str`, *optional*):
|
320 |
+
The git branch on which to push the model. This defaults to
|
321 |
+
the default branch as specified in your repository, which
|
322 |
+
defaults to `"main"`.
|
323 |
+
create_pr (`boolean`, *optional*):
|
324 |
+
Whether or not to create a Pull Request from `branch` with that commit.
|
325 |
+
Defaults to `False`.
|
326 |
+
config (`dict`, *optional*):
|
327 |
+
Configuration object to be saved alongside the model weights.
|
328 |
+
allow_patterns (`List[str]` or `str`, *optional*):
|
329 |
+
If provided, only files matching at least one pattern are pushed.
|
330 |
+
ignore_patterns (`List[str]` or `str`, *optional*):
|
331 |
+
If provided, files matching any of the patterns are not pushed.
|
332 |
+
delete_patterns (`List[str]` or `str`, *optional*):
|
333 |
+
If provided, remote files matching any of the patterns will be deleted from the repo.
|
334 |
+
log_dir (`str`, *optional*):
|
335 |
+
            TensorBoard logging directory to be pushed. The Hub automatically
            hosts and displays a TensorBoard instance if log files are included
            in the repository.
        include_optimizer (`bool`, *optional*, defaults to `False`):
            Whether or not to include optimizer during serialization.
        tags (Union[`list`, `str`], *optional*):
            List of tags that are related to model or string of a single tag. See example tags
            [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1).
        plot_model (`bool`, *optional*, defaults to `True`):
            Setting this to `True` will plot the model and put it in the model
            card. Requires graphviz and pydot to be installed.
        model_save_kwargs(`dict`, *optional*):
            model_save_kwargs will be passed to
            [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).

    Returns:
        The url of the commit of your model in the given repository.
    """
    api = HfApi(endpoint=api_endpoint)
    repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id

    # Push the files to the repo in a single commit
    with SoftTemporaryDirectory() as tmp:
        saved_path = Path(tmp) / repo_id
        save_pretrained_keras(
            model,
            saved_path,
            config=config,
            include_optimizer=include_optimizer,
            tags=tags,
            plot_model=plot_model,
            **model_save_kwargs,
        )

        # If `log_dir` provided, delete remote logs and upload new ones
        if log_dir is not None:
            delete_patterns = (
                []
                if delete_patterns is None
                else (
                    [delete_patterns]  # convert `delete_patterns` to a list
                    if isinstance(delete_patterns, str)
                    else delete_patterns
                )
            )
            delete_patterns.append("logs/*")
            copytree(log_dir, saved_path / "logs")

        return api.upload_folder(
            repo_type="model",
            repo_id=repo_id,
            folder_path=saved_path,
            commit_message=commit_message,
            token=token,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )


class KerasModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download
    capabilities to Keras models.


    ```python
    >>> import tensorflow as tf
    >>> from huggingface_hub import KerasModelHubMixin


    >>> class MyModel(tf.keras.Model, KerasModelHubMixin):
    ...     def __init__(self, **kwargs):
    ...         super().__init__()
    ...         self.config = kwargs.pop("config", None)
    ...         self.dummy_inputs = ...
    ...         self.layer = ...

    ...     def call(self, *args):
    ...         return ...


    >>> # Initialize and compile the model as you normally would
    >>> model = MyModel()
    >>> model.compile(...)
    >>> # Build the graph by training it or passing dummy inputs
    >>> _ = model(model.dummy_inputs)
    >>> # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")
    >>> # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")
    >>> # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/super-cool-model")
    ```
    """

    def _save_pretrained(self, save_directory):
        save_pretrained_keras(self, save_directory)

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        token,
        **model_kwargs,
    ):
        """Here we just call [`from_pretrained_keras`] function so both the mixin and
        functional APIs stay in sync.

        TODO - Some args above aren't used since we are calling
        snapshot_download instead of hf_hub_download.
        """
        if is_tf_available():
            import tensorflow as tf
        else:
            raise ImportError("Called a TensorFlow-specific function but could not import it.")

        # TODO - Figure out what to do about these config values. Config is not going to be needed to load model
        cfg = model_kwargs.pop("config", None)

        # Root is either a local filepath matching model_id or a cached snapshot
        if not os.path.isdir(model_id):
            storage_folder = snapshot_download(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                library_name="keras",
                library_version=get_tf_version(),
            )
        else:
            storage_folder = model_id

        model = tf.keras.models.load_model(storage_folder, **model_kwargs)

        # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
        model.config = cfg

        return model
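For reference, a minimal usage sketch of the functional `push_to_hub_keras` helper documented above; the toy model definition and the repo id below are hypothetical placeholders, not part of the commit:

import tensorflow as tf
from huggingface_hub import push_to_hub_keras

# A trivial compiled model stands in for a real one.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

# One commit: saves the model with save_pretrained_keras and uploads the folder.
push_to_hub_keras(
    model,
    repo_id="username/my-keras-model",  # hypothetical repo id
    include_optimizer=False,
    tags=["keras"],
    log_dir=None,  # point this at a TensorBoard log dir to have the Hub display it
)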
lib/python3.11/site-packages/huggingface_hub/lfs.py
ADDED
@@ -0,0 +1,522 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Git LFS related type definitions and utilities"""
import inspect
import io
import os
import re
import warnings
from contextlib import AbstractContextManager
from dataclasses import dataclass
from math import ceil
from os.path import getsize
from pathlib import Path
from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict
from urllib.parse import unquote

from huggingface_hub.constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER, REPO_TYPES_URL_PREFIXES
from huggingface_hub.utils import get_session

from .utils import (
    build_hf_headers,
    hf_raise_for_status,
    http_backoff,
    logging,
    tqdm,
    validate_hf_hub_args,
)
from .utils.sha import sha256, sha_fileobj


if TYPE_CHECKING:
    from ._commit_api import CommitOperationAdd

logger = logging.get_logger(__name__)

OID_REGEX = re.compile(r"^[0-9a-f]{40}$")

LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"

LFS_HEADERS = {
    "Accept": "application/vnd.git-lfs+json",
    "Content-Type": "application/vnd.git-lfs+json",
}


@dataclass
class UploadInfo:
    """
    Dataclass holding required information to determine whether a blob
    should be uploaded to the hub using the LFS protocol or the regular protocol

    Args:
        sha256 (`bytes`):
            SHA256 hash of the blob
        size (`int`):
            Size in bytes of the blob
        sample (`bytes`):
            First 512 bytes of the blob
    """

    sha256: bytes
    size: int
    sample: bytes

    @classmethod
    def from_path(cls, path: str):
        size = getsize(path)
        with io.open(path, "rb") as file:
            sample = file.peek(512)[:512]
            sha = sha_fileobj(file)
        return cls(size=size, sha256=sha, sample=sample)

    @classmethod
    def from_bytes(cls, data: bytes):
        sha = sha256(data).digest()
        return cls(size=len(data), sample=data[:512], sha256=sha)

    @classmethod
    def from_fileobj(cls, fileobj: BinaryIO):
        sample = fileobj.read(512)
        fileobj.seek(0, io.SEEK_SET)
        sha = sha_fileobj(fileobj)
        size = fileobj.tell()
        fileobj.seek(0, io.SEEK_SET)
        return cls(size=size, sha256=sha, sample=sample)


@validate_hf_hub_args
def post_lfs_batch_info(
    upload_infos: Iterable[UploadInfo],
    token: Optional[str],
    repo_type: str,
    repo_id: str,
    revision: Optional[str] = None,
    endpoint: Optional[str] = None,
) -> Tuple[List[dict], List[dict]]:
    """
    Requests the LFS batch endpoint to retrieve upload instructions

    Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md

    Args:
        upload_infos (`Iterable` of `UploadInfo`):
            `UploadInfo` for the files that are being uploaded, typically obtained
            from `CommitOperationAdd.upload_info`
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        token (`str`, *optional*):
            An authentication token ( See https://huggingface.co/settings/tokens )
        revision (`str`, *optional*):
            The git revision to upload to.

    Returns:
        `LfsBatchInfo`: 2-tuple:
            - First element is the list of upload instructions from the server
            - Second element is a list of errors, if any

    Raises:
        `ValueError`: If an argument is invalid or the server response is malformed

        `HTTPError`: If the server returned an error
    """
    endpoint = endpoint if endpoint is not None else ENDPOINT
    url_prefix = ""
    if repo_type in REPO_TYPES_URL_PREFIXES:
        url_prefix = REPO_TYPES_URL_PREFIXES[repo_type]
    batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch"
    payload: Dict = {
        "operation": "upload",
        "transfers": ["basic", "multipart"],
        "objects": [
            {
                "oid": upload.sha256.hex(),
                "size": upload.size,
            }
            for upload in upload_infos
        ],
        "hash_algo": "sha256",
    }
    if revision is not None:
        payload["ref"] = {"name": unquote(revision)}  # revision has been previously 'quoted'
    headers = {**LFS_HEADERS, **build_hf_headers(token=token or True)}  # Token must be provided or retrieved
    resp = get_session().post(batch_url, headers=headers, json=payload)
    hf_raise_for_status(resp)
    batch_info = resp.json()

    objects = batch_info.get("objects", None)
    if not isinstance(objects, list):
        raise ValueError("Malformed response from server")

    return (
        [_validate_batch_actions(obj) for obj in objects if "error" not in obj],
        [_validate_batch_error(obj) for obj in objects if "error" in obj],
    )


class PayloadPartT(TypedDict):
    partNumber: int
    etag: str


class CompletionPayloadT(TypedDict):
    """Payload that will be sent to the Hub when uploading multi-part."""

    oid: str
    parts: List[PayloadPartT]


def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: Optional[str]) -> None:
    """
    Handles uploading a given object to the Hub with the LFS protocol.

    Can be a No-op if the content of the file is already present on the hub large file storage.

    Args:
        operation (`CommitOperationAdd`):
            The add operation triggering this upload.
        lfs_batch_action (`dict`):
            Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for
            more details.
        token (`str`, *optional*):
            A [user access token](https://hf.co/settings/tokens) to authenticate requests against the Hub

    Raises:
        - `ValueError` if `lfs_batch_action` is improperly formatted
        - `HTTPError` if the upload resulted in an error
    """
    # 0. If LFS file is already present, skip upload
    _validate_batch_actions(lfs_batch_action)
    actions = lfs_batch_action.get("actions")
    if actions is None:
        # The file was already uploaded
        logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload")
        return

    # 1. Validate server response (check required keys in dict)
    upload_action = lfs_batch_action["actions"]["upload"]
    _validate_lfs_action(upload_action)
    verify_action = lfs_batch_action["actions"].get("verify")
    if verify_action is not None:
        _validate_lfs_action(verify_action)

    # 2. Upload file (either single part or multi-part)
    header = upload_action.get("header", {})
    chunk_size = header.get("chunk_size")
    if chunk_size is not None:
        try:
            chunk_size = int(chunk_size)
        except (ValueError, TypeError):
            raise ValueError(
                f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'."
            )
        _upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_action["href"])
    else:
        _upload_single_part(operation=operation, upload_url=upload_action["href"])

    # 3. Verify upload went well
    if verify_action is not None:
        _validate_lfs_action(verify_action)
        verify_resp = get_session().post(
            verify_action["href"],
            headers=build_hf_headers(token=token or True),
            json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size},
        )
        hf_raise_for_status(verify_resp)
    logger.debug(f"{operation.path_in_repo}: Upload successful")


def _validate_lfs_action(lfs_action: dict):
    """validates response from the LFS batch endpoint"""
    if not (
        isinstance(lfs_action.get("href"), str)
        and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict))
    ):
        raise ValueError("lfs_action is improperly formatted")
    return lfs_action


def _validate_batch_actions(lfs_batch_actions: dict):
    """validates response from the LFS batch endpoint"""
    if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)):
        raise ValueError("lfs_batch_actions is improperly formatted")

    upload_action = lfs_batch_actions.get("actions", {}).get("upload")
    verify_action = lfs_batch_actions.get("actions", {}).get("verify")
    if upload_action is not None:
        _validate_lfs_action(upload_action)
    if verify_action is not None:
        _validate_lfs_action(verify_action)
    return lfs_batch_actions


def _validate_batch_error(lfs_batch_error: dict):
    """validates response from the LFS batch endpoint"""
    if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)):
        raise ValueError("lfs_batch_error is improperly formatted")
    error_info = lfs_batch_error.get("error")
    if not (
        isinstance(error_info, dict)
        and isinstance(error_info.get("message"), str)
        and isinstance(error_info.get("code"), int)
    ):
        raise ValueError("lfs_batch_error is improperly formatted")
    return lfs_batch_error


def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None:
    """
    Uploads `fileobj` as a single PUT HTTP request (basic LFS transfer protocol)

    Args:
        upload_url (`str`):
            The URL to PUT the file to.
        fileobj:
            The file-like object holding the data to upload.

    Returns: `requests.Response`

    Raises: `requests.HTTPError` if the upload resulted in an error
    """
    with operation.as_file(with_tqdm=True) as fileobj:
        # S3 might raise a transient 500 error -> let's retry if that happens
        response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 503))
        hf_raise_for_status(response)


def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size: int, upload_url: str) -> None:
    """
    Uploads file using HF multipart LFS transfer protocol.
    """
    # 1. Get upload URLs for each part
    sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size)

    # 2. Upload parts (either with hf_transfer or in pure Python)
    use_hf_transfer = HF_HUB_ENABLE_HF_TRANSFER
    if (
        HF_HUB_ENABLE_HF_TRANSFER
        and not isinstance(operation.path_or_fileobj, str)
        and not isinstance(operation.path_or_fileobj, Path)
    ):
        warnings.warn(
            "hf_transfer is enabled but does not support uploading from bytes or BinaryIO, falling back to regular"
            " upload"
        )
        use_hf_transfer = False

    response_headers = (
        _upload_parts_hf_transfer(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size)
        if use_hf_transfer
        else _upload_parts_iteratively(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size)
    )

    # 3. Send completion request
    completion_res = get_session().post(
        upload_url,
        json=_get_completion_payload(response_headers, operation.upload_info.sha256.hex()),
        headers=LFS_HEADERS,
    )
    hf_raise_for_status(completion_res)


def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: int) -> List[str]:
    sorted_part_upload_urls = [
        upload_url
        for _, upload_url in sorted(
            [
                (int(part_num, 10), upload_url)
                for part_num, upload_url in header.items()
                if part_num.isdigit() and len(part_num) > 0
            ],
            key=lambda t: t[0],
        )
    ]
    num_parts = len(sorted_part_upload_urls)
    if num_parts != ceil(upload_info.size / chunk_size):
        raise ValueError("Invalid server response to upload large LFS file")
    return sorted_part_upload_urls


def _get_completion_payload(response_headers: List[Dict], oid: str) -> CompletionPayloadT:
    parts: List[PayloadPartT] = []
    for part_number, header in enumerate(response_headers):
        etag = header.get("etag")
        if etag is None or etag == "":
            raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_number + 1}")
        parts.append(
            {
                "partNumber": part_number + 1,
                "etag": etag,
            }
        )
    return {"oid": oid, "parts": parts}


def _upload_parts_iteratively(
    operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int
) -> List[Dict]:
    headers = []
    with operation.as_file(with_tqdm=True) as fileobj:
        for part_idx, part_upload_url in enumerate(sorted_parts_urls):
            with SliceFileObj(
                fileobj,
                seek_from=chunk_size * part_idx,
                read_limit=chunk_size,
            ) as fileobj_slice:
                # S3 might raise a transient 500 error -> let's retry if that happens
                part_upload_res = http_backoff(
                    "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 503)
                )
                hf_raise_for_status(part_upload_res)
                headers.append(part_upload_res.headers)
    return headers  # type: ignore


def _upload_parts_hf_transfer(
    operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int
) -> List[Dict]:
    # Upload file using an external Rust-based package. Upload is faster but supports fewer features (no progress bars).
    try:
        from hf_transfer import multipart_upload
    except ImportError:
        raise ValueError(
            "Fast uploading using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is"
            " not available in your environment. Try `pip install hf_transfer`."
        )

    supports_callback = "callback" in inspect.signature(multipart_upload).parameters
    if not supports_callback:
        warnings.warn(
            "You are using an outdated version of `hf_transfer`. Consider upgrading to latest version to enable progress bars using `pip install -U hf_transfer`."
        )

    total = operation.upload_info.size
    desc = operation.path_in_repo
    if len(desc) > 40:
        desc = f"(…){desc[-40:]}"
    disable = bool(logger.getEffectiveLevel() == logging.NOTSET)

    with tqdm(unit="B", unit_scale=True, total=total, initial=0, desc=desc, disable=disable) as progress:
        try:
            output = multipart_upload(
                file_path=operation.path_or_fileobj,
                parts_urls=sorted_parts_urls,
                chunk_size=chunk_size,
                max_files=128,
                parallel_failures=127,  # could be removed
                max_retries=5,
                **({"callback": progress.update} if supports_callback else {}),
            )
        except Exception as e:
            raise RuntimeError(
                "An error occurred while uploading using `hf_transfer`. Consider disabling HF_HUB_ENABLE_HF_TRANSFER for"
                " better error handling."
            ) from e
        if not supports_callback:
            progress.update(total)
    return output


class SliceFileObj(AbstractContextManager):
    """
    Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object.

    This is NOT thread safe

    Inspired by stackoverflow.com/a/29838711/593036

    Credits to @julien-c

    Args:
        fileobj (`BinaryIO`):
            A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course).
            `fileobj` will be reset to its original position when exiting the context manager.
        seek_from (`int`):
            The start of the slice (offset from position 0 in bytes).
        read_limit (`int`):
            The maximum number of bytes to read from the slice.

    Attributes:
        previous_position (`int`):
            The previous position

    Examples:

    Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327):
    ```python
    >>> with open("path/to/file", "rb") as file:
    ...     with SliceFileObj(file, seek_from=128, read_limit=200) as fslice:
    ...         fslice.read(...)
    ```

    Reading a file in chunks of 512 bytes
    ```python
    >>> import os
    >>> chunk_size = 512
    >>> file_size = os.getsize("path/to/file")
    >>> with open("path/to/file", "rb") as file:
    ...     for chunk_idx in range(ceil(file_size / chunk_size)):
    ...         with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice:
    ...             chunk = fslice.read(...)

    ```
    """

    def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int):
        self.fileobj = fileobj
        self.seek_from = seek_from
        self.read_limit = read_limit

    def __enter__(self):
        self._previous_position = self.fileobj.tell()
        end_of_stream = self.fileobj.seek(0, os.SEEK_END)
        self._len = min(self.read_limit, end_of_stream - self.seek_from)
        # ^^ The actual number of bytes that can be read from the slice
        self.fileobj.seek(self.seek_from, io.SEEK_SET)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.fileobj.seek(self._previous_position, io.SEEK_SET)

    def read(self, n: int = -1):
        pos = self.tell()
        if pos >= self._len:
            return b""
        remaining_amount = self._len - pos
        data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount))
        return data

    def tell(self) -> int:
        return self.fileobj.tell() - self.seek_from

    def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
        start = self.seek_from
        end = start + self._len
        if whence in (os.SEEK_SET, os.SEEK_END):
            offset = start + offset if whence == os.SEEK_SET else end + offset
            offset = max(start, min(offset, end))
            whence = os.SEEK_SET
        elif whence == os.SEEK_CUR:
            cur_pos = self.fileobj.tell()
            offset = max(start - cur_pos, min(offset, end - cur_pos))
        else:
            raise ValueError(f"whence value {whence} is not supported")
        return self.fileobj.seek(offset, whence) - self.seek_from

    def __iter__(self):
        yield self.read(n=4 * 1024 * 1024)
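As a side note, here is a minimal, hedged sketch of reading a local file in fixed-size chunks with the `SliceFileObj` helper defined above, mirroring the chunked reads used by `_upload_parts_iteratively` (the file path is a placeholder):

import os
from math import ceil

from huggingface_hub.lfs import SliceFileObj

chunk_size = 512
path = "path/to/file"  # placeholder path
file_size = os.path.getsize(path)

with open(path, "rb") as f:
    for chunk_idx in range(ceil(file_size / chunk_size)):
        # Each slice exposes at most `chunk_size` bytes starting at its offset.
        with SliceFileObj(f, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice:
            chunk = fslice.read()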
lib/python3.11/site-packages/huggingface_hub/repocard.py
ADDED
@@ -0,0 +1,818 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import warnings
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Dict, Literal, Optional, Type, Union
|
6 |
+
|
7 |
+
import requests
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
from huggingface_hub.file_download import hf_hub_download
|
11 |
+
from huggingface_hub.hf_api import upload_file
|
12 |
+
from huggingface_hub.repocard_data import (
|
13 |
+
CardData,
|
14 |
+
DatasetCardData,
|
15 |
+
EvalResult,
|
16 |
+
ModelCardData,
|
17 |
+
SpaceCardData,
|
18 |
+
eval_results_to_model_index,
|
19 |
+
model_index_to_eval_results,
|
20 |
+
)
|
21 |
+
from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump
|
22 |
+
|
23 |
+
from .constants import REPOCARD_NAME
|
24 |
+
from .utils import EntryNotFoundError, SoftTemporaryDirectory, validate_hf_hub_args
|
25 |
+
|
26 |
+
|
27 |
+
TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md"
|
28 |
+
TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md"
|
29 |
+
|
30 |
+
# exact same regex as in the Hub server. Please keep in sync.
|
31 |
+
# See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18
|
32 |
+
REGEX_YAML_BLOCK = re.compile(r"^(\s*---[\r\n]+)([\S\s]*?)([\r\n]+---(\r\n|\n|$))")
|
33 |
+
|
34 |
+
|
35 |
+
class RepoCard:
|
36 |
+
card_data_class = CardData
|
37 |
+
default_template_path = TEMPLATE_MODELCARD_PATH
|
38 |
+
repo_type = "model"
|
39 |
+
|
40 |
+
def __init__(self, content: str, ignore_metadata_errors: bool = False):
|
41 |
+
"""Initialize a RepoCard from string content. The content should be a
|
42 |
+
Markdown file with a YAML block at the beginning and a Markdown body.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
content (`str`): The content of the Markdown file.
|
46 |
+
|
47 |
+
Example:
|
48 |
+
```python
|
49 |
+
>>> from huggingface_hub.repocard import RepoCard
|
50 |
+
>>> text = '''
|
51 |
+
... ---
|
52 |
+
... language: en
|
53 |
+
... license: mit
|
54 |
+
... ---
|
55 |
+
...
|
56 |
+
... # My repo
|
57 |
+
... '''
|
58 |
+
>>> card = RepoCard(text)
|
59 |
+
>>> card.data.to_dict()
|
60 |
+
{'language': 'en', 'license': 'mit'}
|
61 |
+
>>> card.text
|
62 |
+
'\\n# My repo\\n'
|
63 |
+
|
64 |
+
```
|
65 |
+
<Tip>
|
66 |
+
Raises the following error:
|
67 |
+
|
68 |
+
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
69 |
+
when the content of the repo card metadata is not a dictionary.
|
70 |
+
|
71 |
+
</Tip>
|
72 |
+
"""
|
73 |
+
|
74 |
+
# Set the content of the RepoCard, as well as underlying .data and .text attributes.
|
75 |
+
# See the `content` property setter for more details.
|
76 |
+
self.ignore_metadata_errors = ignore_metadata_errors
|
77 |
+
self.content = content
|
78 |
+
|
79 |
+
@property
|
80 |
+
def content(self):
|
81 |
+
"""The content of the RepoCard, including the YAML block and the Markdown body."""
|
82 |
+
line_break = _detect_line_ending(self._content) or "\n"
|
83 |
+
return f"---{line_break}{self.data.to_yaml(line_break=line_break)}{line_break}---{line_break}{self.text}"
|
84 |
+
|
85 |
+
@content.setter
|
86 |
+
def content(self, content: str):
|
87 |
+
"""Set the content of the RepoCard."""
|
88 |
+
self._content = content
|
89 |
+
|
90 |
+
match = REGEX_YAML_BLOCK.search(content)
|
91 |
+
if match:
|
92 |
+
# Metadata found in the YAML block
|
93 |
+
yaml_block = match.group(2)
|
94 |
+
self.text = content[match.end() :]
|
95 |
+
data_dict = yaml.safe_load(yaml_block)
|
96 |
+
|
97 |
+
if data_dict is None:
|
98 |
+
data_dict = {}
|
99 |
+
|
100 |
+
# The YAML block's data should be a dictionary
|
101 |
+
if not isinstance(data_dict, dict):
|
102 |
+
raise ValueError("repo card metadata block should be a dict")
|
103 |
+
else:
|
104 |
+
# Model card without metadata... create empty metadata
|
105 |
+
warnings.warn("Repo card metadata block was not found. Setting CardData to empty.")
|
106 |
+
data_dict = {}
|
107 |
+
self.text = content
|
108 |
+
|
109 |
+
self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors)
|
110 |
+
|
111 |
+
def __str__(self):
|
112 |
+
return self.content
|
113 |
+
|
114 |
+
def save(self, filepath: Union[Path, str]):
|
115 |
+
r"""Save a RepoCard to a file.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
filepath (`Union[Path, str]`): Filepath to the markdown file to save.
|
119 |
+
|
120 |
+
Example:
|
121 |
+
```python
|
122 |
+
>>> from huggingface_hub.repocard import RepoCard
|
123 |
+
>>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card")
|
124 |
+
>>> card.save("/tmp/test.md")
|
125 |
+
|
126 |
+
```
|
127 |
+
"""
|
128 |
+
filepath = Path(filepath)
|
129 |
+
filepath.parent.mkdir(parents=True, exist_ok=True)
|
130 |
+
# Preserve newlines as in the existing file.
|
131 |
+
with open(filepath, mode="w", newline="", encoding="utf-8") as f:
|
132 |
+
f.write(str(self))
|
133 |
+
|
134 |
+
@classmethod
|
135 |
+
def load(
|
136 |
+
cls,
|
137 |
+
repo_id_or_path: Union[str, Path],
|
138 |
+
repo_type: Optional[str] = None,
|
139 |
+
token: Optional[str] = None,
|
140 |
+
ignore_metadata_errors: bool = False,
|
141 |
+
):
|
142 |
+
"""Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
repo_id_or_path (`Union[str, Path]`):
|
146 |
+
The repo ID associated with a Hugging Face Hub repo or a local filepath.
|
147 |
+
repo_type (`str`, *optional*):
|
148 |
+
The type of Hugging Face repo to push to. Defaults to None, which will use use "model". Other options
|
149 |
+
are "dataset" and "space". Not used when loading from a local filepath. If this is called from a child
|
150 |
+
class, the default value will be the child class's `repo_type`.
|
151 |
+
token (`str`, *optional*):
|
152 |
+
Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
|
153 |
+
ignore_metadata_errors (`str`):
|
154 |
+
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
|
155 |
+
the process. Use it at your own risk.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
[`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's
|
159 |
+
README.md file or filepath.
|
160 |
+
|
161 |
+
Example:
|
162 |
+
```python
|
163 |
+
>>> from huggingface_hub.repocard import RepoCard
|
164 |
+
>>> card = RepoCard.load("nateraw/food")
|
165 |
+
>>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"]
|
166 |
+
|
167 |
+
```
|
168 |
+
"""
|
169 |
+
|
170 |
+
if Path(repo_id_or_path).exists():
|
171 |
+
card_path = Path(repo_id_or_path)
|
172 |
+
elif isinstance(repo_id_or_path, str):
|
173 |
+
card_path = Path(
|
174 |
+
hf_hub_download(
|
175 |
+
repo_id_or_path,
|
176 |
+
REPOCARD_NAME,
|
177 |
+
repo_type=repo_type or cls.repo_type,
|
178 |
+
token=token,
|
179 |
+
)
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).")
|
183 |
+
|
184 |
+
# Preserve newlines in the existing file.
|
185 |
+
with card_path.open(mode="r", newline="", encoding="utf-8") as f:
|
186 |
+
return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors)
|
187 |
+
|
188 |
+
def validate(self, repo_type: Optional[str] = None):
|
189 |
+
"""Validates card against Hugging Face Hub's card validation logic.
|
190 |
+
Using this function requires access to the internet, so it is only called
|
191 |
+
internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`].
|
192 |
+
|
193 |
+
Args:
|
194 |
+
repo_type (`str`, *optional*, defaults to "model"):
|
195 |
+
The type of Hugging Face repo to push to. Options are "model", "dataset", and "space".
|
196 |
+
If this function is called from a child class, the default will be the child class's `repo_type`.
|
197 |
+
|
198 |
+
<Tip>
|
199 |
+
Raises the following errors:
|
200 |
+
|
201 |
+
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
202 |
+
if the card fails validation checks.
|
203 |
+
- [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
|
204 |
+
if the request to the Hub API fails for any other reason.
|
205 |
+
|
206 |
+
</Tip>
|
207 |
+
"""
|
208 |
+
|
209 |
+
# If repo type is provided, otherwise, use the repo type of the card.
|
210 |
+
repo_type = repo_type or self.repo_type
|
211 |
+
|
212 |
+
body = {
|
213 |
+
"repoType": repo_type,
|
214 |
+
"content": str(self),
|
215 |
+
}
|
216 |
+
headers = {"Accept": "text/plain"}
|
217 |
+
|
218 |
+
try:
|
219 |
+
r = get_session().post("https://huggingface.co/api/validate-yaml", body, headers=headers)
|
220 |
+
r.raise_for_status()
|
221 |
+
except requests.exceptions.HTTPError as exc:
|
222 |
+
if r.status_code == 400:
|
223 |
+
raise ValueError(r.text)
|
224 |
+
else:
|
225 |
+
raise exc
|
226 |
+
|
227 |
+
def push_to_hub(
|
228 |
+
self,
|
229 |
+
repo_id: str,
|
230 |
+
token: Optional[str] = None,
|
231 |
+
repo_type: Optional[str] = None,
|
232 |
+
commit_message: Optional[str] = None,
|
233 |
+
commit_description: Optional[str] = None,
|
234 |
+
revision: Optional[str] = None,
|
235 |
+
create_pr: Optional[bool] = None,
|
236 |
+
parent_commit: Optional[str] = None,
|
237 |
+
):
|
238 |
+
"""Push a RepoCard to a Hugging Face Hub repo.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
repo_id (`str`):
|
242 |
+
The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food".
|
243 |
+
token (`str`, *optional*):
|
244 |
+
Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to
|
245 |
+
the stored token.
|
246 |
+
repo_type (`str`, *optional*, defaults to "model"):
|
247 |
+
The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this
|
248 |
+
function is called by a child class, it will default to the child class's `repo_type`.
|
249 |
+
commit_message (`str`, *optional*):
|
250 |
+
The summary / title / first line of the generated commit.
|
251 |
+
commit_description (`str`, *optional*)
|
252 |
+
The description of the generated commit.
|
253 |
+
revision (`str`, *optional*):
|
254 |
+
The git revision to commit from. Defaults to the head of the `"main"` branch.
|
255 |
+
create_pr (`bool`, *optional*):
|
256 |
+
Whether or not to create a Pull Request with this commit. Defaults to `False`.
|
257 |
+
parent_commit (`str`, *optional*):
|
258 |
+
The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
|
259 |
+
If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
|
260 |
+
If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
|
261 |
+
Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
|
262 |
+
especially useful if the repo is updated / committed to concurrently.
|
263 |
+
Returns:
|
264 |
+
`str`: URL of the commit which updated the card metadata.
|
265 |
+
"""
|
266 |
+
|
267 |
+
# If repo type is provided, otherwise, use the repo type of the card.
|
268 |
+
repo_type = repo_type or self.repo_type
|
269 |
+
|
270 |
+
# Validate card before pushing to hub
|
271 |
+
self.validate(repo_type=repo_type)
|
272 |
+
|
273 |
+
with SoftTemporaryDirectory() as tmpdir:
|
274 |
+
tmp_path = Path(tmpdir) / REPOCARD_NAME
|
275 |
+
tmp_path.write_text(str(self))
|
276 |
+
url = upload_file(
|
277 |
+
path_or_fileobj=str(tmp_path),
|
278 |
+
path_in_repo=REPOCARD_NAME,
|
279 |
+
repo_id=repo_id,
|
280 |
+
token=token,
|
281 |
+
repo_type=repo_type,
|
282 |
+
commit_message=commit_message,
|
283 |
+
commit_description=commit_description,
|
284 |
+
create_pr=create_pr,
|
285 |
+
revision=revision,
|
286 |
+
parent_commit=parent_commit,
|
287 |
+
)
|
288 |
+
return url
|
289 |
+
|
290 |
+
@classmethod
|
291 |
+
def from_template(
|
292 |
+
cls,
|
293 |
+
card_data: CardData,
|
294 |
+
template_path: Optional[str] = None,
|
295 |
+
**template_kwargs,
|
296 |
+
):
|
297 |
+
"""Initialize a RepoCard from a template. By default, it uses the default template.
|
298 |
+
|
299 |
+
Templates are Jinja2 templates that can be customized by passing keyword arguments.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
card_data (`huggingface_hub.CardData`):
|
303 |
+
A huggingface_hub.CardData instance containing the metadata you want to include in the YAML
|
304 |
+
header of the repo card on the Hugging Face Hub.
|
305 |
+
template_path (`str`, *optional*):
|
306 |
+
A path to a markdown file with optional Jinja template variables that can be filled
|
307 |
+
in with `template_kwargs`. Defaults to the default template.
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
[`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the
|
311 |
+
template.
|
312 |
+
"""
|
313 |
+
if is_jinja_available():
|
314 |
+
import jinja2
|
315 |
+
else:
|
316 |
+
raise ImportError(
|
317 |
+
"Using RepoCard.from_template requires Jinja2 to be installed. Please"
|
318 |
+
" install it with `pip install Jinja2`."
|
319 |
+
)
|
320 |
+
|
321 |
+
kwargs = card_data.to_dict().copy()
|
322 |
+
kwargs.update(template_kwargs) # Template_kwargs have priority
|
323 |
+
template = jinja2.Template(Path(template_path or cls.default_template_path).read_text())
|
324 |
+
content = template.render(card_data=card_data.to_yaml(), **kwargs)
|
325 |
+
return cls(content)
|
326 |
+
|
327 |
+
|
328 |
+
class ModelCard(RepoCard):
|
329 |
+
card_data_class = ModelCardData
|
330 |
+
default_template_path = TEMPLATE_MODELCARD_PATH
|
331 |
+
repo_type = "model"
|
332 |
+
|
333 |
+
@classmethod
|
334 |
+
def from_template( # type: ignore # violates Liskov property but easier to use
|
335 |
+
cls,
|
336 |
+
card_data: ModelCardData,
|
337 |
+
template_path: Optional[str] = None,
|
338 |
+
**template_kwargs,
|
339 |
+
):
|
340 |
+
"""Initialize a ModelCard from a template. By default, it uses the default template, which can be found here:
|
341 |
+
https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md
|
342 |
+
|
343 |
+
Templates are Jinja2 templates that can be customized by passing keyword arguments.
|
344 |
+
|
345 |
+
Args:
|
346 |
+
card_data (`huggingface_hub.ModelCardData`):
|
347 |
+
A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML
|
348 |
+
header of the model card on the Hugging Face Hub.
|
349 |
+
template_path (`str`, *optional*):
|
350 |
+
A path to a markdown file with optional Jinja template variables that can be filled
|
351 |
+
in with `template_kwargs`. Defaults to the default template.
|
352 |
+
|
353 |
+
Returns:
|
354 |
+
[`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the
|
355 |
+
template.
|
356 |
+
|
357 |
+
Example:
|
358 |
+
```python
|
359 |
+
>>> from huggingface_hub import ModelCard, ModelCardData, EvalResult
|
360 |
+
|
361 |
+
>>> # Using the Default Template
|
362 |
+
>>> card_data = ModelCardData(
|
363 |
+
... language='en',
|
364 |
+
... license='mit',
|
365 |
+
... library_name='timm',
|
366 |
+
... tags=['image-classification', 'resnet'],
|
367 |
+
... datasets=['beans'],
|
368 |
+
... metrics=['accuracy'],
|
369 |
+
... )
|
370 |
+
>>> card = ModelCard.from_template(
|
371 |
+
... card_data,
|
372 |
+
... model_description='This model does x + y...'
|
373 |
+
... )
|
374 |
+
|
375 |
+
>>> # Including Evaluation Results
|
376 |
+
>>> card_data = ModelCardData(
|
377 |
+
... language='en',
|
378 |
+
... tags=['image-classification', 'resnet'],
|
379 |
+
... eval_results=[
|
380 |
+
... EvalResult(
|
381 |
+
... task_type='image-classification',
|
382 |
+
... dataset_type='beans',
|
383 |
+
... dataset_name='Beans',
|
384 |
+
... metric_type='accuracy',
|
385 |
+
... metric_value=0.9,
|
386 |
+
... ),
|
387 |
+
... ],
|
388 |
+
... model_name='my-cool-model',
|
389 |
+
... )
|
390 |
+
>>> card = ModelCard.from_template(card_data)
|
391 |
+
|
392 |
+
>>> # Using a Custom Template
|
393 |
+
>>> card_data = ModelCardData(
|
394 |
+
... language='en',
|
395 |
+
... tags=['image-classification', 'resnet']
|
396 |
+
... )
|
397 |
+
>>> card = ModelCard.from_template(
|
398 |
+
... card_data=card_data,
|
399 |
+
... template_path='./src/huggingface_hub/templates/modelcard_template.md',
|
400 |
+
... custom_template_var='custom value', # will be replaced in template if it exists
|
401 |
+
... )
|
402 |
+
|
403 |
+
```
|
404 |
+
"""
|
405 |
+
return super().from_template(card_data, template_path, **template_kwargs)
|
406 |
+
|
407 |
+
|
408 |
+
class DatasetCard(RepoCard):
|
409 |
+
card_data_class = DatasetCardData
|
410 |
+
default_template_path = TEMPLATE_DATASETCARD_PATH
|
411 |
+
repo_type = "dataset"
|
412 |
+
|
413 |
+
@classmethod
|
414 |
+
def from_template( # type: ignore # violates Liskov property but easier to use
|
415 |
+
cls,
|
416 |
+
card_data: DatasetCardData,
|
417 |
+
template_path: Optional[str] = None,
|
418 |
+
**template_kwargs,
|
419 |
+
):
|
420 |
+
"""Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here:
|
421 |
+
https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md
|
422 |
+
|
423 |
+
Templates are Jinja2 templates that can be customized by passing keyword arguments.
|
424 |
+
|
425 |
+
Args:
|
426 |
+
card_data (`huggingface_hub.DatasetCardData`):
|
427 |
+
A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML
|
428 |
+
header of the dataset card on the Hugging Face Hub.
|
429 |
+
template_path (`str`, *optional*):
|
430 |
+
A path to a markdown file with optional Jinja template variables that can be filled
|
431 |
+
in with `template_kwargs`. Defaults to the default template.
|
432 |
+
|
433 |
+
Returns:
|
434 |
+
[`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the
|
435 |
+
template.
|
436 |
+
|
437 |
+
Example:
|
438 |
+
```python
|
439 |
+
>>> from huggingface_hub import DatasetCard, DatasetCardData
|
440 |
+
|
441 |
+
>>> # Using the Default Template
|
442 |
+
>>> card_data = DatasetCardData(
|
443 |
+
... language='en',
|
444 |
+
... license='mit',
|
445 |
+
... annotations_creators='crowdsourced',
|
446 |
+
... task_categories=['text-classification'],
|
447 |
+
... task_ids=['sentiment-classification', 'text-scoring'],
|
448 |
+
... multilinguality='monolingual',
|
449 |
+
... pretty_name='My Text Classification Dataset',
|
450 |
+
... )
|
451 |
+
>>> card = DatasetCard.from_template(
|
452 |
+
... card_data,
|
453 |
+
... pretty_name=card_data.pretty_name,
|
454 |
+
... )
|
455 |
+
|
456 |
+
>>> # Using a Custom Template
|
457 |
+
>>> card_data = DatasetCardData(
|
458 |
+
... language='en',
|
459 |
+
... license='mit',
|
460 |
+
... )
|
461 |
+
>>> card = DatasetCard.from_template(
|
462 |
+
... card_data=card_data,
|
463 |
+
... template_path='./src/huggingface_hub/templates/datasetcard_template.md',
|
464 |
+
... custom_template_var='custom value', # will be replaced in template if it exists
|
465 |
+
... )
|
466 |
+
|
467 |
+
```
|
468 |
+
"""
|
469 |
+
return super().from_template(card_data, template_path, **template_kwargs)
|
470 |
+
|
471 |
+
|
472 |
+
class SpaceCard(RepoCard):
|
473 |
+
card_data_class = SpaceCardData
|
474 |
+
default_template_path = TEMPLATE_MODELCARD_PATH
|
475 |
+
repo_type = "space"
|
476 |
+
|
477 |
+
|
478 |
+
def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722
|
479 |
+
"""Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines.
|
480 |
+
|
481 |
+
Uses same implementation as in Hub server, keep it in sync.
|
482 |
+
|
483 |
+
Returns:
|
484 |
+
str: The detected line ending of the string.
|
485 |
+
"""
|
486 |
+
cr = content.count("\r")
|
487 |
+
lf = content.count("\n")
|
488 |
+
crlf = content.count("\r\n")
|
489 |
+
if cr + lf == 0:
|
490 |
+
return None
|
491 |
+
if crlf == cr and crlf == lf:
|
492 |
+
return "\r\n"
|
493 |
+
if cr > lf:
|
494 |
+
return "\r"
|
495 |
+
else:
|
496 |
+
return "\n"
|
497 |
+
|
498 |
+
|
499 |
+
def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
|
500 |
+
content = Path(local_path).read_text()
|
501 |
+
match = REGEX_YAML_BLOCK.search(content)
|
502 |
+
if match:
|
503 |
+
yaml_block = match.group(2)
|
504 |
+
data = yaml.safe_load(yaml_block)
|
505 |
+
if data is None or isinstance(data, dict):
|
506 |
+
return data
|
507 |
+
raise ValueError("repo card metadata block should be a dict")
|
508 |
+
else:
|
509 |
+
return None
|
510 |
+
|
511 |
+
|
512 |
+
def metadata_save(local_path: Union[str, Path], data: Dict) -> None:
|
513 |
+
"""
|
514 |
+
Save the metadata dict in the upper YAML part Trying to preserve newlines as
|
515 |
+
in the existing file. Docs about open() with newline="" parameter:
|
516 |
+
https://docs.python.org/3/library/functions.html?highlight=open#open Does
|
517 |
+
not work with "^M" linebreaks, which are replaced by \n
|
518 |
+
"""
|
519 |
+
line_break = "\n"
|
520 |
+
content = ""
|
521 |
+
# try to detect existing newline character
|
522 |
+
if os.path.exists(local_path):
|
523 |
+
with open(local_path, "r", newline="", encoding="utf8") as readme:
|
524 |
+
content = readme.read()
|
525 |
+
if isinstance(readme.newlines, tuple):
|
526 |
+
line_break = readme.newlines[0]
|
527 |
+
elif isinstance(readme.newlines, str):
|
528 |
+
line_break = readme.newlines
|
529 |
+
|
530 |
+
# creates a new file if it not
|
531 |
+
with open(local_path, "w", newline="", encoding="utf8") as readme:
|
532 |
+
data_yaml = yaml_dump(data, sort_keys=False, line_break=line_break)
|
533 |
+
# sort_keys: keep dict order
|
534 |
+
match = REGEX_YAML_BLOCK.search(content)
|
535 |
+
if match:
|
536 |
+
output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :]
|
537 |
+
else:
|
538 |
+
output = f"---{line_break}{data_yaml}---{line_break}{content}"
|
539 |
+
|
540 |
+
readme.write(output)
|
541 |
+
readme.close()
|
542 |
+
|
543 |
+
|
544 |
+
def metadata_eval_result(
|
545 |
+
*,
|
546 |
+
model_pretty_name: str,
|
547 |
+
task_pretty_name: str,
|
548 |
+
task_id: str,
|
549 |
+
metrics_pretty_name: str,
|
550 |
+
metrics_id: str,
|
551 |
+
metrics_value: Any,
|
552 |
+
dataset_pretty_name: str,
|
553 |
+
dataset_id: str,
|
554 |
+
metrics_config: Optional[str] = None,
|
555 |
+
metrics_verified: bool = False,
|
556 |
+
dataset_config: Optional[str] = None,
|
557 |
+
dataset_split: Optional[str] = None,
|
558 |
+
dataset_revision: Optional[str] = None,
|
559 |
+
metrics_verification_token: Optional[str] = None,
|
560 |
+
) -> Dict:
|
561 |
+
"""
|
562 |
+
Creates a metadata dict with the result from a model evaluated on a dataset.
|
563 |
+
|
564 |
+
Args:
|
565 |
+
model_pretty_name (`str`):
|
566 |
+
The name of the model in natural language.
|
567 |
+
task_pretty_name (`str`):
|
568 |
+
The name of a task in natural language.
|
569 |
+
task_id (`str`):
|
570 |
+
Example: automatic-speech-recognition. A task id.
|
571 |
+
metrics_pretty_name (`str`):
|
572 |
+
A name for the metric in natural language. Example: Test WER.
|
573 |
+
metrics_id (`str`):
|
574 |
+
Example: wer. A metric id from https://hf.co/metrics.
|
575 |
+
metrics_value (`Any`):
|
576 |
+
The value from the metric. Example: 20.0 or "20.0 ± 1.2".
|
577 |
+
dataset_pretty_name (`str`):
|
578 |
+
The name of the dataset in natural language.
|
579 |
+
dataset_id (`str`):
|
580 |
+
Example: common_voice. A dataset id from https://hf.co/datasets.
|
581 |
+
metrics_config (`str`, *optional*):
|
582 |
+
The name of the metric configuration used in `load_metric()`.
|
583 |
+
Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
|
584 |
+
metrics_verified (`bool`, *optional*, defaults to `False`):
|
585 |
+
Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
|
586 |
+
dataset_config (`str`, *optional*):
|
587 |
+
Example: fr. The name of the dataset configuration used in `load_dataset()`.
|
588 |
+
dataset_split (`str`, *optional*):
|
589 |
+
Example: test. The name of the dataset split used in `load_dataset()`.
|
590 |
+
dataset_revision (`str`, *optional*):
|
591 |
+
Example: 5503434ddd753f426f4b38109466949a1217c2bb. The name of the dataset dataset revision
|
592 |
+
used in `load_dataset()`.
|
593 |
+
metrics_verification_token (`bool`, *optional*):
|
594 |
+
A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
|
595 |
+
|
596 |
+
Returns:
|
597 |
+
`dict`: a metadata dict with the result from a model evaluated on a dataset.
|
598 |
+
|
599 |
+
Example:
|
600 |
+
```python
|
601 |
+
>>> from huggingface_hub import metadata_eval_result
|
602 |
+
>>> results = metadata_eval_result(
|
603 |
+
... model_pretty_name="RoBERTa fine-tuned on ReactionGIF",
|
604 |
+
... task_pretty_name="Text Classification",
|
605 |
+
... task_id="text-classification",
|
606 |
+
... metrics_pretty_name="Accuracy",
|
607 |
+
... metrics_id="accuracy",
|
608 |
+
... metrics_value=0.2662102282047272,
|
609 |
+
... dataset_pretty_name="ReactionJPEG",
|
610 |
+
... dataset_id="julien-c/reactionjpeg",
|
611 |
+
... dataset_config="default",
|
612 |
+
... dataset_split="test",
|
613 |
+
... )
|
614 |
+
>>> results == {
|
615 |
+
... 'model-index': [
|
616 |
+
... {
|
617 |
+
... 'name': 'RoBERTa fine-tuned on ReactionGIF',
|
618 |
+
... 'results': [
|
619 |
+
... {
|
620 |
+
... 'task': {
|
621 |
+
... 'type': 'text-classification',
|
622 |
+
... 'name': 'Text Classification'
|
623 |
+
... },
|
624 |
+
... 'dataset': {
|
625 |
+
... 'name': 'ReactionJPEG',
|
626 |
+
... 'type': 'julien-c/reactionjpeg',
|
627 |
+
... 'config': 'default',
|
628 |
+
... 'split': 'test'
|
629 |
+
... },
|
630 |
+
... 'metrics': [
|
631 |
+
... {
|
632 |
+
... 'type': 'accuracy',
|
633 |
+
... 'value': 0.2662102282047272,
|
634 |
+
... 'name': 'Accuracy',
|
635 |
+
... 'verified': False
|
636 |
+
... }
|
637 |
+
... ]
|
638 |
+
... }
|
639 |
+
... ]
|
640 |
+
... }
|
641 |
+
... ]
|
642 |
+
... }
|
643 |
+
True
|
644 |
+
|
645 |
+
```
|
646 |
+
"""
|
647 |
+
|
648 |
+
return {
|
649 |
+
"model-index": eval_results_to_model_index(
|
650 |
+
model_name=model_pretty_name,
|
651 |
+
eval_results=[
|
652 |
+
EvalResult(
|
653 |
+
task_name=task_pretty_name,
|
654 |
+
task_type=task_id,
|
655 |
+
metric_name=metrics_pretty_name,
|
656 |
+
metric_type=metrics_id,
|
657 |
+
metric_value=metrics_value,
|
658 |
+
dataset_name=dataset_pretty_name,
|
659 |
+
dataset_type=dataset_id,
|
660 |
+
metric_config=metrics_config,
|
661 |
+
verified=metrics_verified,
|
662 |
+
verify_token=metrics_verification_token,
|
663 |
+
dataset_config=dataset_config,
|
664 |
+
dataset_split=dataset_split,
|
665 |
+
dataset_revision=dataset_revision,
|
666 |
+
)
|
667 |
+
],
|
668 |
+
)
|
669 |
+
}
|
670 |
+
|
671 |
+
|
672 |
+
@validate_hf_hub_args
|
673 |
+
def metadata_update(
|
674 |
+
repo_id: str,
|
675 |
+
metadata: Dict,
|
676 |
+
*,
|
677 |
+
repo_type: Optional[str] = None,
|
678 |
+
overwrite: bool = False,
|
679 |
+
token: Optional[str] = None,
|
680 |
+
commit_message: Optional[str] = None,
|
681 |
+
commit_description: Optional[str] = None,
|
682 |
+
revision: Optional[str] = None,
|
683 |
+
create_pr: bool = False,
|
684 |
+
parent_commit: Optional[str] = None,
|
685 |
+
) -> str:
|
686 |
+
"""
|
687 |
+
Updates the metadata in the README.md of a repository on the Hugging Face Hub.
|
688 |
+
If the README.md file doesn't exist yet, a new one is created with metadata and an
|
689 |
+
the default ModelCard or DatasetCard template. For `space` repo, an error is thrown
|
690 |
+
as a Space cannot exist without a `README.md` file.
|
691 |
+
|
692 |
+
Args:
|
693 |
+
repo_id (`str`):
|
694 |
+
The name of the repository.
|
695 |
+
metadata (`dict`):
|
696 |
+
A dictionary containing the metadata to be updated.
|
697 |
+
repo_type (`str`, *optional*):
|
698 |
+
Set to `"dataset"` or `"space"` if updating to a dataset or space,
|
699 |
+
`None` or `"model"` if updating to a model. Default is `None`.
|
700 |
+
overwrite (`bool`, *optional*, defaults to `False`):
|
701 |
+
If set to `True`, an existing field can be overwritten; otherwise
|
702 |
+
attempting to overwrite an existing field will cause an error.
|
703 |
+
token (`str`, *optional*):
|
704 |
+
The Hugging Face authentication token.
|
705 |
+
commit_message (`str`, *optional*):
|
706 |
+
The summary / title / first line of the generated commit. Defaults to
|
707 |
+
`f"Update metadata with huggingface_hub"`
|
708 |
+
commit_description (`str`, *optional*):
|
709 |
+
The description of the generated commit.
|
710 |
+
revision (`str`, *optional*):
|
711 |
+
The git revision to commit from. Defaults to the head of the
|
712 |
+
`"main"` branch.
|
713 |
+
create_pr (`bool`, *optional*):
|
714 |
+
Whether or not to create a Pull Request from `revision` with that commit.
|
715 |
+
Defaults to `False`.
|
716 |
+
parent_commit (`str`, *optional*):
|
717 |
+
The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (first 7 characters) are also supported.
|
718 |
+
If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
|
719 |
+
If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
|
720 |
+
Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
|
721 |
+
especially useful if the repo is updated / committed to concurrently.
|
722 |
+
Returns:
|
723 |
+
`str`: URL of the commit which updated the card metadata.
|
724 |
+
|
725 |
+
Example:
|
726 |
+
```python
|
727 |
+
>>> from huggingface_hub import metadata_update
|
728 |
+
>>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
|
729 |
+
... 'results': [{'dataset': {'name': 'ReactionGIF',
|
730 |
+
... 'type': 'julien-c/reactiongif'},
|
731 |
+
... 'metrics': [{'name': 'Recall',
|
732 |
+
... 'type': 'recall',
|
733 |
+
... 'value': 0.7762102282047272}],
|
734 |
+
... 'task': {'name': 'Text Classification',
|
735 |
+
... 'type': 'text-classification'}}]}]}
|
736 |
+
>>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata)
|
737 |
+
|
738 |
+
```
|
739 |
+
"""
|
740 |
+
commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub"
|
741 |
+
|
742 |
+
# Card class given repo_type
|
743 |
+
card_class: Type[RepoCard]
|
744 |
+
if repo_type is None or repo_type == "model":
|
745 |
+
card_class = ModelCard
|
746 |
+
elif repo_type == "dataset":
|
747 |
+
card_class = DatasetCard
|
748 |
+
elif repo_type == "space":
|
749 |
+
card_class = RepoCard
|
750 |
+
else:
|
751 |
+
raise ValueError(f"Unknown repo_type: {repo_type}")
|
752 |
+
|
753 |
+
# Either load repo_card from the Hub or create an empty one.
|
754 |
+
# NOTE: Will not create the repo if it doesn't exist.
|
755 |
+
try:
|
756 |
+
card = card_class.load(repo_id, token=token, repo_type=repo_type)
|
757 |
+
except EntryNotFoundError:
|
758 |
+
if repo_type == "space":
|
759 |
+
raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.")
|
760 |
+
|
761 |
+
# Initialize a ModelCard or DatasetCard from default template and no data.
|
762 |
+
card = card_class.from_template(CardData())
|
763 |
+
|
764 |
+
for key, value in metadata.items():
|
765 |
+
if key == "model-index":
|
766 |
+
# if the new metadata doesn't include a name, either use the existing one or the repo name
|
767 |
+
if "name" not in value[0]:
|
768 |
+
value[0]["name"] = getattr(card, "model_name", repo_id)
|
769 |
+
model_name, new_results = model_index_to_eval_results(value)
|
770 |
+
if card.data.eval_results is None:
|
771 |
+
card.data.eval_results = new_results
|
772 |
+
card.data.model_name = model_name
|
773 |
+
else:
|
774 |
+
existing_results = card.data.eval_results
|
775 |
+
|
776 |
+
# Iterate over new results
|
777 |
+
# Iterate over existing results
|
778 |
+
# If both results describe the same metric but value is different:
|
779 |
+
# If overwrite=True: overwrite the metric value
|
780 |
+
# Else: raise ValueError
|
781 |
+
# Else: append new result to existing ones.
|
782 |
+
for new_result in new_results:
|
783 |
+
result_found = False
|
784 |
+
for existing_result in existing_results:
|
785 |
+
if new_result.is_equal_except_value(existing_result):
|
786 |
+
if new_result != existing_result and not overwrite:
|
787 |
+
raise ValueError(
|
788 |
+
"You passed a new value for the existing metric"
|
789 |
+
f" 'name: {new_result.metric_name}, type: "
|
790 |
+
f"{new_result.metric_type}'. Set `overwrite=True`"
|
791 |
+
" to overwrite existing metrics."
|
792 |
+
)
|
793 |
+
result_found = True
|
794 |
+
existing_result.metric_value = new_result.metric_value
|
795 |
+
if existing_result.verified is True:
|
796 |
+
existing_result.verify_token = new_result.verify_token
|
797 |
+
if not result_found:
|
798 |
+
card.data.eval_results.append(new_result)
|
799 |
+
else:
|
800 |
+
# Any metadata that is not a result metric
|
801 |
+
if card.data.get(key) is not None and not overwrite and card.data.get(key) != value:
|
802 |
+
raise ValueError(
|
803 |
+
f"You passed a new value for the existing meta data field '{key}'."
|
804 |
+
" Set `overwrite=True` to overwrite existing metadata."
|
805 |
+
)
|
806 |
+
else:
|
807 |
+
card.data[key] = value
|
808 |
+
|
809 |
+
return card.push_to_hub(
|
810 |
+
repo_id,
|
811 |
+
token=token,
|
812 |
+
repo_type=repo_type,
|
813 |
+
commit_message=commit_message,
|
814 |
+
commit_description=commit_description,
|
815 |
+
create_pr=create_pr,
|
816 |
+
revision=revision,
|
817 |
+
parent_commit=parent_commit,
|
818 |
+
)
|
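A minimal usage sketch for `metadata_update` and the `overwrite` behavior described above (the repo id `user/my-model`, the dataset, and the metric values are illustrative; a write token is assumed to be configured):

```python
>>> from huggingface_hub import metadata_update

>>> metadata = {
...     "model-index": [
...         {
...             "name": "my-model",
...             "results": [
...                 {
...                     "task": {"type": "text-classification"},
...                     "dataset": {"name": "IMDb", "type": "imdb"},
...                     "metrics": [{"type": "accuracy", "value": 0.91}],
...                 }
...             ],
...         }
...     ]
... }
>>> # First call adds the accuracy metric to the card's model-index.
>>> url = metadata_update("user/my-model", metadata)

>>> # Re-running with a different value for the same metric raises a ValueError
>>> # unless `overwrite=True` is passed.
>>> metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.93
>>> url = metadata_update("user/my-model", metadata, overwrite=True)
```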
lib/python3.11/site-packages/huggingface_hub/repocard_data.py
ADDED
@@ -0,0 +1,711 @@
1 |
+
import copy
|
2 |
+
import warnings
|
3 |
+
from collections import defaultdict
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
from huggingface_hub.utils import yaml_dump
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class EvalResult:
|
12 |
+
"""
|
13 |
+
Flattened representation of individual evaluation results found in model-index of Model Cards.
|
14 |
+
|
15 |
+
For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
task_type (`str`):
|
19 |
+
The task identifier. Example: "image-classification".
|
20 |
+
dataset_type (`str`):
|
21 |
+
The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets.
|
22 |
+
dataset_name (`str`):
|
23 |
+
A pretty name for the dataset. Example: "Common Voice (French)".
|
24 |
+
metric_type (`str`):
|
25 |
+
The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics.
|
26 |
+
metric_value (`Any`):
|
27 |
+
The metric value. Example: 0.9 or "20.0 ± 1.2".
|
28 |
+
task_name (`str`, *optional*):
|
29 |
+
A pretty name for the task. Example: "Speech Recognition".
|
30 |
+
dataset_config (`str`, *optional*):
|
31 |
+
The name of the dataset configuration used in `load_dataset()`.
|
32 |
+
Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info:
|
33 |
+
https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name
|
34 |
+
dataset_split (`str`, *optional*):
|
35 |
+
The split used in `load_dataset()`. Example: "test".
|
36 |
+
dataset_revision (`str`, *optional*):
|
37 |
+
The revision (AKA Git Sha) of the dataset used in `load_dataset()`.
|
38 |
+
Example: 5503434ddd753f426f4b38109466949a1217c2bb
|
39 |
+
dataset_args (`Dict[str, Any]`, *optional*):
|
40 |
+
The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}`
|
41 |
+
metric_name (`str`, *optional*):
|
42 |
+
A pretty name for the metric. Example: "Test WER".
|
43 |
+
metric_config (`str`, *optional*):
|
44 |
+
The name of the metric configuration used in `load_metric()`.
|
45 |
+
Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
|
46 |
+
See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
|
47 |
+
metric_args (`Dict[str, Any]`, *optional*):
|
48 |
+
The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4
|
49 |
+
verified (`bool`, *optional*):
|
50 |
+
Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
|
51 |
+
verify_token (`str`, *optional*):
|
52 |
+
A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
|
53 |
+
source_name (`str`, *optional*):
|
54 |
+
The name of the source of the evaluation result. Example: "Open LLM Leaderboard".
|
55 |
+
source_url (`str`, *optional*):
|
56 |
+
The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard".
|
57 |
+
"""
|
58 |
+
|
59 |
+
# Required
|
60 |
+
|
61 |
+
# The task identifier
|
62 |
+
# Example: automatic-speech-recognition
|
63 |
+
task_type: str
|
64 |
+
|
65 |
+
# The dataset identifier
|
66 |
+
# Example: common_voice. Use dataset id from https://hf.co/datasets
|
67 |
+
dataset_type: str
|
68 |
+
|
69 |
+
# A pretty name for the dataset.
|
70 |
+
# Example: Common Voice (French)
|
71 |
+
dataset_name: str
|
72 |
+
|
73 |
+
# The metric identifier
|
74 |
+
# Example: wer. Use metric id from https://hf.co/metrics
|
75 |
+
metric_type: str
|
76 |
+
|
77 |
+
# Value of the metric.
|
78 |
+
# Example: 20.0 or "20.0 ± 1.2"
|
79 |
+
metric_value: Any
|
80 |
+
|
81 |
+
# Optional
|
82 |
+
|
83 |
+
# A pretty name for the task.
|
84 |
+
# Example: Speech Recognition
|
85 |
+
task_name: Optional[str] = None
|
86 |
+
|
87 |
+
# The name of the dataset configuration used in `load_dataset()`.
|
88 |
+
# Example: fr in `load_dataset("common_voice", "fr")`.
|
89 |
+
# See the `datasets` docs for more info:
|
90 |
+
# https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name
|
91 |
+
dataset_config: Optional[str] = None
|
92 |
+
|
93 |
+
# The split used in `load_dataset()`.
|
94 |
+
# Example: test
|
95 |
+
dataset_split: Optional[str] = None
|
96 |
+
|
97 |
+
# The revision (AKA Git Sha) of the dataset used in `load_dataset()`.
|
98 |
+
# Example: 5503434ddd753f426f4b38109466949a1217c2bb
|
99 |
+
dataset_revision: Optional[str] = None
|
100 |
+
|
101 |
+
# The arguments passed during `Metric.compute()`.
|
102 |
+
# Example for `bleu`: max_order: 4
|
103 |
+
dataset_args: Optional[Dict[str, Any]] = None
|
104 |
+
|
105 |
+
# A pretty name for the metric.
|
106 |
+
# Example: Test WER
|
107 |
+
metric_name: Optional[str] = None
|
108 |
+
|
109 |
+
# The name of the metric configuration used in `load_metric()`.
|
110 |
+
# Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
|
111 |
+
# See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
|
112 |
+
metric_config: Optional[str] = None
|
113 |
+
|
114 |
+
# The arguments passed during `Metric.compute()`.
|
115 |
+
# Example for `bleu`: max_order: 4
|
116 |
+
metric_args: Optional[Dict[str, Any]] = None
|
117 |
+
|
118 |
+
# Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
|
119 |
+
verified: Optional[bool] = None
|
120 |
+
|
121 |
+
# A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
|
122 |
+
verify_token: Optional[str] = None
|
123 |
+
|
124 |
+
# The name of the source of the evaluation result.
|
125 |
+
# Example: Open LLM Leaderboard
|
126 |
+
source_name: Optional[str] = None
|
127 |
+
|
128 |
+
# The URL of the source of the evaluation result.
|
129 |
+
# Example: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard
|
130 |
+
source_url: Optional[str] = None
|
131 |
+
|
132 |
+
@property
|
133 |
+
def unique_identifier(self) -> tuple:
|
134 |
+
"""Returns a tuple that uniquely identifies this evaluation."""
|
135 |
+
return (
|
136 |
+
self.task_type,
|
137 |
+
self.dataset_type,
|
138 |
+
self.dataset_config,
|
139 |
+
self.dataset_split,
|
140 |
+
self.dataset_revision,
|
141 |
+
)
|
142 |
+
|
143 |
+
def is_equal_except_value(self, other: "EvalResult") -> bool:
|
144 |
+
"""
|
145 |
+
Return True if `self` and `other` describe exactly the same metric but with a
|
146 |
+
different value.
|
147 |
+
"""
|
148 |
+
for key, _ in self.__dict__.items():
|
149 |
+
if key == "metric_value":
|
150 |
+
continue
|
151 |
+
# For metrics computed by Hugging Face's evaluation service, `verify_token` is derived from `metric_value`,
|
152 |
+
# so we exclude it here in the comparison.
|
153 |
+
if key != "verify_token" and getattr(self, key) != getattr(other, key):
|
154 |
+
return False
|
155 |
+
return True
|
156 |
+
|
157 |
+
def __post_init__(self) -> None:
|
158 |
+
if self.source_name is not None and self.source_url is None:
|
159 |
+
raise ValueError("If `source_name` is provided, `source_url` must also be provided.")
|
160 |
+
|
161 |
+
|
162 |
+
@dataclass
|
163 |
+
class CardData:
|
164 |
+
"""Structure containing metadata from a RepoCard.
|
165 |
+
|
166 |
+
[`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`].
|
167 |
+
|
168 |
+
Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data
|
169 |
+
(example: flatten evaluation results). `CardData` behaves as a dictionary (can get, pop, set values) but does not
|
170 |
+
inherit from `dict` to allow this export step.
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, ignore_metadata_errors: bool = False, **kwargs):
|
174 |
+
self.__dict__.update(kwargs)
|
175 |
+
|
176 |
+
def to_dict(self) -> Dict[str, Any]:
|
177 |
+
"""Converts CardData to a dict.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
`dict`: CardData represented as a dictionary ready to be dumped to a YAML
|
181 |
+
block for inclusion in a README.md file.
|
182 |
+
"""
|
183 |
+
|
184 |
+
data_dict = copy.deepcopy(self.__dict__)
|
185 |
+
self._to_dict(data_dict)
|
186 |
+
return _remove_none(data_dict)
|
187 |
+
|
188 |
+
def _to_dict(self, data_dict):
|
189 |
+
"""Use this method in child classes to alter the dict representation of the data. Alter the dict in-place.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
data_dict (`dict`): The raw dict representation of the card data.
|
193 |
+
"""
|
194 |
+
pass
|
195 |
+
|
196 |
+
def to_yaml(self, line_break=None) -> str:
|
197 |
+
"""Dumps CardData to a YAML block for inclusion in a README.md file.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
line_break (str, *optional*):
|
201 |
+
The line break to use when dumping to yaml.
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
`str`: CardData represented as a YAML block.
|
205 |
+
"""
|
206 |
+
return yaml_dump(self.to_dict(), sort_keys=False, line_break=line_break).strip()
|
207 |
+
|
208 |
+
def __repr__(self):
|
209 |
+
return repr(self.__dict__)
|
210 |
+
|
211 |
+
def __str__(self):
|
212 |
+
return self.to_yaml()
|
213 |
+
|
214 |
+
def get(self, key: str, default: Any = None) -> Any:
|
215 |
+
"""Get value for a given metadata key."""
|
216 |
+
return self.__dict__.get(key, default)
|
217 |
+
|
218 |
+
def pop(self, key: str, default: Any = None) -> Any:
|
219 |
+
"""Pop value for a given metadata key."""
|
220 |
+
return self.__dict__.pop(key, default)
|
221 |
+
|
222 |
+
def __getitem__(self, key: str) -> Any:
|
223 |
+
"""Get value for a given metadata key."""
|
224 |
+
return self.__dict__[key]
|
225 |
+
|
226 |
+
def __setitem__(self, key: str, value: Any) -> None:
|
227 |
+
"""Set value for a given metadata key."""
|
228 |
+
self.__dict__[key] = value
|
229 |
+
|
230 |
+
def __contains__(self, key: str) -> bool:
|
231 |
+
"""Check if a given metadata key is set."""
|
232 |
+
return key in self.__dict__
|
233 |
+
|
234 |
+
def __len__(self) -> int:
|
235 |
+
"""Return the number of metadata keys set."""
|
236 |
+
return len(self.__dict__)
|
237 |
+
|
238 |
+
|
239 |
+
class ModelCardData(CardData):
|
240 |
+
"""Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
|
241 |
+
|
242 |
+
Args:
|
243 |
+
language (`Union[str, List[str]]`, *optional*):
|
244 |
+
Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
|
245 |
+
639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
|
246 |
+
license (`str`, *optional*):
|
247 |
+
License of this model. Example: apache-2.0 or any license from
|
248 |
+
https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
|
249 |
+
library_name (`str`, *optional*):
|
250 |
+
Name of library used by this model. Example: keras or any library from
|
251 |
+
https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
|
252 |
+
Defaults to None.
|
253 |
+
tags (`List[str]`, *optional*):
|
254 |
+
List of tags to add to your model that can be used when filtering on the Hugging
|
255 |
+
Face Hub. Defaults to None.
|
256 |
+
datasets (`List[str]`, *optional*):
|
257 |
+
List of datasets that were used to train this model. Should be a dataset ID
|
258 |
+
found on https://hf.co/datasets. Defaults to None.
|
259 |
+
metrics (`List[str]`, *optional*):
|
260 |
+
List of metrics used to evaluate this model. Should be a metric name that can be found
|
261 |
+
at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
|
262 |
+
eval_results (`Union[List[EvalResult], EvalResult]`, *optional*):
|
263 |
+
List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided,
|
264 |
+
`model_name` is used as a name on PapersWithCode's leaderboards. Defaults to `None`.
|
265 |
+
model_name (`str`, *optional*):
|
266 |
+
A name for this model. It is used along with
|
267 |
+
`eval_results` to construct the `model-index` within the card's metadata. The name
|
268 |
+
you supply here is what will be used on PapersWithCode's leaderboards. If None is provided
|
269 |
+
then the repo name is used as a default. Defaults to None.
|
270 |
+
ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
|
271 |
+
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
|
272 |
+
the process. Use it at your own risk.
|
273 |
+
kwargs (`dict`, *optional*):
|
274 |
+
Additional metadata that will be added to the model card. Defaults to None.
|
275 |
+
|
276 |
+
Example:
|
277 |
+
```python
|
278 |
+
>>> from huggingface_hub import ModelCardData
|
279 |
+
>>> card_data = ModelCardData(
|
280 |
+
... language="en",
|
281 |
+
... license="mit",
|
282 |
+
... library_name="timm",
|
283 |
+
... tags=['image-classification', 'resnet'],
|
284 |
+
... )
|
285 |
+
>>> card_data.to_dict()
|
286 |
+
{'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']}
|
287 |
+
|
288 |
+
```
|
289 |
+
"""
|
290 |
+
|
291 |
+
def __init__(
|
292 |
+
self,
|
293 |
+
*,
|
294 |
+
language: Optional[Union[str, List[str]]] = None,
|
295 |
+
license: Optional[str] = None,
|
296 |
+
library_name: Optional[str] = None,
|
297 |
+
tags: Optional[List[str]] = None,
|
298 |
+
datasets: Optional[List[str]] = None,
|
299 |
+
metrics: Optional[List[str]] = None,
|
300 |
+
eval_results: Optional[List[EvalResult]] = None,
|
301 |
+
model_name: Optional[str] = None,
|
302 |
+
ignore_metadata_errors: bool = False,
|
303 |
+
**kwargs,
|
304 |
+
):
|
305 |
+
self.language = language
|
306 |
+
self.license = license
|
307 |
+
self.library_name = library_name
|
308 |
+
self.tags = tags
|
309 |
+
self.datasets = datasets
|
310 |
+
self.metrics = metrics
|
311 |
+
self.eval_results = eval_results
|
312 |
+
self.model_name = model_name
|
313 |
+
|
314 |
+
model_index = kwargs.pop("model-index", None)
|
315 |
+
if model_index:
|
316 |
+
try:
|
317 |
+
model_name, eval_results = model_index_to_eval_results(model_index)
|
318 |
+
self.model_name = model_name
|
319 |
+
self.eval_results = eval_results
|
320 |
+
except (KeyError, TypeError) as error:
|
321 |
+
if ignore_metadata_errors:
|
322 |
+
warnings.warn("Invalid model-index. Not loading eval results into CardData.")
|
323 |
+
else:
|
324 |
+
raise ValueError(
|
325 |
+
f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass"
|
326 |
+
" `ignore_metadata_errors=True` to ignore this error while loading a Model Card. Warning:"
|
327 |
+
" some information will be lost. Use it at your own risk."
|
328 |
+
)
|
329 |
+
|
330 |
+
super().__init__(**kwargs)
|
331 |
+
|
332 |
+
if self.eval_results:
|
333 |
+
if type(self.eval_results) == EvalResult:
|
334 |
+
self.eval_results = [self.eval_results]
|
335 |
+
if self.model_name is None:
|
336 |
+
raise ValueError("Passing `eval_results` requires `model_name` to be set.")
|
337 |
+
|
338 |
+
def _to_dict(self, data_dict):
|
339 |
+
"""Format the internal data dict. In this case, we convert eval results to a valid model index"""
|
340 |
+
if self.eval_results is not None:
|
341 |
+
data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results)
|
342 |
+
del data_dict["eval_results"], data_dict["model_name"]
|
343 |
+
|
344 |
+
|
345 |
+
class DatasetCardData(CardData):
|
346 |
+
"""Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
|
347 |
+
|
348 |
+
Args:
|
349 |
+
language (`List[str]`, *optional*):
|
350 |
+
Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or
|
351 |
+
639-3 code (two/three letters), or a special value like "code", "multilingual".
|
352 |
+
license (`Union[str, List[str]]`, *optional*):
|
353 |
+
License(s) of this dataset. Example: apache-2.0 or any license from
|
354 |
+
https://huggingface.co/docs/hub/repositories-licenses.
|
355 |
+
annotations_creators (`Union[str, List[str]]`, *optional*):
|
356 |
+
How the annotations for the dataset were created.
|
357 |
+
Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'.
|
358 |
+
language_creators (`Union[str, List[str]]`, *optional*):
|
359 |
+
How the text-based data in the dataset was created.
|
360 |
+
Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other'
|
361 |
+
multilinguality (`Union[str, List[str]]`, *optional*):
|
362 |
+
Whether the dataset is multilingual.
|
363 |
+
Options are: 'monolingual', 'multilingual', 'translation', 'other'.
|
364 |
+
size_categories (`Union[str, List[str]]`, *optional*):
|
365 |
+
The number of examples in the dataset. Options are: 'n<1K', '1K<n<10K', '10K<n<100K',
|
366 |
+
'100K<n<1M', '1M<n<10M', '10M<n<100M', '100M<n<1B', '1B<n<10B', '10B<n<100B', '100B<n<1T', 'n>1T', and 'other'.
|
367 |
+
source_datasets (`List[str]]`, *optional*):
|
368 |
+
Indicates whether the dataset is an original dataset or extended from another existing dataset.
|
369 |
+
Options are: 'original' and 'extended'.
|
370 |
+
task_categories (`Union[str, List[str]]`, *optional*):
|
371 |
+
What categories of task does the dataset support?
|
372 |
+
task_ids (`Union[str, List[str]]`, *optional*):
|
373 |
+
What specific tasks does the dataset support?
|
374 |
+
paperswithcode_id (`str`, *optional*):
|
375 |
+
ID of the dataset on PapersWithCode.
|
376 |
+
pretty_name (`str`, *optional*):
|
377 |
+
A more human-readable name for the dataset. (ex. "Cats vs. Dogs")
|
378 |
+
train_eval_index (`Dict`, *optional*):
|
379 |
+
A dictionary that describes the necessary spec for doing evaluation on the Hub.
|
380 |
+
If not provided, it will be gathered from the 'train-eval-index' key of the kwargs.
|
381 |
+
config_names (`Union[str, List[str]]`, *optional*):
|
382 |
+
A list of the available dataset configs for the dataset.
|
383 |
+
"""
|
384 |
+
|
385 |
+
def __init__(
|
386 |
+
self,
|
387 |
+
*,
|
388 |
+
language: Optional[Union[str, List[str]]] = None,
|
389 |
+
license: Optional[Union[str, List[str]]] = None,
|
390 |
+
annotations_creators: Optional[Union[str, List[str]]] = None,
|
391 |
+
language_creators: Optional[Union[str, List[str]]] = None,
|
392 |
+
multilinguality: Optional[Union[str, List[str]]] = None,
|
393 |
+
size_categories: Optional[Union[str, List[str]]] = None,
|
394 |
+
source_datasets: Optional[List[str]] = None,
|
395 |
+
task_categories: Optional[Union[str, List[str]]] = None,
|
396 |
+
task_ids: Optional[Union[str, List[str]]] = None,
|
397 |
+
paperswithcode_id: Optional[str] = None,
|
398 |
+
pretty_name: Optional[str] = None,
|
399 |
+
train_eval_index: Optional[Dict] = None,
|
400 |
+
config_names: Optional[Union[str, List[str]]] = None,
|
401 |
+
ignore_metadata_errors: bool = False,
|
402 |
+
**kwargs,
|
403 |
+
):
|
404 |
+
self.annotations_creators = annotations_creators
|
405 |
+
self.language_creators = language_creators
|
406 |
+
self.language = language
|
407 |
+
self.license = license
|
408 |
+
self.multilinguality = multilinguality
|
409 |
+
self.size_categories = size_categories
|
410 |
+
self.source_datasets = source_datasets
|
411 |
+
self.task_categories = task_categories
|
412 |
+
self.task_ids = task_ids
|
413 |
+
self.paperswithcode_id = paperswithcode_id
|
414 |
+
self.pretty_name = pretty_name
|
415 |
+
self.config_names = config_names
|
416 |
+
|
417 |
+
# TODO - maybe handle this similarly to EvalResult?
|
418 |
+
self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None)
|
419 |
+
super().__init__(**kwargs)
|
420 |
+
|
421 |
+
def _to_dict(self, data_dict):
|
422 |
+
data_dict["train-eval-index"] = data_dict.pop("train_eval_index")
|
423 |
+
|
424 |
+
|
425 |
+
class SpaceCardData(CardData):
|
426 |
+
"""Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
|
427 |
+
|
428 |
+
To get an exhaustive reference of Spaces configuration, please visit https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference.
|
429 |
+
|
430 |
+
Args:
|
431 |
+
title (`str`, *optional*)
|
432 |
+
Title of the Space.
|
433 |
+
sdk (`str`, *optional*)
|
434 |
+
SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`).
|
435 |
+
sdk_version (`str`, *optional*)
|
436 |
+
Version of the used SDK (if Gradio/Streamlit sdk).
|
437 |
+
python_version (`str`, *optional*)
|
438 |
+
Python version used in the Space (if Gradio/Streamlit sdk).
|
439 |
+
app_file (`str`, *optional*)
|
440 |
+
Path to your main application file (which contains either gradio or streamlit Python code, or static html code).
|
441 |
+
Path is relative to the root of the repository.
|
442 |
+
app_port (`int`, *optional*)
|
443 |
+
Port on which your application is running. Used only if sdk is `docker`.
|
444 |
+
license (`str`, *optional*)
|
445 |
+
License of this Space. Example: apache-2.0 or any license from
|
446 |
+
https://huggingface.co/docs/hub/repositories-licenses.
|
447 |
+
duplicated_from (`str`, *optional*)
|
448 |
+
ID of the original Space if this is a duplicated Space.
|
449 |
+
models (List[`str`], *optional*)
|
450 |
+
List of models related to this Space. Should be a model ID found on https://hf.co/models.
|
451 |
+
datasets (`List[str]`, *optional*)
|
452 |
+
List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets.
|
453 |
+
tags (`List[str]`, *optional*)
|
454 |
+
List of tags to add to your Space that can be used when filtering on the Hub.
|
455 |
+
ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
|
456 |
+
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
|
457 |
+
the process. Use it at your own risk.
|
458 |
+
kwargs (`dict`, *optional*):
|
459 |
+
Additional metadata that will be added to the space card.
|
460 |
+
|
461 |
+
Example:
|
462 |
+
```python
|
463 |
+
>>> from huggingface_hub import SpaceCardData
|
464 |
+
>>> card_data = SpaceCardData(
|
465 |
+
... title="Dreambooth Training",
|
466 |
+
... license="mit",
|
467 |
+
... sdk="gradio",
|
468 |
+
... duplicated_from="multimodalart/dreambooth-training"
|
469 |
+
... )
|
470 |
+
>>> card_data.to_dict()
|
471 |
+
{'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'}
|
472 |
+
```
|
473 |
+
"""
|
474 |
+
|
475 |
+
def __init__(
|
476 |
+
self,
|
477 |
+
*,
|
478 |
+
title: Optional[str] = None,
|
479 |
+
sdk: Optional[str] = None,
|
480 |
+
sdk_version: Optional[str] = None,
|
481 |
+
python_version: Optional[str] = None,
|
482 |
+
app_file: Optional[str] = None,
|
483 |
+
app_port: Optional[int] = None,
|
484 |
+
license: Optional[str] = None,
|
485 |
+
duplicated_from: Optional[str] = None,
|
486 |
+
models: Optional[List[str]] = None,
|
487 |
+
datasets: Optional[List[str]] = None,
|
488 |
+
tags: Optional[List[str]] = None,
|
489 |
+
ignore_metadata_errors: bool = False,
|
490 |
+
**kwargs,
|
491 |
+
):
|
492 |
+
self.title = title
|
493 |
+
self.sdk = sdk
|
494 |
+
self.sdk_version = sdk_version
|
495 |
+
self.python_version = python_version
|
496 |
+
self.app_file = app_file
|
497 |
+
self.app_port = app_port
|
498 |
+
self.license = license
|
499 |
+
self.duplicated_from = duplicated_from
|
500 |
+
self.models = models
|
501 |
+
self.datasets = datasets
|
502 |
+
self.tags = tags
|
503 |
+
super().__init__(**kwargs)
|
504 |
+
|
505 |
+
|
506 |
+
def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]:
|
507 |
+
"""Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects.
|
508 |
+
|
509 |
+
A detailed spec of the model index can be found here:
|
510 |
+
https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
511 |
+
|
512 |
+
Args:
|
513 |
+
model_index (`List[Dict[str, Any]]`):
|
514 |
+
A model index data structure, likely coming from a README.md file on the
|
515 |
+
Hugging Face Hub.
|
516 |
+
|
517 |
+
Returns:
|
518 |
+
model_name (`str`):
|
519 |
+
The name of the model as found in the model index. This is used as the
|
520 |
+
identifier for the model on leaderboards like PapersWithCode.
|
521 |
+
eval_results (`List[EvalResult]`):
|
522 |
+
A list of `huggingface_hub.EvalResult` objects containing the metrics
|
523 |
+
reported in the provided model_index.
|
524 |
+
|
525 |
+
Example:
|
526 |
+
```python
|
527 |
+
>>> from huggingface_hub.repocard_data import model_index_to_eval_results
|
528 |
+
>>> # Define a minimal model index
|
529 |
+
>>> model_index = [
|
530 |
+
... {
|
531 |
+
... "name": "my-cool-model",
|
532 |
+
... "results": [
|
533 |
+
... {
|
534 |
+
... "task": {
|
535 |
+
... "type": "image-classification"
|
536 |
+
... },
|
537 |
+
... "dataset": {
|
538 |
+
... "type": "beans",
|
539 |
+
... "name": "Beans"
|
540 |
+
... },
|
541 |
+
... "metrics": [
|
542 |
+
... {
|
543 |
+
... "type": "accuracy",
|
544 |
+
... "value": 0.9
|
545 |
+
... }
|
546 |
+
... ]
|
547 |
+
... }
|
548 |
+
... ]
|
549 |
+
... }
|
550 |
+
... ]
|
551 |
+
>>> model_name, eval_results = model_index_to_eval_results(model_index)
|
552 |
+
>>> model_name
|
553 |
+
'my-cool-model'
|
554 |
+
>>> eval_results[0].task_type
|
555 |
+
'image-classification'
|
556 |
+
>>> eval_results[0].metric_type
|
557 |
+
'accuracy'
|
558 |
+
|
559 |
+
```
|
560 |
+
"""
|
561 |
+
|
562 |
+
eval_results = []
|
563 |
+
for elem in model_index:
|
564 |
+
name = elem["name"]
|
565 |
+
results = elem["results"]
|
566 |
+
for result in results:
|
567 |
+
task_type = result["task"]["type"]
|
568 |
+
task_name = result["task"].get("name")
|
569 |
+
dataset_type = result["dataset"]["type"]
|
570 |
+
dataset_name = result["dataset"]["name"]
|
571 |
+
dataset_config = result["dataset"].get("config")
|
572 |
+
dataset_split = result["dataset"].get("split")
|
573 |
+
dataset_revision = result["dataset"].get("revision")
|
574 |
+
dataset_args = result["dataset"].get("args")
|
575 |
+
source_name = result.get("source", {}).get("name")
|
576 |
+
source_url = result.get("source", {}).get("url")
|
577 |
+
|
578 |
+
for metric in result["metrics"]:
|
579 |
+
metric_type = metric["type"]
|
580 |
+
metric_value = metric["value"]
|
581 |
+
metric_name = metric.get("name")
|
582 |
+
metric_args = metric.get("args")
|
583 |
+
metric_config = metric.get("config")
|
584 |
+
verified = metric.get("verified")
|
585 |
+
verify_token = metric.get("verifyToken")
|
586 |
+
|
587 |
+
eval_result = EvalResult(
|
588 |
+
task_type=task_type, # Required
|
589 |
+
dataset_type=dataset_type, # Required
|
590 |
+
dataset_name=dataset_name, # Required
|
591 |
+
metric_type=metric_type, # Required
|
592 |
+
metric_value=metric_value, # Required
|
593 |
+
task_name=task_name,
|
594 |
+
dataset_config=dataset_config,
|
595 |
+
dataset_split=dataset_split,
|
596 |
+
dataset_revision=dataset_revision,
|
597 |
+
dataset_args=dataset_args,
|
598 |
+
metric_name=metric_name,
|
599 |
+
metric_args=metric_args,
|
600 |
+
metric_config=metric_config,
|
601 |
+
verified=verified,
|
602 |
+
verify_token=verify_token,
|
603 |
+
source_name=source_name,
|
604 |
+
source_url=source_url,
|
605 |
+
)
|
606 |
+
eval_results.append(eval_result)
|
607 |
+
return name, eval_results
|
608 |
+
|
609 |
+
|
610 |
+
def _remove_none(obj):
|
611 |
+
"""
|
612 |
+
Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778
|
613 |
+
"""
|
614 |
+
if isinstance(obj, (list, tuple, set)):
|
615 |
+
return type(obj)(_remove_none(x) for x in obj if x is not None)
|
616 |
+
elif isinstance(obj, dict):
|
617 |
+
return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None)
|
618 |
+
else:
|
619 |
+
return obj
|
620 |
+
|
621 |
+
|
622 |
+
def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]:
|
623 |
+
"""Takes in given model name and list of `huggingface_hub.EvalResult` and returns a
|
624 |
+
valid model-index that will be compatible with the format expected by the
|
625 |
+
Hugging Face Hub.
|
626 |
+
|
627 |
+
Args:
|
628 |
+
model_name (`str`):
|
629 |
+
Name of the model (ex. "my-cool-model"). This is used as the identifier
|
630 |
+
for the model on leaderboards like PapersWithCode.
|
631 |
+
eval_results (`List[EvalResult]`):
|
632 |
+
List of `huggingface_hub.EvalResult` objects containing the metrics to be
|
633 |
+
reported in the model-index.
|
634 |
+
|
635 |
+
Returns:
|
636 |
+
model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index.
|
637 |
+
|
638 |
+
Example:
|
639 |
+
```python
|
640 |
+
>>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult
|
641 |
+
>>> # Define minimal eval_results
|
642 |
+
>>> eval_results = [
|
643 |
+
... EvalResult(
|
644 |
+
... task_type="image-classification", # Required
|
645 |
+
... dataset_type="beans", # Required
|
646 |
+
... dataset_name="Beans", # Required
|
647 |
+
... metric_type="accuracy", # Required
|
648 |
+
... metric_value=0.9, # Required
|
649 |
+
... )
|
650 |
+
... ]
|
651 |
+
>>> eval_results_to_model_index("my-cool-model", eval_results)
|
652 |
+
[{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}]
|
653 |
+
|
654 |
+
```
|
655 |
+
"""
|
656 |
+
|
657 |
+
# Metrics are reported on a unique task-and-dataset basis.
|
658 |
+
# Here, we make a map of those pairs and the associated EvalResults.
|
659 |
+
task_and_ds_types_map: Dict[Any, List[EvalResult]] = defaultdict(list)
|
660 |
+
for eval_result in eval_results:
|
661 |
+
task_and_ds_types_map[eval_result.unique_identifier].append(eval_result)
|
662 |
+
|
663 |
+
# Use the map from above to generate the model index data.
|
664 |
+
model_index_data = []
|
665 |
+
for results in task_and_ds_types_map.values():
|
666 |
+
# All items from `results` share the same metadata
|
667 |
+
sample_result = results[0]
|
668 |
+
data = {
|
669 |
+
"task": {
|
670 |
+
"type": sample_result.task_type,
|
671 |
+
"name": sample_result.task_name,
|
672 |
+
},
|
673 |
+
"dataset": {
|
674 |
+
"name": sample_result.dataset_name,
|
675 |
+
"type": sample_result.dataset_type,
|
676 |
+
"config": sample_result.dataset_config,
|
677 |
+
"split": sample_result.dataset_split,
|
678 |
+
"revision": sample_result.dataset_revision,
|
679 |
+
"args": sample_result.dataset_args,
|
680 |
+
},
|
681 |
+
"metrics": [
|
682 |
+
{
|
683 |
+
"type": result.metric_type,
|
684 |
+
"value": result.metric_value,
|
685 |
+
"name": result.metric_name,
|
686 |
+
"config": result.metric_config,
|
687 |
+
"args": result.metric_args,
|
688 |
+
"verified": result.verified,
|
689 |
+
"verifyToken": result.verify_token,
|
690 |
+
}
|
691 |
+
for result in results
|
692 |
+
],
|
693 |
+
}
|
694 |
+
if sample_result.source_url is not None:
|
695 |
+
source = {
|
696 |
+
"url": sample_result.source_url,
|
697 |
+
}
|
698 |
+
if sample_result.source_name is not None:
|
699 |
+
source["name"] = sample_result.source_name
|
700 |
+
data["source"] = source
|
701 |
+
model_index_data.append(data)
|
702 |
+
|
703 |
+
# TODO - Check if there are cases where this list is longer than one?
|
704 |
+
# Finally, the model index itself is list of dicts.
|
705 |
+
model_index = [
|
706 |
+
{
|
707 |
+
"name": model_name,
|
708 |
+
"results": model_index_data,
|
709 |
+
}
|
710 |
+
]
|
711 |
+
return _remove_none(model_index)
|
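Tying the pieces above together, a short sketch (names are illustrative) of how `ModelCardData` folds `eval_results` and `model_name` back into a `model-index` entry through `eval_results_to_model_index` when exporting with `to_dict()`:

```python
>>> from huggingface_hub import EvalResult, ModelCardData

>>> card_data = ModelCardData(
...     model_name="my-cool-model",
...     eval_results=[
...         EvalResult(
...             task_type="image-classification",
...             dataset_type="beans",
...             dataset_name="Beans",
...             metric_type="accuracy",
...             metric_value=0.9,
...         )
...     ],
... )
>>> card_data.to_dict()["model-index"]
[{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}]
```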
lib/python3.11/site-packages/huggingface_hub/repository.py
ADDED
@@ -0,0 +1,1476 @@
1 |
+
import atexit
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import subprocess
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
from contextlib import contextmanager
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
|
12 |
+
from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES, REPOCARD_NAME
|
13 |
+
from huggingface_hub.repocard import metadata_load, metadata_save
|
14 |
+
|
15 |
+
from .hf_api import HfApi, repo_type_and_id_from_hf_id
|
16 |
+
from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
|
17 |
+
from .utils import (
|
18 |
+
SoftTemporaryDirectory,
|
19 |
+
get_token,
|
20 |
+
logging,
|
21 |
+
run_subprocess,
|
22 |
+
tqdm,
|
23 |
+
validate_hf_hub_args,
|
24 |
+
)
|
25 |
+
from .utils._deprecation import _deprecate_method
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
class CommandInProgress:
|
32 |
+
"""
|
33 |
+
Utility to follow commands launched asynchronously.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
title: str,
|
39 |
+
is_done_method: Callable,
|
40 |
+
status_method: Callable,
|
41 |
+
process: subprocess.Popen,
|
42 |
+
post_method: Optional[Callable] = None,
|
43 |
+
):
|
44 |
+
self.title = title
|
45 |
+
self._is_done = is_done_method
|
46 |
+
self._status = status_method
|
47 |
+
self._process = process
|
48 |
+
self._stderr = ""
|
49 |
+
self._stdout = ""
|
50 |
+
self._post_method = post_method
|
51 |
+
|
52 |
+
@property
|
53 |
+
def is_done(self) -> bool:
|
54 |
+
"""
|
55 |
+
Whether the process is done.
|
56 |
+
"""
|
57 |
+
result = self._is_done()
|
58 |
+
|
59 |
+
if result and self._post_method is not None:
|
60 |
+
self._post_method()
|
61 |
+
self._post_method = None
|
62 |
+
|
63 |
+
return result
|
64 |
+
|
65 |
+
@property
|
66 |
+
def status(self) -> int:
|
67 |
+
"""
|
68 |
+
The exit code/status of the current action. Will return `0` if the
|
69 |
+
command has completed successfully, and a number between 1 and 255 if
|
70 |
+
the process errored-out.
|
71 |
+
|
72 |
+
Will return -1 if the command is still ongoing.
|
73 |
+
"""
|
74 |
+
return self._status()
|
75 |
+
|
76 |
+
@property
|
77 |
+
def failed(self) -> bool:
|
78 |
+
"""
|
79 |
+
Whether the process errored-out.
|
80 |
+
"""
|
81 |
+
return self.status > 0
|
82 |
+
|
83 |
+
@property
|
84 |
+
def stderr(self) -> str:
|
85 |
+
"""
|
86 |
+
The current output message on the standard error.
|
87 |
+
"""
|
88 |
+
if self._process.stderr is not None:
|
89 |
+
self._stderr += self._process.stderr.read()
|
90 |
+
return self._stderr
|
91 |
+
|
92 |
+
@property
|
93 |
+
def stdout(self) -> str:
|
94 |
+
"""
|
95 |
+
The current output message on the standard output.
|
96 |
+
"""
|
97 |
+
if self._process.stdout is not None:
|
98 |
+
self._stdout += self._process.stdout.read()
|
99 |
+
return self._stdout
|
100 |
+
|
101 |
+
def __repr__(self):
|
102 |
+
status = self.status
|
103 |
+
|
104 |
+
if status == -1:
|
105 |
+
status = "running"
|
106 |
+
|
107 |
+
return (
|
108 |
+
f"[{self.title} command, status code: {status},"
|
109 |
+
f" {'in progress.' if not self.is_done else 'finished.'} PID:"
|
110 |
+
f" {self._process.pid}]"
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
def is_git_repo(folder: Union[str, Path]) -> bool:
|
115 |
+
"""
|
116 |
+
Check if the folder is the root or part of a git repository
|
117 |
+
|
118 |
+
Args:
|
119 |
+
folder (`str`):
|
120 |
+
The folder in which to run the command.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
`bool`: `True` if the folder is the root or part of a git repository, `False`
|
124 |
+
otherwise.
|
125 |
+
"""
|
126 |
+
folder_exists = os.path.exists(os.path.join(folder, ".git"))
|
127 |
+
git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
128 |
+
return folder_exists and git_branch.returncode == 0
|
129 |
+
|
130 |
+
|
131 |
+
def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool:
|
132 |
+
"""
|
133 |
+
Check if the folder is a local clone of the remote_url
|
134 |
+
|
135 |
+
Args:
|
136 |
+
folder (`str` or `Path`):
|
137 |
+
The folder in which to run the command.
|
138 |
+
remote_url (`str`):
|
139 |
+
The url of a git repository.
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
`bool`: `True` if the repository is a local clone of the remote
|
143 |
+
repository specified, `False` otherwise.
|
144 |
+
"""
|
145 |
+
if not is_git_repo(folder):
|
146 |
+
return False
|
147 |
+
|
148 |
+
remotes = run_subprocess("git remote -v", folder).stdout
|
149 |
+
|
150 |
+
# Strip any credentials token from the URLs before comparing with the remotes.
|
151 |
+
remote_url = re.sub(r"https://.*@", "https://", remote_url)
|
152 |
+
remotes = [re.sub(r"https://.*@", "https://", remote) for remote in remotes.split()]
|
153 |
+
return remote_url in remotes
|
154 |
+
|
155 |
+
|
156 |
+
def is_tracked_with_lfs(filename: Union[str, Path]) -> bool:
|
157 |
+
"""
|
158 |
+
Check if the file passed is tracked with git-lfs.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
filename (`str` or `Path`):
|
162 |
+
The filename to check.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
`bool`: `True` if the file passed is tracked with git-lfs, `False`
|
166 |
+
otherwise.
|
167 |
+
"""
|
168 |
+
folder = Path(filename).parent
|
169 |
+
filename = Path(filename).name
|
170 |
+
|
171 |
+
try:
|
172 |
+
p = run_subprocess("git check-attr -a".split() + [filename], folder)
|
173 |
+
attributes = p.stdout.strip()
|
174 |
+
except subprocess.CalledProcessError as exc:
|
175 |
+
if not is_git_repo(folder):
|
176 |
+
return False
|
177 |
+
else:
|
178 |
+
raise OSError(exc.stderr)
|
179 |
+
|
180 |
+
if len(attributes) == 0:
|
181 |
+
return False
|
182 |
+
|
183 |
+
found_lfs_tag = {"diff": False, "merge": False, "filter": False}
|
184 |
+
|
185 |
+
for attribute in attributes.split("\n"):
|
186 |
+
for tag in found_lfs_tag.keys():
|
187 |
+
if tag in attribute and "lfs" in attribute:
|
188 |
+
found_lfs_tag[tag] = True
|
189 |
+
|
190 |
+
return all(found_lfs_tag.values())
|
191 |
+
|
192 |
+
|
193 |
+
def is_git_ignored(filename: Union[str, Path]) -> bool:
|
194 |
+
"""
|
195 |
+
Check if file is git-ignored. Supports nested .gitignore files.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
filename (`str` or `Path`):
|
199 |
+
The filename to check.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
`bool`: `True` if the file passed is ignored by `git`, `False`
|
203 |
+
otherwise.
|
204 |
+
"""
|
205 |
+
folder = Path(filename).parent
|
206 |
+
filename = Path(filename).name
|
207 |
+
|
208 |
+
try:
|
209 |
+
p = run_subprocess("git check-ignore".split() + [filename], folder, check=False)
|
210 |
+
# Will return exit code 1 if not gitignored
|
211 |
+
is_ignored = not bool(p.returncode)
|
212 |
+
except subprocess.CalledProcessError as exc:
|
213 |
+
raise OSError(exc.stderr)
|
214 |
+
|
215 |
+
return is_ignored
|
216 |
+
|
217 |
+
|
218 |
+
def is_binary_file(filename: Union[str, Path]) -> bool:
|
219 |
+
"""
|
220 |
+
Check if file is a binary file.
|
221 |
+
|
222 |
+
Args:
|
223 |
+
filename (`str` or `Path`):
|
224 |
+
The filename to check.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
`bool`: `True` if the file passed is a binary file, `False` otherwise.
|
228 |
+
"""
|
229 |
+
try:
|
230 |
+
with open(filename, "rb") as f:
|
231 |
+
content = f.read(10 * (1024**2)) # Read a maximum of 10MB
|
232 |
+
|
233 |
+
# Code sample taken from the following stack overflow thread
|
234 |
+
# https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391
|
235 |
+
text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
|
236 |
+
return bool(content.translate(None, text_chars))
|
237 |
+
except UnicodeDecodeError:
|
238 |
+
return True
|
239 |
+
|
240 |
+
|
241 |
+
def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]:
|
242 |
+
"""
|
243 |
+
Returns a list of filenames that are to be staged.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
pattern (`str` or `Path`):
|
247 |
+
The pattern of filenames to check. Put `.` to get all files.
|
248 |
+
folder (`str` or `Path`):
|
249 |
+
The folder in which to run the command.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
`List[str]`: List of files that are to be staged.
|
253 |
+
"""
|
254 |
+
try:
|
255 |
+
p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder)
|
256 |
+
if len(p.stdout.strip()):
|
257 |
+
files = p.stdout.strip().split("\n")
|
258 |
+
else:
|
259 |
+
files = []
|
260 |
+
except subprocess.CalledProcessError as exc:
|
261 |
+
raise EnvironmentError(exc.stderr)
|
262 |
+
|
263 |
+
return files
|
264 |
+
|
265 |
+
|
266 |
+

def is_tracked_upstream(folder: Union[str, Path]) -> bool:
    """
    Check if the current checked-out branch is tracked upstream.

    Args:
        folder (`str` or `Path`):
            The folder in which to run the command.

    Returns:
        `bool`: `True` if the current checked-out branch is tracked upstream,
        `False` otherwise.
    """
    try:
        run_subprocess("git rev-parse --symbolic-full-name --abbrev-ref @{u}", folder)
        return True
    except subprocess.CalledProcessError as exc:
        if "HEAD" in exc.stderr:
            raise OSError("No branch checked out")

        return False


def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int:
    """
    Check the number of commits that would be pushed upstream

    Args:
        folder (`str` or `Path`):
            The folder in which to run the command.
        upstream (`str`, *optional*):
            The name of the upstream repository with which the comparison should be
            made.

    Returns:
        `int`: Number of commits that would be pushed upstream were a `git
        push` to proceed.
    """
    try:
        result = run_subprocess(f"git cherry -v {upstream or ''}", folder)
        return len(result.stdout.split("\n")) - 1
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)


class PbarT(TypedDict):
    # Used to store an opened progress bar in `_lfs_log_progress`
    bar: tqdm
    past_bytes: int

@contextmanager
def _lfs_log_progress():
    """
    This is a context manager that will log the Git LFS progress of cleaning,
    smudging, pulling and pushing.
    """

    if logger.getEffectiveLevel() >= logging.ERROR:
        try:
            yield
        except Exception:
            pass
        return

    def output_progress(stopping_event: threading.Event):
        """
        To be launched as a separate thread with an event meaning it should stop
        the tail.
        """
        # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value)
        pbars: Dict[Tuple[str, str], PbarT] = {}

        def close_pbars():
            for pbar in pbars.values():
                pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"])
                pbar["bar"].refresh()
                pbar["bar"].close()

        def tail_file(filename) -> Iterator[str]:
            """
            Creates a generator to be iterated through, which will return each
            line one by one. Will stop tailing the file if the stopping_event is
            set.
            """
            with open(filename, "r") as file:
                current_line = ""
                while True:
                    if stopping_event.is_set():
                        close_pbars()
                        break

                    line_bit = file.readline()
                    if line_bit is not None and not len(line_bit.strip()) == 0:
                        current_line += line_bit
                        if current_line.endswith("\n"):
                            yield current_line
                            current_line = ""
                    else:
                        time.sleep(1)

        # If the file isn't created yet, wait for a few seconds before trying again.
        # Can be interrupted with the stopping_event.
        while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]):
            if stopping_event.is_set():
                close_pbars()
                return

            time.sleep(2)

        for line in tail_file(os.environ["GIT_LFS_PROGRESS"]):
            try:
                state, file_progress, byte_progress, filename = line.split()
            except ValueError as error:
                # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373.
                raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error
            description = f"{state.capitalize()} file {filename}"

            current_bytes, total_bytes = byte_progress.split("/")
            current_bytes_int = int(current_bytes)
            total_bytes_int = int(total_bytes)

            pbar = pbars.get((state, filename))
            if pbar is None:
                # Initialize progress bar
                pbars[(state, filename)] = {
                    "bar": tqdm(
                        desc=description,
                        initial=current_bytes_int,
                        total=total_bytes_int,
                        unit="B",
                        unit_scale=True,
                        unit_divisor=1024,
                    ),
                    "past_bytes": int(current_bytes),
                }
            else:
                # Update progress bar
                pbar["bar"].update(current_bytes_int - pbar["past_bytes"])
                pbar["past_bytes"] = current_bytes_int

    current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "")

    with SoftTemporaryDirectory() as tmpdir:
        os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress")
        logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}")

        exit_event = threading.Event()
        x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True)
        x.start()

        try:
            yield
        finally:
            exit_event.set()
            x.join()

        os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value
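
# Illustrative note (not part of the original module): each line git-lfs writes
# to the GIT_LFS_PROGRESS file is expected to look roughly like
#
#     download 3/5 1048576/4194304 weights/pytorch_model.bin
#
# i.e. "<state> <file progress> <bytes so far>/<total bytes> <filename>",
# which is what `output_progress` unpacks with `line.split()` above. The
# filename shown here is a made-up example.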

class Repository:
    """
    Helper class to wrap the git and git-lfs commands.

    The aim is to facilitate interacting with huggingface.co hosted model or
    dataset repos, though not a lot here (if any) is actually specific to
    huggingface.co.

    <Tip warning={true}>

    [`Repository`] is deprecated in favor of the http-based alternatives implemented in
    [`HfApi`]. Given its large adoption in legacy code, the complete removal of
    [`Repository`] will only happen in release `v1.0`. For more details, please read
    https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.

    </Tip>
    """

    command_queue: List[CommandInProgress]

    @validate_hf_hub_args
    @_deprecate_method(
        version="1.0",
        message=(
            "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete"
            " removal is only planned on next major release.\nFor more details, please read"
            " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http."
        ),
    )
    def __init__(
        self,
        local_dir: Union[str, Path],
        clone_from: Optional[str] = None,
        repo_type: Optional[str] = None,
        token: Union[bool, str] = True,
        git_user: Optional[str] = None,
        git_email: Optional[str] = None,
        revision: Optional[str] = None,
        skip_lfs_files: bool = False,
        client: Optional[HfApi] = None,
    ):
        """
        Instantiate a local clone of a git repo.

        If `clone_from` is set, the repo will be cloned from an existing remote repository.
        If the remote repo does not exist, an `EnvironmentError` exception will be thrown.
        Please create the remote repo first using [`create_repo`].

        `Repository` uses the local git credentials by default. If explicitly set, the `token`
        or the `git_user`/`git_email` pair will be used instead.

        Args:
            local_dir (`str` or `Path`):
                path (e.g. `'my_trained_model/'`) to the local directory, where
                the `Repository` will be initialized.
            clone_from (`str`, *optional*):
                Either a repository url or `repo_id`.
                Example:
                - `"https://huggingface.co/philschmid/playground-tests"`
                - `"philschmid/playground-tests"`
            repo_type (`str`, *optional*):
                To set when cloning a repo from a repo_id. Default is model.
            token (`bool` or `str`, *optional*):
                A valid authentication token (see https://huggingface.co/settings/token).
                If `None` or `True` and machine is logged in (through `huggingface-cli login`
                or [`~huggingface_hub.login`]), token will be retrieved from the cache.
                If `False`, token is not sent in the request header.
            git_user (`str`, *optional*):
                will override the `git config user.name` for committing and
                pushing files to the hub.
            git_email (`str`, *optional*):
                will override the `git config user.email` for committing and
                pushing files to the hub.
            revision (`str`, *optional*):
                Revision to checkout after initializing the repository. If the
                revision doesn't exist, a branch will be created with that
                revision name from the default branch's current HEAD.
            skip_lfs_files (`bool`, *optional*, defaults to `False`):
                whether to skip git-LFS files or not.
            client (`HfApi`, *optional*):
                Instance of [`HfApi`] to use when calling the HF Hub API. A new
                instance will be created if this is left to `None`.

        Raises:
            - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
              if the remote repository set in `clone_from` does not exist.
        """
        if isinstance(local_dir, Path):
            local_dir = str(local_dir)
        os.makedirs(local_dir, exist_ok=True)
        self.local_dir = os.path.join(os.getcwd(), local_dir)
        self._repo_type = repo_type
        self.command_queue = []
        self.skip_lfs_files = skip_lfs_files
        self.client = client if client is not None else HfApi()

        self.check_git_versions()

        if isinstance(token, str):
            self.huggingface_token: Optional[str] = token
        elif token is False:
            self.huggingface_token = None
        else:
            # if `True` -> explicit use of the cached token
            # if `None` -> implicit use of the cached token
            self.huggingface_token = get_token()

        if clone_from is not None:
            self.clone_from(repo_url=clone_from)
        else:
            if is_git_repo(self.local_dir):
                logger.debug("[Repository] is a valid git repo")
            else:
                raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.")

        if self.huggingface_token is not None and (git_email is None or git_user is None):
            user = self.client.whoami(self.huggingface_token)

            if git_email is None:
                git_email = user["email"]

            if git_user is None:
                git_user = user["fullname"]

        if git_user is not None or git_email is not None:
            self.git_config_username_and_email(git_user, git_email)

        self.lfs_enable_largefiles()
        self.git_credential_helper_store()

        if revision is not None:
            self.git_checkout(revision, create_branch_ok=True)

        # This ensures that all commands exit before exiting the Python runtime.
        # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations.
        atexit.register(self.wait_for_commands)

    @property
    def current_branch(self) -> str:
        """
        Returns the current checked out branch.

        Returns:
            `str`: Current checked out branch.
        """
        try:
            result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip()
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

        return result

    def check_git_versions(self):
        """
        Checks that `git` and `git-lfs` can be run.

        Raises:
            - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
              if `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess("git --version", self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError("Looks like you do not have git installed, please install.")

        try:
            lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                "Looks like you do not have git-lfs installed, please install."
                " You can install from https://git-lfs.github.com/."
                " Then run `git lfs install` (you only have to do this once)."
            )
        logger.info(git_version + "\n" + lfs_version)

    @validate_hf_hub_args
    def clone_from(self, repo_url: str, token: Union[bool, str, None] = None):
        """
        Clone from a remote. If the folder already exists, will try to clone the
        repository within it.

        If this folder is a git repository with linked history, will try to
        update the repository.

        Args:
            repo_url (`str`):
                The URL from which to clone the repository
            token (`Union[str, bool]`, *optional*):
                Whether to use the authentication token. It can be:
                - a string which is the token itself
                - `False`, which would not use the authentication token
                - `True`, which would fetch the authentication token from the
                  local folder and use it (you should be logged in for this to
                  work).
                - `None`, which would retrieve the value of
                  `self.huggingface_token`.

        <Tip>

        Raises the following error:

            - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              if an organization token (starts with "api_org") is passed. You must use
              your own personal access token (see https://hf.co/settings/tokens).

            - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
              if you are trying to clone the repository in a non-empty folder, or if the
              `git` operations raise errors.

        </Tip>
        """
        token = (
            token  # str -> use it
            if isinstance(token, str)
            else (
                None  # `False` -> explicit no token
                if token is False
                else self.huggingface_token  # `None` or `True` -> use default
            )
        )
        if token is not None and token.startswith("api_org"):
            raise ValueError(
                "You must use your personal access token, not an Organization token"
                " (see https://hf.co/settings/tokens)."
            )

        hub_url = self.client.endpoint
        if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2):
            repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url)
            repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name

            if repo_type is not None:
                self._repo_type = repo_type

            repo_url = hub_url + "/"

            if self._repo_type in REPO_TYPES_URL_PREFIXES:
                repo_url += REPO_TYPES_URL_PREFIXES[self._repo_type]

            if token is not None:
                # Add token in git url when provided
                scheme = urlparse(repo_url).scheme
                repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@")

            repo_url += repo_id

        # For error messages, it's cleaner to show the repo url without the token.
        clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url)
        try:
            run_subprocess("git lfs install", self.local_dir)

            # checks if repository is initialized in an empty repository or in one with files
            if len(os.listdir(self.local_dir)) == 0:
                logger.warning(f"Cloning {clean_repo_url} into local empty directory.")

                with _lfs_log_progress():
                    env = os.environ.copy()

                    if self.skip_lfs_files:
                        env.update({"GIT_LFS_SKIP_SMUDGE": "1"})

                    run_subprocess(
                        # 'git lfs clone' is deprecated (will display a warning in the terminal)
                        # but we still use it as it provides a nicer UX when downloading large
                        # files (shows progress).
                        f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .",
                        self.local_dir,
                        env=env,
                    )
            else:
                # Check if the folder is the root of a git repository
                if not is_git_repo(self.local_dir):
                    raise EnvironmentError(
                        "Tried to clone a repository in a non-empty folder that isn't"
                        f" a git repository ('{self.local_dir}'). If you really want to"
                        f" do this, do it manually:\n cd {self.local_dir} && git init"
                        " && git remote add origin && git pull origin main\n or clone"
                        " repo to a new folder and move your existing files there"
                        " afterwards."
                    )

                if is_local_clone(self.local_dir, repo_url):
                    logger.warning(
                        f"{self.local_dir} is already a clone of {clean_repo_url}."
                        " Make sure you pull the latest changes with"
                        " `repo.git_pull()`."
                    )
                else:
                    output = run_subprocess("git remote get-url origin", self.local_dir, check=False)

                    error_msg = (
                        f"Tried to clone {clean_repo_url} in an unrelated git"
                        " repository.\nIf you believe this is an error, please add"
                        f" a remote with the following URL: {clean_repo_url}."
                    )
                    if output.returncode == 0:
                        clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout)
                        error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}"
                    raise EnvironmentError(error_msg)

        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)
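
    # Illustrative sketch (not part of the original file): a typical way this
    # class is instantiated, assuming a hypothetical repo_id "user/my-model"
    # that already exists on the Hub and a cached login token.
    #
    # >>> repo = Repository(local_dir="my-model", clone_from="user/my-model", token=True)
    # >>> repo.git_pull()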

    def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None):
        """
        Sets git username and email (only in the current repo).

        Args:
            git_user (`str`, *optional*):
                The username to register through `git`.
            git_email (`str`, *optional*):
                The email to register through `git`.
        """
        try:
            if git_user is not None:
                run_subprocess("git config user.name".split() + [git_user], self.local_dir)

            if git_email is not None:
                run_subprocess(f"git config user.email {git_email}".split(), self.local_dir)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_credential_helper_store(self):
        """
        Sets the git credential helper to `store`
        """
        try:
            run_subprocess("git config credential.helper store", self.local_dir)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_head_hash(self) -> str:
        """
        Get commit sha on top of HEAD.

        Returns:
            `str`: The current checked out commit SHA.
        """
        try:
            p = run_subprocess("git rev-parse HEAD", self.local_dir)
            return p.stdout.strip()
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_remote_url(self) -> str:
        """
        Get URL to origin remote.

        Returns:
            `str`: The URL of the `origin` remote.
        """
        try:
            p = run_subprocess("git config --get remote.origin.url", self.local_dir)
            url = p.stdout.strip()
            # Strip basic auth info.
            return re.sub(r"https://.*@", "https://", url)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_head_commit_url(self) -> str:
        """
        Get URL to last commit on HEAD. We assume it's been pushed, and the url
        scheme is the same one as for GitHub or HuggingFace.

        Returns:
            `str`: The URL to the current checked-out commit.
        """
        sha = self.git_head_hash()
        url = self.git_remote_url()
        if url.endswith("/"):
            url = url[:-1]
        return f"{url}/commit/{sha}"

    def list_deleted_files(self) -> List[str]:
        """
        Returns a list of the files that are deleted in the working directory or
        index.

        Returns:
            `List[str]`: A list of files that have been deleted in the working
            directory or index.
        """
        try:
            git_status = run_subprocess("git status -s", self.local_dir).stdout.strip()
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

        if len(git_status) == 0:
            return []

        # Receives a status like the following
        #  D .gitignore
        #  D new_file.json
        # AD new_file1.json
        # ?? new_file2.json
        # ?? new_file4.json

        # Strip each line of whitespaces
        modified_files_statuses = [status.strip() for status in git_status.split("\n")]

        # Only keep files that are deleted using the D prefix
        deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]]

        # Remove the D prefix and strip to keep only the relevant filename
        deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses]

        return deleted_files

    def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False):
        """
        Tell git-lfs to track files according to a pattern.

        Setting the `filename` argument to `True` will treat the arguments as
        literal filenames, not as patterns. Any special glob characters in the
        filename will be escaped when writing to the `.gitattributes` file.

        Args:
            patterns (`Union[str, List[str]]`):
                The pattern, or list of patterns, to track with git-lfs.
            filename (`bool`, *optional*, defaults to `False`):
                Whether to use the patterns as literal filenames.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        try:
            for pattern in patterns:
                run_subprocess(
                    f"git lfs track {'--filename' if filename else ''} {pattern}",
                    self.local_dir,
                )
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def lfs_untrack(self, patterns: Union[str, List[str]]):
        """
        Tell git-lfs to untrack those files.

        Args:
            patterns (`Union[str, List[str]]`):
                The pattern, or list of patterns, to untrack with git-lfs.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        try:
            for pattern in patterns:
                run_subprocess("git lfs untrack".split() + [pattern], self.local_dir)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def lfs_enable_largefiles(self):
        """
        HF-specific. This enables upload support of files >5GB.
        """
        try:
            lfs_config = "git config lfs.customtransfer.multipart"
            run_subprocess(f"{lfs_config}.path huggingface-cli", self.local_dir)
            run_subprocess(
                f"{lfs_config}.args {LFS_MULTIPART_UPLOAD_COMMAND}",
                self.local_dir,
            )
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def auto_track_binary_files(self, pattern: str = ".") -> List[str]:
        """
        Automatically track binary files with git-lfs.

        Args:
            pattern (`str`, *optional*, defaults to "."):
                The pattern with which to track files that are binary.

        Returns:
            `List[str]`: List of filenames that are now tracked due to being
            binary files
        """
        files_to_be_tracked_with_lfs = []

        deleted_files = self.list_deleted_files()

        for filename in files_to_be_staged(pattern, folder=self.local_dir):
            if filename in deleted_files:
                continue

            path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)

            if not (is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file)):
                size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)

                if size_in_mb >= 10:
                    logger.warning(
                        "Parsing a large file to check if binary or not. Tracking large"
                        " files using `repository.auto_track_large_files` is"
                        " recommended so as to not load the full file in memory."
                    )

                is_binary = is_binary_file(path_to_file)

                if is_binary:
                    self.lfs_track(filename)
                    files_to_be_tracked_with_lfs.append(filename)

        # Cleanup the .gitattributes if files were deleted
        self.lfs_untrack(deleted_files)

        return files_to_be_tracked_with_lfs

    def auto_track_large_files(self, pattern: str = ".") -> List[str]:
        """
        Automatically track large files (files that weigh more than 10MBs) with
        git-lfs.

        Args:
            pattern (`str`, *optional*, defaults to "."):
                The pattern with which to track files that are above 10MBs.

        Returns:
            `List[str]`: List of filenames that are now tracked due to their
            size.
        """
        files_to_be_tracked_with_lfs = []

        deleted_files = self.list_deleted_files()

        for filename in files_to_be_staged(pattern, folder=self.local_dir):
            if filename in deleted_files:
                continue

            path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)
            size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)

            if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file):
                self.lfs_track(filename)
                files_to_be_tracked_with_lfs.append(filename)

        # Cleanup the .gitattributes if files were deleted
        self.lfs_untrack(deleted_files)

        return files_to_be_tracked_with_lfs
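
    # Illustrative sketch (not part of the original file): both auto-tracking
    # helpers are normally invoked for you by `git_add(auto_lfs_track=True)`,
    # but they can also be called directly. The local clone name is hypothetical.
    #
    # >>> repo = Repository("my-model")
    # >>> repo.auto_track_large_files()   # files >= 10MB -> tracked by git-lfs
    # >>> repo.auto_track_binary_files()  # remaining binary files -> git-lfs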

    def lfs_prune(self, recent=False):
        """
        git lfs prune

        Args:
            recent (`bool`, *optional*, defaults to `False`):
                Whether to prune files even if they were referenced by recent
                commits. See the following
                [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files)
                for more information.
        """
        try:
            with _lfs_log_progress():
                result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir)
                logger.info(result.stdout)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_pull(self, rebase: bool = False, lfs: bool = False):
        """
        git pull

        Args:
            rebase (`bool`, *optional*, defaults to `False`):
                Whether to rebase the current branch on top of the upstream
                branch after fetching.
            lfs (`bool`, *optional*, defaults to `False`):
                Whether to fetch the LFS files too. This option only changes the
                behavior when a repository was cloned without fetching the LFS
                files; calling `repo.git_pull(lfs=True)` will then fetch the LFS
                file from the remote repository.
        """
        command = "git pull" if not lfs else "git lfs pull"
        if rebase:
            command += " --rebase"
        try:
            with _lfs_log_progress():
                result = run_subprocess(command, self.local_dir)
                logger.info(result.stdout)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_add(self, pattern: str = ".", auto_lfs_track: bool = False):
        """
        git add

        Setting the `auto_lfs_track` parameter to `True` will automatically
        track files that are larger than 10MB with `git-lfs`.

        Args:
            pattern (`str`, *optional*, defaults to "."):
                The pattern with which to add files to staging.
            auto_lfs_track (`bool`, *optional*, defaults to `False`):
                Whether to automatically track large and binary files with
                git-lfs. Any file over 10MB in size, or in binary format, will
                be automatically tracked.
        """
        if auto_lfs_track:
            # Track files according to their size (>=10MB)
            tracked_files = self.auto_track_large_files(pattern)

            # Read the remaining files and track them if they're binary
            tracked_files.extend(self.auto_track_binary_files(pattern))

            if tracked_files:
                logger.warning(
                    f"Adding files tracked by Git LFS: {tracked_files}. This may take a"
                    " bit of time if the files are large."
                )

        try:
            result = run_subprocess("git add -v".split() + [pattern], self.local_dir)
            logger.info(f"Adding to index:\n{result.stdout}\n")
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

    def git_commit(self, commit_message: str = "commit files to HF hub"):
        """
        git commit

        Args:
            commit_message (`str`, *optional*, defaults to "commit files to HF hub"):
                The message attributed to the commit.
        """
        try:
            result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir)
            logger.info(f"Committed:\n{result.stdout}\n")
        except subprocess.CalledProcessError as exc:
            if len(exc.stderr) > 0:
                raise EnvironmentError(exc.stderr)
            else:
                raise EnvironmentError(exc.stdout)

    def git_push(
        self,
        upstream: Optional[str] = None,
        blocking: bool = True,
        auto_lfs_prune: bool = False,
    ) -> Union[str, Tuple[str, CommandInProgress]]:
        """
        git push

        If used without setting `blocking`, will return url to commit on remote
        repo. If used with `blocking=False`, will return a tuple containing the
        url to commit and the command object to follow for information about the
        process.

        Args:
            upstream (`str`, *optional*):
                Upstream to which this should push. If not specified, will push
                to the lastly defined upstream or to the default one (`origin
                main`).
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the push has
                finished. Setting this to `False` will return a
                `CommandInProgress` object which has an `is_done` property. This
                property will be set to `True` when the push is finished.
            auto_lfs_prune (`bool`, *optional*, defaults to `False`):
                Whether to automatically prune files once they have been pushed
                to the remote.
        """
        command = "git push"

        if upstream:
            command += f" --set-upstream {upstream}"

        number_of_commits = commits_to_push(self.local_dir, upstream)

        if number_of_commits > 1:
            logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.")
            if blocking:
                logger.warning("The progress bars may be unreliable.")

        try:
            with _lfs_log_progress():
                process = subprocess.Popen(
                    command.split(),
                    stderr=subprocess.PIPE,
                    stdout=subprocess.PIPE,
                    encoding="utf-8",
                    cwd=self.local_dir,
                )

                if blocking:
                    stdout, stderr = process.communicate()
                    return_code = process.poll()
                    process.kill()

                    if len(stderr):
                        logger.warning(stderr)

                    if return_code:
                        raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr)

        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

        if not blocking:

            def status_method():
                status = process.poll()
                if status is None:
                    return -1
                else:
                    return status

            command_in_progress = CommandInProgress(
                "push",
                is_done_method=lambda: process.poll() is not None,
                status_method=status_method,
                process=process,
                post_method=self.lfs_prune if auto_lfs_prune else None,
            )

            self.command_queue.append(command_in_progress)

            return self.git_head_commit_url(), command_in_progress

        if auto_lfs_prune:
            self.lfs_prune()

        return self.git_head_commit_url()
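
    # Illustrative sketch (not part of the original file): a non-blocking push,
    # polling the returned CommandInProgress object. The repo name is hypothetical.
    #
    # >>> repo = Repository("my-model")
    # >>> url, push_command = repo.git_push(blocking=False)
    # >>> while not push_command.is_done:
    # ...     time.sleep(1)
    # >>> push_command.status  # 0 on success, non-zero on failure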

    def git_checkout(self, revision: str, create_branch_ok: bool = False):
        """
        git checkout a given revision

        Specifying `create_branch_ok` to `True` will create the branch to the
        given revision if that revision doesn't exist.

        Args:
            revision (`str`):
                The revision to checkout.
            create_branch_ok (`bool`, *optional*, defaults to `False`):
                Whether creating a branch named with the `revision` passed at
                the current checked-out reference if `revision` isn't an
                existing revision is allowed.
        """
        try:
            result = run_subprocess(f"git checkout {revision}", self.local_dir)
            logger.warning(f"Checked out {revision} from {self.current_branch}.")
            logger.warning(result.stdout)
        except subprocess.CalledProcessError as exc:
            if not create_branch_ok:
                raise EnvironmentError(exc.stderr)
            else:
                try:
                    result = run_subprocess(f"git checkout -b {revision}", self.local_dir)
                    logger.warning(
                        f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`."
                    )
                    logger.warning(result.stdout)
                except subprocess.CalledProcessError as exc:
                    raise EnvironmentError(exc.stderr)

    def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool:
        """
        Check if a tag exists or not.

        Args:
            tag_name (`str`):
                The name of the tag to check.
            remote (`str`, *optional*):
                Whether to check if the tag exists on a remote. This parameter
                should be the identifier of the remote.

        Returns:
            `bool`: Whether the tag exists.
        """
        if remote:
            try:
                result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip()
            except subprocess.CalledProcessError as exc:
                raise EnvironmentError(exc.stderr)

            return len(result) != 0
        else:
            try:
                git_tags = run_subprocess("git tag", self.local_dir).stdout.strip()
            except subprocess.CalledProcessError as exc:
                raise EnvironmentError(exc.stderr)

            git_tags = git_tags.split("\n")
            return tag_name in git_tags

    def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool:
        """
        Delete a tag, both local and remote, if it exists

        Args:
            tag_name (`str`):
                The tag name to delete.
            remote (`str`, *optional*):
                The remote on which to delete the tag.

        Returns:
            `bool`: `True` if deleted, `False` if the tag didn't exist.
            If remote is not passed, will just be updated locally
        """
        delete_locally = True
        delete_remotely = True

        if not self.tag_exists(tag_name):
            delete_locally = False

        if not self.tag_exists(tag_name, remote=remote):
            delete_remotely = False

        if delete_locally:
            try:
                run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip()
            except subprocess.CalledProcessError as exc:
                raise EnvironmentError(exc.stderr)

        if remote and delete_remotely:
            try:
                run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip()
            except subprocess.CalledProcessError as exc:
                raise EnvironmentError(exc.stderr)

        return True

    def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None):
        """
        Add a tag at the current head and push it

        If remote is None, will just be updated locally

        If no message is provided, the tag will be lightweight. If a message is
        provided, the tag will be annotated.

        Args:
            tag_name (`str`):
                The name of the tag to be added.
            message (`str`, *optional*):
                The message that accompanies the tag. The tag will turn into an
                annotated tag if a message is passed.
            remote (`str`, *optional*):
                The remote on which to add the tag.
        """
        if message:
            tag_args = ["git", "tag", "-a", tag_name, "-m", message]
        else:
            tag_args = ["git", "tag", tag_name]

        try:
            run_subprocess(tag_args, self.local_dir).stdout.strip()
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

        if remote:
            try:
                run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip()
            except subprocess.CalledProcessError as exc:
                raise EnvironmentError(exc.stderr)

    def is_repo_clean(self) -> bool:
        """
        Return whether or not the git status is clean.

        Returns:
            `bool`: `True` if the git status is clean, `False` otherwise.
        """
        try:
            git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip()
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

        return len(git_status) == 0

    def push_to_hub(
        self,
        commit_message: str = "commit files to HF hub",
        blocking: bool = True,
        clean_ok: bool = True,
        auto_lfs_prune: bool = False,
    ) -> Union[None, str, Tuple[str, CommandInProgress]]:
        """
        Helper to add, commit, and push files to remote repository on the
        HuggingFace Hub. Will automatically track large files (>10MB).

        Args:
            commit_message (`str`):
                Message to use for the commit.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has
                finished.
            clean_ok (`bool`, *optional*, defaults to `True`):
                If True, this function will return None if the repo is
                untouched. Default behavior is to fail because the git command
                fails.
            auto_lfs_prune (`bool`, *optional*, defaults to `False`):
                Whether to automatically prune files once they have been pushed
                to the remote.
        """
        if clean_ok and self.is_repo_clean():
            logger.info("Repo currently clean. Ignoring push_to_hub")
            return None
        self.git_add(auto_lfs_track=True)
        self.git_commit(commit_message)
        return self.git_push(
            upstream=f"origin {self.current_branch}",
            blocking=blocking,
            auto_lfs_prune=auto_lfs_prune,
        )
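
    # Illustrative sketch (not part of the original file): the usual
    # add/commit/push round-trip in one call. The local clone path and the
    # commit message are hypothetical.
    #
    # >>> repo = Repository("my-model")
    # >>> # ... write or update files inside "my-model" ...
    # >>> repo.push_to_hub(commit_message="Update weights")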

    @contextmanager
    def commit(
        self,
        commit_message: str,
        branch: Optional[str] = None,
        track_large_files: bool = True,
        blocking: bool = True,
        auto_lfs_prune: bool = False,
    ):
        """
        Context manager utility to handle committing to a repository. This
        automatically tracks large files (>10Mb) with git-lfs. Set the
        `track_large_files` argument to `False` if you wish to ignore that
        behavior.

        Args:
            commit_message (`str`):
                Message to use for the commit.
            branch (`str`, *optional*):
                The branch on which the commit will appear. This branch will be
                checked-out before any operation.
            track_large_files (`bool`, *optional*, defaults to `True`):
                Whether to automatically track large files or not. Will do so by
                default.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has
                finished.
            auto_lfs_prune (`bool`, defaults to `False`):
                Whether to automatically prune files once they have been pushed
                to the remote.

        Examples:

        ```python
        >>> with Repository(
        ...     "text-files",
        ...     clone_from="<user>/text-files",
        ...     token=True,
        ... ).commit("My first file :)"):
        ...     with open("file.txt", "w+") as f:
        ...         f.write(json.dumps({"hey": 8}))

        >>> import torch

        >>> model = torch.nn.Transformer()
        >>> with Repository(
        ...     "torch-model",
        ...     clone_from="<user>/torch-model",
        ...     token=True,
        ... ).commit("My cool model :)"):
        ...     torch.save(model.state_dict(), "model.pt")
        ```

        """

        files_to_stage = files_to_be_staged(".", folder=self.local_dir)

        if len(files_to_stage):
            files_in_msg = str(files_to_stage[:5])[:-1] + ", ...]" if len(files_to_stage) > 5 else str(files_to_stage)
            logger.error(
                "There exists some updated files in the local repository that are not"
                f" committed: {files_in_msg}. This may lead to errors if checking out"
                " a branch. These files and their modifications will be added to the"
                " current commit."
            )

        if branch is not None:
            self.git_checkout(branch, create_branch_ok=True)

        if is_tracked_upstream(self.local_dir):
            logger.warning("Pulling changes ...")
            self.git_pull(rebase=True)
        else:
            logger.warning(f"The current branch has no upstream branch. Will push to 'origin {self.current_branch}'")

        current_working_directory = os.getcwd()
        os.chdir(os.path.join(current_working_directory, self.local_dir))

        try:
            yield self
        finally:
            self.git_add(auto_lfs_track=track_large_files)

            try:
                self.git_commit(commit_message)
            except OSError as e:
                # If no changes are detected, there is nothing to commit.
                if "nothing to commit" not in str(e):
                    raise e

            try:
                self.git_push(
                    upstream=f"origin {self.current_branch}",
                    blocking=blocking,
                    auto_lfs_prune=auto_lfs_prune,
                )
            except OSError as e:
                # If the push fails because git cannot read credentials, surface a clearer error.
                if "could not read Username" in str(e):
                    raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e
                else:
                    raise e

            os.chdir(current_working_directory)

    def repocard_metadata_load(self) -> Optional[Dict]:
        filepath = os.path.join(self.local_dir, REPOCARD_NAME)
        if os.path.isfile(filepath):
            return metadata_load(filepath)
        return None

    def repocard_metadata_save(self, data: Dict) -> None:
        return metadata_save(os.path.join(self.local_dir, REPOCARD_NAME), data)

    @property
    def commands_failed(self):
        """
        Returns the asynchronous commands that failed.
        """
        return [c for c in self.command_queue if c.status > 0]

    @property
    def commands_in_progress(self):
        """
        Returns the asynchronous commands that are currently in progress.
        """
        return [c for c in self.command_queue if not c.is_done]

    def wait_for_commands(self):
        """
        Blocking method: blocks all subsequent execution until all commands have
        been processed.
        """
        index = 0
        for command_failed in self.commands_failed:
            logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.")
            logger.error(command_failed.stderr)

        while self.commands_in_progress:
            if index % 10 == 0:
                logger.warning(
                    f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}."
                )

            index += 1

            time.sleep(1)