reach-vb HF staff commited on
Commit
254a3c6
1 Parent(s): 50f8b94

72cce7d5913aad64af86c29560a31d664b4ae723604106669ab386e118a0be60

Browse files
Files changed (50) hide show
  1. lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py +168 -0
  2. lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py +115 -0
  3. lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py +379 -0
  4. lib/python3.11/site-packages/huggingface_hub/commands/__init__.py +27 -0
  5. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-311.pyc +0 -0
  6. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-311.pyc +0 -0
  7. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-311.pyc +0 -0
  8. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/download.cpython-311.pyc +0 -0
  9. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/env.cpython-311.pyc +0 -0
  10. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-311.pyc +0 -0
  11. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-311.pyc +0 -0
  12. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-311.pyc +0 -0
  13. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-311.pyc +0 -0
  14. lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/user.cpython-311.pyc +0 -0
  15. lib/python3.11/site-packages/huggingface_hub/commands/_cli_utils.py +63 -0
  16. lib/python3.11/site-packages/huggingface_hub/commands/delete_cache.py +427 -0
  17. lib/python3.11/site-packages/huggingface_hub/commands/download.py +214 -0
  18. lib/python3.11/site-packages/huggingface_hub/commands/env.py +35 -0
  19. lib/python3.11/site-packages/huggingface_hub/commands/huggingface_cli.py +53 -0
  20. lib/python3.11/site-packages/huggingface_hub/commands/lfs.py +199 -0
  21. lib/python3.11/site-packages/huggingface_hub/commands/scan_cache.py +138 -0
  22. lib/python3.11/site-packages/huggingface_hub/commands/upload.py +297 -0
  23. lib/python3.11/site-packages/huggingface_hub/commands/user.py +188 -0
  24. lib/python3.11/site-packages/huggingface_hub/community.py +354 -0
  25. lib/python3.11/site-packages/huggingface_hub/constants.py +213 -0
  26. lib/python3.11/site-packages/huggingface_hub/fastai_utils.py +425 -0
  27. lib/python3.11/site-packages/huggingface_hub/file_download.py +1727 -0
  28. lib/python3.11/site-packages/huggingface_hub/hf_api.py +0 -0
  29. lib/python3.11/site-packages/huggingface_hub/hf_file_system.py +670 -0
  30. lib/python3.11/site-packages/huggingface_hub/hub_mixin.py +368 -0
  31. lib/python3.11/site-packages/huggingface_hub/inference/__init__.py +0 -0
  32. lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc +0 -0
  33. lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-311.pyc +0 -0
  34. lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc +0 -0
  35. lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_text_generation.cpython-311.pyc +0 -0
  36. lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_types.cpython-311.pyc +0 -0
  37. lib/python3.11/site-packages/huggingface_hub/inference/_client.py +1990 -0
  38. lib/python3.11/site-packages/huggingface_hub/inference/_common.py +327 -0
  39. lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py +0 -0
  40. lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-311.pyc +0 -0
  41. lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-311.pyc +0 -0
  42. lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py +2020 -0
  43. lib/python3.11/site-packages/huggingface_hub/inference/_text_generation.py +546 -0
  44. lib/python3.11/site-packages/huggingface_hub/inference/_types.py +183 -0
  45. lib/python3.11/site-packages/huggingface_hub/inference_api.py +217 -0
  46. lib/python3.11/site-packages/huggingface_hub/keras_mixin.py +480 -0
  47. lib/python3.11/site-packages/huggingface_hub/lfs.py +522 -0
  48. lib/python3.11/site-packages/huggingface_hub/repocard.py +818 -0
  49. lib/python3.11/site-packages/huggingface_hub/repocard_data.py +711 -0
  50. lib/python3.11/site-packages/huggingface_hub/repository.py +1476 -0
lib/python3.11/site-packages/huggingface_hub/_tensorboard_logger.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains a logger to push training logs to the Hub, using Tensorboard."""
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, List, Optional, Union
17
+
18
+ from huggingface_hub._commit_scheduler import CommitScheduler
19
+
20
+ from .utils import experimental, is_tensorboard_available
21
+
22
+
23
+ if is_tensorboard_available():
24
+ from tensorboardX import SummaryWriter
25
+
26
+ # TODO: clarify: should we import from torch.utils.tensorboard ?
27
+
28
+ else:
29
+ SummaryWriter = object # Dummy class to avoid failing at import. Will raise on instance creation.
30
+
31
+ if TYPE_CHECKING:
32
+ from tensorboardX import SummaryWriter
33
+
34
+
35
+ class HFSummaryWriter(SummaryWriter):
36
+ """
37
+ Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.
38
+
39
+ Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
40
+ thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
41
+ issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
42
+ minutes (default to every 5 minutes).
43
+
44
+ <Tip warning={true}>
45
+
46
+ `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.
47
+
48
+ </Tip>
49
+
50
+ Args:
51
+ repo_id (`str`):
52
+ The id of the repo to which the logs will be pushed.
53
+ logdir (`str`, *optional*):
54
+ The directory where the logs will be written. If not specified, a local directory will be created by the
55
+ underlying `SummaryWriter` object.
56
+ commit_every (`int` or `float`, *optional*):
57
+ The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
58
+ squash_history (`bool`, *optional*):
59
+ Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
60
+ useful to avoid degraded performances on the repo when it grows too large.
61
+ repo_type (`str`, *optional*):
62
+ The type of the repo to which the logs will be pushed. Defaults to "model".
63
+ repo_revision (`str`, *optional*):
64
+ The revision of the repo to which the logs will be pushed. Defaults to "main".
65
+ repo_private (`bool`, *optional*):
66
+ Whether to create a private repo or not. Defaults to False. This argument is ignored if the repo already
67
+ exists.
68
+ path_in_repo (`str`, *optional*):
69
+ The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/".
70
+ repo_allow_patterns (`List[str]` or `str`, *optional*):
71
+ A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
72
+ [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
73
+ repo_ignore_patterns (`List[str]` or `str`, *optional*):
74
+ A list of patterns to exclude in the upload. Check out the
75
+ [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
76
+ token (`str`, *optional*):
77
+ Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
78
+ details
79
+ kwargs:
80
+ Additional keyword arguments passed to `SummaryWriter`.
81
+
82
+ Examples:
83
+ ```py
84
+ >>> from huggingface_hub import HFSummaryWriter
85
+
86
+ # Logs are automatically pushed every 15 minutes
87
+ >>> logger = HFSummaryWriter(repo_id="test_hf_logger", commit_every=15)
88
+ >>> logger.add_scalar("a", 1)
89
+ >>> logger.add_scalar("b", 2)
90
+ ...
91
+
92
+ # You can also trigger a push manually
93
+ >>> logger.scheduler.trigger()
94
+ ```
95
+
96
+ ```py
97
+ >>> from huggingface_hub import HFSummaryWriter
98
+
99
+ # Logs are automatically pushed every 5 minutes (default) + when exiting the context manager
100
+ >>> with HFSummaryWriter(repo_id="test_hf_logger") as logger:
101
+ ... logger.add_scalar("a", 1)
102
+ ... logger.add_scalar("b", 2)
103
+ ```
104
+ """
105
+
106
+ @experimental
107
+ def __new__(cls, *args, **kwargs) -> "HFSummaryWriter":
108
+ if not is_tensorboard_available():
109
+ raise ImportError(
110
+ "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade"
111
+ " tensorboardX` first."
112
+ )
113
+ return super().__new__(cls)
114
+
115
+ def __init__(
116
+ self,
117
+ repo_id: str,
118
+ *,
119
+ logdir: Optional[str] = None,
120
+ commit_every: Union[int, float] = 5,
121
+ squash_history: bool = False,
122
+ repo_type: Optional[str] = None,
123
+ repo_revision: Optional[str] = None,
124
+ repo_private: bool = False,
125
+ path_in_repo: Optional[str] = "tensorboard",
126
+ repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*",
127
+ repo_ignore_patterns: Optional[Union[List[str], str]] = None,
128
+ token: Optional[str] = None,
129
+ **kwargs,
130
+ ):
131
+ # Initialize SummaryWriter
132
+ super().__init__(logdir=logdir, **kwargs)
133
+
134
+ # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it.
135
+ if not isinstance(self.logdir, str):
136
+ raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")
137
+
138
+ # Append logdir name to `path_in_repo`
139
+ if path_in_repo is None or path_in_repo == "":
140
+ path_in_repo = Path(self.logdir).name
141
+ else:
142
+ path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name
143
+
144
+ # Initialize scheduler
145
+ self.scheduler = CommitScheduler(
146
+ folder_path=self.logdir,
147
+ path_in_repo=path_in_repo,
148
+ repo_id=repo_id,
149
+ repo_type=repo_type,
150
+ revision=repo_revision,
151
+ private=repo_private,
152
+ token=token,
153
+ allow_patterns=repo_allow_patterns,
154
+ ignore_patterns=repo_ignore_patterns,
155
+ every=commit_every,
156
+ squash_history=squash_history,
157
+ )
158
+
159
+ # Exposing some high-level info at root level
160
+ self.repo_id = self.scheduler.repo_id
161
+ self.repo_type = self.scheduler.repo_type
162
+ self.repo_revision = self.scheduler.revision
163
+
164
+ def __exit__(self, exc_type, exc_val, exc_tb):
165
+ """Push to hub in a non-blocking way when exiting the logger's context manager."""
166
+ super().__exit__(exc_type, exc_val, exc_tb)
167
+ future = self.scheduler.trigger()
168
+ future.result()
lib/python3.11/site-packages/huggingface_hub/_webhooks_payload.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains data structures to parse the webhooks payload."""
16
+ from typing import List, Literal, Optional
17
+
18
+ from pydantic import BaseModel
19
+
20
+
21
+ # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
22
+ # are not in used anymore. To keep in sync when format is updated in
23
+ # https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link).
24
+
25
+
26
+ WebhookEvent_T = Literal[
27
+ "create",
28
+ "delete",
29
+ "move",
30
+ "update",
31
+ ]
32
+ RepoChangeEvent_T = Literal[
33
+ "add",
34
+ "move",
35
+ "remove",
36
+ "update",
37
+ ]
38
+ RepoType_T = Literal[
39
+ "dataset",
40
+ "model",
41
+ "space",
42
+ ]
43
+ DiscussionStatus_T = Literal[
44
+ "closed",
45
+ "draft",
46
+ "open",
47
+ "merged",
48
+ ]
49
+ SupportedWebhookVersion = Literal[3]
50
+
51
+
52
+ class ObjectId(BaseModel):
53
+ id: str
54
+
55
+
56
+ class WebhookPayloadUrl(BaseModel):
57
+ web: str
58
+ api: Optional[str] = None
59
+
60
+
61
+ class WebhookPayloadMovedTo(BaseModel):
62
+ name: str
63
+ owner: ObjectId
64
+
65
+
66
+ class WebhookPayloadWebhook(ObjectId):
67
+ version: SupportedWebhookVersion
68
+
69
+
70
+ class WebhookPayloadEvent(BaseModel):
71
+ action: WebhookEvent_T
72
+ scope: str
73
+
74
+
75
+ class WebhookPayloadDiscussionChanges(BaseModel):
76
+ base: str
77
+ mergeCommitId: Optional[str] = None
78
+
79
+
80
+ class WebhookPayloadComment(ObjectId):
81
+ author: ObjectId
82
+ hidden: bool
83
+ content: Optional[str] = None
84
+ url: WebhookPayloadUrl
85
+
86
+
87
+ class WebhookPayloadDiscussion(ObjectId):
88
+ num: int
89
+ author: ObjectId
90
+ url: WebhookPayloadUrl
91
+ title: str
92
+ isPullRequest: bool
93
+ status: DiscussionStatus_T
94
+ changes: Optional[WebhookPayloadDiscussionChanges] = None
95
+ pinned: Optional[bool] = None
96
+
97
+
98
+ class WebhookPayloadRepo(ObjectId):
99
+ owner: ObjectId
100
+ head_sha: Optional[str] = None
101
+ name: str
102
+ private: bool
103
+ subdomain: Optional[str] = None
104
+ tags: Optional[List[str]] = None
105
+ type: Literal["dataset", "model", "space"]
106
+ url: WebhookPayloadUrl
107
+
108
+
109
+ class WebhookPayload(BaseModel):
110
+ event: WebhookPayloadEvent
111
+ repo: WebhookPayloadRepo
112
+ discussion: Optional[WebhookPayloadDiscussion] = None
113
+ comment: Optional[WebhookPayloadComment] = None
114
+ webhook: WebhookPayloadWebhook
115
+ movedTo: Optional[WebhookPayloadMovedTo] = None
lib/python3.11/site-packages/huggingface_hub/_webhooks_server.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
16
+ import atexit
17
+ import inspect
18
+ import os
19
+ from functools import wraps
20
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
21
+
22
+ from .utils import experimental, is_gradio_available
23
+ from .utils._deprecation import _deprecate_method
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ import gradio as gr
28
+
29
+
30
+ from fastapi import FastAPI, Request
31
+ from fastapi.responses import JSONResponse
32
+
33
+
34
+ _global_app: Optional["WebhooksServer"] = None
35
+ _is_local = os.getenv("SYSTEM") != "spaces"
36
+
37
+
38
+ @experimental
39
+ class WebhooksServer:
40
+ """
41
+ The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Huggingface webhooks.
42
+ These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
43
+ the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `run` method has to be
44
+ called to start the app.
45
+
46
+ It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
47
+ model that contains all the information about the webhook event. The data will be parsed automatically for you.
48
+
49
+ Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
50
+ WebhooksServer and deploy it on a Space.
51
+
52
+ <Tip warning={true}>
53
+
54
+ `WebhooksServer` is experimental. Its API is subject to change in the future.
55
+
56
+ </Tip>
57
+
58
+ <Tip warning={true}>
59
+
60
+ You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).
61
+
62
+ </Tip>
63
+
64
+ Args:
65
+ ui (`gradio.Blocks`, optional):
66
+ A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
67
+ about the configured webhooks is created.
68
+ webhook_secret (`str`, optional):
69
+ A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
70
+ you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
71
+ can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
72
+ webhook endpoints are opened without any security.
73
+
74
+ Example:
75
+
76
+ ```python
77
+ import gradio as gr
78
+ from huggingface_hub import WebhooksServer, WebhookPayload
79
+
80
+ with gr.Blocks() as ui:
81
+ ...
82
+
83
+ app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")
84
+
85
+ @app.add_webhook("/say_hello")
86
+ async def hello(payload: WebhookPayload):
87
+ return {"message": "hello"}
88
+
89
+ app.run()
90
+ ```
91
+ """
92
+
93
+ def __new__(cls, *args, **kwargs) -> "WebhooksServer":
94
+ if not is_gradio_available():
95
+ raise ImportError(
96
+ "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio`"
97
+ " first."
98
+ )
99
+ return super().__new__(cls)
100
+
101
+ def __init__(
102
+ self,
103
+ ui: Optional["gr.Blocks"] = None,
104
+ webhook_secret: Optional[str] = None,
105
+ ) -> None:
106
+ self._ui = ui
107
+
108
+ self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
109
+ self.registered_webhooks: Dict[str, Callable] = {}
110
+ _warn_on_empty_secret(self.webhook_secret)
111
+
112
+ def add_webhook(self, path: Optional[str] = None) -> Callable:
113
+ """
114
+ Decorator to add a webhook to the [`WebhooksServer`] server.
115
+
116
+ Args:
117
+ path (`str`, optional):
118
+ The URL path to register the webhook function. If not provided, the function name will be used as the
119
+ path. In any case, all webhooks are registered under `/webhooks`.
120
+
121
+ Raises:
122
+ ValueError: If the provided path is already registered as a webhook.
123
+
124
+ Example:
125
+ ```python
126
+ from huggingface_hub import WebhooksServer, WebhookPayload
127
+
128
+ app = WebhooksServer()
129
+
130
+ @app.add_webhook
131
+ async def trigger_training(payload: WebhookPayload):
132
+ if payload.repo.type == "dataset" and payload.event.action == "update":
133
+ # Trigger a training job if a dataset is updated
134
+ ...
135
+
136
+ app.run()
137
+ ```
138
+ """
139
+ # Usage: directly as decorator. Example: `@app.add_webhook`
140
+ if callable(path):
141
+ # If path is a function, it means it was used as a decorator without arguments
142
+ return self.add_webhook()(path)
143
+
144
+ # Usage: provide a path. Example: `@app.add_webhook(...)`
145
+ @wraps(FastAPI.post)
146
+ def _inner_post(*args, **kwargs):
147
+ func = args[0]
148
+ abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
149
+ if abs_path in self.registered_webhooks:
150
+ raise ValueError(f"Webhook {abs_path} already exists.")
151
+ self.registered_webhooks[abs_path] = func
152
+
153
+ return _inner_post
154
+
155
+ def launch(self, prevent_thread_lock: bool = False, **launch_kwargs: Any) -> None:
156
+ """Launch the Gradio app and register webhooks to the underlying FastAPI server.
157
+
158
+ Input parameters are forwarded to Gradio when launching the app.
159
+ """
160
+ ui = self._ui or self._get_default_ui()
161
+
162
+ # Start Gradio App
163
+ # - as non-blocking so that webhooks can be added afterwards
164
+ # - as shared if launch locally (to debug webhooks)
165
+ launch_kwargs.setdefault("share", _is_local)
166
+ self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, **launch_kwargs)
167
+
168
+ # Register webhooks to FastAPI app
169
+ for path, func in self.registered_webhooks.items():
170
+ # Add secret check if required
171
+ if self.webhook_secret is not None:
172
+ func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)
173
+
174
+ # Add route to FastAPI app
175
+ self.fastapi_app.post(path)(func)
176
+
177
+ # Print instructions and block main thread
178
+ url = (ui.share_url or ui.local_url).strip("/")
179
+ message = "\nWebhooks are correctly setup and ready to use:"
180
+ message += "\n" + "\n".join(f" - POST {url}{webhook}" for webhook in self.registered_webhooks)
181
+ message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
182
+ print(message)
183
+
184
+ if not prevent_thread_lock:
185
+ ui.block_thread()
186
+
187
+ @_deprecate_method(version="0.23", message="Use `WebhooksServer.launch` instead.")
188
+ def run(self) -> None:
189
+ return self.launch()
190
+
191
+ def _get_default_ui(self) -> "gr.Blocks":
192
+ """Default UI if not provided (lists webhooks and provides basic instructions)."""
193
+ import gradio as gr
194
+
195
+ with gr.Blocks() as ui:
196
+ gr.Markdown("# This is an app to process 🤗 Webhooks")
197
+ gr.Markdown(
198
+ "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
199
+ " specific repos or to all repos belonging to particular set of users/organizations (not just your"
200
+ " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
201
+ " know more about webhooks on the Huggingface Hub."
202
+ )
203
+ gr.Markdown(
204
+ f"{len(self.registered_webhooks)} webhook(s) are registered:"
205
+ + "\n\n"
206
+ + "\n ".join(
207
+ f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
208
+ for webhook_path, webhook in self.registered_webhooks.items()
209
+ )
210
+ )
211
+ gr.Markdown(
212
+ "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
213
+ + "\nYou app is running locally. Please look at the logs to check the full URL you need to set."
214
+ if _is_local
215
+ else (
216
+ "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
217
+ " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
218
+ )
219
+ )
220
+ return ui
221
+
222
+
223
+ @experimental
224
+ def webhook_endpoint(path: Optional[str] = None) -> Callable:
225
+ """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.
226
+
227
+ This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret),
228
+ you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using
229
+ this decorator multiple times.
230
+
231
+ Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to setup your
232
+ server and deploy it on a Space.
233
+
234
+ <Tip warning={true}>
235
+
236
+ `webhook_endpoint` is experimental. Its API is subject to change in the future.
237
+
238
+ </Tip>
239
+
240
+ <Tip warning={true}>
241
+
242
+ You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).
243
+
244
+ </Tip>
245
+
246
+ Args:
247
+ path (`str`, optional):
248
+ The URL path to register the webhook function. If not provided, the function name will be used as the path.
249
+ In any case, all webhooks are registered under `/webhooks`.
250
+
251
+ Examples:
252
+ The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
253
+ The server will be started automatically at exit (i.e. at the end of the script).
254
+
255
+ ```python
256
+ from huggingface_hub import webhook_endpoint, WebhookPayload
257
+
258
+ @webhook_endpoint
259
+ async def trigger_training(payload: WebhookPayload):
260
+ if payload.repo.type == "dataset" and payload.event.action == "update":
261
+ # Trigger a training job if a dataset is updated
262
+ ...
263
+
264
+ # Server is automatically started at the end of the script.
265
+ ```
266
+
267
+ Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
268
+ are running it in a notebook.
269
+
270
+ ```python
271
+ from huggingface_hub import webhook_endpoint, WebhookPayload
272
+
273
+ @webhook_endpoint
274
+ async def trigger_training(payload: WebhookPayload):
275
+ if payload.repo.type == "dataset" and payload.event.action == "update":
276
+ # Trigger a training job if a dataset is updated
277
+ ...
278
+
279
+ # Start the server manually
280
+ trigger_training.run()
281
+ ```
282
+ """
283
+ if callable(path):
284
+ # If path is a function, it means it was used as a decorator without arguments
285
+ return webhook_endpoint()(path)
286
+
287
+ @wraps(WebhooksServer.add_webhook)
288
+ def _inner(func: Callable) -> Callable:
289
+ app = _get_global_app()
290
+ app.add_webhook(path)(func)
291
+ if len(app.registered_webhooks) == 1:
292
+ # Register `app.run` to run at exit (only once)
293
+ atexit.register(app.run)
294
+
295
+ @wraps(app.run)
296
+ def _run_now():
297
+ # Run the app directly (without waiting atexit)
298
+ atexit.unregister(app.run)
299
+ app.run()
300
+
301
+ func.run = _run_now # type: ignore
302
+ return func
303
+
304
+ return _inner
305
+
306
+
307
+ def _get_global_app() -> WebhooksServer:
308
+ global _global_app
309
+ if _global_app is None:
310
+ _global_app = WebhooksServer()
311
+ return _global_app
312
+
313
+
314
+ def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
315
+ if webhook_secret is None:
316
+ print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
317
+ print(
318
+ "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
319
+ "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
320
+ )
321
+ print(
322
+ "For more details about webhook secrets, please refer to"
323
+ " https://huggingface.co/docs/hub/webhooks#webhook-secret."
324
+ )
325
+ else:
326
+ print("Webhook secret is correctly defined.")
327
+
328
+
329
+ def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
330
+ """Returns the anchor to a given webhook in the docs (experimental)"""
331
+ return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
332
+
333
+
334
+ def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
335
+ """Wraps a webhook function to check the webhook secret before calling the function.
336
+
337
+ This is a hacky way to add the `request` parameter to the function signature. Since FastAPI based itself on route
338
+ parameters to inject the values to the function, we need to hack the function signature to retrieve the `Request`
339
+ object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
340
+ `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
341
+ Gradio internals (and not by us), we cannot add a middleware.
342
+
343
+ This method is called only when a secret has been defined by the user. If a request is sent without the
344
+ "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
345
+ the function will return a 403 error (forbidden).
346
+
347
+ Inspired by https://stackoverflow.com/a/33112180.
348
+ """
349
+ initial_sig = inspect.signature(func)
350
+
351
+ @wraps(func)
352
+ async def _protected_func(request: Request, **kwargs):
353
+ request_secret = request.headers.get("x-webhook-secret")
354
+ if request_secret is None:
355
+ return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
356
+ if request_secret != webhook_secret:
357
+ return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)
358
+
359
+ # Inject `request` in kwargs if required
360
+ if "request" in initial_sig.parameters:
361
+ kwargs["request"] = request
362
+
363
+ # Handle both sync and async routes
364
+ if inspect.iscoroutinefunction(func):
365
+ return await func(**kwargs)
366
+ else:
367
+ return func(**kwargs)
368
+
369
+ # Update signature to include request
370
+ if "request" not in initial_sig.parameters:
371
+ _protected_func.__signature__ = initial_sig.replace( # type: ignore
372
+ parameters=(
373
+ inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
374
+ )
375
+ + tuple(initial_sig.parameters.values())
376
+ )
377
+
378
+ # Return protected route
379
+ return _protected_func
lib/python3.11/site-packages/huggingface_hub/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import _SubParsersAction
17
+
18
+
19
+ class BaseHuggingfaceCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: _SubParsersAction):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.16 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-311.pyc ADDED
Binary file (3.65 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/download.cpython-311.pyc ADDED
Binary file (9.56 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/env.cpython-311.pyc ADDED
Binary file (1.67 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-311.pyc ADDED
Binary file (2.24 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-311.pyc ADDED
Binary file (9.61 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-311.pyc ADDED
Binary file (6.85 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/__pycache__/user.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/commands/_cli_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains a utility for good-looking prints."""
15
+ import os
16
+ from typing import List, Union
17
+
18
+
19
+ class ANSI:
20
+ """
21
+ Helper for en.wikipedia.org/wiki/ANSI_escape_code
22
+ """
23
+
24
+ _bold = "\u001b[1m"
25
+ _gray = "\u001b[90m"
26
+ _red = "\u001b[31m"
27
+ _reset = "\u001b[0m"
28
+
29
+ @classmethod
30
+ def bold(cls, s: str) -> str:
31
+ return cls._format(s, cls._bold)
32
+
33
+ @classmethod
34
+ def gray(cls, s: str) -> str:
35
+ return cls._format(s, cls._gray)
36
+
37
+ @classmethod
38
+ def red(cls, s: str) -> str:
39
+ return cls._format(s, cls._bold + cls._red)
40
+
41
+ @classmethod
42
+ def _format(cls, s: str, code: str) -> str:
43
+ if os.environ.get("NO_COLOR"):
44
+ # See https://no-color.org/
45
+ return s
46
+ return f"{code}{s}{cls._reset}"
47
+
48
+
49
+ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
50
+ """
51
+ Inspired by:
52
+
53
+ - stackoverflow.com/a/8356620/593036
54
+ - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
55
+ """
56
+ col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
57
+ row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
58
+ lines = []
59
+ lines.append(row_format.format(*headers))
60
+ lines.append(row_format.format(*["-" * w for w in col_widths]))
61
+ for row in rows:
62
+ lines.append(row_format.format(*row))
63
+ return "\n".join(lines)
lib/python3.11/site-packages/huggingface_hub/commands/delete_cache.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains command to delete some revisions from the HF cache directory.
16
+
17
+ Usage:
18
+ huggingface-cli delete-cache
19
+ huggingface-cli delete-cache --disable-tui
20
+ huggingface-cli delete-cache --dir ~/.cache/huggingface/hub
21
+
22
+ NOTE:
23
+ This command is based on `InquirerPy` to build the multiselect menu in the terminal.
24
+ This dependency has to be installed with `pip install huggingface_hub[cli]`. Since
25
+ we want to avoid as much as possible cross-platform issues, I chose a library that
26
+ is built on top of `python-prompt-toolkit` which seems to be a reference in terminal
27
+ GUI (actively maintained on both Unix and Windows, 7.9k stars).
28
+
29
+ For the moment, the TUI feature is in beta.
30
+
31
+ See:
32
+ - https://github.com/kazhala/InquirerPy
33
+ - https://inquirerpy.readthedocs.io/en/latest/
34
+ - https://github.com/prompt-toolkit/python-prompt-toolkit
35
+
36
+ Other solutions could have been:
37
+ - `simple_term_menu`: would be good as well for our use case but some issues suggest
38
+ that Windows is less supported.
39
+ See: https://github.com/IngoMeyer441/simple-term-menu
40
+ - `PyInquirer`: very similar to `InquirerPy` but older and not maintained anymore.
41
+ In particular, no support of Python3.10.
42
+ See: https://github.com/CITGuru/PyInquirer
43
+ - `pick` (or `pickpack`): easy to use and flexible but built on top of Python's
44
+ standard library `curses` that is specific to Unix (not implemented on Windows).
45
+ See https://github.com/wong2/pick and https://github.com/anafvana/pickpack.
46
+ - `inquirer`: lot of traction (700 stars) but explicitly states "experimental
47
+ support of Windows". Not built on top of `python-prompt-toolkit`.
48
+ See https://github.com/magmax/python-inquirer
49
+
50
+ TODO: add support for `huggingface-cli delete-cache aaaaaa bbbbbb cccccc (...)` ?
51
+ TODO: add "--keep-last" arg to delete revisions that are not on `main` ref
52
+ TODO: add "--filter" arg to filter repositories by name ?
53
+ TODO: add "--sort" arg to sort by size ?
54
+ TODO: add "--limit" arg to limit to X repos ?
55
+ TODO: add "-y" arg for immediate deletion ?
56
+ See discussions in https://github.com/huggingface/huggingface_hub/issues/1025.
57
+ """
58
+ import os
59
+ from argparse import Namespace, _SubParsersAction
60
+ from functools import wraps
61
+ from tempfile import mkstemp
62
+ from typing import Any, Callable, Iterable, List, Optional, Union
63
+
64
+ from ..utils import CachedRepoInfo, CachedRevisionInfo, HFCacheInfo, scan_cache_dir
65
+ from . import BaseHuggingfaceCLICommand
66
+ from ._cli_utils import ANSI
67
+
68
+
69
+ try:
70
+ from InquirerPy import inquirer
71
+ from InquirerPy.base.control import Choice
72
+ from InquirerPy.separator import Separator
73
+
74
+ _inquirer_py_available = True
75
+ except ImportError:
76
+ _inquirer_py_available = False
77
+
78
+
79
+ def require_inquirer_py(fn: Callable) -> Callable:
80
+ """Decorator to flag methods that require `InquirerPy`."""
81
+
82
+ # TODO: refactor this + imports in a unified pattern across codebase
83
+ @wraps(fn)
84
+ def _inner(*args, **kwargs):
85
+ if not _inquirer_py_available:
86
+ raise ImportError(
87
+ "The `delete-cache` command requires extra dependencies to work with"
88
+ " the TUI.\nPlease run `pip install huggingface_hub[cli]` to install"
89
+ " them.\nOtherwise, disable TUI using the `--disable-tui` flag."
90
+ )
91
+
92
+ return fn(*args, **kwargs)
93
+
94
+ return _inner
95
+
96
+
97
+ # Possibility for the user to cancel deletion
98
+ _CANCEL_DELETION_STR = "CANCEL_DELETION"
99
+
100
+
101
+ class DeleteCacheCommand(BaseHuggingfaceCLICommand):
102
+ @staticmethod
103
+ def register_subcommand(parser: _SubParsersAction):
104
+ delete_cache_parser = parser.add_parser("delete-cache", help="Delete revisions from the cache directory.")
105
+
106
+ delete_cache_parser.add_argument(
107
+ "--dir",
108
+ type=str,
109
+ default=None,
110
+ help="cache directory (optional). Default to the default HuggingFace cache.",
111
+ )
112
+
113
+ delete_cache_parser.add_argument(
114
+ "--disable-tui",
115
+ action="store_true",
116
+ help=(
117
+ "Disable Terminal User Interface (TUI) mode. Useful if your"
118
+ " platform/terminal doesn't support the multiselect menu."
119
+ ),
120
+ )
121
+
122
+ delete_cache_parser.set_defaults(func=DeleteCacheCommand)
123
+
124
+ def __init__(self, args: Namespace) -> None:
125
+ self.cache_dir: Optional[str] = args.dir
126
+ self.disable_tui: bool = args.disable_tui
127
+
128
+ def run(self):
129
+ """Run `delete-cache` command with or without TUI."""
130
+ # Scan cache directory
131
+ hf_cache_info = scan_cache_dir(self.cache_dir)
132
+
133
+ # Manual review from the user
134
+ if self.disable_tui:
135
+ selected_hashes = _manual_review_no_tui(hf_cache_info, preselected=[])
136
+ else:
137
+ selected_hashes = _manual_review_tui(hf_cache_info, preselected=[])
138
+
139
+ # If deletion is not cancelled
140
+ if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes:
141
+ confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion ?"
142
+
143
+ # Confirm deletion
144
+ if self.disable_tui:
145
+ confirmed = _ask_for_confirmation_no_tui(confirm_message)
146
+ else:
147
+ confirmed = _ask_for_confirmation_tui(confirm_message)
148
+
149
+ # Deletion is confirmed
150
+ if confirmed:
151
+ strategy = hf_cache_info.delete_revisions(*selected_hashes)
152
+ print("Start deletion.")
153
+ strategy.execute()
154
+ print(
155
+ f"Done. Deleted {len(strategy.repos)} repo(s) and"
156
+ f" {len(strategy.snapshots)} revision(s) for a total of"
157
+ f" {strategy.expected_freed_size_str}."
158
+ )
159
+ return
160
+
161
+ # Deletion is cancelled
162
+ print("Deletion is cancelled. Do nothing.")
163
+
164
+
165
+ @require_inquirer_py
166
+ def _manual_review_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]:
167
+ """Ask the user for a manual review of the revisions to delete.
168
+
169
+ Displays a multi-select menu in the terminal (TUI).
170
+ """
171
+ # Define multiselect list
172
+ choices = _get_tui_choices_from_scan(repos=hf_cache_info.repos, preselected=preselected)
173
+ checkbox = inquirer.checkbox(
174
+ message="Select revisions to delete:",
175
+ choices=choices, # List of revisions with some pre-selection
176
+ cycle=False, # No loop between top and bottom
177
+ height=100, # Large list if possible
178
+ # We use the instruction to display to the user the expected effect of the
179
+ # deletion.
180
+ instruction=_get_expectations_str(
181
+ hf_cache_info,
182
+ selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled],
183
+ ),
184
+ # We use the long instruction to should keybindings instructions to the user
185
+ long_instruction="Press <space> to select, <enter> to validate and <ctrl+c> to quit without modification.",
186
+ # Message that is displayed once the user validates its selection.
187
+ transformer=lambda result: f"{len(result)} revision(s) selected.",
188
+ )
189
+
190
+ # Add a callback to update the information line when a revision is
191
+ # selected/unselected
192
+ def _update_expectations(_) -> None:
193
+ # Hacky way to dynamically set an instruction message to the checkbox when
194
+ # a revision hash is selected/unselected.
195
+ checkbox._instruction = _get_expectations_str(
196
+ hf_cache_info,
197
+ selected_hashes=[choice["value"] for choice in checkbox.content_control.choices if choice["enabled"]],
198
+ )
199
+
200
+ checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations})
201
+
202
+ # Finally display the form to the user.
203
+ try:
204
+ return checkbox.execute()
205
+ except KeyboardInterrupt:
206
+ return [] # Quit without deletion
207
+
208
+
209
+ @require_inquirer_py
210
+ def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool:
211
+ """Ask for confirmation using Inquirer."""
212
+ return inquirer.confirm(message, default=default).execute()
213
+
214
+
215
+ def _get_tui_choices_from_scan(repos: Iterable[CachedRepoInfo], preselected: List[str]) -> List:
216
+ """Build a list of choices from the scanned repos.
217
+
218
+ Args:
219
+ repos (*Iterable[`CachedRepoInfo`]*):
220
+ List of scanned repos on which we want to delete revisions.
221
+ preselected (*List[`str`]*):
222
+ List of revision hashes that will be preselected.
223
+
224
+ Return:
225
+ The list of choices to pass to `inquirer.checkbox`.
226
+ """
227
+ choices: List[Union[Choice, Separator]] = []
228
+
229
+ # First choice is to cancel the deletion. If selected, nothing will be deleted,
230
+ # no matter the other selected items.
231
+ choices.append(
232
+ Choice(
233
+ _CANCEL_DELETION_STR,
234
+ name="None of the following (if selected, nothing will be deleted).",
235
+ enabled=False,
236
+ )
237
+ )
238
+
239
+ # Display a separator per repo and a Choice for each revisions of the repo
240
+ for repo in sorted(repos, key=_repo_sorting_order):
241
+ # Repo as separator
242
+ choices.append(
243
+ Separator(
244
+ f"\n{repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str},"
245
+ f" used {repo.last_accessed_str})"
246
+ )
247
+ )
248
+ for revision in sorted(repo.revisions, key=_revision_sorting_order):
249
+ # Revision as choice
250
+ choices.append(
251
+ Choice(
252
+ revision.commit_hash,
253
+ name=(
254
+ f"{revision.commit_hash[:8]}:"
255
+ f" {', '.join(sorted(revision.refs)) or '(detached)'} #"
256
+ f" modified {revision.last_modified_str}"
257
+ ),
258
+ enabled=revision.commit_hash in preselected,
259
+ )
260
+ )
261
+
262
+ # Return choices
263
+ return choices
264
+
265
+
266
+ def _manual_review_no_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]:
267
+ """Ask the user for a manual review of the revisions to delete.
268
+
269
+ Used when TUI is disabled. Manual review happens in a separate tmp file that the
270
+ user can manually edit.
271
+ """
272
+ # 1. Generate temporary file with delete commands.
273
+ fd, tmp_path = mkstemp(suffix=".txt") # suffix to make it easier to find by editors
274
+ os.close(fd)
275
+
276
+ lines = []
277
+ for repo in sorted(hf_cache_info.repos, key=_repo_sorting_order):
278
+ lines.append(
279
+ f"\n# {repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str},"
280
+ f" used {repo.last_accessed_str})"
281
+ )
282
+ for revision in sorted(repo.revisions, key=_revision_sorting_order):
283
+ lines.append(
284
+ # Deselect by prepending a '#'
285
+ f"{'' if revision.commit_hash in preselected else '#'} "
286
+ f" {revision.commit_hash} # Refs:"
287
+ # Print `refs` as comment on same line
288
+ f" {', '.join(sorted(revision.refs)) or '(detached)'} # modified"
289
+ # Print `last_modified` as comment on same line
290
+ f" {revision.last_modified_str}"
291
+ )
292
+
293
+ with open(tmp_path, "w") as f:
294
+ f.write(_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS)
295
+ f.write("\n".join(lines))
296
+
297
+ # 2. Prompt instructions to user.
298
+ instructions = f"""
299
+ TUI is disabled. In order to select which revisions you want to delete, please edit
300
+ the following file using the text editor of your choice. Instructions for manual
301
+ editing are located at the beginning of the file. Edit the file, save it and confirm
302
+ to continue.
303
+ File to edit: {ANSI.bold(tmp_path)}
304
+ """
305
+ print("\n".join(line.strip() for line in instructions.strip().split("\n")))
306
+
307
+ # 3. Wait for user confirmation.
308
+ while True:
309
+ selected_hashes = _read_manual_review_tmp_file(tmp_path)
310
+ if _ask_for_confirmation_no_tui(
311
+ _get_expectations_str(hf_cache_info, selected_hashes) + " Continue ?",
312
+ default=False,
313
+ ):
314
+ break
315
+
316
+ # 4. Return selected_hashes
317
+ os.remove(tmp_path)
318
+ return selected_hashes
319
+
320
+
321
+ def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool:
322
+ """Ask for confirmation using pure-python."""
323
+ YES = ("y", "yes", "1")
324
+ NO = ("n", "no", "0")
325
+ DEFAULT = ""
326
+ ALL = YES + NO + (DEFAULT,)
327
+ full_message = message + (" (Y/n) " if default else " (y/N) ")
328
+ while True:
329
+ answer = input(full_message).lower()
330
+ if answer == DEFAULT:
331
+ return default
332
+ if answer in YES:
333
+ return True
334
+ if answer in NO:
335
+ return False
336
+ print(f"Invalid input. Must be one of {ALL}")
337
+
338
+
339
+ def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str:
340
+ """Format a string to display to the user how much space would be saved.
341
+
342
+ Example:
343
+ ```
344
+ >>> _get_expectations_str(hf_cache_info, selected_hashes)
345
+ '7 revisions selected counting for 4.3G.'
346
+ ```
347
+ """
348
+ if _CANCEL_DELETION_STR in selected_hashes:
349
+ return "Nothing will be deleted."
350
+ strategy = hf_cache_info.delete_revisions(*selected_hashes)
351
+ return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}."
352
+
353
+
354
+ def _read_manual_review_tmp_file(tmp_path: str) -> List[str]:
355
+ """Read the manually reviewed instruction file and return a list of revision hash.
356
+
357
+ Example:
358
+ ```txt
359
+ # This is the tmp file content
360
+ ###
361
+
362
+ # Commented out line
363
+ 123456789 # revision hash
364
+
365
+ # Something else
366
+ # a_newer_hash # 2 days ago
367
+ an_older_hash # 3 days ago
368
+ ```
369
+
370
+ ```py
371
+ >>> _read_manual_review_tmp_file(tmp_path)
372
+ ['123456789', 'an_older_hash']
373
+ ```
374
+ """
375
+ with open(tmp_path) as f:
376
+ content = f.read()
377
+
378
+ # Split lines
379
+ lines = [line.strip() for line in content.split("\n")]
380
+
381
+ # Filter commented lines
382
+ selected_lines = [line for line in lines if not line.startswith("#")]
383
+
384
+ # Select only before comment
385
+ selected_hashes = [line.split("#")[0].strip() for line in selected_lines]
386
+
387
+ # Return revision hashes
388
+ return [hash for hash in selected_hashes if len(hash) > 0]
389
+
390
+
391
+ _MANUAL_REVIEW_NO_TUI_INSTRUCTIONS = f"""
392
+ # INSTRUCTIONS
393
+ # ------------
394
+ # This is a temporary file created by running `huggingface-cli delete-cache` with the
395
+ # `--disable-tui` option. It contains a set of revisions that can be deleted from your
396
+ # local cache directory.
397
+ #
398
+ # Please manually review the revisions you want to delete:
399
+ # - Revision hashes can be commented out with '#'.
400
+ # - Only non-commented revisions in this file will be deleted.
401
+ # - Revision hashes that are removed from this file are ignored as well.
402
+ # - If `{_CANCEL_DELETION_STR}` line is uncommented, the all cache deletion is cancelled and
403
+ # no changes will be applied.
404
+ #
405
+ # Once you've manually reviewed this file, please confirm deletion in the terminal. This
406
+ # file will be automatically removed once done.
407
+ # ------------
408
+
409
+ # KILL SWITCH
410
+ # ------------
411
+ # Un-comment following line to completely cancel the deletion process
412
+ # {_CANCEL_DELETION_STR}
413
+ # ------------
414
+
415
+ # REVISIONS
416
+ # ------------
417
+ """.strip()
418
+
419
+
420
+ def _repo_sorting_order(repo: CachedRepoInfo) -> Any:
421
+ # First split by Dataset/Model, then sort by last accessed (oldest first)
422
+ return (repo.repo_type, repo.last_accessed)
423
+
424
+
425
+ def _revision_sorting_order(revision: CachedRevisionInfo) -> Any:
426
+ # Sort by last modified (oldest first)
427
+ return revision.last_modified
lib/python3.11/site-packages/huggingface_hub/commands/download.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains command to download files from the Hub with the CLI.
16
+
17
+ Usage:
18
+ huggingface-cli download --help
19
+
20
+ # Download file
21
+ huggingface-cli download gpt2 config.json
22
+
23
+ # Download entire repo
24
+ huggingface-cli download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78
25
+
26
+ # Download repo with filters
27
+ huggingface-cli download gpt2 --include="*.safetensors"
28
+
29
+ # Download with token
30
+ huggingface-cli download Wauplin/private-model --token=hf_***
31
+
32
+ # Download quietly (no progress bar, no warnings, only the returned path)
33
+ huggingface-cli download gpt2 config.json --quiet
34
+
35
+ # Download to local dir
36
+ huggingface-cli download gpt2 --local-dir=./models/gpt2
37
+ """
38
+ import warnings
39
+ from argparse import Namespace, _SubParsersAction
40
+ from typing import List, Literal, Optional, Union
41
+
42
+ from huggingface_hub import logging
43
+ from huggingface_hub._snapshot_download import snapshot_download
44
+ from huggingface_hub.commands import BaseHuggingfaceCLICommand
45
+ from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
46
+ from huggingface_hub.file_download import hf_hub_download
47
+ from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ class DownloadCommand(BaseHuggingfaceCLICommand):
54
+ @staticmethod
55
+ def register_subcommand(parser: _SubParsersAction):
56
+ download_parser = parser.add_parser("download", help="Download files from the Hub")
57
+ download_parser.add_argument(
58
+ "repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)."
59
+ )
60
+ download_parser.add_argument(
61
+ "filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)."
62
+ )
63
+ download_parser.add_argument(
64
+ "--repo-type",
65
+ choices=["model", "dataset", "space"],
66
+ default="model",
67
+ help="Type of repo to download from (e.g. `dataset`).",
68
+ )
69
+ download_parser.add_argument(
70
+ "--revision",
71
+ type=str,
72
+ help="An optional Git revision id which can be a branch name, a tag, or a commit hash.",
73
+ )
74
+ download_parser.add_argument(
75
+ "--include", nargs="*", type=str, help="Glob patterns to match files to download."
76
+ )
77
+ download_parser.add_argument(
78
+ "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download."
79
+ )
80
+ download_parser.add_argument(
81
+ "--cache-dir", type=str, help="Path to the directory where to save the downloaded files."
82
+ )
83
+ download_parser.add_argument(
84
+ "--local-dir",
85
+ type=str,
86
+ help=(
87
+ "If set, the downloaded file will be placed under this directory either as a symlink (default) or a"
88
+ " regular file. Check out"
89
+ " https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder for more"
90
+ " details."
91
+ ),
92
+ )
93
+ download_parser.add_argument(
94
+ "--local-dir-use-symlinks",
95
+ choices=["auto", "True", "False"],
96
+ default="auto",
97
+ help=(
98
+ "To be used with `local_dir`. If set to 'auto', the cache directory will be used and the file will be"
99
+ " either duplicated or symlinked to the local directory depending on its size. It set to `True`, a"
100
+ " symlink will be created, no matter the file size. If set to `False`, the file will either be"
101
+ " duplicated from cache (if already exists) or downloaded from the Hub and not cached."
102
+ ),
103
+ )
104
+ download_parser.add_argument(
105
+ "--force-download",
106
+ action="store_true",
107
+ help="If True, the files will be downloaded even if they are already cached.",
108
+ )
109
+ download_parser.add_argument(
110
+ "--resume-download", action="store_true", help="If True, resume a previously interrupted download."
111
+ )
112
+ download_parser.add_argument(
113
+ "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
114
+ )
115
+ download_parser.add_argument(
116
+ "--quiet",
117
+ action="store_true",
118
+ help="If True, progress bars are disabled and only the path to the download files is printed.",
119
+ )
120
+ download_parser.set_defaults(func=DownloadCommand)
121
+
122
+ def __init__(self, args: Namespace) -> None:
123
+ self.token = args.token
124
+ self.repo_id: str = args.repo_id
125
+ self.filenames: List[str] = args.filenames
126
+ self.repo_type: str = args.repo_type
127
+ self.revision: Optional[str] = args.revision
128
+ self.include: Optional[List[str]] = args.include
129
+ self.exclude: Optional[List[str]] = args.exclude
130
+ self.cache_dir: Optional[str] = args.cache_dir
131
+ self.local_dir: Optional[str] = args.local_dir
132
+ self.force_download: bool = args.force_download
133
+ self.resume_download: bool = args.resume_download
134
+ self.quiet: bool = args.quiet
135
+
136
+ # Raise if local_dir_use_symlinks is invalid
137
+ self.local_dir_use_symlinks: Union[Literal["auto"], bool]
138
+ use_symlinks_lowercase = args.local_dir_use_symlinks.lower()
139
+ if use_symlinks_lowercase == "true":
140
+ self.local_dir_use_symlinks = True
141
+ elif use_symlinks_lowercase == "false":
142
+ self.local_dir_use_symlinks = False
143
+ elif use_symlinks_lowercase == "auto":
144
+ self.local_dir_use_symlinks = "auto"
145
+ else:
146
+ raise ValueError(
147
+ f"'{args.local_dir_use_symlinks}' is not a valid value for `local_dir_use_symlinks`. It must be either"
148
+ " 'auto', 'True' or 'False'."
149
+ )
150
+
151
+ def run(self) -> None:
152
+ if self.quiet:
153
+ disable_progress_bars()
154
+ with warnings.catch_warnings():
155
+ warnings.simplefilter("ignore")
156
+ print(self._download()) # Print path to downloaded files
157
+ enable_progress_bars()
158
+ else:
159
+ logging.set_verbosity_info()
160
+ print(self._download()) # Print path to downloaded files
161
+ logging.set_verbosity_warning()
162
+
163
+ def _download(self) -> str:
164
+ # Warn user if patterns are ignored
165
+ if len(self.filenames) > 0:
166
+ if self.include is not None and len(self.include) > 0:
167
+ warnings.warn("Ignoring `--include` since filenames have being explicitly set.")
168
+ if self.exclude is not None and len(self.exclude) > 0:
169
+ warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.")
170
+
171
+ if not HF_HUB_ENABLE_HF_TRANSFER:
172
+ logger.info(
173
+ "Consider using `hf_transfer` for faster downloads. This solution comes with some limitations. See"
174
+ " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details."
175
+ )
176
+
177
+ # Single file to download: use `hf_hub_download`
178
+ if len(self.filenames) == 1:
179
+ return hf_hub_download(
180
+ repo_id=self.repo_id,
181
+ repo_type=self.repo_type,
182
+ revision=self.revision,
183
+ filename=self.filenames[0],
184
+ cache_dir=self.cache_dir,
185
+ resume_download=self.resume_download,
186
+ force_download=self.force_download,
187
+ token=self.token,
188
+ local_dir=self.local_dir,
189
+ local_dir_use_symlinks=self.local_dir_use_symlinks,
190
+ library_name="huggingface-cli",
191
+ )
192
+
193
+ # Otherwise: use `snapshot_download` to ensure all files comes from same revision
194
+ elif len(self.filenames) == 0:
195
+ allow_patterns = self.include
196
+ ignore_patterns = self.exclude
197
+ else:
198
+ allow_patterns = self.filenames
199
+ ignore_patterns = None
200
+
201
+ return snapshot_download(
202
+ repo_id=self.repo_id,
203
+ repo_type=self.repo_type,
204
+ revision=self.revision,
205
+ allow_patterns=allow_patterns,
206
+ ignore_patterns=ignore_patterns,
207
+ resume_download=self.resume_download,
208
+ force_download=self.force_download,
209
+ cache_dir=self.cache_dir,
210
+ token=self.token,
211
+ local_dir=self.local_dir,
212
+ local_dir_use_symlinks=self.local_dir_use_symlinks,
213
+ library_name="huggingface-cli",
214
+ )
lib/python3.11/site-packages/huggingface_hub/commands/env.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Contains command to print information about the environment.
15
+
16
+ Usage:
17
+ huggingface-cli env
18
+ """
19
+ from argparse import _SubParsersAction
20
+
21
+ from ..utils import dump_environment_info
22
+ from . import BaseHuggingfaceCLICommand
23
+
24
+
25
+ class EnvironmentCommand(BaseHuggingfaceCLICommand):
26
+ def __init__(self, args):
27
+ self.args = args
28
+
29
+ @staticmethod
30
+ def register_subcommand(parser: _SubParsersAction):
31
+ env_parser = parser.add_parser("env", help="Print information about the environment.")
32
+ env_parser.set_defaults(func=EnvironmentCommand)
33
+
34
+ def run(self) -> None:
35
+ dump_environment_info()
lib/python3.11/site-packages/huggingface_hub/commands/huggingface_cli.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from huggingface_hub.commands.delete_cache import DeleteCacheCommand
19
+ from huggingface_hub.commands.download import DownloadCommand
20
+ from huggingface_hub.commands.env import EnvironmentCommand
21
+ from huggingface_hub.commands.lfs import LfsCommands
22
+ from huggingface_hub.commands.scan_cache import ScanCacheCommand
23
+ from huggingface_hub.commands.upload import UploadCommand
24
+ from huggingface_hub.commands.user import UserCommands
25
+
26
+
27
+ def main():
28
+ parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
29
+ commands_parser = parser.add_subparsers(help="huggingface-cli command helpers")
30
+
31
+ # Register commands
32
+ EnvironmentCommand.register_subcommand(commands_parser)
33
+ UserCommands.register_subcommand(commands_parser)
34
+ UploadCommand.register_subcommand(commands_parser)
35
+ DownloadCommand.register_subcommand(commands_parser)
36
+ LfsCommands.register_subcommand(commands_parser)
37
+ ScanCacheCommand.register_subcommand(commands_parser)
38
+ DeleteCacheCommand.register_subcommand(commands_parser)
39
+
40
+ # Let's go
41
+ args = parser.parse_args()
42
+
43
+ if not hasattr(args, "func"):
44
+ parser.print_help()
45
+ exit(1)
46
+
47
+ # Run
48
+ service = args.func(args)
49
+ service.run()
50
+
51
+
52
+ if __name__ == "__main__":
53
+ main()
lib/python3.11/site-packages/huggingface_hub/commands/lfs.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of a custom transfer agent for the transfer type "multipart" for
3
+ git-lfs.
4
+
5
+ Inspired by:
6
+ github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py
7
+
8
+ Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
9
+
10
+
11
+ To launch debugger while developing:
12
+
13
+ ``` [lfs "customtransfer.multipart"]
14
+ path = /path/to/huggingface_hub/.env/bin/python args = -m debugpy --listen 5678
15
+ --wait-for-client
16
+ /path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py
17
+ lfs-multipart-upload ```"""
18
+
19
+ import json
20
+ import os
21
+ import subprocess
22
+ import sys
23
+ from argparse import _SubParsersAction
24
+ from typing import Dict, List, Optional
25
+
26
+ from huggingface_hub.commands import BaseHuggingfaceCLICommand
27
+ from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND, SliceFileObj
28
+
29
+ from ..utils import get_session, hf_raise_for_status, logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class LfsCommands(BaseHuggingfaceCLICommand):
36
+ """
37
+ Implementation of a custom transfer agent for the transfer type "multipart"
38
+ for git-lfs. This lets users upload large files >5GB 🔥. Spec for LFS custom
39
+ transfer agent is:
40
+ https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
41
+
42
+ This introduces two commands to the CLI:
43
+
44
+ 1. $ huggingface-cli lfs-enable-largefiles
45
+
46
+ This should be executed once for each model repo that contains a model file
47
+ >5GB. It's documented in the error message you get if you just try to git
48
+ push a 5GB file without having enabled it before.
49
+
50
+ 2. $ huggingface-cli lfs-multipart-upload
51
+
52
+ This command is called by lfs directly and is not meant to be called by the
53
+ user.
54
+ """
55
+
56
+ @staticmethod
57
+ def register_subcommand(parser: _SubParsersAction):
58
+ enable_parser = parser.add_parser(
59
+ "lfs-enable-largefiles", help="Configure your repository to enable upload of files > 5GB."
60
+ )
61
+ enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
62
+ enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))
63
+
64
+ # Command will get called by git-lfs, do not call it directly.
65
+ upload_parser = parser.add_parser(LFS_MULTIPART_UPLOAD_COMMAND, add_help=False)
66
+ upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args))
67
+
68
+
69
+ class LfsEnableCommand:
70
+ def __init__(self, args):
71
+ self.args = args
72
+
73
+ def run(self):
74
+ local_path = os.path.abspath(self.args.path)
75
+ if not os.path.isdir(local_path):
76
+ print("This does not look like a valid git repo.")
77
+ exit(1)
78
+ subprocess.run(
79
+ "git config lfs.customtransfer.multipart.path huggingface-cli".split(),
80
+ check=True,
81
+ cwd=local_path,
82
+ )
83
+ subprocess.run(
84
+ f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
85
+ check=True,
86
+ cwd=local_path,
87
+ )
88
+ print("Local repo set up for largefiles")
89
+
90
+
91
+ def write_msg(msg: Dict):
92
+ """Write out the message in Line delimited JSON."""
93
+ msg_str = json.dumps(msg) + "\n"
94
+ sys.stdout.write(msg_str)
95
+ sys.stdout.flush()
96
+
97
+
98
+ def read_msg() -> Optional[Dict]:
99
+ """Read Line delimited JSON from stdin."""
100
+ msg = json.loads(sys.stdin.readline().strip())
101
+
102
+ if "terminate" in (msg.get("type"), msg.get("event")):
103
+ # terminate message received
104
+ return None
105
+
106
+ if msg.get("event") not in ("download", "upload"):
107
+ logger.critical("Received unexpected message")
108
+ sys.exit(1)
109
+
110
+ return msg
111
+
112
+
113
+ class LfsUploadCommand:
114
+ def __init__(self, args) -> None:
115
+ self.args = args
116
+
117
+ def run(self) -> None:
118
+ # Immediately after invoking a custom transfer process, git-lfs
119
+ # sends initiation data to the process over stdin.
120
+ # This tells the process useful information about the configuration.
121
+ init_msg = json.loads(sys.stdin.readline().strip())
122
+ if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
123
+ write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
124
+ sys.exit(1)
125
+
126
+ # The transfer process should use the information it needs from the
127
+ # initiation structure, and also perform any one-off setup tasks it
128
+ # needs to do. It should then respond on stdout with a simple empty
129
+ # confirmation structure, as follows:
130
+ write_msg({})
131
+
132
+ # After the initiation exchange, git-lfs will send any number of
133
+ # transfer requests to the stdin of the transfer process, in a serial sequence.
134
+ while True:
135
+ msg = read_msg()
136
+ if msg is None:
137
+ # When all transfers have been processed, git-lfs will send
138
+ # a terminate event to the stdin of the transfer process.
139
+ # On receiving this message the transfer process should
140
+ # clean up and terminate. No response is expected.
141
+ sys.exit(0)
142
+
143
+ oid = msg["oid"]
144
+ filepath = msg["path"]
145
+ completion_url = msg["action"]["href"]
146
+ header = msg["action"]["header"]
147
+ chunk_size = int(header.pop("chunk_size"))
148
+ presigned_urls: List[str] = list(header.values())
149
+
150
+ # Send a "started" progress event to allow other workers to start.
151
+ # Otherwise they're delayed until first "progress" event is reported,
152
+ # i.e. after the first 5GB by default (!)
153
+ write_msg(
154
+ {
155
+ "event": "progress",
156
+ "oid": oid,
157
+ "bytesSoFar": 1,
158
+ "bytesSinceLast": 0,
159
+ }
160
+ )
161
+
162
+ parts = []
163
+ with open(filepath, "rb") as file:
164
+ for i, presigned_url in enumerate(presigned_urls):
165
+ with SliceFileObj(
166
+ file,
167
+ seek_from=i * chunk_size,
168
+ read_limit=chunk_size,
169
+ ) as data:
170
+ r = get_session().put(presigned_url, data=data)
171
+ hf_raise_for_status(r)
172
+ parts.append(
173
+ {
174
+ "etag": r.headers.get("etag"),
175
+ "partNumber": i + 1,
176
+ }
177
+ )
178
+ # In order to support progress reporting while data is uploading / downloading,
179
+ # the transfer process should post messages to stdout
180
+ write_msg(
181
+ {
182
+ "event": "progress",
183
+ "oid": oid,
184
+ "bytesSoFar": (i + 1) * chunk_size,
185
+ "bytesSinceLast": chunk_size,
186
+ }
187
+ )
188
+ # Not precise but that's ok.
189
+
190
+ r = get_session().post(
191
+ completion_url,
192
+ json={
193
+ "oid": oid,
194
+ "parts": parts,
195
+ },
196
+ )
197
+ hf_raise_for_status(r)
198
+
199
+ write_msg({"event": "complete", "oid": oid})
lib/python3.11/site-packages/huggingface_hub/commands/scan_cache.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains command to scan the HF cache directory.
16
+
17
+ Usage:
18
+ huggingface-cli scan-cache
19
+ huggingface-cli scan-cache -v
20
+ huggingface-cli scan-cache -vvv
21
+ huggingface-cli scan-cache --dir ~/.cache/huggingface/hub
22
+ """
23
+ import time
24
+ from argparse import Namespace, _SubParsersAction
25
+ from typing import Optional
26
+
27
+ from ..utils import CacheNotFound, HFCacheInfo, scan_cache_dir
28
+ from . import BaseHuggingfaceCLICommand
29
+ from ._cli_utils import ANSI, tabulate
30
+
31
+
32
+ class ScanCacheCommand(BaseHuggingfaceCLICommand):
33
+ @staticmethod
34
+ def register_subcommand(parser: _SubParsersAction):
35
+ scan_cache_parser = parser.add_parser("scan-cache", help="Scan cache directory.")
36
+
37
+ scan_cache_parser.add_argument(
38
+ "--dir",
39
+ type=str,
40
+ default=None,
41
+ help="cache directory to scan (optional). Default to the default HuggingFace cache.",
42
+ )
43
+ scan_cache_parser.add_argument(
44
+ "-v",
45
+ "--verbose",
46
+ action="count",
47
+ default=0,
48
+ help="show a more verbose output",
49
+ )
50
+ scan_cache_parser.set_defaults(func=ScanCacheCommand)
51
+
52
+ def __init__(self, args: Namespace) -> None:
53
+ self.verbosity: int = args.verbose
54
+ self.cache_dir: Optional[str] = args.dir
55
+
56
+ def run(self):
57
+ try:
58
+ t0 = time.time()
59
+ hf_cache_info = scan_cache_dir(self.cache_dir)
60
+ t1 = time.time()
61
+ except CacheNotFound as exc:
62
+ cache_dir = exc.cache_dir
63
+ print(f"Cache directory not found: {cache_dir}")
64
+ return
65
+
66
+ self._print_hf_cache_info_as_table(hf_cache_info)
67
+
68
+ print(
69
+ f"\nDone in {round(t1-t0,1)}s. Scanned {len(hf_cache_info.repos)} repo(s)"
70
+ f" for a total of {ANSI.red(hf_cache_info.size_on_disk_str)}."
71
+ )
72
+ if len(hf_cache_info.warnings) > 0:
73
+ message = f"Got {len(hf_cache_info.warnings)} warning(s) while scanning."
74
+ if self.verbosity >= 3:
75
+ print(ANSI.gray(message))
76
+ for warning in hf_cache_info.warnings:
77
+ print(ANSI.gray(warning))
78
+ else:
79
+ print(ANSI.gray(message + " Use -vvv to print details."))
80
+
81
+ def _print_hf_cache_info_as_table(self, hf_cache_info: HFCacheInfo) -> None:
82
+ if self.verbosity == 0:
83
+ print(
84
+ tabulate(
85
+ rows=[
86
+ [
87
+ repo.repo_id,
88
+ repo.repo_type,
89
+ "{:>12}".format(repo.size_on_disk_str),
90
+ repo.nb_files,
91
+ repo.last_accessed_str,
92
+ repo.last_modified_str,
93
+ ", ".join(sorted(repo.refs)),
94
+ str(repo.repo_path),
95
+ ]
96
+ for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
97
+ ],
98
+ headers=[
99
+ "REPO ID",
100
+ "REPO TYPE",
101
+ "SIZE ON DISK",
102
+ "NB FILES",
103
+ "LAST_ACCESSED",
104
+ "LAST_MODIFIED",
105
+ "REFS",
106
+ "LOCAL PATH",
107
+ ],
108
+ )
109
+ )
110
+ else:
111
+ print(
112
+ tabulate(
113
+ rows=[
114
+ [
115
+ repo.repo_id,
116
+ repo.repo_type,
117
+ revision.commit_hash,
118
+ "{:>12}".format(revision.size_on_disk_str),
119
+ revision.nb_files,
120
+ revision.last_modified_str,
121
+ ", ".join(sorted(revision.refs)),
122
+ str(revision.snapshot_path),
123
+ ]
124
+ for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
125
+ for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
126
+ ],
127
+ headers=[
128
+ "REPO ID",
129
+ "REPO TYPE",
130
+ "REVISION",
131
+ "SIZE ON DISK",
132
+ "NB FILES",
133
+ "LAST_MODIFIED",
134
+ "REFS",
135
+ "LOCAL PATH",
136
+ ],
137
+ )
138
+ )
lib/python3.11/site-packages/huggingface_hub/commands/upload.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains command to upload a repo or file with the CLI.
16
+
17
+ Usage:
18
+ # Upload file (implicit)
19
+ huggingface-cli upload my-cool-model ./my-cool-model.safetensors
20
+
21
+ # Upload file (explicit)
22
+ huggingface-cli upload my-cool-model ./my-cool-model.safetensors model.safetensors
23
+
24
+ # Upload directory (implicit). If `my-cool-model/` is a directory it will be uploaded, otherwise an exception is raised.
25
+ huggingface-cli upload my-cool-model
26
+
27
+ # Upload directory (explicit)
28
+ huggingface-cli upload my-cool-model ./models/my-cool-model .
29
+
30
+ # Upload filtered directory (example: tensorboard logs except for the last run)
31
+ huggingface-cli upload my-cool-model ./model/training /logs --include "*.tfevents.*" --exclude "*20230905*"
32
+
33
+ # Upload private dataset
34
+ huggingface-cli upload Wauplin/my-cool-dataset ./data . --repo-type=dataset --private
35
+
36
+ # Upload with token
37
+ huggingface-cli upload Wauplin/my-cool-model --token=hf_****
38
+
39
+ # Sync local Space with Hub (upload new files, delete removed files)
40
+ huggingface-cli upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub"
41
+
42
+ # Schedule commits every 30 minutes
43
+ huggingface-cli upload Wauplin/my-cool-model --every=30
44
+ """
45
+ import os
46
+ import time
47
+ import warnings
48
+ from argparse import Namespace, _SubParsersAction
49
+ from typing import List, Optional
50
+
51
+ from huggingface_hub import logging
52
+ from huggingface_hub._commit_scheduler import CommitScheduler
53
+ from huggingface_hub.commands import BaseHuggingfaceCLICommand
54
+ from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
55
+ from huggingface_hub.hf_api import HfApi
56
+ from huggingface_hub.utils import RevisionNotFoundError, disable_progress_bars, enable_progress_bars
57
+
58
+
59
+ logger = logging.get_logger(__name__)
60
+
61
+
62
+ class UploadCommand(BaseHuggingfaceCLICommand):
63
+ @staticmethod
64
+ def register_subcommand(parser: _SubParsersAction):
65
+ upload_parser = parser.add_parser("upload", help="Upload a file or a folder to a repo on the Hub")
66
+ upload_parser.add_argument(
67
+ "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)."
68
+ )
69
+ upload_parser.add_argument(
70
+ "local_path", nargs="?", help="Local path to the file or folder to upload. Defaults to current directory."
71
+ )
72
+ upload_parser.add_argument(
73
+ "path_in_repo",
74
+ nargs="?",
75
+ help="Path of the file or folder in the repo. Defaults to the relative path of the file or folder.",
76
+ )
77
+ upload_parser.add_argument(
78
+ "--repo-type",
79
+ choices=["model", "dataset", "space"],
80
+ default="model",
81
+ help="Type of the repo to upload to (e.g. `dataset`).",
82
+ )
83
+ upload_parser.add_argument(
84
+ "--revision",
85
+ type=str,
86
+ help=(
87
+ "An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not"
88
+ " exist and `--create-pr` is not set, a branch will be automatically created."
89
+ ),
90
+ )
91
+ upload_parser.add_argument(
92
+ "--private",
93
+ action="store_true",
94
+ help=(
95
+ "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already"
96
+ " exists."
97
+ ),
98
+ )
99
+ upload_parser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.")
100
+ upload_parser.add_argument(
101
+ "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload."
102
+ )
103
+ upload_parser.add_argument(
104
+ "--delete",
105
+ nargs="*",
106
+ type=str,
107
+ help="Glob patterns for file to be deleted from the repo while committing.",
108
+ )
109
+ upload_parser.add_argument(
110
+ "--commit-message", type=str, help="The summary / title / first line of the generated commit."
111
+ )
112
+ upload_parser.add_argument("--commit-description", type=str, help="The description of the generated commit.")
113
+ upload_parser.add_argument(
114
+ "--create-pr", action="store_true", help="Whether to upload content as a new Pull Request."
115
+ )
116
+ upload_parser.add_argument(
117
+ "--every",
118
+ type=float,
119
+ help="If set, a background job is scheduled to create commits every `every` minutes.",
120
+ )
121
+ upload_parser.add_argument(
122
+ "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
123
+ )
124
+ upload_parser.add_argument(
125
+ "--quiet",
126
+ action="store_true",
127
+ help="If True, progress bars are disabled and only the path to the uploaded files is printed.",
128
+ )
129
+ upload_parser.set_defaults(func=UploadCommand)
130
+
131
+ def __init__(self, args: Namespace) -> None:
132
+ self.repo_id: str = args.repo_id
133
+ self.repo_type: Optional[str] = args.repo_type
134
+ self.revision: Optional[str] = args.revision
135
+ self.private: bool = args.private
136
+
137
+ self.include: Optional[List[str]] = args.include
138
+ self.exclude: Optional[List[str]] = args.exclude
139
+ self.delete: Optional[List[str]] = args.delete
140
+
141
+ self.commit_message: Optional[str] = args.commit_message
142
+ self.commit_description: Optional[str] = args.commit_description
143
+ self.create_pr: bool = args.create_pr
144
+ self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli")
145
+ self.quiet: bool = args.quiet # disable warnings and progress bars
146
+
147
+ # Check `--every` is valid
148
+ if args.every is not None and args.every <= 0:
149
+ raise ValueError(f"`every` must be a positive value (got '{args.every}')")
150
+ self.every: Optional[float] = args.every
151
+
152
+ # Resolve `local_path` and `path_in_repo`
153
+ repo_name: str = args.repo_id.split("/")[-1] # e.g. "Wauplin/my-cool-model" => "my-cool-model"
154
+ self.local_path: str
155
+ self.path_in_repo: str
156
+ if args.local_path is None and os.path.isfile(repo_name):
157
+ # Implicit case 1: user provided only a repo_id which happen to be a local file as well => upload it with same name
158
+ self.local_path = repo_name
159
+ self.path_in_repo = repo_name
160
+ elif args.local_path is None and os.path.isdir(repo_name):
161
+ # Implicit case 2: user provided only a repo_id which happen to be a local folder as well => upload it at root
162
+ self.local_path = repo_name
163
+ self.path_in_repo = "."
164
+ elif args.local_path is None:
165
+ # Implicit case 3: user provided only a repo_id that does not match a local file or folder
166
+ # => the user must explicitly provide a local_path => raise exception
167
+ raise ValueError(f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly.")
168
+ elif args.path_in_repo is None and os.path.isfile(args.local_path):
169
+ # Explicit local path to file, no path in repo => upload it at root with same name
170
+ self.local_path = args.local_path
171
+ self.path_in_repo = os.path.basename(args.local_path)
172
+ elif args.path_in_repo is None:
173
+ # Explicit local path to folder, no path in repo => upload at root
174
+ self.local_path = args.local_path
175
+ self.path_in_repo = "."
176
+ else:
177
+ # Finally, if both paths are explicit
178
+ self.local_path = args.local_path
179
+ self.path_in_repo = args.path_in_repo
180
+
181
+ def run(self) -> None:
182
+ if self.quiet:
183
+ disable_progress_bars()
184
+ with warnings.catch_warnings():
185
+ warnings.simplefilter("ignore")
186
+ print(self._upload())
187
+ enable_progress_bars()
188
+ else:
189
+ logging.set_verbosity_info()
190
+ print(self._upload())
191
+ logging.set_verbosity_warning()
192
+
193
+ def _upload(self) -> str:
194
+ if os.path.isfile(self.local_path):
195
+ if self.include is not None and len(self.include) > 0:
196
+ warnings.warn("Ignoring `--include` since a single file is uploaded.")
197
+ if self.exclude is not None and len(self.exclude) > 0:
198
+ warnings.warn("Ignoring `--exclude` since a single file is uploaded.")
199
+ if self.delete is not None and len(self.delete) > 0:
200
+ warnings.warn("Ignoring `--delete` since a single file is uploaded.")
201
+
202
+ if not HF_HUB_ENABLE_HF_TRANSFER:
203
+ logger.info(
204
+ "Consider using `hf_transfer` for faster uploads. This solution comes with some limitations. See"
205
+ " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details."
206
+ )
207
+
208
+ # Schedule commits if `every` is set
209
+ if self.every is not None:
210
+ if os.path.isfile(self.local_path):
211
+ # If file => watch entire folder + use allow_patterns
212
+ folder_path = os.path.dirname(self.local_path)
213
+ path_in_repo = (
214
+ self.path_in_repo[: -len(self.local_path)] # remove filename from path_in_repo
215
+ if self.path_in_repo.endswith(self.local_path)
216
+ else self.path_in_repo
217
+ )
218
+ allow_patterns = [self.local_path]
219
+ ignore_patterns = []
220
+ else:
221
+ folder_path = self.local_path
222
+ path_in_repo = self.path_in_repo
223
+ allow_patterns = self.include or []
224
+ ignore_patterns = self.exclude or []
225
+ if self.delete is not None and len(self.delete) > 0:
226
+ warnings.warn("Ignoring `--delete` when uploading with scheduled commits.")
227
+
228
+ scheduler = CommitScheduler(
229
+ folder_path=folder_path,
230
+ repo_id=self.repo_id,
231
+ repo_type=self.repo_type,
232
+ revision=self.revision,
233
+ allow_patterns=allow_patterns,
234
+ ignore_patterns=ignore_patterns,
235
+ path_in_repo=path_in_repo,
236
+ private=self.private,
237
+ every=self.every,
238
+ hf_api=self.api,
239
+ )
240
+ print(f"Scheduling commits every {self.every} minutes to {scheduler.repo_id}.")
241
+ try: # Block main thread until KeyboardInterrupt
242
+ while True:
243
+ time.sleep(100)
244
+ except KeyboardInterrupt:
245
+ scheduler.stop()
246
+ return "Stopped scheduled commits."
247
+
248
+ # Otherwise, create repo and proceed with the upload
249
+ if not os.path.isfile(self.local_path) and not os.path.isdir(self.local_path):
250
+ raise FileNotFoundError(f"No such file or directory: '{self.local_path}'.")
251
+ repo_id = self.api.create_repo(
252
+ repo_id=self.repo_id,
253
+ repo_type=self.repo_type,
254
+ exist_ok=True,
255
+ private=self.private,
256
+ space_sdk="gradio" if self.repo_type == "space" else None,
257
+ # ^ We don't want it to fail when uploading to a Space => let's set Gradio by default.
258
+ # ^ I'd rather not add CLI args to set it explicitly as we already have `huggingface-cli repo create` for that.
259
+ ).repo_id
260
+
261
+ # Check if branch already exists and if not, create it
262
+ if self.revision is not None and not self.create_pr:
263
+ try:
264
+ self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision)
265
+ except RevisionNotFoundError:
266
+ logger.info(f"Branch '{self.revision}' not found. Creating it...")
267
+ self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True)
268
+ # ^ `exist_ok=True` to avoid race concurrency issues
269
+
270
+ # File-based upload
271
+ if os.path.isfile(self.local_path):
272
+ return self.api.upload_file(
273
+ path_or_fileobj=self.local_path,
274
+ path_in_repo=self.path_in_repo,
275
+ repo_id=repo_id,
276
+ repo_type=self.repo_type,
277
+ revision=self.revision,
278
+ commit_message=self.commit_message,
279
+ commit_description=self.commit_description,
280
+ create_pr=self.create_pr,
281
+ )
282
+
283
+ # Folder-based upload
284
+ else:
285
+ return self.api.upload_folder(
286
+ folder_path=self.local_path,
287
+ path_in_repo=self.path_in_repo,
288
+ repo_id=repo_id,
289
+ repo_type=self.repo_type,
290
+ revision=self.revision,
291
+ commit_message=self.commit_message,
292
+ commit_description=self.commit_description,
293
+ create_pr=self.create_pr,
294
+ allow_patterns=self.include,
295
+ ignore_patterns=self.exclude,
296
+ delete_patterns=self.delete,
297
+ )
lib/python3.11/site-packages/huggingface_hub/commands/user.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import subprocess
15
+ from argparse import _SubParsersAction
16
+
17
+ from requests.exceptions import HTTPError
18
+
19
+ from huggingface_hub.commands import BaseHuggingfaceCLICommand
20
+ from huggingface_hub.constants import (
21
+ ENDPOINT,
22
+ REPO_TYPES,
23
+ REPO_TYPES_URL_PREFIXES,
24
+ SPACES_SDK_TYPES,
25
+ )
26
+ from huggingface_hub.hf_api import HfApi
27
+
28
+ from .._login import ( # noqa: F401 # for backward compatibility # noqa: F401 # for backward compatibility
29
+ NOTEBOOK_LOGIN_PASSWORD_HTML,
30
+ NOTEBOOK_LOGIN_TOKEN_HTML_END,
31
+ NOTEBOOK_LOGIN_TOKEN_HTML_START,
32
+ login,
33
+ logout,
34
+ notebook_login,
35
+ )
36
+ from ..utils import get_token
37
+ from ._cli_utils import ANSI
38
+
39
+
40
+ class UserCommands(BaseHuggingfaceCLICommand):
41
+ @staticmethod
42
+ def register_subcommand(parser: _SubParsersAction):
43
+ login_parser = parser.add_parser("login", help="Log in using a token from huggingface.co/settings/tokens")
44
+ login_parser.add_argument(
45
+ "--token",
46
+ type=str,
47
+ help="Token generated from https://huggingface.co/settings/tokens",
48
+ )
49
+ login_parser.add_argument(
50
+ "--add-to-git-credential",
51
+ action="store_true",
52
+ help="Optional: Save token to git credential helper.",
53
+ )
54
+ login_parser.set_defaults(func=lambda args: LoginCommand(args))
55
+ whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
56
+ whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
57
+ logout_parser = parser.add_parser("logout", help="Log out")
58
+ logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
59
+
60
+ # new system: git-based repo system
61
+ repo_parser = parser.add_parser("repo", help="{create} Commands to interact with your huggingface.co repos.")
62
+ repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
63
+ repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
64
+ repo_create_parser.add_argument(
65
+ "name",
66
+ type=str,
67
+ help="Name for your repo. Will be namespaced under your username to build the repo id.",
68
+ )
69
+ repo_create_parser.add_argument(
70
+ "--type",
71
+ type=str,
72
+ help='Optional: repo_type: set to "dataset" or "space" if creating a dataset or space, default is model.',
73
+ )
74
+ repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
75
+ repo_create_parser.add_argument(
76
+ "--space_sdk",
77
+ type=str,
78
+ help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".',
79
+ choices=SPACES_SDK_TYPES,
80
+ )
81
+ repo_create_parser.add_argument(
82
+ "-y",
83
+ "--yes",
84
+ action="store_true",
85
+ help="Optional: answer Yes to the prompt",
86
+ )
87
+ repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
88
+
89
+
90
+ class BaseUserCommand:
91
+ def __init__(self, args):
92
+ self.args = args
93
+ self._api = HfApi()
94
+
95
+
96
+ class LoginCommand(BaseUserCommand):
97
+ def run(self):
98
+ login(token=self.args.token, add_to_git_credential=self.args.add_to_git_credential)
99
+
100
+
101
+ class LogoutCommand(BaseUserCommand):
102
+ def run(self):
103
+ logout()
104
+
105
+
106
+ class WhoamiCommand(BaseUserCommand):
107
+ def run(self):
108
+ token = get_token()
109
+ if token is None:
110
+ print("Not logged in")
111
+ exit()
112
+ try:
113
+ info = self._api.whoami(token)
114
+ print(info["name"])
115
+ orgs = [org["name"] for org in info["orgs"]]
116
+ if orgs:
117
+ print(ANSI.bold("orgs: "), ",".join(orgs))
118
+
119
+ if ENDPOINT != "https://huggingface.co":
120
+ print(f"Authenticated through private endpoint: {ENDPOINT}")
121
+ except HTTPError as e:
122
+ print(e)
123
+ print(ANSI.red(e.response.text))
124
+ exit(1)
125
+
126
+
127
+ class RepoCreateCommand(BaseUserCommand):
128
+ def run(self):
129
+ token = get_token()
130
+ if token is None:
131
+ print("Not logged in")
132
+ exit(1)
133
+ try:
134
+ stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
135
+ print(ANSI.gray(stdout.strip()))
136
+ except FileNotFoundError:
137
+ print("Looks like you do not have git installed, please install.")
138
+
139
+ try:
140
+ stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
141
+ print(ANSI.gray(stdout.strip()))
142
+ except FileNotFoundError:
143
+ print(
144
+ ANSI.red(
145
+ "Looks like you do not have git-lfs installed, please install."
146
+ " You can install from https://git-lfs.github.com/."
147
+ " Then run `git lfs install` (you only have to do this once)."
148
+ )
149
+ )
150
+ print("")
151
+
152
+ user = self._api.whoami(token)["name"]
153
+ namespace = self.args.organization if self.args.organization is not None else user
154
+
155
+ repo_id = f"{namespace}/{self.args.name}"
156
+
157
+ if self.args.type not in REPO_TYPES:
158
+ print("Invalid repo --type")
159
+ exit(1)
160
+
161
+ if self.args.type in REPO_TYPES_URL_PREFIXES:
162
+ prefixed_repo_id = REPO_TYPES_URL_PREFIXES[self.args.type] + repo_id
163
+ else:
164
+ prefixed_repo_id = repo_id
165
+
166
+ print(f"You are about to create {ANSI.bold(prefixed_repo_id)}")
167
+
168
+ if not self.args.yes:
169
+ choice = input("Proceed? [Y/n] ").lower()
170
+ if not (choice == "" or choice == "y" or choice == "yes"):
171
+ print("Abort")
172
+ exit()
173
+ try:
174
+ url = self._api.create_repo(
175
+ repo_id=repo_id,
176
+ token=token,
177
+ repo_type=self.args.type,
178
+ space_sdk=self.args.space_sdk,
179
+ )
180
+ except HTTPError as e:
181
+ print(e)
182
+ print(ANSI.red(e.response.text))
183
+ exit(1)
184
+ print("\nYour repo now lives at:")
185
+ print(f" {ANSI.bold(url)}")
186
+ print("\nYou can clone it locally with the command below, and commit/push as usual.")
187
+ print(f"\n git clone {url}")
188
+ print("")
lib/python3.11/site-packages/huggingface_hub/community.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data structures to interact with Discussions and Pull Requests on the Hub.
3
+
4
+ See [the Discussions and Pull Requests guide](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)
5
+ for more information on Pull Requests, Discussions, and the community tab.
6
+ """
7
+ from dataclasses import dataclass
8
+ from datetime import datetime
9
+ from typing import List, Literal, Optional, Union
10
+
11
+ from .constants import REPO_TYPE_MODEL
12
+ from .utils import parse_datetime
13
+
14
+
15
+ DiscussionStatus = Literal["open", "closed", "merged", "draft"]
16
+
17
+
18
+ @dataclass
19
+ class Discussion:
20
+ """
21
+ A Discussion or Pull Request on the Hub.
22
+
23
+ This dataclass is not intended to be instantiated directly.
24
+
25
+ Attributes:
26
+ title (`str`):
27
+ The title of the Discussion / Pull Request
28
+ status (`str`):
29
+ The status of the Discussion / Pull Request.
30
+ It must be one of:
31
+ * `"open"`
32
+ * `"closed"`
33
+ * `"merged"` (only for Pull Requests )
34
+ * `"draft"` (only for Pull Requests )
35
+ num (`int`):
36
+ The number of the Discussion / Pull Request.
37
+ repo_id (`str`):
38
+ The id (`"{namespace}/{repo_name}"`) of the repo on which
39
+ the Discussion / Pull Request was open.
40
+ repo_type (`str`):
41
+ The type of the repo on which the Discussion / Pull Request was open.
42
+ Possible values are: `"model"`, `"dataset"`, `"space"`.
43
+ author (`str`):
44
+ The username of the Discussion / Pull Request author.
45
+ Can be `"deleted"` if the user has been deleted since.
46
+ is_pull_request (`bool`):
47
+ Whether or not this is a Pull Request.
48
+ created_at (`datetime`):
49
+ The `datetime` of creation of the Discussion / Pull Request.
50
+ endpoint (`str`):
51
+ Endpoint of the Hub. Default is https://huggingface.co.
52
+ git_reference (`str`, *optional*):
53
+ (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
54
+ url (`str`):
55
+ (property) URL of the discussion on the Hub.
56
+ """
57
+
58
+ title: str
59
+ status: DiscussionStatus
60
+ num: int
61
+ repo_id: str
62
+ repo_type: str
63
+ author: str
64
+ is_pull_request: bool
65
+ created_at: datetime
66
+ endpoint: str
67
+
68
+ @property
69
+ def git_reference(self) -> Optional[str]:
70
+ """
71
+ If this is a Pull Request , returns the git reference to which changes can be pushed.
72
+ Returns `None` otherwise.
73
+ """
74
+ if self.is_pull_request:
75
+ return f"refs/pr/{self.num}"
76
+ return None
77
+
78
+ @property
79
+ def url(self) -> str:
80
+ """Returns the URL of the discussion on the Hub."""
81
+ if self.repo_type is None or self.repo_type == REPO_TYPE_MODEL:
82
+ return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}"
83
+ return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}"
84
+
85
+
86
+ @dataclass
87
+ class DiscussionWithDetails(Discussion):
88
+ """
89
+ Subclass of [`Discussion`].
90
+
91
+ Attributes:
92
+ title (`str`):
93
+ The title of the Discussion / Pull Request
94
+ status (`str`):
95
+ The status of the Discussion / Pull Request.
96
+ It can be one of:
97
+ * `"open"`
98
+ * `"closed"`
99
+ * `"merged"` (only for Pull Requests )
100
+ * `"draft"` (only for Pull Requests )
101
+ num (`int`):
102
+ The number of the Discussion / Pull Request.
103
+ repo_id (`str`):
104
+ The id (`"{namespace}/{repo_name}"`) of the repo on which
105
+ the Discussion / Pull Request was open.
106
+ repo_type (`str`):
107
+ The type of the repo on which the Discussion / Pull Request was open.
108
+ Possible values are: `"model"`, `"dataset"`, `"space"`.
109
+ author (`str`):
110
+ The username of the Discussion / Pull Request author.
111
+ Can be `"deleted"` if the user has been deleted since.
112
+ is_pull_request (`bool`):
113
+ Whether or not this is a Pull Request.
114
+ created_at (`datetime`):
115
+ The `datetime` of creation of the Discussion / Pull Request.
116
+ events (`list` of [`DiscussionEvent`])
117
+ The list of [`DiscussionEvents`] in this Discussion or Pull Request.
118
+ conflicting_files (`Union[List[str], bool, None]`, *optional*):
119
+ A list of conflicting files if this is a Pull Request.
120
+ `None` if `self.is_pull_request` is `False`.
121
+ `True` if there are conflicting files but the list can't be retrieved.
122
+ target_branch (`str`, *optional*):
123
+ The branch into which changes are to be merged if this is a
124
+ Pull Request . `None` if `self.is_pull_request` is `False`.
125
+ merge_commit_oid (`str`, *optional*):
126
+ If this is a merged Pull Request , this is set to the OID / SHA of
127
+ the merge commit, `None` otherwise.
128
+ diff (`str`, *optional*):
129
+ The git diff if this is a Pull Request , `None` otherwise.
130
+ endpoint (`str`):
131
+ Endpoint of the Hub. Default is https://huggingface.co.
132
+ git_reference (`str`, *optional*):
133
+ (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
134
+ url (`str`):
135
+ (property) URL of the discussion on the Hub.
136
+ """
137
+
138
+ events: List["DiscussionEvent"]
139
+ conflicting_files: Union[List[str], bool, None]
140
+ target_branch: Optional[str]
141
+ merge_commit_oid: Optional[str]
142
+ diff: Optional[str]
143
+
144
+
145
+ @dataclass
146
+ class DiscussionEvent:
147
+ """
148
+ An event in a Discussion or Pull Request.
149
+
150
+ Use concrete classes:
151
+ * [`DiscussionComment`]
152
+ * [`DiscussionStatusChange`]
153
+ * [`DiscussionCommit`]
154
+ * [`DiscussionTitleChange`]
155
+
156
+ Attributes:
157
+ id (`str`):
158
+ The ID of the event. An hexadecimal string.
159
+ type (`str`):
160
+ The type of the event.
161
+ created_at (`datetime`):
162
+ A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
163
+ object holding the creation timestamp for the event.
164
+ author (`str`):
165
+ The username of the Discussion / Pull Request author.
166
+ Can be `"deleted"` if the user has been deleted since.
167
+ """
168
+
169
+ id: str
170
+ type: str
171
+ created_at: datetime
172
+ author: str
173
+
174
+ _event: dict
175
+ """Stores the original event data, in case we need to access it later."""
176
+
177
+
178
+ @dataclass
179
+ class DiscussionComment(DiscussionEvent):
180
+ """A comment in a Discussion / Pull Request.
181
+
182
+ Subclass of [`DiscussionEvent`].
183
+
184
+
185
+ Attributes:
186
+ id (`str`):
187
+ The ID of the event. An hexadecimal string.
188
+ type (`str`):
189
+ The type of the event.
190
+ created_at (`datetime`):
191
+ A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
192
+ object holding the creation timestamp for the event.
193
+ author (`str`):
194
+ The username of the Discussion / Pull Request author.
195
+ Can be `"deleted"` if the user has been deleted since.
196
+ content (`str`):
197
+ The raw markdown content of the comment. Mentions, links and images are not rendered.
198
+ edited (`bool`):
199
+ Whether or not this comment has been edited.
200
+ hidden (`bool`):
201
+ Whether or not this comment has been hidden.
202
+ """
203
+
204
+ content: str
205
+ edited: bool
206
+ hidden: bool
207
+
208
+ @property
209
+ def rendered(self) -> str:
210
+ """The rendered comment, as a HTML string"""
211
+ return self._event["data"]["latest"]["html"]
212
+
213
+ @property
214
+ def last_edited_at(self) -> datetime:
215
+ """The last edit time, as a `datetime` object."""
216
+ return parse_datetime(self._event["data"]["latest"]["updatedAt"])
217
+
218
+ @property
219
+ def last_edited_by(self) -> str:
220
+ """The last edit time, as a `datetime` object."""
221
+ return self._event["data"]["latest"].get("author", {}).get("name", "deleted")
222
+
223
+ @property
224
+ def edit_history(self) -> List[dict]:
225
+ """The edit history of the comment"""
226
+ return self._event["data"]["history"]
227
+
228
+ @property
229
+ def number_of_edits(self) -> int:
230
+ return len(self.edit_history)
231
+
232
+
233
+ @dataclass
234
+ class DiscussionStatusChange(DiscussionEvent):
235
+ """A change of status in a Discussion / Pull Request.
236
+
237
+ Subclass of [`DiscussionEvent`].
238
+
239
+ Attributes:
240
+ id (`str`):
241
+ The ID of the event. An hexadecimal string.
242
+ type (`str`):
243
+ The type of the event.
244
+ created_at (`datetime`):
245
+ A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
246
+ object holding the creation timestamp for the event.
247
+ author (`str`):
248
+ The username of the Discussion / Pull Request author.
249
+ Can be `"deleted"` if the user has been deleted since.
250
+ new_status (`str`):
251
+ The status of the Discussion / Pull Request after the change.
252
+ It can be one of:
253
+ * `"open"`
254
+ * `"closed"`
255
+ * `"merged"` (only for Pull Requests )
256
+ """
257
+
258
+ new_status: str
259
+
260
+
261
+ @dataclass
262
+ class DiscussionCommit(DiscussionEvent):
263
+ """A commit in a Pull Request.
264
+
265
+ Subclass of [`DiscussionEvent`].
266
+
267
+ Attributes:
268
+ id (`str`):
269
+ The ID of the event. An hexadecimal string.
270
+ type (`str`):
271
+ The type of the event.
272
+ created_at (`datetime`):
273
+ A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
274
+ object holding the creation timestamp for the event.
275
+ author (`str`):
276
+ The username of the Discussion / Pull Request author.
277
+ Can be `"deleted"` if the user has been deleted since.
278
+ summary (`str`):
279
+ The summary of the commit.
280
+ oid (`str`):
281
+ The OID / SHA of the commit, as a hexadecimal string.
282
+ """
283
+
284
+ summary: str
285
+ oid: str
286
+
287
+
288
+ @dataclass
289
+ class DiscussionTitleChange(DiscussionEvent):
290
+ """A rename event in a Discussion / Pull Request.
291
+
292
+ Subclass of [`DiscussionEvent`].
293
+
294
+ Attributes:
295
+ id (`str`):
296
+ The ID of the event. An hexadecimal string.
297
+ type (`str`):
298
+ The type of the event.
299
+ created_at (`datetime`):
300
+ A [`datetime`](https://docs.python.org/3/library/datetime.html?highlight=datetime#datetime.datetime)
301
+ object holding the creation timestamp for the event.
302
+ author (`str`):
303
+ The username of the Discussion / Pull Request author.
304
+ Can be `"deleted"` if the user has been deleted since.
305
+ old_title (`str`):
306
+ The previous title for the Discussion / Pull Request.
307
+ new_title (`str`):
308
+ The new title.
309
+ """
310
+
311
+ old_title: str
312
+ new_title: str
313
+
314
+
315
+ def deserialize_event(event: dict) -> DiscussionEvent:
316
+ """Instantiates a [`DiscussionEvent`] from a dict"""
317
+ event_id: str = event["id"]
318
+ event_type: str = event["type"]
319
+ created_at = parse_datetime(event["createdAt"])
320
+
321
+ common_args = dict(
322
+ id=event_id,
323
+ type=event_type,
324
+ created_at=created_at,
325
+ author=event.get("author", {}).get("name", "deleted"),
326
+ _event=event,
327
+ )
328
+
329
+ if event_type == "comment":
330
+ return DiscussionComment(
331
+ **common_args,
332
+ edited=event["data"]["edited"],
333
+ hidden=event["data"]["hidden"],
334
+ content=event["data"]["latest"]["raw"],
335
+ )
336
+ if event_type == "status-change":
337
+ return DiscussionStatusChange(
338
+ **common_args,
339
+ new_status=event["data"]["status"],
340
+ )
341
+ if event_type == "commit":
342
+ return DiscussionCommit(
343
+ **common_args,
344
+ summary=event["data"]["subject"],
345
+ oid=event["data"]["oid"],
346
+ )
347
+ if event_type == "title-change":
348
+ return DiscussionTitleChange(
349
+ **common_args,
350
+ old_title=event["data"]["from"],
351
+ new_title=event["data"]["to"],
352
+ )
353
+
354
+ return DiscussionEvent(**common_args)
lib/python3.11/site-packages/huggingface_hub/constants.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import typing
4
+ from typing import Literal, Optional, Tuple
5
+
6
+
7
+ # Possible values for env variables
8
+
9
+
10
+ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
11
+ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
12
+
13
+
14
+ def _is_true(value: Optional[str]) -> bool:
15
+ if value is None:
16
+ return False
17
+ return value.upper() in ENV_VARS_TRUE_VALUES
18
+
19
+
20
+ def _as_int(value: Optional[str]) -> Optional[int]:
21
+ if value is None:
22
+ return None
23
+ return int(value)
24
+
25
+
26
+ # Constants for file downloads
27
+
28
+ PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
29
+ TF2_WEIGHTS_NAME = "tf_model.h5"
30
+ TF_WEIGHTS_NAME = "model.ckpt"
31
+ FLAX_WEIGHTS_NAME = "flax_model.msgpack"
32
+ CONFIG_NAME = "config.json"
33
+ REPOCARD_NAME = "README.md"
34
+ DEFAULT_ETAG_TIMEOUT = 10
35
+ DEFAULT_DOWNLOAD_TIMEOUT = 10
36
+ DEFAULT_REQUEST_TIMEOUT = 10
37
+ DOWNLOAD_CHUNK_SIZE = 10 * 1024 * 1024
38
+ HF_TRANSFER_CONCURRENCY = 100
39
+
40
+ # Constants for safetensors repos
41
+
42
+ SAFETENSORS_SINGLE_FILE = "model.safetensors"
43
+ SAFETENSORS_INDEX_FILE = "model.safetensors.index.json"
44
+ SAFETENSORS_MAX_HEADER_LENGTH = 25_000_000
45
+
46
+ # Git-related constants
47
+
48
+ DEFAULT_REVISION = "main"
49
+ REGEX_COMMIT_OID = re.compile(r"[A-Fa-f0-9]{5,40}")
50
+
51
+ HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"
52
+
53
+ _staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING"))
54
+
55
+ ENDPOINT = os.getenv("HF_ENDPOINT") or ("https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co")
56
+
57
+ HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
58
+ HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit"
59
+ HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag"
60
+ HUGGINGFACE_HEADER_X_LINKED_SIZE = "X-Linked-Size"
61
+
62
+ INFERENCE_ENDPOINT = os.environ.get("HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co")
63
+
64
+ # See https://huggingface.co/docs/inference-endpoints/index
65
+ INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2"
66
+
67
+
68
+ REPO_ID_SEPARATOR = "--"
69
+ # ^ this substring is not allowed in repo_ids on hf.co
70
+ # and is the canonical one we use for serialization of repo ids elsewhere.
71
+
72
+
73
+ REPO_TYPE_DATASET = "dataset"
74
+ REPO_TYPE_SPACE = "space"
75
+ REPO_TYPE_MODEL = "model"
76
+ REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
77
+ SPACES_SDK_TYPES = ["gradio", "streamlit", "docker", "static"]
78
+
79
+ REPO_TYPES_URL_PREFIXES = {
80
+ REPO_TYPE_DATASET: "datasets/",
81
+ REPO_TYPE_SPACE: "spaces/",
82
+ }
83
+ REPO_TYPES_MAPPING = {
84
+ "datasets": REPO_TYPE_DATASET,
85
+ "spaces": REPO_TYPE_SPACE,
86
+ "models": REPO_TYPE_MODEL,
87
+ }
88
+
89
+ DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
90
+ DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)
91
+ DiscussionStatusFilter = Literal["all", "open", "closed"]
92
+ DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter)
93
+
94
+ # default cache
95
+ default_home = os.path.join(os.path.expanduser("~"), ".cache")
96
+ HF_HOME = os.path.expanduser(
97
+ os.getenv(
98
+ "HF_HOME",
99
+ os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"),
100
+ )
101
+ )
102
+ hf_cache_home = HF_HOME # for backward compatibility. TODO: remove this in 1.0.0
103
+
104
+ default_cache_path = os.path.join(HF_HOME, "hub")
105
+ default_assets_cache_path = os.path.join(HF_HOME, "assets")
106
+
107
+ # Legacy env variables
108
+ HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
109
+ HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path)
110
+
111
+ # New env variables
112
+ HF_HUB_CACHE = os.getenv("HF_HUB_CACHE", HUGGINGFACE_HUB_CACHE)
113
+ HF_ASSETS_CACHE = os.getenv("HF_ASSETS_CACHE", HUGGINGFACE_ASSETS_CACHE)
114
+
115
+ HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE") or os.environ.get("TRANSFORMERS_OFFLINE"))
116
+
117
+ # Opt-out from telemetry requests
118
+ HF_HUB_DISABLE_TELEMETRY = (
119
+ _is_true(os.environ.get("HF_HUB_DISABLE_TELEMETRY")) # HF-specific env variable
120
+ or _is_true(os.environ.get("DISABLE_TELEMETRY"))
121
+ or _is_true(os.environ.get("DO_NOT_TRACK")) # https://consoledonottrack.com/
122
+ )
123
+
124
+ # In the past, token was stored in a hardcoded location
125
+ # `_OLD_HF_TOKEN_PATH` is deprecated and will be removed "at some point".
126
+ # See https://github.com/huggingface/huggingface_hub/issues/1232
127
+ _OLD_HF_TOKEN_PATH = os.path.expanduser("~/.huggingface/token")
128
+ HF_TOKEN_PATH = os.path.join(HF_HOME, "token")
129
+
130
+
131
+ if _staging_mode:
132
+ # In staging mode, we use a different cache to ensure we don't mix up production and staging data or tokens
133
+ _staging_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface_staging")
134
+ HUGGINGFACE_HUB_CACHE = os.path.join(_staging_home, "hub")
135
+ _OLD_HF_TOKEN_PATH = os.path.join(_staging_home, "_old_token")
136
+ HF_TOKEN_PATH = os.path.join(_staging_home, "token")
137
+
138
+ # Here, `True` will disable progress bars globally without possibility of enabling it
139
+ # programmatically. `False` will enable them without possibility of disabling them.
140
+ # If environment variable is not set (None), then the user is free to enable/disable
141
+ # them programmatically.
142
+ # TL;DR: env variable has priority over code
143
+ __HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS")
144
+ HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = (
145
+ _is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None
146
+ )
147
+
148
+ # Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
149
+ HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))
150
+
151
+ # Disable warning when using experimental features
152
+ HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))
153
+
154
+ # Disable sending the cached token by default is all HTTP requests to the Hub
155
+ HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))
156
+
157
+ # Enable fast-download using external dependency "hf_transfer"
158
+ # See:
159
+ # - https://pypi.org/project/hf-transfer/
160
+ # - https://github.com/huggingface/hf_transfer (private)
161
+ HF_HUB_ENABLE_HF_TRANSFER: bool = _is_true(os.environ.get("HF_HUB_ENABLE_HF_TRANSFER"))
162
+
163
+
164
+ # Used if download to `local_dir` and `local_dir_use_symlinks="auto"`
165
+ # Files smaller than 5MB are copy-pasted while bigger files are symlinked. The idea is to save disk-usage by symlinking
166
+ # huge files (i.e. LFS files most of the time) while allowing small files to be manually edited in local folder.
167
+ HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = (
168
+ _as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024
169
+ )
170
+
171
+ # Used to override the etag timeout on a system level
172
+ HF_HUB_ETAG_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_ETAG_TIMEOUT")) or DEFAULT_ETAG_TIMEOUT
173
+
174
+ # Used to override the get request timeout on a system level
175
+ HF_HUB_DOWNLOAD_TIMEOUT: int = _as_int(os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT")) or DEFAULT_DOWNLOAD_TIMEOUT
176
+
177
+ # List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are
178
+ # deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by
179
+ # default. We still keep the full list of supported frameworks in case we want to scan all of them.
180
+ MAIN_INFERENCE_API_FRAMEWORKS = [
181
+ "diffusers",
182
+ "sentence-transformers",
183
+ "text-generation-inference",
184
+ "transformers",
185
+ ]
186
+
187
+ ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [
188
+ "adapter-transformers",
189
+ "allennlp",
190
+ "asteroid",
191
+ "bertopic",
192
+ "doctr",
193
+ "espnet",
194
+ "fairseq",
195
+ "fastai",
196
+ "fasttext",
197
+ "flair",
198
+ "generic",
199
+ "k2",
200
+ "keras",
201
+ "mindspore",
202
+ "nemo",
203
+ "open_clip",
204
+ "paddlenlp",
205
+ "peft",
206
+ "pyannote-audio",
207
+ "sklearn",
208
+ "spacy",
209
+ "span-marker",
210
+ "speechbrain",
211
+ "stanza",
212
+ "timm",
213
+ ]
lib/python3.11/site-packages/huggingface_hub/fastai_utils.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from pickle import DEFAULT_PROTOCOL, PicklingError
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from packaging import version
8
+
9
+ from huggingface_hub import snapshot_download
10
+ from huggingface_hub.constants import CONFIG_NAME
11
+ from huggingface_hub.hf_api import HfApi
12
+ from huggingface_hub.utils import (
13
+ SoftTemporaryDirectory,
14
+ get_fastai_version,
15
+ get_fastcore_version,
16
+ get_python_version,
17
+ )
18
+
19
+ from .utils import logging, validate_hf_hub_args
20
+ from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility...
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ def _check_fastai_fastcore_versions(
27
+ fastai_min_version: str = "2.4",
28
+ fastcore_min_version: str = "1.3.27",
29
+ ):
30
+ """
31
+ Checks that the installed fastai and fastcore versions are compatible for pickle serialization.
32
+
33
+ Args:
34
+ fastai_min_version (`str`, *optional*):
35
+ The minimum fastai version supported.
36
+ fastcore_min_version (`str`, *optional*):
37
+ The minimum fastcore version supported.
38
+
39
+ <Tip>
40
+ Raises the following error:
41
+
42
+ - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
43
+ if the fastai or fastcore libraries are not available or are of an invalid version.
44
+
45
+ </Tip>
46
+ """
47
+
48
+ if (get_fastcore_version() or get_fastai_version()) == "N/A":
49
+ raise ImportError(
50
+ f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
51
+ f" required. Currently using fastai=={get_fastai_version()} and"
52
+ f" fastcore=={get_fastcore_version()}."
53
+ )
54
+
55
+ current_fastai_version = version.Version(get_fastai_version())
56
+ current_fastcore_version = version.Version(get_fastcore_version())
57
+
58
+ if current_fastai_version < version.Version(fastai_min_version):
59
+ raise ImportError(
60
+ "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
61
+ f" fastai>={fastai_min_version} version, but you are using fastai version"
62
+ f" {get_fastai_version()} which is incompatible. Upgrade with `pip install"
63
+ " fastai==2.5.6`."
64
+ )
65
+
66
+ if current_fastcore_version < version.Version(fastcore_min_version):
67
+ raise ImportError(
68
+ "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
69
+ f" fastcore>={fastcore_min_version} version, but you are using fastcore"
70
+ f" version {get_fastcore_version()} which is incompatible. Upgrade with"
71
+ " `pip install fastcore==1.3.27`."
72
+ )
73
+
74
+
75
+ def _check_fastai_fastcore_pyproject_versions(
76
+ storage_folder: str,
77
+ fastai_min_version: str = "2.4",
78
+ fastcore_min_version: str = "1.3.27",
79
+ ):
80
+ """
81
+ Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions
82
+ that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist
83
+ or does not contain versions for fastai and fastcore, then it logs a warning.
84
+
85
+ Args:
86
+ storage_folder (`str`):
87
+ Folder to look for the `pyproject.toml` file.
88
+ fastai_min_version (`str`, *optional*):
89
+ The minimum fastai version supported.
90
+ fastcore_min_version (`str`, *optional*):
91
+ The minimum fastcore version supported.
92
+
93
+ <Tip>
94
+ Raises the following errors:
95
+
96
+ - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
97
+ if the `toml` module is not installed.
98
+ - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
99
+ if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.
100
+
101
+ </Tip>
102
+ """
103
+
104
+ try:
105
+ import toml
106
+ except ModuleNotFoundError:
107
+ raise ImportError(
108
+ "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module."
109
+ " Install it with `pip install toml`."
110
+ )
111
+
112
+ # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages.
113
+ if not os.path.isfile(f"{storage_folder}/pyproject.toml"):
114
+ logger.warning(
115
+ "There is no `pyproject.toml` in the repository that contains the fastai"
116
+ " `Learner`. The `pyproject.toml` would allow us to verify that your fastai"
117
+ " and fastcore versions are compatible with those of the model you want to"
118
+ " load."
119
+ )
120
+ return
121
+ pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml")
122
+
123
+ if "build-system" not in pyproject_toml.keys():
124
+ logger.warning(
125
+ "There is no `build-system` section in the pyproject.toml of the repository"
126
+ " that contains the fastai `Learner`. The `build-system` would allow us to"
127
+ " verify that your fastai and fastcore versions are compatible with those"
128
+ " of the model you want to load."
129
+ )
130
+ return
131
+ build_system_toml = pyproject_toml["build-system"]
132
+
133
+ if "requires" not in build_system_toml.keys():
134
+ logger.warning(
135
+ "There is no `requires` section in the pyproject.toml of the repository"
136
+ " that contains the fastai `Learner`. The `requires` would allow us to"
137
+ " verify that your fastai and fastcore versions are compatible with those"
138
+ " of the model you want to load."
139
+ )
140
+ return
141
+ package_versions = build_system_toml["requires"]
142
+
143
+ # Extracts contains fastai and fastcore versions from `pyproject.toml` if available.
144
+ # If the package is specified but not the version (e.g. "fastai" instead of "fastai=2.4"), the default versions are the highest.
145
+ fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")]
146
+ if len(fastai_packages) == 0:
147
+ logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.")
148
+ # fastai_version is an empty string if not specified
149
+ else:
150
+ fastai_version = str(fastai_packages[0]).partition("=")[2]
151
+ if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version):
152
+ raise ImportError(
153
+ "`from_pretrained_fastai` requires"
154
+ f" fastai>={fastai_min_version} version but the model to load uses"
155
+ f" {fastai_version} which is incompatible."
156
+ )
157
+
158
+ fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")]
159
+ if len(fastcore_packages) == 0:
160
+ logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.")
161
+ # fastcore_version is an empty string if not specified
162
+ else:
163
+ fastcore_version = str(fastcore_packages[0]).partition("=")[2]
164
+ if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version):
165
+ raise ImportError(
166
+ "`from_pretrained_fastai` requires"
167
+ f" fastcore>={fastcore_min_version} version, but you are using fastcore"
168
+ f" version {fastcore_version} which is incompatible."
169
+ )
170
+
171
+
172
+ README_TEMPLATE = """---
173
+ tags:
174
+ - fastai
175
+ ---
176
+
177
+ # Amazing!
178
+
179
+ 🥳 Congratulations on hosting your fastai model on the Hugging Face Hub!
180
+
181
+ # Some next steps
182
+ 1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))!
183
+
184
+ 2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)).
185
+
186
+ 3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)!
187
+
188
+ Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card.
189
+
190
+
191
+ ---
192
+
193
+
194
+ # Model card
195
+
196
+ ## Model description
197
+ More information needed
198
+
199
+ ## Intended uses & limitations
200
+ More information needed
201
+
202
+ ## Training and evaluation data
203
+ More information needed
204
+ """
205
+
206
+ PYPROJECT_TEMPLATE = f"""[build-system]
207
+ requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"]
208
+ build-backend = "setuptools.build_meta:__legacy__"
209
+ """
210
+
211
+
212
+ def _create_model_card(repo_dir: Path):
213
+ """
214
+ Creates a model card for the repository.
215
+
216
+ Args:
217
+ repo_dir (`Path`):
218
+ Directory where model card is created.
219
+ """
220
+ readme_path = repo_dir / "README.md"
221
+
222
+ if not readme_path.exists():
223
+ with readme_path.open("w", encoding="utf-8") as f:
224
+ f.write(README_TEMPLATE)
225
+
226
+
227
+ def _create_model_pyproject(repo_dir: Path):
228
+ """
229
+ Creates a `pyproject.toml` for the repository.
230
+
231
+ Args:
232
+ repo_dir (`Path`):
233
+ Directory where `pyproject.toml` is created.
234
+ """
235
+ pyproject_path = repo_dir / "pyproject.toml"
236
+
237
+ if not pyproject_path.exists():
238
+ with pyproject_path.open("w", encoding="utf-8") as f:
239
+ f.write(PYPROJECT_TEMPLATE)
240
+
241
+
242
+ def _save_pretrained_fastai(
243
+ learner,
244
+ save_directory: Union[str, Path],
245
+ config: Optional[Dict[str, Any]] = None,
246
+ ):
247
+ """
248
+ Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used.
249
+
250
+ Args:
251
+ learner (`Learner`):
252
+ The `fastai.Learner` you'd like to save.
253
+ save_directory (`str` or `Path`):
254
+ Specific directory in which you want to save the fastai learner.
255
+ config (`dict`, *optional*):
256
+ Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.
257
+
258
+ <Tip>
259
+
260
+ Raises the following error:
261
+
262
+ - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
263
+ if the config file provided is not a dictionary.
264
+
265
+ </Tip>
266
+ """
267
+ _check_fastai_fastcore_versions()
268
+
269
+ os.makedirs(save_directory, exist_ok=True)
270
+
271
+ # if the user provides config then we update it with the fastai and fastcore versions in CONFIG_TEMPLATE.
272
+ if config is not None:
273
+ if not isinstance(config, dict):
274
+ raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
275
+ path = os.path.join(save_directory, CONFIG_NAME)
276
+ with open(path, "w") as f:
277
+ json.dump(config, f)
278
+
279
+ _create_model_card(Path(save_directory))
280
+ _create_model_pyproject(Path(save_directory))
281
+
282
+ # learner.export saves the model in `self.path`.
283
+ learner.path = Path(save_directory)
284
+ os.makedirs(save_directory, exist_ok=True)
285
+ try:
286
+ learner.export(
287
+ fname="model.pkl",
288
+ pickle_protocol=DEFAULT_PROTOCOL,
289
+ )
290
+ except PicklingError:
291
+ raise PicklingError(
292
+ "You are using a lambda function, i.e., an anonymous function. `pickle`"
293
+ " cannot pickle function objects and requires that all functions have"
294
+ " names. One possible solution is to name the function."
295
+ )
296
+
297
+
298
+ @validate_hf_hub_args
299
+ def from_pretrained_fastai(
300
+ repo_id: str,
301
+ revision: Optional[str] = None,
302
+ ):
303
+ """
304
+ Load pretrained fastai model from the Hub or from a local directory.
305
+
306
+ Args:
307
+ repo_id (`str`):
308
+ The location where the pickled fastai.Learner is. It can be either of the two:
309
+ - Hosted on the Hugging Face Hub. E.g.: 'espejelomar/fatai-pet-breeds-classification' or 'distilgpt2'.
310
+ You can add a `revision` by appending `@` at the end of `repo_id`. E.g.: `dbmdz/bert-base-german-cased@main`.
311
+ Revision is the specific model version to use. Since we use a git-based system for storing models and other
312
+ artifacts on the Hugging Face Hub, it can be a branch name, a tag name, or a commit id.
313
+ - Hosted locally. `repo_id` would be a directory containing the pickle and a pyproject.toml
314
+ indicating the fastai and fastcore versions used to build the `fastai.Learner`. E.g.: `./my_model_directory/`.
315
+ revision (`str`, *optional*):
316
+ Revision at which the repo's files are downloaded. See documentation of `snapshot_download`.
317
+
318
+ Returns:
319
+ The `fastai.Learner` model in the `repo_id` repo.
320
+ """
321
+ _check_fastai_fastcore_versions()
322
+
323
+ # Load the `repo_id` repo.
324
+ # `snapshot_download` returns the folder where the model was stored.
325
+ # `cache_dir` will be the default '/root/.cache/huggingface/hub'
326
+ if not os.path.isdir(repo_id):
327
+ storage_folder = snapshot_download(
328
+ repo_id=repo_id,
329
+ revision=revision,
330
+ library_name="fastai",
331
+ library_version=get_fastai_version(),
332
+ )
333
+ else:
334
+ storage_folder = repo_id
335
+
336
+ _check_fastai_fastcore_pyproject_versions(storage_folder)
337
+
338
+ from fastai.learner import load_learner # type: ignore
339
+
340
+ return load_learner(os.path.join(storage_folder, "model.pkl"))
341
+
342
+
343
+ @validate_hf_hub_args
344
+ def push_to_hub_fastai(
345
+ learner,
346
+ *,
347
+ repo_id: str,
348
+ commit_message: str = "Push FastAI model using huggingface_hub.",
349
+ private: bool = False,
350
+ token: Optional[str] = None,
351
+ config: Optional[dict] = None,
352
+ branch: Optional[str] = None,
353
+ create_pr: Optional[bool] = None,
354
+ allow_patterns: Optional[Union[List[str], str]] = None,
355
+ ignore_patterns: Optional[Union[List[str], str]] = None,
356
+ delete_patterns: Optional[Union[List[str], str]] = None,
357
+ api_endpoint: Optional[str] = None,
358
+ ):
359
+ """
360
+ Upload learner checkpoint files to the Hub.
361
+
362
+ Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
363
+ `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
364
+ details.
365
+
366
+ Args:
367
+ learner (`Learner`):
368
+ The `fastai.Learner' you'd like to push to the Hub.
369
+ repo_id (`str`):
370
+ The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be your individual account or an organization to which you have write access (for example, 'stanfordnlp/stanza-de').
371
+ commit_message (`str`, *optional*):
372
+ Message to commit while pushing. Will default to :obj:`"add model"`.
373
+ private (`bool`, *optional*, defaults to `False`):
374
+ Whether or not the repository created should be private.
375
+ token (`str`, *optional*):
376
+ The Hugging Face account token to use as HTTP bearer authorization for remote files. If :obj:`None`, the token will be asked by a prompt.
377
+ config (`dict`, *optional*):
378
+ Configuration object to be saved alongside the model weights.
379
+ branch (`str`, *optional*):
380
+ The git branch on which to push the model. This defaults to
381
+ the default branch as specified in your repository, which
382
+ defaults to `"main"`.
383
+ create_pr (`boolean`, *optional*):
384
+ Whether or not to create a Pull Request from `branch` with that commit.
385
+ Defaults to `False`.
386
+ api_endpoint (`str`, *optional*):
387
+ The API endpoint to use when pushing the model to the hub.
388
+ allow_patterns (`List[str]` or `str`, *optional*):
389
+ If provided, only files matching at least one pattern are pushed.
390
+ ignore_patterns (`List[str]` or `str`, *optional*):
391
+ If provided, files matching any of the patterns are not pushed.
392
+ delete_patterns (`List[str]` or `str`, *optional*):
393
+ If provided, remote files matching any of the patterns will be deleted from the repo.
394
+
395
+ Returns:
396
+ The url of the commit of your model in the given repository.
397
+
398
+ <Tip>
399
+
400
+ Raises the following error:
401
+
402
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
403
+ if the user is not log on to the Hugging Face Hub.
404
+
405
+ </Tip>
406
+ """
407
+ _check_fastai_fastcore_versions()
408
+ api = HfApi(endpoint=api_endpoint)
409
+ repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id
410
+
411
+ # Push the files to the repo in a single commit
412
+ with SoftTemporaryDirectory() as tmp:
413
+ saved_path = Path(tmp) / repo_id
414
+ _save_pretrained_fastai(learner, saved_path, config=config)
415
+ return api.upload_folder(
416
+ repo_id=repo_id,
417
+ token=token,
418
+ folder_path=saved_path,
419
+ commit_message=commit_message,
420
+ revision=branch,
421
+ create_pr=create_pr,
422
+ allow_patterns=allow_patterns,
423
+ ignore_patterns=ignore_patterns,
424
+ delete_patterns=delete_patterns,
425
+ )
lib/python3.11/site-packages/huggingface_hub/file_download.py ADDED
@@ -0,0 +1,1727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import fnmatch
3
+ import inspect
4
+ import io
5
+ import json
6
+ import os
7
+ import re
8
+ import shutil
9
+ import stat
10
+ import tempfile
11
+ import time
12
+ import uuid
13
+ import warnings
14
+ from contextlib import contextmanager
15
+ from dataclasses import dataclass
16
+ from functools import partial
17
+ from pathlib import Path
18
+ from typing import Any, BinaryIO, Dict, Generator, Literal, Optional, Tuple, Union
19
+ from urllib.parse import quote, urlparse
20
+
21
+ import requests
22
+ from filelock import FileLock
23
+
24
+ from huggingface_hub import constants
25
+
26
+ from . import __version__ # noqa: F401 # for backward compatibility
27
+ from .constants import (
28
+ DEFAULT_ETAG_TIMEOUT,
29
+ DEFAULT_REQUEST_TIMEOUT,
30
+ DEFAULT_REVISION,
31
+ DOWNLOAD_CHUNK_SIZE,
32
+ ENDPOINT,
33
+ HF_HUB_CACHE,
34
+ HF_HUB_DISABLE_SYMLINKS_WARNING,
35
+ HF_HUB_DOWNLOAD_TIMEOUT,
36
+ HF_HUB_ENABLE_HF_TRANSFER,
37
+ HF_HUB_ETAG_TIMEOUT,
38
+ HF_TRANSFER_CONCURRENCY,
39
+ HUGGINGFACE_CO_URL_TEMPLATE,
40
+ HUGGINGFACE_HEADER_X_LINKED_ETAG,
41
+ HUGGINGFACE_HEADER_X_LINKED_SIZE,
42
+ HUGGINGFACE_HEADER_X_REPO_COMMIT,
43
+ HUGGINGFACE_HUB_CACHE, # noqa: F401 # for backward compatibility
44
+ REPO_ID_SEPARATOR,
45
+ REPO_TYPES,
46
+ REPO_TYPES_URL_PREFIXES,
47
+ )
48
+ from .utils import (
49
+ EntryNotFoundError,
50
+ FileMetadataError,
51
+ GatedRepoError,
52
+ LocalEntryNotFoundError,
53
+ OfflineModeIsEnabled,
54
+ RepositoryNotFoundError,
55
+ RevisionNotFoundError,
56
+ SoftTemporaryDirectory,
57
+ build_hf_headers,
58
+ get_fastai_version, # noqa: F401 # for backward compatibility
59
+ get_fastcore_version, # noqa: F401 # for backward compatibility
60
+ get_graphviz_version, # noqa: F401 # for backward compatibility
61
+ get_jinja_version, # noqa: F401 # for backward compatibility
62
+ get_pydot_version, # noqa: F401 # for backward compatibility
63
+ get_session,
64
+ get_tf_version, # noqa: F401 # for backward compatibility
65
+ get_torch_version, # noqa: F401 # for backward compatibility
66
+ hf_raise_for_status,
67
+ is_fastai_available, # noqa: F401 # for backward compatibility
68
+ is_fastcore_available, # noqa: F401 # for backward compatibility
69
+ is_graphviz_available, # noqa: F401 # for backward compatibility
70
+ is_jinja_available, # noqa: F401 # for backward compatibility
71
+ is_pydot_available, # noqa: F401 # for backward compatibility
72
+ is_tf_available, # noqa: F401 # for backward compatibility
73
+ is_torch_available, # noqa: F401 # for backward compatibility
74
+ logging,
75
+ reset_sessions,
76
+ tqdm,
77
+ validate_hf_hub_args,
78
+ )
79
+ from .utils._deprecation import _deprecate_method
80
+ from .utils._headers import _http_user_agent
81
+ from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
82
+ from .utils._typing import HTTP_METHOD_T
83
+ from .utils.insecure_hashlib import sha256
84
+
85
+
86
+ logger = logging.get_logger(__name__)
87
+
88
+ # Regex to get filename from a "Content-Disposition" header for CDN-served files
89
+ HEADER_FILENAME_PATTERN = re.compile(r'filename="(?P<filename>.*?)";')
90
+
91
+
92
+ _are_symlinks_supported_in_dir: Dict[str, bool] = {}
93
+
94
+
95
+ def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
96
+ """Return whether the symlinks are supported on the machine.
97
+
98
+ Since symlinks support can change depending on the mounted disk, we need to check
99
+ on the precise cache folder. By default, the default HF cache directory is checked.
100
+
101
+ Args:
102
+ cache_dir (`str`, `Path`, *optional*):
103
+ Path to the folder where cached files are stored.
104
+
105
+ Returns: [bool] Whether symlinks are supported in the directory.
106
+ """
107
+ # Defaults to HF cache
108
+ if cache_dir is None:
109
+ cache_dir = HF_HUB_CACHE
110
+ cache_dir = str(Path(cache_dir).expanduser().resolve()) # make it unique
111
+
112
+ # Check symlink compatibility only once (per cache directory) at first time use
113
+ if cache_dir not in _are_symlinks_supported_in_dir:
114
+ _are_symlinks_supported_in_dir[cache_dir] = True
115
+
116
+ os.makedirs(cache_dir, exist_ok=True)
117
+ with SoftTemporaryDirectory(dir=cache_dir) as tmpdir:
118
+ src_path = Path(tmpdir) / "dummy_file_src"
119
+ src_path.touch()
120
+ dst_path = Path(tmpdir) / "dummy_file_dst"
121
+
122
+ # Relative source path as in `_create_symlink``
123
+ relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
124
+ try:
125
+ os.symlink(relative_src, dst_path)
126
+ except OSError:
127
+ # Likely running on Windows
128
+ _are_symlinks_supported_in_dir[cache_dir] = False
129
+
130
+ if not HF_HUB_DISABLE_SYMLINKS_WARNING:
131
+ message = (
132
+ "`huggingface_hub` cache-system uses symlinks by default to"
133
+ " efficiently store duplicated files but your machine does not"
134
+ f" support them in {cache_dir}. Caching files will still work"
135
+ " but in a degraded version that might require more space on"
136
+ " your disk. This warning can be disabled by setting the"
137
+ " `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For"
138
+ " more details, see"
139
+ " https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations."
140
+ )
141
+ if os.name == "nt":
142
+ message += (
143
+ "\nTo support symlinks on Windows, you either need to"
144
+ " activate Developer Mode or to run Python as an"
145
+ " administrator. In order to see activate developer mode,"
146
+ " see this article:"
147
+ " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
148
+ )
149
+ warnings.warn(message)
150
+
151
+ return _are_symlinks_supported_in_dir[cache_dir]
152
+
153
+
154
+ # Return value when trying to load a file from cache but the file does not exist in the distant repo.
155
+ _CACHED_NO_EXIST = object()
156
+ _CACHED_NO_EXIST_T = Any
157
+ REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
158
+
159
+
160
+ @dataclass(frozen=True)
161
+ class HfFileMetadata:
162
+ """Data structure containing information about a file versioned on the Hub.
163
+
164
+ Returned by [`get_hf_file_metadata`] based on a URL.
165
+
166
+ Args:
167
+ commit_hash (`str`, *optional*):
168
+ The commit_hash related to the file.
169
+ etag (`str`, *optional*):
170
+ Etag of the file on the server.
171
+ location (`str`):
172
+ Location where to download the file. Can be a Hub url or not (CDN).
173
+ size (`size`):
174
+ Size of the file. In case of an LFS file, contains the size of the actual
175
+ LFS file, not the pointer.
176
+ """
177
+
178
+ commit_hash: Optional[str]
179
+ etag: Optional[str]
180
+ location: str
181
+ size: Optional[int]
182
+
183
+
184
+ @validate_hf_hub_args
185
+ def hf_hub_url(
186
+ repo_id: str,
187
+ filename: str,
188
+ *,
189
+ subfolder: Optional[str] = None,
190
+ repo_type: Optional[str] = None,
191
+ revision: Optional[str] = None,
192
+ endpoint: Optional[str] = None,
193
+ ) -> str:
194
+ """Construct the URL of a file from the given information.
195
+
196
+ The resolved address can either be a huggingface.co-hosted url, or a link to
197
+ Cloudfront (a Content Delivery Network, or CDN) for large files which are
198
+ more than a few MBs.
199
+
200
+ Args:
201
+ repo_id (`str`):
202
+ A namespace (user or an organization) name and a repo name separated
203
+ by a `/`.
204
+ filename (`str`):
205
+ The name of the file in the repo.
206
+ subfolder (`str`, *optional*):
207
+ An optional value corresponding to a folder inside the repo.
208
+ repo_type (`str`, *optional*):
209
+ Set to `"dataset"` or `"space"` if downloading from a dataset or space,
210
+ `None` or `"model"` if downloading from a model. Default is `None`.
211
+ revision (`str`, *optional*):
212
+ An optional Git revision id which can be a branch name, a tag, or a
213
+ commit hash.
214
+
215
+ Example:
216
+
217
+ ```python
218
+ >>> from huggingface_hub import hf_hub_url
219
+
220
+ >>> hf_hub_url(
221
+ ... repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin"
222
+ ... )
223
+ 'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin'
224
+ ```
225
+
226
+ <Tip>
227
+
228
+ Notes:
229
+
230
+ Cloudfront is replicated over the globe so downloads are way faster for
231
+ the end user (and it also lowers our bandwidth costs).
232
+
233
+ Cloudfront aggressively caches files by default (default TTL is 24
234
+ hours), however this is not an issue here because we implement a
235
+ git-based versioning system on huggingface.co, which means that we store
236
+ the files on S3/Cloudfront in a content-addressable way (i.e., the file
237
+ name is its hash). Using content-addressable filenames means cache can't
238
+ ever be stale.
239
+
240
+ In terms of client-side caching from this library, we base our caching
241
+ on the objects' entity tag (`ETag`), which is an identifier of a
242
+ specific version of a resource [1]_. An object's ETag is: its git-sha1
243
+ if stored in git, or its sha256 if stored in git-lfs.
244
+
245
+ </Tip>
246
+
247
+ References:
248
+
249
+ - [1] https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
250
+ """
251
+ if subfolder == "":
252
+ subfolder = None
253
+ if subfolder is not None:
254
+ filename = f"{subfolder}/{filename}"
255
+
256
+ if repo_type not in REPO_TYPES:
257
+ raise ValueError("Invalid repo type")
258
+
259
+ if repo_type in REPO_TYPES_URL_PREFIXES:
260
+ repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
261
+
262
+ if revision is None:
263
+ revision = DEFAULT_REVISION
264
+ url = HUGGINGFACE_CO_URL_TEMPLATE.format(
265
+ repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename)
266
+ )
267
+ # Update endpoint if provided
268
+ if endpoint is not None and url.startswith(ENDPOINT):
269
+ url = endpoint + url[len(ENDPOINT) :]
270
+ return url
271
+
272
+
273
+ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
274
+ """Generate a local filename from a url.
275
+
276
+ Convert `url` into a hashed filename in a reproducible way. If `etag` is
277
+ specified, append its hash to the url's, delimited by a period. If the url
278
+ ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
279
+ identify it as a HDF5 file (see
280
+ https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
281
+
282
+ Args:
283
+ url (`str`):
284
+ The address to the file.
285
+ etag (`str`, *optional*):
286
+ The ETag of the file.
287
+
288
+ Returns:
289
+ The generated filename.
290
+ """
291
+ url_bytes = url.encode("utf-8")
292
+ filename = sha256(url_bytes).hexdigest()
293
+
294
+ if etag:
295
+ etag_bytes = etag.encode("utf-8")
296
+ filename += "." + sha256(etag_bytes).hexdigest()
297
+
298
+ if url.endswith(".h5"):
299
+ filename += ".h5"
300
+
301
+ return filename
302
+
303
+
304
+ def filename_to_url(
305
+ filename,
306
+ cache_dir: Optional[str] = None,
307
+ legacy_cache_layout: bool = False,
308
+ ) -> Tuple[str, str]:
309
+ """
310
+ Return the url and etag (which may be `None`) stored for `filename`. Raise
311
+ `EnvironmentError` if `filename` or its stored metadata do not exist.
312
+
313
+ Args:
314
+ filename (`str`):
315
+ The name of the file
316
+ cache_dir (`str`, *optional*):
317
+ The cache directory to use instead of the default one.
318
+ legacy_cache_layout (`bool`, *optional*, defaults to `False`):
319
+ If `True`, uses the legacy file cache layout i.e. just call `hf_hub_url`
320
+ then `cached_download`. This is deprecated as the new cache layout is
321
+ more powerful.
322
+ """
323
+ if not legacy_cache_layout:
324
+ warnings.warn(
325
+ "`filename_to_url` uses the legacy way cache file layout",
326
+ FutureWarning,
327
+ )
328
+
329
+ if cache_dir is None:
330
+ cache_dir = HF_HUB_CACHE
331
+ if isinstance(cache_dir, Path):
332
+ cache_dir = str(cache_dir)
333
+
334
+ cache_path = os.path.join(cache_dir, filename)
335
+ if not os.path.exists(cache_path):
336
+ raise EnvironmentError(f"file {cache_path} not found")
337
+
338
+ meta_path = cache_path + ".json"
339
+ if not os.path.exists(meta_path):
340
+ raise EnvironmentError(f"file {meta_path} not found")
341
+
342
+ with open(meta_path, encoding="utf-8") as meta_file:
343
+ metadata = json.load(meta_file)
344
+ url = metadata["url"]
345
+ etag = metadata["etag"]
346
+
347
+ return url, etag
348
+
349
+
350
+ @_deprecate_method(version="0.22.0", message="Use `huggingface_hub.utils.build_hf_headers` instead.")
351
+ def http_user_agent(
352
+ *,
353
+ library_name: Optional[str] = None,
354
+ library_version: Optional[str] = None,
355
+ user_agent: Union[Dict, str, None] = None,
356
+ ) -> str:
357
+ """Deprecated in favor of [`build_hf_headers`]."""
358
+ return _http_user_agent(
359
+ library_name=library_name,
360
+ library_version=library_version,
361
+ user_agent=user_agent,
362
+ )
363
+
364
+
365
+ def _request_wrapper(
366
+ method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params
367
+ ) -> requests.Response:
368
+ """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when
369
+ `allow_redirection=False`.
370
+
371
+ Args:
372
+ method (`str`):
373
+ HTTP method, such as 'GET' or 'HEAD'.
374
+ url (`str`):
375
+ The URL of the resource to fetch.
376
+ follow_relative_redirects (`bool`, *optional*, defaults to `False`)
377
+ If True, relative redirection (redirection to the same site) will be resolved even when `allow_redirection`
378
+ kwarg is set to False. Useful when we want to follow a redirection to a renamed repository without
379
+ following redirection to a CDN.
380
+ **params (`dict`, *optional*):
381
+ Params to pass to `requests.request`.
382
+ """
383
+ # Recursively follow relative redirects
384
+ if follow_relative_redirects:
385
+ response = _request_wrapper(
386
+ method=method,
387
+ url=url,
388
+ follow_relative_redirects=False,
389
+ **params,
390
+ )
391
+
392
+ # If redirection, we redirect only relative paths.
393
+ # This is useful in case of a renamed repository.
394
+ if 300 <= response.status_code <= 399:
395
+ parsed_target = urlparse(response.headers["Location"])
396
+ if parsed_target.netloc == "":
397
+ # This means it is a relative 'location' headers, as allowed by RFC 7231.
398
+ # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
399
+ # We want to follow this relative redirect !
400
+ #
401
+ # Highly inspired by `resolve_redirects` from requests library.
402
+ # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
403
+ next_url = urlparse(url)._replace(path=parsed_target.path).geturl()
404
+ return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params)
405
+ return response
406
+
407
+ # Perform request and return if status_code is not in the retry list.
408
+ response = get_session().request(method=method, url=url, **params)
409
+ hf_raise_for_status(response)
410
+ return response
411
+
412
+
413
+ def http_get(
414
+ url: str,
415
+ temp_file: BinaryIO,
416
+ *,
417
+ proxies=None,
418
+ resume_size: float = 0,
419
+ headers: Optional[Dict[str, str]] = None,
420
+ expected_size: Optional[int] = None,
421
+ _nb_retries: int = 5,
422
+ ):
423
+ """
424
+ Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
425
+
426
+ If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely a
427
+ transient error (network outage?). We log a warning message and try to resume the download a few times before
428
+ giving up. The method gives up after 5 attempts if no new data has being received from the server.
429
+ """
430
+ hf_transfer = None
431
+ if HF_HUB_ENABLE_HF_TRANSFER:
432
+ if resume_size != 0:
433
+ warnings.warn("'hf_transfer' does not support `resume_size`: falling back to regular download method")
434
+ elif proxies is not None:
435
+ warnings.warn("'hf_transfer' does not support `proxies`: falling back to regular download method")
436
+ else:
437
+ try:
438
+ import hf_transfer # type: ignore[no-redef]
439
+ except ImportError:
440
+ raise ValueError(
441
+ "Fast download using 'hf_transfer' is enabled"
442
+ " (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is not"
443
+ " available in your environment. Try `pip install hf_transfer`."
444
+ )
445
+
446
+ initial_headers = headers
447
+ headers = copy.deepcopy(headers) or {}
448
+ if resume_size > 0:
449
+ headers["Range"] = "bytes=%d-" % (resume_size,)
450
+
451
+ r = _request_wrapper(
452
+ method="GET", url=url, stream=True, proxies=proxies, headers=headers, timeout=HF_HUB_DOWNLOAD_TIMEOUT
453
+ )
454
+ hf_raise_for_status(r)
455
+ content_length = r.headers.get("Content-Length")
456
+
457
+ # NOTE: 'total' is the total number of bytes to download, not the number of bytes in the file.
458
+ # If the file is compressed, the number of bytes in the saved file will be higher than 'total'.
459
+ total = resume_size + int(content_length) if content_length is not None else None
460
+
461
+ displayed_name = url
462
+ content_disposition = r.headers.get("Content-Disposition")
463
+ if content_disposition is not None:
464
+ match = HEADER_FILENAME_PATTERN.search(content_disposition)
465
+ if match is not None:
466
+ # Means file is on CDN
467
+ displayed_name = match.groupdict()["filename"]
468
+
469
+ # Truncate filename if too long to display
470
+ if len(displayed_name) > 40:
471
+ displayed_name = f"(…){displayed_name[-40:]}"
472
+
473
+ consistency_error_message = (
474
+ f"Consistency check failed: file should be of size {expected_size} but has size"
475
+ f" {{actual_size}} ({displayed_name}).\nWe are sorry for the inconvenience. Please retry download and"
476
+ " pass `force_download=True, resume_download=False` as argument.\nIf the issue persists, please let us"
477
+ " know by opening an issue on https://github.com/huggingface/huggingface_hub."
478
+ )
479
+
480
+ # Stream file to buffer
481
+ with tqdm(
482
+ unit="B",
483
+ unit_scale=True,
484
+ total=total,
485
+ initial=resume_size,
486
+ desc=displayed_name,
487
+ disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
488
+ ) as progress:
489
+ if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
490
+ supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
491
+ if not supports_callback:
492
+ warnings.warn(
493
+ "You are using an outdated version of `hf_transfer`. "
494
+ "Consider upgrading to latest version to enable progress bars "
495
+ "using `pip install -U hf_transfer`."
496
+ )
497
+ try:
498
+ hf_transfer.download(
499
+ url=url,
500
+ filename=temp_file.name,
501
+ max_files=HF_TRANSFER_CONCURRENCY,
502
+ chunk_size=DOWNLOAD_CHUNK_SIZE,
503
+ headers=headers,
504
+ parallel_failures=3,
505
+ max_retries=5,
506
+ **({"callback": progress.update} if supports_callback else {}),
507
+ )
508
+ except Exception as e:
509
+ raise RuntimeError(
510
+ "An error occurred while downloading using `hf_transfer`. Consider"
511
+ " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
512
+ ) from e
513
+ if not supports_callback:
514
+ progress.update(total)
515
+ if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
516
+ raise EnvironmentError(
517
+ consistency_error_message.format(
518
+ actual_size=os.path.getsize(temp_file.name),
519
+ )
520
+ )
521
+ return
522
+ new_resume_size = resume_size
523
+ try:
524
+ for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
525
+ if chunk: # filter out keep-alive new chunks
526
+ progress.update(len(chunk))
527
+ temp_file.write(chunk)
528
+ new_resume_size += len(chunk)
529
+ # Some data has been downloaded from the server so we reset the number of retries.
530
+ _nb_retries = 5
531
+ except (requests.ConnectionError, requests.ReadTimeout) as e:
532
+ # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
533
+ # a transient error (network outage?). We log a warning message and try to resume the download a few times
534
+ # before giving up. Tre retry mechanism is basic but should be enough in most cases.
535
+ if _nb_retries <= 0:
536
+ logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
537
+ raise
538
+ logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
539
+ time.sleep(1)
540
+ reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects
541
+ return http_get(
542
+ url=url,
543
+ temp_file=temp_file,
544
+ proxies=proxies,
545
+ resume_size=new_resume_size,
546
+ headers=initial_headers,
547
+ expected_size=expected_size,
548
+ _nb_retries=_nb_retries - 1,
549
+ )
550
+
551
+ if expected_size is not None and expected_size != temp_file.tell():
552
+ raise EnvironmentError(
553
+ consistency_error_message.format(
554
+ actual_size=temp_file.tell(),
555
+ )
556
+ )
557
+
558
+
559
+ @validate_hf_hub_args
560
+ def cached_download(
561
+ url: str,
562
+ *,
563
+ library_name: Optional[str] = None,
564
+ library_version: Optional[str] = None,
565
+ cache_dir: Union[str, Path, None] = None,
566
+ user_agent: Union[Dict, str, None] = None,
567
+ force_download: bool = False,
568
+ force_filename: Optional[str] = None,
569
+ proxies: Optional[Dict] = None,
570
+ etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
571
+ resume_download: bool = False,
572
+ token: Union[bool, str, None] = None,
573
+ local_files_only: bool = False,
574
+ legacy_cache_layout: bool = False,
575
+ ) -> str:
576
+ """
577
+ Download from a given URL and cache it if it's not already present in the
578
+ local cache.
579
+
580
+ Given a URL, this function looks for the corresponding file in the local
581
+ cache. If it's not there, download it. Then return the path to the cached
582
+ file.
583
+
584
+ Will raise errors tailored to the Hugging Face Hub.
585
+
586
+ Args:
587
+ url (`str`):
588
+ The path to the file to be downloaded.
589
+ library_name (`str`, *optional*):
590
+ The name of the library to which the object corresponds.
591
+ library_version (`str`, *optional*):
592
+ The version of the library.
593
+ cache_dir (`str`, `Path`, *optional*):
594
+ Path to the folder where cached files are stored.
595
+ user_agent (`dict`, `str`, *optional*):
596
+ The user-agent info in the form of a dictionary or a string.
597
+ force_download (`bool`, *optional*, defaults to `False`):
598
+ Whether the file should be downloaded even if it already exists in
599
+ the local cache.
600
+ force_filename (`str`, *optional*):
601
+ Use this name instead of a generated file name.
602
+ proxies (`dict`, *optional*):
603
+ Dictionary mapping protocol to the URL of the proxy passed to
604
+ `requests.request`.
605
+ etag_timeout (`float`, *optional* defaults to `10`):
606
+ When fetching ETag, how many seconds to wait for the server to send
607
+ data before giving up which is passed to `requests.request`.
608
+ resume_download (`bool`, *optional*, defaults to `False`):
609
+ If `True`, resume a previously interrupted download.
610
+ token (`bool`, `str`, *optional*):
611
+ A token to be used for the download.
612
+ - If `True`, the token is read from the HuggingFace config
613
+ folder.
614
+ - If a string, it's used as the authentication token.
615
+ local_files_only (`bool`, *optional*, defaults to `False`):
616
+ If `True`, avoid downloading the file and return the path to the
617
+ local cached file if it exists.
618
+ legacy_cache_layout (`bool`, *optional*, defaults to `False`):
619
+ Set this parameter to `True` to mention that you'd like to continue
620
+ the old cache layout. Putting this to `True` manually will not raise
621
+ any warning when using `cached_download`. We recommend using
622
+ `hf_hub_download` to take advantage of the new cache.
623
+
624
+ Returns:
625
+ Local path (string) of file or if networking is off, last version of
626
+ file cached on disk.
627
+
628
+ <Tip>
629
+
630
+ Raises the following errors:
631
+
632
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
633
+ if `token=True` and the token cannot be found.
634
+ - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
635
+ if ETag cannot be determined.
636
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
637
+ if some parameter value is invalid
638
+ - [`~utils.RepositoryNotFoundError`]
639
+ If the repository to download from cannot be found. This may be because it doesn't exist,
640
+ or because it is set to `private` and you do not have access.
641
+ - [`~utils.RevisionNotFoundError`]
642
+ If the revision to download from cannot be found.
643
+ - [`~utils.EntryNotFoundError`]
644
+ If the file to download cannot be found.
645
+ - [`~utils.LocalEntryNotFoundError`]
646
+ If network is disabled or unavailable and file is not found in cache.
647
+
648
+ </Tip>
649
+ """
650
+ if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
651
+ # Respect environment variable above user value
652
+ etag_timeout = HF_HUB_ETAG_TIMEOUT
653
+
654
+ if not legacy_cache_layout:
655
+ warnings.warn(
656
+ "'cached_download' is the legacy way to download files from the HF hub, please consider upgrading to"
657
+ " 'hf_hub_download'",
658
+ FutureWarning,
659
+ )
660
+
661
+ if cache_dir is None:
662
+ cache_dir = HF_HUB_CACHE
663
+ if isinstance(cache_dir, Path):
664
+ cache_dir = str(cache_dir)
665
+
666
+ os.makedirs(cache_dir, exist_ok=True)
667
+
668
+ headers = build_hf_headers(
669
+ token=token,
670
+ library_name=library_name,
671
+ library_version=library_version,
672
+ user_agent=user_agent,
673
+ )
674
+
675
+ url_to_download = url
676
+ etag = None
677
+ expected_size = None
678
+ if not local_files_only:
679
+ try:
680
+ # Temporary header: we want the full (decompressed) content size returned to be able to check the
681
+ # downloaded file size
682
+ headers["Accept-Encoding"] = "identity"
683
+ r = _request_wrapper(
684
+ method="HEAD",
685
+ url=url,
686
+ headers=headers,
687
+ allow_redirects=False,
688
+ follow_relative_redirects=True,
689
+ proxies=proxies,
690
+ timeout=etag_timeout,
691
+ )
692
+ headers.pop("Accept-Encoding", None)
693
+ hf_raise_for_status(r)
694
+ etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
695
+ # We favor a custom header indicating the etag of the linked resource, and
696
+ # we fallback to the regular etag header.
697
+ # If we don't have any of those, raise an error.
698
+ if etag is None:
699
+ raise FileMetadataError(
700
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
701
+ )
702
+ # We get the expected size of the file, to check the download went well.
703
+ expected_size = _int_or_none(r.headers.get("Content-Length"))
704
+ # In case of a redirect, save an extra redirect on the request.get call,
705
+ # and ensure we download the exact atomic version even if it changed
706
+ # between the HEAD and the GET (unlikely, but hey).
707
+ # Useful for lfs blobs that are stored on a CDN.
708
+ if 300 <= r.status_code <= 399:
709
+ url_to_download = r.headers["Location"]
710
+ headers.pop("authorization", None)
711
+ expected_size = None # redirected -> can't know the expected size
712
+ except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
713
+ # Actually raise for those subclasses of ConnectionError
714
+ raise
715
+ except (
716
+ requests.exceptions.ConnectionError,
717
+ requests.exceptions.Timeout,
718
+ OfflineModeIsEnabled,
719
+ ):
720
+ # Otherwise, our Internet connection is down.
721
+ # etag is None
722
+ pass
723
+
724
+ filename = force_filename if force_filename is not None else url_to_filename(url, etag)
725
+
726
+ # get cache path to put the file
727
+ cache_path = os.path.join(cache_dir, filename)
728
+
729
+ # etag is None == we don't have a connection or we passed local_files_only.
730
+ # try to get the last downloaded one
731
+ if etag is None:
732
+ if os.path.exists(cache_path) and not force_download:
733
+ return cache_path
734
+ else:
735
+ matching_files = [
736
+ file
737
+ for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
738
+ if not file.endswith(".json") and not file.endswith(".lock")
739
+ ]
740
+ if len(matching_files) > 0 and not force_download and force_filename is None:
741
+ return os.path.join(cache_dir, matching_files[-1])
742
+ else:
743
+ # If files cannot be found and local_files_only=True,
744
+ # the models might've been found if local_files_only=False
745
+ # Notify the user about that
746
+ if local_files_only:
747
+ raise LocalEntryNotFoundError(
748
+ "Cannot find the requested files in the cached path and"
749
+ " outgoing traffic has been disabled. To enable model look-ups"
750
+ " and downloads online, set 'local_files_only' to False."
751
+ )
752
+ else:
753
+ raise LocalEntryNotFoundError(
754
+ "Connection error, and we cannot find the requested files in"
755
+ " the cached path. Please try again or make sure your Internet"
756
+ " connection is on."
757
+ )
758
+
759
+ # From now on, etag is not None.
760
+ if os.path.exists(cache_path) and not force_download:
761
+ return cache_path
762
+
763
+ # Prevent parallel downloads of the same file with a lock.
764
+ lock_path = cache_path + ".lock"
765
+
766
+ # Some Windows versions do not allow for paths longer than 255 characters.
767
+ # In this case, we must specify it is an extended path by using the "\\?\" prefix.
768
+ if os.name == "nt" and len(os.path.abspath(lock_path)) > 255:
769
+ lock_path = "\\\\?\\" + os.path.abspath(lock_path)
770
+
771
+ if os.name == "nt" and len(os.path.abspath(cache_path)) > 255:
772
+ cache_path = "\\\\?\\" + os.path.abspath(cache_path)
773
+
774
+ with FileLock(lock_path):
775
+ # If the download just completed while the lock was activated.
776
+ if os.path.exists(cache_path) and not force_download:
777
+ # Even if returning early like here, the lock will be released.
778
+ return cache_path
779
+
780
+ if resume_download:
781
+ incomplete_path = cache_path + ".incomplete"
782
+
783
+ @contextmanager
784
+ def _resumable_file_manager() -> Generator[io.BufferedWriter, None, None]:
785
+ with open(incomplete_path, "ab") as f:
786
+ yield f
787
+
788
+ temp_file_manager = _resumable_file_manager
789
+ if os.path.exists(incomplete_path):
790
+ resume_size = os.stat(incomplete_path).st_size
791
+ else:
792
+ resume_size = 0
793
+ else:
794
+ temp_file_manager = partial( # type: ignore
795
+ tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
796
+ )
797
+ resume_size = 0
798
+
799
+ # Download to temporary file, then copy to cache dir once finished.
800
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
801
+ with temp_file_manager() as temp_file:
802
+ logger.info("downloading %s to %s", url, temp_file.name)
803
+
804
+ http_get(
805
+ url_to_download,
806
+ temp_file,
807
+ proxies=proxies,
808
+ resume_size=resume_size,
809
+ headers=headers,
810
+ expected_size=expected_size,
811
+ )
812
+
813
+ logger.info("storing %s in cache at %s", url, cache_path)
814
+ _chmod_and_replace(temp_file.name, cache_path)
815
+
816
+ if force_filename is None:
817
+ logger.info("creating metadata file for %s", cache_path)
818
+ meta = {"url": url, "etag": etag}
819
+ meta_path = cache_path + ".json"
820
+ with open(meta_path, "w") as meta_file:
821
+ json.dump(meta, meta_file)
822
+
823
+ return cache_path
824
+
825
+
826
+ def _normalize_etag(etag: Optional[str]) -> Optional[str]:
827
+ """Normalize ETag HTTP header, so it can be used to create nice filepaths.
828
+
829
+ The HTTP spec allows two forms of ETag:
830
+ ETag: W/"<etag_value>"
831
+ ETag: "<etag_value>"
832
+
833
+ For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
834
+ more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.
835
+
836
+ Args:
837
+ etag (`str`, *optional*): HTTP header
838
+
839
+ Returns:
840
+ `str` or `None`: string that can be used as a nice directory name.
841
+ Returns `None` if input is None.
842
+ """
843
+ if etag is None:
844
+ return None
845
+ return etag.lstrip("W/").strip('"')
846
+
847
+
848
+ def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
849
+ """Alias method used in `transformers` conversion script."""
850
+ return _create_symlink(src=src, dst=dst, new_blob=new_blob)
851
+
852
+
853
+ def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
854
+ """Create a symbolic link named dst pointing to src.
855
+
856
+ By default, it will try to create a symlink using a relative path. Relative paths have 2 advantages:
857
+ - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder will
858
+ not brake.
859
+ - Relative paths seems to be better handled on Windows. Issue was reported 3 times in less than a week when
860
+ changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398,
861
+ https://github.com/huggingface/diffusers/issues/2729 and https://github.com/huggingface/transformers/pull/22228.
862
+ NOTE: The issue with absolute paths doesn't happen on admin mode.
863
+ When creating a symlink from the cache to a local folder, it is possible that a relative path cannot be created.
864
+ This happens when paths are not on the same volume. In that case, we use absolute paths.
865
+
866
+
867
+ The result layout looks something like
868
+ └── [ 128] snapshots
869
+ ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
870
+ │ ├── [ 52] README.md -> ../../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
871
+ │ └── [ 76] pytorch_model.bin -> ../../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
872
+
873
+ If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by
874
+ having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file
875
+ (`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing
876
+ cache, the file is duplicated on the disk.
877
+
878
+ In case symlinks are not supported, a warning message is displayed to the user once when loading `huggingface_hub`.
879
+ The warning message can be disable with the `DISABLE_SYMLINKS_WARNING` environment variable.
880
+ """
881
+ try:
882
+ os.remove(dst)
883
+ except OSError:
884
+ pass
885
+
886
+ abs_src = os.path.abspath(os.path.expanduser(src))
887
+ abs_dst = os.path.abspath(os.path.expanduser(dst))
888
+ abs_dst_folder = os.path.dirname(abs_dst)
889
+
890
+ # Use relative_dst in priority
891
+ try:
892
+ relative_src = os.path.relpath(abs_src, abs_dst_folder)
893
+ except ValueError:
894
+ # Raised on Windows if src and dst are not on the same volume. This is the case when creating a symlink to a
895
+ # local_dir instead of within the cache directory.
896
+ # See https://docs.python.org/3/library/os.path.html#os.path.relpath
897
+ relative_src = None
898
+
899
+ try:
900
+ commonpath = os.path.commonpath([abs_src, abs_dst])
901
+ _support_symlinks = are_symlinks_supported(commonpath)
902
+ except ValueError:
903
+ # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos.
904
+ # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
905
+ _support_symlinks = os.name != "nt"
906
+ except PermissionError:
907
+ # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
908
+ # by the user via `local_dir`. Let's test symlink support there)
909
+ _support_symlinks = are_symlinks_supported(abs_dst_folder)
910
+
911
+ # Symlinks are supported => let's create a symlink.
912
+ if _support_symlinks:
913
+ src_rel_or_abs = relative_src or abs_src
914
+ logger.debug(f"Creating pointer from {src_rel_or_abs} to {abs_dst}")
915
+ try:
916
+ os.symlink(src_rel_or_abs, abs_dst)
917
+ return
918
+ except FileExistsError:
919
+ if os.path.islink(abs_dst) and os.path.realpath(abs_dst) == os.path.realpath(abs_src):
920
+ # `abs_dst` already exists and is a symlink to the `abs_src` blob. It is most likely that the file has
921
+ # been cached twice concurrently (exactly between `os.remove` and `os.symlink`). Do nothing.
922
+ return
923
+ else:
924
+ # Very unlikely to happen. Means a file `dst` has been created exactly between `os.remove` and
925
+ # `os.symlink` and is not a symlink to the `abs_src` blob file. Raise exception.
926
+ raise
927
+ except PermissionError:
928
+ # Permission error means src and dst are not in the same volume (e.g. download to local dir) and symlink
929
+ # is supported on both volumes but not between them. Let's just make a hard copy in that case.
930
+ pass
931
+
932
+ # Symlinks are not supported => let's move or copy the file.
933
+ if new_blob:
934
+ logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
935
+ shutil.move(abs_src, abs_dst)
936
+ else:
937
+ logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
938
+ shutil.copyfile(abs_src, abs_dst)
939
+
940
+
941
+ def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None:
942
+ """Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash.
943
+
944
+ Does nothing if `revision` is already a proper `commit_hash` or reference is already cached.
945
+ """
946
+ if revision != commit_hash:
947
+ ref_path = Path(storage_folder) / "refs" / revision
948
+ ref_path.parent.mkdir(parents=True, exist_ok=True)
949
+ if not ref_path.exists() or commit_hash != ref_path.read_text():
950
+ # Update ref only if has been updated. Could cause useless error in case
951
+ # repo is already cached and user doesn't have write access to cache folder.
952
+ # See https://github.com/huggingface/huggingface_hub/issues/1216.
953
+ ref_path.write_text(commit_hash)
954
+
955
+
956
+ @validate_hf_hub_args
957
+ def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
958
+ """Return a serialized version of a hf.co repo name and type, safe for disk storage
959
+ as a single non-nested folder.
960
+
961
+ Example: models--julien-c--EsperBERTo-small
962
+ """
963
+ # remove all `/` occurrences to correctly convert repo to directory name
964
+ parts = [f"{repo_type}s", *repo_id.split("/")]
965
+ return REPO_ID_SEPARATOR.join(parts)
966
+
967
+
968
+ def _check_disk_space(expected_size: int, target_dir: Union[str, Path]) -> None:
969
+ """Check disk usage and log a warning if there is not enough disk space to download the file.
970
+
971
+ Args:
972
+ expected_size (`int`):
973
+ The expected size of the file in bytes.
974
+ target_dir (`str`):
975
+ The directory where the file will be stored after downloading.
976
+ """
977
+
978
+ target_dir = Path(target_dir) # format as `Path`
979
+ for path in [target_dir] + list(target_dir.parents): # first check target_dir, then each parents one by one
980
+ try:
981
+ target_dir_free = shutil.disk_usage(path).free
982
+ if target_dir_free < expected_size:
983
+ warnings.warn(
984
+ "Not enough free disk space to download the file. "
985
+ f"The expected file size is: {expected_size / 1e6:.2f} MB. "
986
+ f"The target location {target_dir} only has {target_dir_free / 1e6:.2f} MB free disk space."
987
+ )
988
+ return
989
+ except OSError: # raise on anything: file does not exist or space disk cannot be checked
990
+ pass
991
+
992
+
993
+ @validate_hf_hub_args
994
+ def hf_hub_download(
995
+ repo_id: str,
996
+ filename: str,
997
+ *,
998
+ subfolder: Optional[str] = None,
999
+ repo_type: Optional[str] = None,
1000
+ revision: Optional[str] = None,
1001
+ library_name: Optional[str] = None,
1002
+ library_version: Optional[str] = None,
1003
+ cache_dir: Union[str, Path, None] = None,
1004
+ local_dir: Union[str, Path, None] = None,
1005
+ local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
1006
+ user_agent: Union[Dict, str, None] = None,
1007
+ force_download: bool = False,
1008
+ force_filename: Optional[str] = None,
1009
+ proxies: Optional[Dict] = None,
1010
+ etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
1011
+ resume_download: bool = False,
1012
+ token: Union[bool, str, None] = None,
1013
+ local_files_only: bool = False,
1014
+ legacy_cache_layout: bool = False,
1015
+ endpoint: Optional[str] = None,
1016
+ ) -> str:
1017
+ """Download a given file if it's not already present in the local cache.
1018
+
1019
+ The new cache file layout looks like this:
1020
+ - The cache directory contains one subfolder per repo_id (namespaced by repo type)
1021
+ - inside each repo folder:
1022
+ - refs is a list of the latest known revision => commit_hash pairs
1023
+ - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on
1024
+ whether they're LFS files or not)
1025
+ - snapshots contains one subfolder per commit, each "commit" contains the subset of the files
1026
+ that have been resolved at that particular commit. Each filename is a symlink to the blob
1027
+ at that particular commit.
1028
+
1029
+ If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
1030
+ how you want to move those files:
1031
+ - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
1032
+ files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
1033
+ is to be able to manually edit and save small files without corrupting the cache while saving disk space for
1034
+ binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
1035
+ environment variable.
1036
+ - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
1037
+ This is optimal in term of disk usage but files must not be manually edited.
1038
+ - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
1039
+ local dir. This means disk usage is not optimized.
1040
+ - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
1041
+ files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
1042
+ they will be re-downloaded entirely.
1043
+
1044
+ ```
1045
+ [ 96] .
1046
+ └── [ 160] models--julien-c--EsperBERTo-small
1047
+ ├── [ 160] blobs
1048
+ │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
1049
+ │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e
1050
+ │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812
1051
+ ├── [ 96] refs
1052
+ │ └── [ 40] main
1053
+ └── [ 128] snapshots
1054
+ ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
1055
+ │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
1056
+ │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
1057
+ └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48
1058
+ ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
1059
+ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
1060
+ ```
1061
+
1062
+ Args:
1063
+ repo_id (`str`):
1064
+ A user or an organization name and a repo name separated by a `/`.
1065
+ filename (`str`):
1066
+ The name of the file in the repo.
1067
+ subfolder (`str`, *optional*):
1068
+ An optional value corresponding to a folder inside the model repo.
1069
+ repo_type (`str`, *optional*):
1070
+ Set to `"dataset"` or `"space"` if downloading from a dataset or space,
1071
+ `None` or `"model"` if downloading from a model. Default is `None`.
1072
+ revision (`str`, *optional*):
1073
+ An optional Git revision id which can be a branch name, a tag, or a
1074
+ commit hash.
1075
+ library_name (`str`, *optional*):
1076
+ The name of the library to which the object corresponds.
1077
+ library_version (`str`, *optional*):
1078
+ The version of the library.
1079
+ cache_dir (`str`, `Path`, *optional*):
1080
+ Path to the folder where cached files are stored.
1081
+ local_dir (`str` or `Path`, *optional*):
1082
+ If provided, the downloaded file will be placed under this directory, either as a symlink (default) or
1083
+ a regular file (see description for more details).
1084
+ local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
1085
+ To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
1086
+ duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be
1087
+ created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if
1088
+ already exists) or downloaded from the Hub and not cached. See description for more details.
1089
+ user_agent (`dict`, `str`, *optional*):
1090
+ The user-agent info in the form of a dictionary or a string.
1091
+ force_download (`bool`, *optional*, defaults to `False`):
1092
+ Whether the file should be downloaded even if it already exists in
1093
+ the local cache.
1094
+ proxies (`dict`, *optional*):
1095
+ Dictionary mapping protocol to the URL of the proxy passed to
1096
+ `requests.request`.
1097
+ etag_timeout (`float`, *optional*, defaults to `10`):
1098
+ When fetching ETag, how many seconds to wait for the server to send
1099
+ data before giving up which is passed to `requests.request`.
1100
+ resume_download (`bool`, *optional*, defaults to `False`):
1101
+ If `True`, resume a previously interrupted download.
1102
+ token (`str`, `bool`, *optional*):
1103
+ A token to be used for the download.
1104
+ - If `True`, the token is read from the HuggingFace config
1105
+ folder.
1106
+ - If a string, it's used as the authentication token.
1107
+ local_files_only (`bool`, *optional*, defaults to `False`):
1108
+ If `True`, avoid downloading the file and return the path to the
1109
+ local cached file if it exists.
1110
+ legacy_cache_layout (`bool`, *optional*, defaults to `False`):
1111
+ If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`]
1112
+ then `cached_download`. This is deprecated as the new cache layout is
1113
+ more powerful.
1114
+
1115
+ Returns:
1116
+ Local path (string) of file or if networking is off, last version of
1117
+ file cached on disk.
1118
+
1119
+ <Tip>
1120
+
1121
+ Raises the following errors:
1122
+
1123
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
1124
+ if `token=True` and the token cannot be found.
1125
+ - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
1126
+ if ETag cannot be determined.
1127
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
1128
+ if some parameter value is invalid
1129
+ - [`~utils.RepositoryNotFoundError`]
1130
+ If the repository to download from cannot be found. This may be because it doesn't exist,
1131
+ or because it is set to `private` and you do not have access.
1132
+ - [`~utils.RevisionNotFoundError`]
1133
+ If the revision to download from cannot be found.
1134
+ - [`~utils.EntryNotFoundError`]
1135
+ If the file to download cannot be found.
1136
+ - [`~utils.LocalEntryNotFoundError`]
1137
+ If network is disabled or unavailable and file is not found in cache.
1138
+
1139
+ </Tip>
1140
+ """
1141
+ if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
1142
+ # Respect environment variable above user value
1143
+ etag_timeout = HF_HUB_ETAG_TIMEOUT
1144
+
1145
+ if force_filename is not None:
1146
+ warnings.warn(
1147
+ "The `force_filename` parameter is deprecated as a new caching system, "
1148
+ "which keeps the filenames as they are on the Hub, is now in place.",
1149
+ FutureWarning,
1150
+ )
1151
+ legacy_cache_layout = True
1152
+
1153
+ if legacy_cache_layout:
1154
+ url = hf_hub_url(
1155
+ repo_id,
1156
+ filename,
1157
+ subfolder=subfolder,
1158
+ repo_type=repo_type,
1159
+ revision=revision,
1160
+ endpoint=endpoint,
1161
+ )
1162
+
1163
+ return cached_download(
1164
+ url,
1165
+ library_name=library_name,
1166
+ library_version=library_version,
1167
+ cache_dir=cache_dir,
1168
+ user_agent=user_agent,
1169
+ force_download=force_download,
1170
+ force_filename=force_filename,
1171
+ proxies=proxies,
1172
+ etag_timeout=etag_timeout,
1173
+ resume_download=resume_download,
1174
+ token=token,
1175
+ local_files_only=local_files_only,
1176
+ legacy_cache_layout=legacy_cache_layout,
1177
+ )
1178
+
1179
+ if cache_dir is None:
1180
+ cache_dir = HF_HUB_CACHE
1181
+ if revision is None:
1182
+ revision = DEFAULT_REVISION
1183
+ if isinstance(cache_dir, Path):
1184
+ cache_dir = str(cache_dir)
1185
+ if isinstance(local_dir, Path):
1186
+ local_dir = str(local_dir)
1187
+ locks_dir = os.path.join(cache_dir, ".locks")
1188
+
1189
+ if subfolder == "":
1190
+ subfolder = None
1191
+ if subfolder is not None:
1192
+ # This is used to create a URL, and not a local path, hence the forward slash.
1193
+ filename = f"{subfolder}/{filename}"
1194
+
1195
+ if repo_type is None:
1196
+ repo_type = "model"
1197
+ if repo_type not in REPO_TYPES:
1198
+ raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
1199
+
1200
+ storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
1201
+ os.makedirs(storage_folder, exist_ok=True)
1202
+
1203
+ # cross platform transcription of filename, to be used as a local file path.
1204
+ relative_filename = os.path.join(*filename.split("/"))
1205
+ if os.name == "nt":
1206
+ if relative_filename.startswith("..\\") or "\\..\\" in relative_filename:
1207
+ raise ValueError(
1208
+ f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository"
1209
+ " owner to rename this file."
1210
+ )
1211
+
1212
+ # if user provides a commit_hash and they already have the file on disk,
1213
+ # shortcut everything.
1214
+ if REGEX_COMMIT_HASH.match(revision):
1215
+ pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
1216
+ if os.path.exists(pointer_path):
1217
+ if local_dir is not None:
1218
+ return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1219
+ return pointer_path
1220
+
1221
+ url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)
1222
+
1223
+ headers = build_hf_headers(
1224
+ token=token,
1225
+ library_name=library_name,
1226
+ library_version=library_version,
1227
+ user_agent=user_agent,
1228
+ )
1229
+
1230
+ url_to_download = url
1231
+ etag = None
1232
+ commit_hash = None
1233
+ expected_size = None
1234
+ head_call_error: Optional[Exception] = None
1235
+ if not local_files_only:
1236
+ try:
1237
+ try:
1238
+ metadata = get_hf_file_metadata(
1239
+ url=url,
1240
+ token=token,
1241
+ proxies=proxies,
1242
+ timeout=etag_timeout,
1243
+ library_name=library_name,
1244
+ library_version=library_version,
1245
+ user_agent=user_agent,
1246
+ )
1247
+ except EntryNotFoundError as http_error:
1248
+ # Cache the non-existence of the file and raise
1249
+ commit_hash = http_error.response.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
1250
+ if commit_hash is not None and not legacy_cache_layout:
1251
+ no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename
1252
+ no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
1253
+ no_exist_file_path.touch()
1254
+ _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
1255
+ raise
1256
+
1257
+ # Commit hash must exist
1258
+ commit_hash = metadata.commit_hash
1259
+ if commit_hash is None:
1260
+ raise FileMetadataError(
1261
+ "Distant resource does not seem to be on huggingface.co. It is possible that a configuration issue"
1262
+ " prevents you from downloading resources from https://huggingface.co. Please check your firewall"
1263
+ " and proxy settings and make sure your SSL certificates are updated."
1264
+ )
1265
+
1266
+ # Etag must exist
1267
+ etag = metadata.etag
1268
+ # We favor a custom header indicating the etag of the linked resource, and
1269
+ # we fallback to the regular etag header.
1270
+ # If we don't have any of those, raise an error.
1271
+ if etag is None:
1272
+ raise FileMetadataError(
1273
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
1274
+ )
1275
+
1276
+ # Expected (uncompressed) size
1277
+ expected_size = metadata.size
1278
+
1279
+ # In case of a redirect, save an extra redirect on the request.get call,
1280
+ # and ensure we download the exact atomic version even if it changed
1281
+ # between the HEAD and the GET (unlikely, but hey).
1282
+ # Useful for lfs blobs that are stored on a CDN.
1283
+ if metadata.location != url:
1284
+ url_to_download = metadata.location
1285
+ # Remove authorization header when downloading a LFS blob
1286
+ headers.pop("authorization", None)
1287
+ except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
1288
+ # Actually raise for those subclasses of ConnectionError
1289
+ raise
1290
+ except (
1291
+ requests.exceptions.ConnectionError,
1292
+ requests.exceptions.Timeout,
1293
+ OfflineModeIsEnabled,
1294
+ ) as error:
1295
+ # Otherwise, our Internet connection is down.
1296
+ # etag is None
1297
+ head_call_error = error
1298
+ pass
1299
+ except (RevisionNotFoundError, EntryNotFoundError):
1300
+ # The repo was found but the revision or entry doesn't exist on the Hub (never existed or got deleted)
1301
+ raise
1302
+ except requests.HTTPError as error:
1303
+ # Multiple reasons for an http error:
1304
+ # - Repository is private and invalid/missing token sent
1305
+ # - Repository is gated and invalid/missing token sent
1306
+ # - Hub is down (error 500 or 504)
1307
+ # => let's switch to 'local_files_only=True' to check if the files are already cached.
1308
+ # (if it's not the case, the error will be re-raised)
1309
+ head_call_error = error
1310
+ pass
1311
+ except FileMetadataError as error:
1312
+ # Multiple reasons for a FileMetadataError:
1313
+ # - Wrong network configuration (proxy, firewall, SSL certificates)
1314
+ # - Inconsistency on the Hub
1315
+ # => let's switch to 'local_files_only=True' to check if the files are already cached.
1316
+ # (if it's not the case, the error will be re-raised)
1317
+ head_call_error = error
1318
+ pass
1319
+
1320
+ # etag can be None for several reasons:
1321
+ # 1. we passed local_files_only.
1322
+ # 2. we don't have a connection
1323
+ # 3. Hub is down (HTTP 500 or 504)
1324
+ # 4. repo is not found -for example private or gated- and invalid/missing token sent
1325
+ # 5. Hub is blocked by a firewall or proxy is not set correctly.
1326
+ # => Try to get the last downloaded one from the specified revision.
1327
+ #
1328
+ # If the specified revision is a commit hash, look inside "snapshots".
1329
+ # If the specified revision is a branch or tag, look inside "refs".
1330
+ if etag is None:
1331
+ # In those cases, we cannot force download.
1332
+ if force_download:
1333
+ raise ValueError(
1334
+ "We have no connection or you passed local_files_only, so force_download is not an accepted option."
1335
+ )
1336
+
1337
+ # Try to get "commit_hash" from "revision"
1338
+ commit_hash = None
1339
+ if REGEX_COMMIT_HASH.match(revision):
1340
+ commit_hash = revision
1341
+ else:
1342
+ ref_path = os.path.join(storage_folder, "refs", revision)
1343
+ if os.path.isfile(ref_path):
1344
+ with open(ref_path) as f:
1345
+ commit_hash = f.read()
1346
+
1347
+ # Return pointer file if exists
1348
+ if commit_hash is not None:
1349
+ pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
1350
+ if os.path.exists(pointer_path):
1351
+ if local_dir is not None:
1352
+ return _to_local_dir(
1353
+ pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks
1354
+ )
1355
+ return pointer_path
1356
+
1357
+ # If we couldn't find an appropriate file on disk, raise an error.
1358
+ # If files cannot be found and local_files_only=True,
1359
+ # the models might've been found if local_files_only=False
1360
+ # Notify the user about that
1361
+ if local_files_only:
1362
+ raise LocalEntryNotFoundError(
1363
+ "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable"
1364
+ " hf.co look-ups and downloads online, set 'local_files_only' to False."
1365
+ )
1366
+ elif isinstance(head_call_error, RepositoryNotFoundError) or isinstance(head_call_error, GatedRepoError):
1367
+ # Repo not found => let's raise the actual error
1368
+ raise head_call_error
1369
+ else:
1370
+ # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
1371
+ raise LocalEntryNotFoundError(
1372
+ "An error happened while trying to locate the file on the Hub and we cannot find the requested files"
1373
+ " in the local cache. Please check your connection and try again or make sure your Internet connection"
1374
+ " is on."
1375
+ ) from head_call_error
1376
+
1377
+ # From now on, etag and commit_hash are not None.
1378
+ assert etag is not None, "etag must have been retrieved from server"
1379
+ assert commit_hash is not None, "commit_hash must have been retrieved from server"
1380
+ blob_path = os.path.join(storage_folder, "blobs", etag)
1381
+ pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
1382
+
1383
+ os.makedirs(os.path.dirname(blob_path), exist_ok=True)
1384
+ os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
1385
+ # if passed revision is not identical to commit_hash
1386
+ # then revision has to be a branch name or tag name.
1387
+ # In that case store a ref.
1388
+ _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
1389
+
1390
+ if os.path.exists(pointer_path) and not force_download:
1391
+ if local_dir is not None:
1392
+ return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1393
+ return pointer_path
1394
+
1395
+ if os.path.exists(blob_path) and not force_download:
1396
+ # we have the blob already, but not the pointer
1397
+ if local_dir is not None: # to local dir
1398
+ return _to_local_dir(blob_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1399
+ else: # or in snapshot cache
1400
+ _create_symlink(blob_path, pointer_path, new_blob=False)
1401
+ return pointer_path
1402
+
1403
+ # Prevent parallel downloads of the same file with a lock.
1404
+ # etag could be duplicated across repos,
1405
+ lock_path = os.path.join(locks_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type), f"{etag}.lock")
1406
+
1407
+ # Some Windows versions do not allow for paths longer than 255 characters.
1408
+ # In this case, we must specify it is an extended path by using the "\\?\" prefix.
1409
+ if os.name == "nt" and len(os.path.abspath(lock_path)) > 255:
1410
+ lock_path = "\\\\?\\" + os.path.abspath(lock_path)
1411
+
1412
+ if os.name == "nt" and len(os.path.abspath(blob_path)) > 255:
1413
+ blob_path = "\\\\?\\" + os.path.abspath(blob_path)
1414
+
1415
+ Path(lock_path).parent.mkdir(parents=True, exist_ok=True)
1416
+ with FileLock(lock_path):
1417
+ # If the download just completed while the lock was activated.
1418
+ if os.path.exists(pointer_path) and not force_download:
1419
+ # Even if returning early like here, the lock will be released.
1420
+ if local_dir is not None:
1421
+ return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
1422
+ return pointer_path
1423
+
1424
+ if resume_download:
1425
+ incomplete_path = blob_path + ".incomplete"
1426
+
1427
+ @contextmanager
1428
+ def _resumable_file_manager() -> Generator[io.BufferedWriter, None, None]:
1429
+ with open(incomplete_path, "ab") as f:
1430
+ yield f
1431
+
1432
+ temp_file_manager = _resumable_file_manager
1433
+ if os.path.exists(incomplete_path):
1434
+ resume_size = os.stat(incomplete_path).st_size
1435
+ else:
1436
+ resume_size = 0
1437
+ else:
1438
+ temp_file_manager = partial( # type: ignore
1439
+ tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
1440
+ )
1441
+ resume_size = 0
1442
+
1443
+ # Download to temporary file, then copy to cache dir once finished.
1444
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
1445
+ with temp_file_manager() as temp_file:
1446
+ logger.info("downloading %s to %s", url, temp_file.name)
1447
+
1448
+ if expected_size is not None: # might be None if HTTP header not set correctly
1449
+ # Check tmp path
1450
+ _check_disk_space(expected_size, os.path.dirname(temp_file.name))
1451
+
1452
+ # Check destination
1453
+ _check_disk_space(expected_size, os.path.dirname(blob_path))
1454
+ if local_dir is not None:
1455
+ _check_disk_space(expected_size, local_dir)
1456
+
1457
+ http_get(
1458
+ url_to_download,
1459
+ temp_file,
1460
+ proxies=proxies,
1461
+ resume_size=resume_size,
1462
+ headers=headers,
1463
+ expected_size=expected_size,
1464
+ )
1465
+
1466
+ if local_dir is None:
1467
+ logger.debug(f"Storing {url} in cache at {blob_path}")
1468
+ _chmod_and_replace(temp_file.name, blob_path)
1469
+ _create_symlink(blob_path, pointer_path, new_blob=True)
1470
+ else:
1471
+ local_dir_filepath = os.path.join(local_dir, relative_filename)
1472
+ os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
1473
+
1474
+ # If "auto" (default) copy-paste small files to ease manual editing but symlink big files to save disk
1475
+ # In both cases, blob file is cached.
1476
+ is_big_file = os.stat(temp_file.name).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD
1477
+ if local_dir_use_symlinks is True or (local_dir_use_symlinks == "auto" and is_big_file):
1478
+ logger.debug(f"Storing {url} in cache at {blob_path}")
1479
+ _chmod_and_replace(temp_file.name, blob_path)
1480
+ logger.debug("Create symlink to local dir")
1481
+ _create_symlink(blob_path, local_dir_filepath, new_blob=False)
1482
+ elif local_dir_use_symlinks == "auto" and not is_big_file:
1483
+ logger.debug(f"Storing {url} in cache at {blob_path}")
1484
+ _chmod_and_replace(temp_file.name, blob_path)
1485
+ logger.debug("Duplicate in local dir (small file and use_symlink set to 'auto')")
1486
+ shutil.copyfile(blob_path, local_dir_filepath)
1487
+ else:
1488
+ logger.debug(f"Storing {url} in local_dir at {local_dir_filepath} (not cached).")
1489
+ _chmod_and_replace(temp_file.name, local_dir_filepath)
1490
+ pointer_path = local_dir_filepath # for return value
1491
+
1492
+ return pointer_path
1493
+
1494
+
1495
+ @validate_hf_hub_args
1496
+ def try_to_load_from_cache(
1497
+ repo_id: str,
1498
+ filename: str,
1499
+ cache_dir: Union[str, Path, None] = None,
1500
+ revision: Optional[str] = None,
1501
+ repo_type: Optional[str] = None,
1502
+ ) -> Union[str, _CACHED_NO_EXIST_T, None]:
1503
+ """
1504
+ Explores the cache to return the latest cached file for a given revision if found.
1505
+
1506
+ This function will not raise any exception if the file in not cached.
1507
+
1508
+ Args:
1509
+ cache_dir (`str` or `os.PathLike`):
1510
+ The folder where the cached files lie.
1511
+ repo_id (`str`):
1512
+ The ID of the repo on huggingface.co.
1513
+ filename (`str`):
1514
+ The filename to look for inside `repo_id`.
1515
+ revision (`str`, *optional*):
1516
+ The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
1517
+ provided either.
1518
+ repo_type (`str`, *optional*):
1519
+ The type of the repository. Will default to `"model"`.
1520
+
1521
+ Returns:
1522
+ `Optional[str]` or `_CACHED_NO_EXIST`:
1523
+ Will return `None` if the file was not cached. Otherwise:
1524
+ - The exact path to the cached file if it's found in the cache
1525
+ - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
1526
+ cached.
1527
+
1528
+ Example:
1529
+
1530
+ ```python
1531
+ from huggingface_hub import try_to_load_from_cache, _CACHED_NO_EXIST
1532
+
1533
+ filepath = try_to_load_from_cache()
1534
+ if isinstance(filepath, str):
1535
+ # file exists and is cached
1536
+ ...
1537
+ elif filepath is _CACHED_NO_EXIST:
1538
+ # non-existence of file is cached
1539
+ ...
1540
+ else:
1541
+ # file is not cached
1542
+ ...
1543
+ ```
1544
+ """
1545
+ if revision is None:
1546
+ revision = "main"
1547
+ if repo_type is None:
1548
+ repo_type = "model"
1549
+ if repo_type not in REPO_TYPES:
1550
+ raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
1551
+ if cache_dir is None:
1552
+ cache_dir = HF_HUB_CACHE
1553
+
1554
+ object_id = repo_id.replace("/", "--")
1555
+ repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
1556
+ if not os.path.isdir(repo_cache):
1557
+ # No cache for this model
1558
+ return None
1559
+
1560
+ refs_dir = os.path.join(repo_cache, "refs")
1561
+ snapshots_dir = os.path.join(repo_cache, "snapshots")
1562
+ no_exist_dir = os.path.join(repo_cache, ".no_exist")
1563
+
1564
+ # Resolve refs (for instance to convert main to the associated commit sha)
1565
+ if os.path.isdir(refs_dir):
1566
+ revision_file = os.path.join(refs_dir, revision)
1567
+ if os.path.isfile(revision_file):
1568
+ with open(revision_file) as f:
1569
+ revision = f.read()
1570
+
1571
+ # Check if file is cached as "no_exist"
1572
+ if os.path.isfile(os.path.join(no_exist_dir, revision, filename)):
1573
+ return _CACHED_NO_EXIST
1574
+
1575
+ # Check if revision folder exists
1576
+ if not os.path.exists(snapshots_dir):
1577
+ return None
1578
+ cached_shas = os.listdir(snapshots_dir)
1579
+ if revision not in cached_shas:
1580
+ # No cache for this revision and we won't try to return a random revision
1581
+ return None
1582
+
1583
+ # Check if file exists in cache
1584
+ cached_file = os.path.join(snapshots_dir, revision, filename)
1585
+ return cached_file if os.path.isfile(cached_file) else None
1586
+
1587
+
1588
+ @validate_hf_hub_args
1589
+ def get_hf_file_metadata(
1590
+ url: str,
1591
+ token: Union[bool, str, None] = None,
1592
+ proxies: Optional[Dict] = None,
1593
+ timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
1594
+ library_name: Optional[str] = None,
1595
+ library_version: Optional[str] = None,
1596
+ user_agent: Union[Dict, str, None] = None,
1597
+ ) -> HfFileMetadata:
1598
+ """Fetch metadata of a file versioned on the Hub for a given url.
1599
+
1600
+ Args:
1601
+ url (`str`):
1602
+ File url, for example returned by [`hf_hub_url`].
1603
+ token (`str` or `bool`, *optional*):
1604
+ A token to be used for the download.
1605
+ - If `True`, the token is read from the HuggingFace config
1606
+ folder.
1607
+ - If `False` or `None`, no token is provided.
1608
+ - If a string, it's used as the authentication token.
1609
+ proxies (`dict`, *optional*):
1610
+ Dictionary mapping protocol to the URL of the proxy passed to
1611
+ `requests.request`.
1612
+ timeout (`float`, *optional*, defaults to 10):
1613
+ How many seconds to wait for the server to send metadata before giving up.
1614
+ library_name (`str`, *optional*):
1615
+ The name of the library to which the object corresponds.
1616
+ library_version (`str`, *optional*):
1617
+ The version of the library.
1618
+ user_agent (`dict`, `str`, *optional*):
1619
+ The user-agent info in the form of a dictionary or a string.
1620
+
1621
+ Returns:
1622
+ A [`HfFileMetadata`] object containing metadata such as location, etag, size and
1623
+ commit_hash.
1624
+ """
1625
+ headers = build_hf_headers(
1626
+ token=token, library_name=library_name, library_version=library_version, user_agent=user_agent
1627
+ )
1628
+ headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file
1629
+
1630
+ # Retrieve metadata
1631
+ r = _request_wrapper(
1632
+ method="HEAD",
1633
+ url=url,
1634
+ headers=headers,
1635
+ allow_redirects=False,
1636
+ follow_relative_redirects=True,
1637
+ proxies=proxies,
1638
+ timeout=timeout,
1639
+ )
1640
+ hf_raise_for_status(r)
1641
+
1642
+ # Return
1643
+ return HfFileMetadata(
1644
+ commit_hash=r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT),
1645
+ # We favor a custom header indicating the etag of the linked resource, and
1646
+ # we fallback to the regular etag header.
1647
+ etag=_normalize_etag(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")),
1648
+ # Either from response headers (if redirected) or defaults to request url
1649
+ # Do not use directly `url`, as `_request_wrapper` might have followed relative
1650
+ # redirects.
1651
+ location=r.headers.get("Location") or r.request.url, # type: ignore
1652
+ size=_int_or_none(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")),
1653
+ )
1654
+
1655
+
1656
+ def _int_or_none(value: Optional[str]) -> Optional[int]:
1657
+ try:
1658
+ return int(value) # type: ignore
1659
+ except (TypeError, ValueError):
1660
+ return None
1661
+
1662
+
1663
+ def _chmod_and_replace(src: str, dst: str) -> None:
1664
+ """Set correct permission before moving a blob from tmp directory to cache dir.
1665
+
1666
+ Do not take into account the `umask` from the process as there is no convenient way
1667
+ to get it that is thread-safe.
1668
+
1669
+ See:
1670
+ - About umask: https://docs.python.org/3/library/os.html#os.umask
1671
+ - Thread-safety: https://stackoverflow.com/a/70343066
1672
+ - About solution: https://github.com/huggingface/huggingface_hub/pull/1220#issuecomment-1326211591
1673
+ - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1141
1674
+ - Fix issue: https://github.com/huggingface/huggingface_hub/issues/1215
1675
+ """
1676
+ # Get umask by creating a temporary file in the cached repo folder.
1677
+ tmp_file = Path(dst).parent.parent / f"tmp_{uuid.uuid4()}"
1678
+ try:
1679
+ tmp_file.touch()
1680
+ cache_dir_mode = Path(tmp_file).stat().st_mode
1681
+ os.chmod(src, stat.S_IMODE(cache_dir_mode))
1682
+ finally:
1683
+ tmp_file.unlink()
1684
+
1685
+ shutil.move(src, dst)
1686
+
1687
+
1688
+ def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
1689
+ # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
1690
+ snapshot_path = os.path.join(storage_folder, "snapshots")
1691
+ pointer_path = os.path.join(snapshot_path, revision, relative_filename)
1692
+ if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents:
1693
+ raise ValueError(
1694
+ "Invalid pointer path: cannot create pointer path in snapshot folder if"
1695
+ f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and"
1696
+ f" `relative_filename='{relative_filename}'`."
1697
+ )
1698
+ return pointer_path
1699
+
1700
+
1701
+ def _to_local_dir(
1702
+ path: str, local_dir: str, relative_filename: str, use_symlinks: Union[bool, Literal["auto"]]
1703
+ ) -> str:
1704
+ """Place a file in a local dir (different than cache_dir).
1705
+
1706
+ Either symlink to blob file in cache or duplicate file depending on `use_symlinks` and file size.
1707
+ """
1708
+ # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
1709
+ local_dir_filepath = os.path.join(local_dir, relative_filename)
1710
+ if Path(os.path.abspath(local_dir)) not in Path(os.path.abspath(local_dir_filepath)).parents:
1711
+ raise ValueError(
1712
+ f"Cannot copy file '{relative_filename}' to local dir '{local_dir}': file would not be in the local"
1713
+ " directory."
1714
+ )
1715
+
1716
+ os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
1717
+ real_blob_path = os.path.realpath(path)
1718
+
1719
+ # If "auto" (default) copy-paste small files to ease manual editing but symlink big files to save disk
1720
+ if use_symlinks == "auto":
1721
+ use_symlinks = os.stat(real_blob_path).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD
1722
+
1723
+ if use_symlinks:
1724
+ _create_symlink(real_blob_path, local_dir_filepath, new_blob=False)
1725
+ else:
1726
+ shutil.copyfile(real_blob_path, local_dir_filepath)
1727
+ return local_dir_filepath
lib/python3.11/site-packages/huggingface_hub/hf_api.py ADDED
The diff for this file is too large to render. See raw diff
 
lib/python3.11/site-packages/huggingface_hub/hf_file_system.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import re
4
+ import tempfile
5
+ from collections import deque
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime
8
+ from itertools import chain
9
+ from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
10
+ from urllib.parse import quote, unquote
11
+
12
+ import fsspec
13
+
14
+ from ._commit_api import CommitOperationCopy, CommitOperationDelete
15
+ from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
16
+ from .file_download import hf_hub_url
17
+ from .hf_api import HfApi, LastCommitInfo, RepoFile
18
+ from .utils import (
19
+ EntryNotFoundError,
20
+ HFValidationError,
21
+ RepositoryNotFoundError,
22
+ RevisionNotFoundError,
23
+ hf_raise_for_status,
24
+ http_backoff,
25
+ )
26
+
27
+
28
+ # Regex used to match special revisions with "/" in them (see #1710)
29
+ SPECIAL_REFS_REVISION_REGEX = re.compile(
30
+ r"""
31
+ (^refs\/convert\/\w+) # `refs/convert/parquet` revisions
32
+ |
33
+ (^refs\/pr\/\d+) # PR revisions
34
+ """,
35
+ re.VERBOSE,
36
+ )
37
+
38
+
39
+ @dataclass
40
+ class HfFileSystemResolvedPath:
41
+ """Data structure containing information about a resolved Hugging Face file system path."""
42
+
43
+ repo_type: str
44
+ repo_id: str
45
+ revision: str
46
+ path_in_repo: str
47
+ # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision.
48
+ # Used to reconstruct the unresolved path to return to the user.
49
+ _raw_revision: Optional[str] = field(default=None, repr=False)
50
+
51
+ def unresolve(self) -> str:
52
+ repo_path = REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id
53
+ if self._raw_revision:
54
+ return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/")
55
+ elif self.revision != DEFAULT_REVISION:
56
+ return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/")
57
+ else:
58
+ return f"{repo_path}/{self.path_in_repo}".rstrip("/")
59
+
60
+
61
+ class HfFileSystem(fsspec.AbstractFileSystem):
62
+ """
63
+ Access a remote Hugging Face Hub repository as if were a local file system.
64
+
65
+ Args:
66
+ token (`str`, *optional*):
67
+ Authentication token, obtained with [`HfApi.login`] method. Will default to the stored token.
68
+
69
+ Usage:
70
+
71
+ ```python
72
+ >>> from huggingface_hub import HfFileSystem
73
+
74
+ >>> fs = HfFileSystem()
75
+
76
+ >>> # List files
77
+ >>> fs.glob("my-username/my-model/*.bin")
78
+ ['my-username/my-model/pytorch_model.bin']
79
+ >>> fs.ls("datasets/my-username/my-dataset", detail=False)
80
+ ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']
81
+
82
+ >>> # Read/write files
83
+ >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
84
+ ... data = f.read()
85
+ >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
86
+ ... f.write(data)
87
+ ```
88
+ """
89
+
90
+ root_marker = ""
91
+ protocol = "hf"
92
+
93
+ def __init__(
94
+ self,
95
+ *args,
96
+ endpoint: Optional[str] = None,
97
+ token: Optional[str] = None,
98
+ **storage_options,
99
+ ):
100
+ super().__init__(*args, **storage_options)
101
+ self.endpoint = endpoint or ENDPOINT
102
+ self.token = token
103
+ self._api = HfApi(endpoint=endpoint, token=token)
104
+ # Maps (repo_type, repo_id, revision) to a 2-tuple with:
105
+ # * the 1st element indicating whether the repositoy and the revision exist
106
+ # * the 2nd element being the exception raised if the repository or revision doesn't exist
107
+ self._repo_and_revision_exists_cache: Dict[
108
+ Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
109
+ ] = {}
110
+
111
+ def _repo_and_revision_exist(
112
+ self, repo_type: str, repo_id: str, revision: Optional[str]
113
+ ) -> Tuple[bool, Optional[Exception]]:
114
+ if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
115
+ try:
116
+ self._api.repo_info(repo_id, revision=revision, repo_type=repo_type)
117
+ except (RepositoryNotFoundError, HFValidationError) as e:
118
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
119
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
120
+ except RevisionNotFoundError as e:
121
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
122
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
123
+ else:
124
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
125
+ self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
126
+ return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]
127
+
128
+ def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
129
+ def _align_revision_in_path_with_revision(
130
+ revision_in_path: Optional[str], revision: Optional[str]
131
+ ) -> Optional[str]:
132
+ if revision is not None:
133
+ if revision_in_path is not None and revision_in_path != revision:
134
+ raise ValueError(
135
+ f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
136
+ " are not the same."
137
+ )
138
+ else:
139
+ revision = revision_in_path
140
+ return revision
141
+
142
+ path = self._strip_protocol(path)
143
+ if not path:
144
+ # can't list repositories at root
145
+ raise NotImplementedError("Access to repositories lists is not implemented.")
146
+ elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
147
+ if "/" not in path:
148
+ # can't list repositories at the repository type level
149
+ raise NotImplementedError("Access to repositories lists is not implemented.")
150
+ repo_type, path = path.split("/", 1)
151
+ repo_type = REPO_TYPES_MAPPING[repo_type]
152
+ else:
153
+ repo_type = REPO_TYPE_MODEL
154
+ if path.count("/") > 0:
155
+ if "@" in path:
156
+ repo_id, revision_in_path = path.split("@", 1)
157
+ if "/" in revision_in_path:
158
+ match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
159
+ if match is not None and revision in (None, match.group()):
160
+ # Handle `refs/convert/parquet` and PR revisions separately
161
+ path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
162
+ revision_in_path = match.group()
163
+ else:
164
+ revision_in_path, path_in_repo = revision_in_path.split("/", 1)
165
+ else:
166
+ path_in_repo = ""
167
+ revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
168
+ repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
169
+ if not repo_and_revision_exist:
170
+ _raise_file_not_found(path, err)
171
+ else:
172
+ revision_in_path = None
173
+ repo_id_with_namespace = "/".join(path.split("/")[:2])
174
+ path_in_repo_with_namespace = "/".join(path.split("/")[2:])
175
+ repo_id_without_namespace = path.split("/")[0]
176
+ path_in_repo_without_namespace = "/".join(path.split("/")[1:])
177
+ repo_id = repo_id_with_namespace
178
+ path_in_repo = path_in_repo_with_namespace
179
+ repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
180
+ if not repo_and_revision_exist:
181
+ if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
182
+ repo_id = repo_id_without_namespace
183
+ path_in_repo = path_in_repo_without_namespace
184
+ repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
185
+ if not repo_and_revision_exist:
186
+ _raise_file_not_found(path, err)
187
+ else:
188
+ _raise_file_not_found(path, err)
189
+ else:
190
+ repo_id = path
191
+ path_in_repo = ""
192
+ if "@" in path:
193
+ repo_id, revision_in_path = path.split("@", 1)
194
+ revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision)
195
+ else:
196
+ revision_in_path = None
197
+ repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
198
+ if not repo_and_revision_exist:
199
+ raise NotImplementedError("Access to repositories lists is not implemented.")
200
+
201
+ revision = revision if revision is not None else DEFAULT_REVISION
202
+ return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path)
203
+
204
+ def invalidate_cache(self, path: Optional[str] = None) -> None:
205
+ if not path:
206
+ self.dircache.clear()
207
+ self._repo_and_revision_exists_cache.clear()
208
+ else:
209
+ path = self.resolve_path(path).unresolve()
210
+ while path:
211
+ self.dircache.pop(path, None)
212
+ path = self._parent(path)
213
+
214
+ def _open(
215
+ self,
216
+ path: str,
217
+ mode: str = "rb",
218
+ revision: Optional[str] = None,
219
+ **kwargs,
220
+ ) -> "HfFileSystemFile":
221
+ if "a" in mode:
222
+ raise NotImplementedError("Appending to remote files is not yet supported.")
223
+ return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)
224
+
225
+ def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
226
+ resolved_path = self.resolve_path(path, revision=revision)
227
+ self._api.delete_file(
228
+ path_in_repo=resolved_path.path_in_repo,
229
+ repo_id=resolved_path.repo_id,
230
+ token=self.token,
231
+ repo_type=resolved_path.repo_type,
232
+ revision=resolved_path.revision,
233
+ commit_message=kwargs.get("commit_message"),
234
+ commit_description=kwargs.get("commit_description"),
235
+ )
236
+ self.invalidate_cache(path=resolved_path.unresolve())
237
+
238
+ def rm(
239
+ self,
240
+ path: str,
241
+ recursive: bool = False,
242
+ maxdepth: Optional[int] = None,
243
+ revision: Optional[str] = None,
244
+ **kwargs,
245
+ ) -> None:
246
+ resolved_path = self.resolve_path(path, revision=revision)
247
+ root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id
248
+ paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision)
249
+ paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)]
250
+ operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]
251
+ commit_message = f"Delete {path} "
252
+ commit_message += "recursively " if recursive else ""
253
+ commit_message += f"up to depth {maxdepth} " if maxdepth is not None else ""
254
+ # TODO: use `commit_description` to list all the deleted paths?
255
+ self._api.create_commit(
256
+ repo_id=resolved_path.repo_id,
257
+ repo_type=resolved_path.repo_type,
258
+ token=self.token,
259
+ operations=operations,
260
+ revision=resolved_path.revision,
261
+ commit_message=kwargs.get("commit_message", commit_message),
262
+ commit_description=kwargs.get("commit_description"),
263
+ )
264
+ self.invalidate_cache(path=resolved_path.unresolve())
265
+
266
+ def ls(
267
+ self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
268
+ ) -> List[Union[str, Dict[str, Any]]]:
269
+ """List the contents of a directory."""
270
+ resolved_path = self.resolve_path(path, revision=revision)
271
+ path = resolved_path.unresolve()
272
+ kwargs = {"expand_info": detail, **kwargs}
273
+ try:
274
+ out = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs)
275
+ except EntryNotFoundError:
276
+ # Path could be a file
277
+ if not resolved_path.path_in_repo:
278
+ _raise_file_not_found(path, None)
279
+ out = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs)
280
+ out = [o for o in out if o["name"] == path]
281
+ if len(out) == 0:
282
+ _raise_file_not_found(path, None)
283
+ return out if detail else [o["name"] for o in out]
284
+
285
+ def _ls_tree(
286
+ self,
287
+ path: str,
288
+ recursive: bool = False,
289
+ refresh: bool = False,
290
+ revision: Optional[str] = None,
291
+ expand_info: bool = True,
292
+ ):
293
+ resolved_path = self.resolve_path(path, revision=revision)
294
+ path = resolved_path.unresolve()
295
+ root_path = HfFileSystemResolvedPath(
296
+ resolved_path.repo_type,
297
+ resolved_path.repo_id,
298
+ resolved_path.revision,
299
+ path_in_repo="",
300
+ _raw_revision=resolved_path._raw_revision,
301
+ ).unresolve()
302
+
303
+ out = []
304
+ if path in self.dircache and not refresh:
305
+ cached_path_infos = self.dircache[path]
306
+ out.extend(cached_path_infos)
307
+ dirs_not_in_dircache = []
308
+ if recursive:
309
+ # Use BFS to traverse the cache and build the "recursive "output
310
+ # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same)
311
+ dirs_to_visit = deque(
312
+ [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
313
+ )
314
+ while dirs_to_visit:
315
+ dir_info = dirs_to_visit.popleft()
316
+ if dir_info["name"] not in self.dircache:
317
+ dirs_not_in_dircache.append(dir_info["name"])
318
+ else:
319
+ cached_path_infos = self.dircache[dir_info["name"]]
320
+ out.extend(cached_path_infos)
321
+ dirs_to_visit.extend(
322
+ [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
323
+ )
324
+
325
+ dirs_not_expanded = []
326
+ if expand_info:
327
+ # Check if there are directories with non-expanded entries
328
+ dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None]
329
+
330
+ if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded):
331
+ # If the dircache is incomplete, find the common path of the missing and non-expanded entries
332
+ # and extend the output with the result of `_ls_tree(common_path, recursive=True)`
333
+ common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded)
334
+ # Get the parent directory if the common prefix itself is not a directory
335
+ common_path = (
336
+ common_prefix.rstrip("/")
337
+ if common_prefix.endswith("/")
338
+ or common_prefix == root_path
339
+ or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded)
340
+ else self._parent(common_prefix)
341
+ )
342
+ out = [o for o in out if not o["name"].startswith(common_path + "/")]
343
+ for cached_path in self.dircache:
344
+ if cached_path.startswith(common_path + "/"):
345
+ self.dircache.pop(cached_path, None)
346
+ self.dircache.pop(common_path, None)
347
+ out.extend(
348
+ self._ls_tree(
349
+ common_path,
350
+ recursive=recursive,
351
+ refresh=True,
352
+ revision=revision,
353
+ expand_info=expand_info,
354
+ )
355
+ )
356
+ else:
357
+ tree = self._api.list_repo_tree(
358
+ resolved_path.repo_id,
359
+ resolved_path.path_in_repo,
360
+ recursive=recursive,
361
+ expand=expand_info,
362
+ revision=resolved_path.revision,
363
+ repo_type=resolved_path.repo_type,
364
+ )
365
+ for path_info in tree:
366
+ if isinstance(path_info, RepoFile):
367
+ cache_path_info = {
368
+ "name": root_path + "/" + path_info.path,
369
+ "size": path_info.size,
370
+ "type": "file",
371
+ "blob_id": path_info.blob_id,
372
+ "lfs": path_info.lfs,
373
+ "last_commit": path_info.last_commit,
374
+ "security": path_info.security,
375
+ }
376
+ else:
377
+ cache_path_info = {
378
+ "name": root_path + "/" + path_info.path,
379
+ "size": 0,
380
+ "type": "directory",
381
+ "tree_id": path_info.tree_id,
382
+ "last_commit": path_info.last_commit,
383
+ }
384
+ parent_path = self._parent(cache_path_info["name"])
385
+ self.dircache.setdefault(parent_path, []).append(cache_path_info)
386
+ out.append(cache_path_info)
387
+ return copy.deepcopy(out) # copy to not let users modify the dircache
388
+
389
+ def glob(self, path, **kwargs):
390
+ # Set expand_info=False by default to get a x10 speed boost
391
+ kwargs = {"expand_info": kwargs.get("detail", False), **kwargs}
392
+ path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
393
+ return super().glob(path, **kwargs)
394
+
395
+ def find(
396
+ self,
397
+ path: str,
398
+ maxdepth: Optional[int] = None,
399
+ withdirs: bool = False,
400
+ detail: bool = False,
401
+ refresh: bool = False,
402
+ revision: Optional[str] = None,
403
+ **kwargs,
404
+ ) -> Union[List[str], Dict[str, Dict[str, Any]]]:
405
+ if maxdepth:
406
+ return super().find(
407
+ path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs
408
+ )
409
+ resolved_path = self.resolve_path(path, revision=revision)
410
+ path = resolved_path.unresolve()
411
+ kwargs = {"expand_info": detail, **kwargs}
412
+ try:
413
+ out = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved_path.revision, **kwargs)
414
+ except EntryNotFoundError:
415
+ # Path could be a file
416
+ if self.info(path, revision=revision, **kwargs)["type"] == "file":
417
+ out = {path: {}}
418
+ else:
419
+ out = {}
420
+ else:
421
+ if not withdirs:
422
+ out = [o for o in out if o["type"] != "directory"]
423
+ else:
424
+ # If `withdirs=True`, include the directory itself to be consistent with the spec
425
+ path_info = self.info(path, revision=resolved_path.revision, **kwargs)
426
+ out = [path_info] + out if path_info["type"] == "directory" else out
427
+ out = {o["name"]: o for o in out}
428
+ names = sorted(out)
429
+ if not detail:
430
+ return names
431
+ else:
432
+ return {name: out[name] for name in names}
433
+
434
+ def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
435
+ resolved_path1 = self.resolve_path(path1, revision=revision)
436
+ resolved_path2 = self.resolve_path(path2, revision=revision)
437
+
438
+ same_repo = (
439
+ resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
440
+ )
441
+
442
+ if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None:
443
+ commit_message = f"Copy {path1} to {path2}"
444
+ self._api.create_commit(
445
+ repo_id=resolved_path1.repo_id,
446
+ repo_type=resolved_path1.repo_type,
447
+ revision=resolved_path2.revision,
448
+ commit_message=kwargs.get("commit_message", commit_message),
449
+ commit_description=kwargs.get("commit_description", ""),
450
+ operations=[
451
+ CommitOperationCopy(
452
+ src_path_in_repo=resolved_path1.path_in_repo,
453
+ path_in_repo=resolved_path2.path_in_repo,
454
+ src_revision=resolved_path1.revision,
455
+ )
456
+ ],
457
+ )
458
+ else:
459
+ with self.open(path1, "rb", revision=resolved_path1.revision) as f:
460
+ content = f.read()
461
+ commit_message = f"Copy {path1} to {path2}"
462
+ self._api.upload_file(
463
+ path_or_fileobj=content,
464
+ path_in_repo=resolved_path2.path_in_repo,
465
+ repo_id=resolved_path2.repo_id,
466
+ token=self.token,
467
+ repo_type=resolved_path2.repo_type,
468
+ revision=resolved_path2.revision,
469
+ commit_message=kwargs.get("commit_message", commit_message),
470
+ commit_description=kwargs.get("commit_description"),
471
+ )
472
+ self.invalidate_cache(path=resolved_path1.unresolve())
473
+ self.invalidate_cache(path=resolved_path2.unresolve())
474
+
475
+ def modified(self, path: str, **kwargs) -> datetime:
476
+ info = self.info(path, **kwargs)
477
+ return info["last_commit"]["date"]
478
+
479
+ def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]:
480
+ resolved_path = self.resolve_path(path, revision=revision)
481
+ path = resolved_path.unresolve()
482
+ expand_info = kwargs.get(
483
+ "expand_info", True
484
+ ) # don't expose it as a parameter in the public API to follow the spec
485
+ if not resolved_path.path_in_repo:
486
+ # Path is the root directory
487
+ out = {
488
+ "name": path,
489
+ "size": 0,
490
+ "type": "directory",
491
+ }
492
+ if expand_info:
493
+ last_commit = self._api.list_repo_commits(
494
+ resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision
495
+ )[-1]
496
+ out = {
497
+ **out,
498
+ "tree_id": None, # TODO: tree_id of the root directory?
499
+ "last_commit": LastCommitInfo(
500
+ oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at
501
+ ),
502
+ }
503
+ else:
504
+ out = None
505
+ parent_path = self._parent(path)
506
+ if parent_path in self.dircache:
507
+ # Check if the path is in the cache
508
+ out1 = [o for o in self.dircache[parent_path] if o["name"] == path]
509
+ if not out1:
510
+ _raise_file_not_found(path, None)
511
+ out = out1[0]
512
+ if refresh or out is None or (expand_info and out and out["last_commit"] is None):
513
+ paths_info = self._api.get_paths_info(
514
+ resolved_path.repo_id,
515
+ resolved_path.path_in_repo,
516
+ expand=expand_info,
517
+ revision=resolved_path.revision,
518
+ repo_type=resolved_path.repo_type,
519
+ )
520
+ if not paths_info:
521
+ _raise_file_not_found(path, None)
522
+ path_info = paths_info[0]
523
+ root_path = HfFileSystemResolvedPath(
524
+ resolved_path.repo_type,
525
+ resolved_path.repo_id,
526
+ resolved_path.revision,
527
+ path_in_repo="",
528
+ _raw_revision=resolved_path._raw_revision,
529
+ ).unresolve()
530
+ if isinstance(path_info, RepoFile):
531
+ out = {
532
+ "name": root_path + "/" + path_info.path,
533
+ "size": path_info.size,
534
+ "type": "file",
535
+ "blob_id": path_info.blob_id,
536
+ "lfs": path_info.lfs,
537
+ "last_commit": path_info.last_commit,
538
+ "security": path_info.security,
539
+ }
540
+ else:
541
+ out = {
542
+ "name": root_path + "/" + path_info.path,
543
+ "size": 0,
544
+ "type": "directory",
545
+ "tree_id": path_info.tree_id,
546
+ "last_commit": path_info.last_commit,
547
+ }
548
+ if not expand_info:
549
+ out = {k: out[k] for k in ["name", "size", "type"]}
550
+ assert out is not None
551
+ return copy.deepcopy(out) # copy to not let users modify the dircache
552
+
553
+ def exists(self, path, **kwargs):
554
+ """Is there a file at the given path"""
555
+ try:
556
+ self.info(path, expand_info=False, **kwargs)
557
+ return True
558
+ except: # noqa: E722
559
+ # any exception allowed bar FileNotFoundError?
560
+ return False
561
+
562
+ def isdir(self, path):
563
+ """Is this entry directory-like?"""
564
+ try:
565
+ return self.info(path, expand_info=False)["type"] == "directory"
566
+ except OSError:
567
+ return False
568
+
569
+ def isfile(self, path):
570
+ """Is this entry file-like?"""
571
+ try:
572
+ return self.info(path, expand_info=False)["type"] == "file"
573
+ except: # noqa: E722
574
+ return False
575
+
576
+ @property
577
+ def transaction(self):
578
+ """A context within which files are committed together upon exit
579
+
580
+ Requires the file class to implement `.commit()` and `.discard()`
581
+ for the normal and exception cases.
582
+ """
583
+ # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231
584
+ # See https://github.com/huggingface/huggingface_hub/issues/1733
585
+ raise NotImplementedError("Transactional commits are not supported.")
586
+
587
+ def start_transaction(self):
588
+ """Begin write transaction for deferring files, non-context version"""
589
+ # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241
590
+ # See https://github.com/huggingface/huggingface_hub/issues/1733
591
+ raise NotImplementedError("Transactional commits are not supported.")
592
+
593
+
594
+ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
595
+ def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
596
+ super().__init__(fs, path, **kwargs)
597
+ self.fs: HfFileSystem
598
+
599
+ try:
600
+ self.resolved_path = fs.resolve_path(path, revision=revision)
601
+ except FileNotFoundError as e:
602
+ if "w" in kwargs.get("mode", ""):
603
+ raise FileNotFoundError(
604
+ f"{e}.\nMake sure the repository and revision exist before writing data."
605
+ ) from e
606
+
607
+ def __del__(self):
608
+ if not hasattr(self, "resolved_path"):
609
+ # Means that the constructor failed. Nothing to do.
610
+ return
611
+ return super().__del__()
612
+
613
+ def _fetch_range(self, start: int, end: int) -> bytes:
614
+ headers = {
615
+ "range": f"bytes={start}-{end - 1}",
616
+ **self.fs._api._build_hf_headers(),
617
+ }
618
+ url = hf_hub_url(
619
+ repo_id=self.resolved_path.repo_id,
620
+ revision=self.resolved_path.revision,
621
+ filename=self.resolved_path.path_in_repo,
622
+ repo_type=self.resolved_path.repo_type,
623
+ endpoint=self.fs.endpoint,
624
+ )
625
+ r = http_backoff("GET", url, headers=headers)
626
+ hf_raise_for_status(r)
627
+ return r.content
628
+
629
+ def _initiate_upload(self) -> None:
630
+ self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False)
631
+
632
+ def _upload_chunk(self, final: bool = False) -> None:
633
+ self.buffer.seek(0)
634
+ block = self.buffer.read()
635
+ self.temp_file.write(block)
636
+ if final:
637
+ self.temp_file.close()
638
+ self.fs._api.upload_file(
639
+ path_or_fileobj=self.temp_file.name,
640
+ path_in_repo=self.resolved_path.path_in_repo,
641
+ repo_id=self.resolved_path.repo_id,
642
+ token=self.fs.token,
643
+ repo_type=self.resolved_path.repo_type,
644
+ revision=self.resolved_path.revision,
645
+ commit_message=self.kwargs.get("commit_message"),
646
+ commit_description=self.kwargs.get("commit_description"),
647
+ )
648
+ os.remove(self.temp_file.name)
649
+ self.fs.invalidate_cache(
650
+ path=self.resolved_path.unresolve(),
651
+ )
652
+
653
+
654
+ def safe_revision(revision: str) -> str:
655
+ return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
656
+
657
+
658
+ def safe_quote(s: str) -> str:
659
+ return quote(s, safe="")
660
+
661
+
662
+ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
663
+ msg = path
664
+ if isinstance(err, RepositoryNotFoundError):
665
+ msg = f"{path} (repository not found)"
666
+ elif isinstance(err, RevisionNotFoundError):
667
+ msg = f"{path} (revision not found)"
668
+ elif isinstance(err, HFValidationError):
669
+ msg = f"{path} (invalid repository id)"
670
+ raise FileNotFoundError(msg) from err
lib/python3.11/site-packages/huggingface_hub/hub_mixin.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional, Type, TypeVar, Union
5
+
6
+ from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
7
+ from .file_download import hf_hub_download, is_torch_available
8
+ from .hf_api import HfApi
9
+ from .utils import HfHubHTTPError, SoftTemporaryDirectory, logging, validate_hf_hub_args
10
+
11
+
12
+ if is_torch_available():
13
+ import torch # type: ignore
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+ # Generic variable that is either ModelHubMixin or a subclass thereof
18
+ T = TypeVar("T", bound="ModelHubMixin")
19
+
20
+
21
+ class ModelHubMixin:
22
+ """
23
+ A generic mixin to integrate ANY machine learning framework with the Hub.
24
+
25
+ To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
26
+ have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
27
+ of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
28
+ """
29
+
30
+ def save_pretrained(
31
+ self,
32
+ save_directory: Union[str, Path],
33
+ *,
34
+ config: Optional[dict] = None,
35
+ repo_id: Optional[str] = None,
36
+ push_to_hub: bool = False,
37
+ **kwargs,
38
+ ) -> Optional[str]:
39
+ """
40
+ Save weights in local directory.
41
+
42
+ Args:
43
+ save_directory (`str` or `Path`):
44
+ Path to directory in which the model weights and configuration will be saved.
45
+ config (`dict`, *optional*):
46
+ Model configuration specified as a key/value dictionary.
47
+ push_to_hub (`bool`, *optional*, defaults to `False`):
48
+ Whether or not to push your model to the Huggingface Hub after saving it.
49
+ repo_id (`str`, *optional*):
50
+ ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
51
+ not provided.
52
+ kwargs:
53
+ Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
54
+ """
55
+ save_directory = Path(save_directory)
56
+ save_directory.mkdir(parents=True, exist_ok=True)
57
+
58
+ # saving model weights/files
59
+ self._save_pretrained(save_directory)
60
+
61
+ # saving config
62
+ if isinstance(config, dict):
63
+ (save_directory / CONFIG_NAME).write_text(json.dumps(config))
64
+
65
+ if push_to_hub:
66
+ kwargs = kwargs.copy() # soft-copy to avoid mutating input
67
+ if config is not None: # kwarg for `push_to_hub`
68
+ kwargs["config"] = config
69
+ if repo_id is None:
70
+ repo_id = save_directory.name # Defaults to `save_directory` name
71
+ return self.push_to_hub(repo_id=repo_id, **kwargs)
72
+ return None
73
+
74
+ def _save_pretrained(self, save_directory: Path) -> None:
75
+ """
76
+ Overwrite this method in subclass to define how to save your model.
77
+ Check out our [integration guide](../guides/integrations) for instructions.
78
+
79
+ Args:
80
+ save_directory (`str` or `Path`):
81
+ Path to directory in which the model weights and configuration will be saved.
82
+ """
83
+ raise NotImplementedError
84
+
85
+ @classmethod
86
+ @validate_hf_hub_args
87
+ def from_pretrained(
88
+ cls: Type[T],
89
+ pretrained_model_name_or_path: Union[str, Path],
90
+ *,
91
+ force_download: bool = False,
92
+ resume_download: bool = False,
93
+ proxies: Optional[Dict] = None,
94
+ token: Optional[Union[str, bool]] = None,
95
+ cache_dir: Optional[Union[str, Path]] = None,
96
+ local_files_only: bool = False,
97
+ revision: Optional[str] = None,
98
+ **model_kwargs,
99
+ ) -> T:
100
+ """
101
+ Download a model from the Huggingface Hub and instantiate it.
102
+
103
+ Args:
104
+ pretrained_model_name_or_path (`str`, `Path`):
105
+ - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
106
+ - Or a path to a `directory` containing model weights saved using
107
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
108
+ revision (`str`, *optional*):
109
+ Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
110
+ Defaults to the latest commit on `main` branch.
111
+ force_download (`bool`, *optional*, defaults to `False`):
112
+ Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
113
+ the existing cache.
114
+ resume_download (`bool`, *optional*, defaults to `False`):
115
+ Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
116
+ proxies (`Dict[str, str]`, *optional*):
117
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
118
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
119
+ token (`str` or `bool`, *optional*):
120
+ The token to use as HTTP bearer authorization for remote files. By default, it will use the token
121
+ cached when running `huggingface-cli login`.
122
+ cache_dir (`str`, `Path`, *optional*):
123
+ Path to the folder where cached files are stored.
124
+ local_files_only (`bool`, *optional*, defaults to `False`):
125
+ If `True`, avoid downloading the file and return the path to the local cached file if it exists.
126
+ model_kwargs (`Dict`, *optional*):
127
+ Additional kwargs to pass to the model during initialization.
128
+ """
129
+ model_id = pretrained_model_name_or_path
130
+ config_file: Optional[str] = None
131
+ if os.path.isdir(model_id):
132
+ if CONFIG_NAME in os.listdir(model_id):
133
+ config_file = os.path.join(model_id, CONFIG_NAME)
134
+ else:
135
+ logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
136
+ elif isinstance(model_id, str):
137
+ try:
138
+ config_file = hf_hub_download(
139
+ repo_id=str(model_id),
140
+ filename=CONFIG_NAME,
141
+ revision=revision,
142
+ cache_dir=cache_dir,
143
+ force_download=force_download,
144
+ proxies=proxies,
145
+ resume_download=resume_download,
146
+ token=token,
147
+ local_files_only=local_files_only,
148
+ )
149
+ except HfHubHTTPError:
150
+ logger.info(f"{CONFIG_NAME} not found in HuggingFace Hub.")
151
+
152
+ if config_file is not None:
153
+ with open(config_file, "r", encoding="utf-8") as f:
154
+ config = json.load(f)
155
+ model_kwargs.update({"config": config})
156
+
157
+ return cls._from_pretrained(
158
+ model_id=str(model_id),
159
+ revision=revision,
160
+ cache_dir=cache_dir,
161
+ force_download=force_download,
162
+ proxies=proxies,
163
+ resume_download=resume_download,
164
+ local_files_only=local_files_only,
165
+ token=token,
166
+ **model_kwargs,
167
+ )
168
+
169
+ @classmethod
170
+ def _from_pretrained(
171
+ cls: Type[T],
172
+ *,
173
+ model_id: str,
174
+ revision: Optional[str],
175
+ cache_dir: Optional[Union[str, Path]],
176
+ force_download: bool,
177
+ proxies: Optional[Dict],
178
+ resume_download: bool,
179
+ local_files_only: bool,
180
+ token: Optional[Union[str, bool]],
181
+ **model_kwargs,
182
+ ) -> T:
183
+ """Overwrite this method in subclass to define how to load your model from pretrained.
184
+
185
+ Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
186
+ args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
187
+ method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
188
+ parameter to set on which device the model should be loaded.
189
+
190
+ Check out our [integration guide](../guides/integrations) for more instructions.
191
+
192
+ Args:
193
+ model_id (`str`):
194
+ ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
195
+ revision (`str`, *optional*):
196
+ Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
197
+ latest commit on `main` branch.
198
+ force_download (`bool`, *optional*, defaults to `False`):
199
+ Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
200
+ the existing cache.
201
+ resume_download (`bool`, *optional*, defaults to `False`):
202
+ Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
203
+ proxies (`Dict[str, str]`, *optional*):
204
+ A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
205
+ 'http://hostname': 'foo.bar:4012'}`).
206
+ token (`str` or `bool`, *optional*):
207
+ The token to use as HTTP bearer authorization for remote files. By default, it will use the token
208
+ cached when running `huggingface-cli login`.
209
+ cache_dir (`str`, `Path`, *optional*):
210
+ Path to the folder where cached files are stored.
211
+ local_files_only (`bool`, *optional*, defaults to `False`):
212
+ If `True`, avoid downloading the file and return the path to the local cached file if it exists.
213
+ model_kwargs:
214
+ Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
215
+ """
216
+ raise NotImplementedError
217
+
218
+ @validate_hf_hub_args
219
+ def push_to_hub(
220
+ self,
221
+ repo_id: str,
222
+ *,
223
+ config: Optional[dict] = None,
224
+ commit_message: str = "Push model using huggingface_hub.",
225
+ private: bool = False,
226
+ api_endpoint: Optional[str] = None,
227
+ token: Optional[str] = None,
228
+ branch: Optional[str] = None,
229
+ create_pr: Optional[bool] = None,
230
+ allow_patterns: Optional[Union[List[str], str]] = None,
231
+ ignore_patterns: Optional[Union[List[str], str]] = None,
232
+ delete_patterns: Optional[Union[List[str], str]] = None,
233
+ ) -> str:
234
+ """
235
+ Upload model checkpoint to the Hub.
236
+
237
+ Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
238
+ `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
239
+ details.
240
+
241
+
242
+ Args:
243
+ repo_id (`str`):
244
+ ID of the repository to push to (example: `"username/my-model"`).
245
+ config (`dict`, *optional*):
246
+ Configuration object to be saved alongside the model weights.
247
+ commit_message (`str`, *optional*):
248
+ Message to commit while pushing.
249
+ private (`bool`, *optional*, defaults to `False`):
250
+ Whether the repository created should be private.
251
+ api_endpoint (`str`, *optional*):
252
+ The API endpoint to use when pushing the model to the hub.
253
+ token (`str`, *optional*):
254
+ The token to use as HTTP bearer authorization for remote files. By default, it will use the token
255
+ cached when running `huggingface-cli login`.
256
+ branch (`str`, *optional*):
257
+ The git branch on which to push the model. This defaults to `"main"`.
258
+ create_pr (`boolean`, *optional*):
259
+ Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
260
+ allow_patterns (`List[str]` or `str`, *optional*):
261
+ If provided, only files matching at least one pattern are pushed.
262
+ ignore_patterns (`List[str]` or `str`, *optional*):
263
+ If provided, files matching any of the patterns are not pushed.
264
+ delete_patterns (`List[str]` or `str`, *optional*):
265
+ If provided, remote files matching any of the patterns will be deleted from the repo.
266
+
267
+ Returns:
268
+ The url of the commit of your model in the given repository.
269
+ """
270
+ api = HfApi(endpoint=api_endpoint, token=token)
271
+ repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
272
+
273
+ # Push the files to the repo in a single commit
274
+ with SoftTemporaryDirectory() as tmp:
275
+ saved_path = Path(tmp) / repo_id
276
+ self.save_pretrained(saved_path, config=config)
277
+ return api.upload_folder(
278
+ repo_id=repo_id,
279
+ repo_type="model",
280
+ folder_path=saved_path,
281
+ commit_message=commit_message,
282
+ revision=branch,
283
+ create_pr=create_pr,
284
+ allow_patterns=allow_patterns,
285
+ ignore_patterns=ignore_patterns,
286
+ delete_patterns=delete_patterns,
287
+ )
288
+
289
+
290
+ class PyTorchModelHubMixin(ModelHubMixin):
291
+ """
292
+ Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
293
+ is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
294
+ you should first set it back in training mode with `model.train()`.
295
+
296
+ Example:
297
+
298
+ ```python
299
+ >>> import torch
300
+ >>> import torch.nn as nn
301
+ >>> from huggingface_hub import PyTorchModelHubMixin
302
+
303
+
304
+ >>> class MyModel(nn.Module, PyTorchModelHubMixin):
305
+ ... def __init__(self):
306
+ ... super().__init__()
307
+ ... self.param = nn.Parameter(torch.rand(3, 4))
308
+ ... self.linear = nn.Linear(4, 5)
309
+
310
+ ... def forward(self, x):
311
+ ... return self.linear(x + self.param)
312
+ >>> model = MyModel()
313
+
314
+ # Save model weights to local directory
315
+ >>> model.save_pretrained("my-awesome-model")
316
+
317
+ # Push model weights to the Hub
318
+ >>> model.push_to_hub("my-awesome-model")
319
+
320
+ # Download and initialize weights from the Hub
321
+ >>> model = MyModel.from_pretrained("username/my-awesome-model")
322
+ ```
323
+ """
324
+
325
+ def _save_pretrained(self, save_directory: Path) -> None:
326
+ """Save weights from a Pytorch model to a local directory."""
327
+ model_to_save = self.module if hasattr(self, "module") else self # type: ignore
328
+ torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
329
+
330
+ @classmethod
331
+ def _from_pretrained(
332
+ cls,
333
+ *,
334
+ model_id: str,
335
+ revision: Optional[str],
336
+ cache_dir: Optional[Union[str, Path]],
337
+ force_download: bool,
338
+ proxies: Optional[Dict],
339
+ resume_download: bool,
340
+ local_files_only: bool,
341
+ token: Union[str, bool, None],
342
+ map_location: str = "cpu",
343
+ strict: bool = False,
344
+ **model_kwargs,
345
+ ):
346
+ """Load Pytorch pretrained weights and return the loaded model."""
347
+ if os.path.isdir(model_id):
348
+ print("Loading weights from local directory")
349
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
350
+ else:
351
+ model_file = hf_hub_download(
352
+ repo_id=model_id,
353
+ filename=PYTORCH_WEIGHTS_NAME,
354
+ revision=revision,
355
+ cache_dir=cache_dir,
356
+ force_download=force_download,
357
+ proxies=proxies,
358
+ resume_download=resume_download,
359
+ token=token,
360
+ local_files_only=local_files_only,
361
+ )
362
+ model = cls(**model_kwargs)
363
+
364
+ state_dict = torch.load(model_file, map_location=torch.device(map_location))
365
+ model.load_state_dict(state_dict, strict=strict) # type: ignore
366
+ model.eval() # type: ignore
367
+
368
+ return model
lib/python3.11/site-packages/huggingface_hub/inference/__init__.py ADDED
File without changes
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (240 Bytes). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_client.cpython-311.pyc ADDED
Binary file (90.8 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_text_generation.cpython-311.pyc ADDED
Binary file (23.9 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/__pycache__/_types.cpython-311.pyc ADDED
Binary file (7.6 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/_client.py ADDED
@@ -0,0 +1,1990 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Related resources:
17
+ # https://huggingface.co/tasks
18
+ # https://huggingface.co/docs/huggingface.js/inference/README
19
+ # https://github.com/huggingface/huggingface.js/tree/main/packages/inference/src
20
+ # https://github.com/huggingface/text-generation-inference/tree/main/clients/python
21
+ # https://github.com/huggingface/text-generation-inference/blob/main/clients/python/text_generation/client.py
22
+ # https://huggingface.slack.com/archives/C03E4DQ9LAJ/p1680169099087869
23
+ # https://github.com/huggingface/unity-api#tasks
24
+ #
25
+ # Some TODO:
26
+ # - validate inputs/options/parameters? with Pydantic for instance? or only optionally?
27
+ # - add all tasks
28
+ #
29
+ # NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some
30
+ # examples of how it translates:
31
+ # - Timeout / Server unavailable is handled by the client in a single "timeout" parameter.
32
+ # - Files can be provided as bytes, file paths, or URLs and the client will try to "guess" the type.
33
+ # - Images are parsed as PIL.Image for easier manipulation.
34
+ # - Provides a "recommended model" for each task => suboptimal but user-wise quicker to get a first script running.
35
+ # - Only the main parameters are publicly exposed. Power users can always read the docs for more options.
36
+ import logging
37
+ import time
38
+ import warnings
39
+ from dataclasses import asdict
40
+ from typing import (
41
+ TYPE_CHECKING,
42
+ Any,
43
+ Dict,
44
+ Iterable,
45
+ List,
46
+ Literal,
47
+ Optional,
48
+ Union,
49
+ overload,
50
+ )
51
+
52
+ from requests import HTTPError
53
+ from requests.structures import CaseInsensitiveDict
54
+
55
+ from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
56
+ from huggingface_hub.inference._common import (
57
+ TASKS_EXPECTING_IMAGES,
58
+ ContentT,
59
+ InferenceTimeoutError,
60
+ ModelStatus,
61
+ _b64_encode,
62
+ _b64_to_image,
63
+ _bytes_to_dict,
64
+ _bytes_to_image,
65
+ _bytes_to_list,
66
+ _fetch_recommended_models,
67
+ _import_numpy,
68
+ _is_tgi_server,
69
+ _open_as_binary,
70
+ _set_as_non_tgi,
71
+ _stream_text_generation_response,
72
+ )
73
+ from huggingface_hub.inference._text_generation import (
74
+ TextGenerationParameters,
75
+ TextGenerationRequest,
76
+ TextGenerationResponse,
77
+ TextGenerationStreamResponse,
78
+ raise_text_generation_error,
79
+ )
80
+ from huggingface_hub.inference._types import (
81
+ ClassificationOutput,
82
+ ConversationalOutput,
83
+ FillMaskOutput,
84
+ ImageSegmentationOutput,
85
+ ObjectDetectionOutput,
86
+ QuestionAnsweringOutput,
87
+ TableQuestionAnsweringOutput,
88
+ TokenClassificationOutput,
89
+ )
90
+ from huggingface_hub.utils import (
91
+ BadRequestError,
92
+ build_hf_headers,
93
+ get_session,
94
+ hf_raise_for_status,
95
+ )
96
+
97
+
98
+ if TYPE_CHECKING:
99
+ import numpy as np
100
+ from PIL import Image
101
+
102
+ logger = logging.getLogger(__name__)
103
+
104
+
105
+ class InferenceClient:
106
+ """
107
+ Initialize a new Inference Client.
108
+
109
+ [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
110
+ seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
111
+
112
+ Args:
113
+ model (`str`, `optional`):
114
+ The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
115
+ or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
116
+ automatically selected for the task.
117
+ token (`str`, *optional*):
118
+ Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
119
+ your token to the server.
120
+ timeout (`float`, `optional`):
121
+ The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
122
+ API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
123
+ headers (`Dict[str, str]`, `optional`):
124
+ Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
125
+ Values in this dictionary will override the default values.
126
+ cookies (`Dict[str, str]`, `optional`):
127
+ Additional cookies to send to the server.
128
+ """
129
+
130
+ def __init__(
131
+ self,
132
+ model: Optional[str] = None,
133
+ token: Union[str, bool, None] = None,
134
+ timeout: Optional[float] = None,
135
+ headers: Optional[Dict[str, str]] = None,
136
+ cookies: Optional[Dict[str, str]] = None,
137
+ ) -> None:
138
+ self.model: Optional[str] = model
139
+ self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
140
+ if headers is not None:
141
+ self.headers.update(headers)
142
+ self.cookies = cookies
143
+ self.timeout = timeout
144
+
145
+ def __repr__(self):
146
+ return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
147
+
148
+ @overload
149
+ def post( # type: ignore[misc]
150
+ self,
151
+ *,
152
+ json: Optional[Union[str, Dict, List]] = None,
153
+ data: Optional[ContentT] = None,
154
+ model: Optional[str] = None,
155
+ task: Optional[str] = None,
156
+ stream: Literal[False] = ...,
157
+ ) -> bytes:
158
+ pass
159
+
160
+ @overload
161
+ def post(
162
+ self,
163
+ *,
164
+ json: Optional[Union[str, Dict, List]] = None,
165
+ data: Optional[ContentT] = None,
166
+ model: Optional[str] = None,
167
+ task: Optional[str] = None,
168
+ stream: Literal[True] = ...,
169
+ ) -> Iterable[bytes]:
170
+ pass
171
+
172
+ def post(
173
+ self,
174
+ *,
175
+ json: Optional[Union[str, Dict, List]] = None,
176
+ data: Optional[ContentT] = None,
177
+ model: Optional[str] = None,
178
+ task: Optional[str] = None,
179
+ stream: bool = False,
180
+ ) -> Union[bytes, Iterable[bytes]]:
181
+ """
182
+ Make a POST request to the inference server.
183
+
184
+ Args:
185
+ json (`Union[str, Dict, List]`, *optional*):
186
+ The JSON data to send in the request body, specific to each task. Defaults to None.
187
+ data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
188
+ The content to send in the request body, specific to each task.
189
+ It can be raw bytes, a pointer to an opened file, a local file path,
190
+ or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed,
191
+ `data` will take precedence. At least `json` or `data` must be provided. Defaults to None.
192
+ model (`str`, *optional*):
193
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
194
+ Inference Endpoint. Will override the model defined at the instance level. Defaults to None.
195
+ task (`str`, *optional*):
196
+ The task to perform on the inference. All available tasks can be found
197
+ [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not
198
+ provided. At least `model` or `task` must be provided. Defaults to None.
199
+ stream (`bool`, *optional*):
200
+ Whether to iterate over streaming APIs.
201
+
202
+ Returns:
203
+ bytes: The raw bytes returned by the server.
204
+
205
+ Raises:
206
+ [`InferenceTimeoutError`]:
207
+ If the model is unavailable or the request times out.
208
+ `HTTPError`:
209
+ If the request fails with an HTTP error status code other than HTTP 503.
210
+ """
211
+ url = self._resolve_url(model, task)
212
+
213
+ if data is not None and json is not None:
214
+ warnings.warn("Ignoring `json` as `data` is passed as binary.")
215
+
216
+ # Set Accept header if relevant
217
+ headers = self.headers.copy()
218
+ if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
219
+ headers["Accept"] = "image/png"
220
+
221
+ t0 = time.time()
222
+ timeout = self.timeout
223
+ while True:
224
+ with _open_as_binary(data) as data_as_binary:
225
+ try:
226
+ response = get_session().post(
227
+ url,
228
+ json=json,
229
+ data=data_as_binary,
230
+ headers=headers,
231
+ cookies=self.cookies,
232
+ timeout=self.timeout,
233
+ stream=stream,
234
+ )
235
+ except TimeoutError as error:
236
+ # Convert any `TimeoutError` to a `InferenceTimeoutError`
237
+ raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
238
+
239
+ try:
240
+ hf_raise_for_status(response)
241
+ return response.iter_lines() if stream else response.content
242
+ except HTTPError as error:
243
+ if error.response.status_code == 422 and task is not None:
244
+ error.args = (
245
+ f"{error.args[0]}\nMake sure '{task}' task is supported by the model.",
246
+ ) + error.args[1:]
247
+ if error.response.status_code == 503:
248
+ # If Model is unavailable, either raise a TimeoutError...
249
+ if timeout is not None and time.time() - t0 > timeout:
250
+ raise InferenceTimeoutError(
251
+ f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:"
252
+ f" {self.timeout}).",
253
+ request=error.request,
254
+ response=error.response,
255
+ ) from error
256
+ # ...or wait 1s and retry
257
+ logger.info(f"Waiting for model to be loaded on the server: {error}")
258
+ time.sleep(1)
259
+ if timeout is not None:
260
+ timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
261
+ continue
262
+ raise
263
+
264
+ def audio_classification(
265
+ self,
266
+ audio: ContentT,
267
+ *,
268
+ model: Optional[str] = None,
269
+ ) -> List[ClassificationOutput]:
270
+ """
271
+ Perform audio classification on the provided audio content.
272
+
273
+ Args:
274
+ audio (Union[str, Path, bytes, BinaryIO]):
275
+ The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an
276
+ audio file.
277
+ model (`str`, *optional*):
278
+ The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub
279
+ or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
280
+ audio classification will be used.
281
+
282
+ Returns:
283
+ `List[Dict]`: The classification output containing the predicted label and its confidence.
284
+
285
+ Raises:
286
+ [`InferenceTimeoutError`]:
287
+ If the model is unavailable or the request times out.
288
+ `HTTPError`:
289
+ If the request fails with an HTTP error status code other than HTTP 503.
290
+
291
+ Example:
292
+ ```py
293
+ >>> from huggingface_hub import InferenceClient
294
+ >>> client = InferenceClient()
295
+ >>> client.audio_classification("audio.flac")
296
+ [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
297
+ ```
298
+ """
299
+ response = self.post(data=audio, model=model, task="audio-classification")
300
+ return _bytes_to_list(response)
301
+
302
+ def automatic_speech_recognition(
303
+ self,
304
+ audio: ContentT,
305
+ *,
306
+ model: Optional[str] = None,
307
+ ) -> str:
308
+ """
309
+ Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
310
+
311
+ Args:
312
+ audio (Union[str, Path, bytes, BinaryIO]):
313
+ The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file.
314
+ model (`str`, *optional*):
315
+ The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
316
+ Inference Endpoint. If not provided, the default recommended model for ASR will be used.
317
+
318
+ Returns:
319
+ str: The transcribed text.
320
+
321
+ Raises:
322
+ [`InferenceTimeoutError`]:
323
+ If the model is unavailable or the request times out.
324
+ `HTTPError`:
325
+ If the request fails with an HTTP error status code other than HTTP 503.
326
+
327
+ Example:
328
+ ```py
329
+ >>> from huggingface_hub import InferenceClient
330
+ >>> client = InferenceClient()
331
+ >>> client.automatic_speech_recognition("hello_world.flac")
332
+ "hello world"
333
+ ```
334
+ """
335
+ response = self.post(data=audio, model=model, task="automatic-speech-recognition")
336
+ return _bytes_to_dict(response)["text"]
337
+
338
+ def conversational(
339
+ self,
340
+ text: str,
341
+ generated_responses: Optional[List[str]] = None,
342
+ past_user_inputs: Optional[List[str]] = None,
343
+ *,
344
+ parameters: Optional[Dict[str, Any]] = None,
345
+ model: Optional[str] = None,
346
+ ) -> ConversationalOutput:
347
+ """
348
+ Generate conversational responses based on the given input text (i.e. chat with the API).
349
+
350
+ Args:
351
+ text (`str`):
352
+ The last input from the user in the conversation.
353
+ generated_responses (`List[str]`, *optional*):
354
+ A list of strings corresponding to the earlier replies from the model. Defaults to None.
355
+ past_user_inputs (`List[str]`, *optional*):
356
+ A list of strings corresponding to the earlier replies from the user. Should be the same length as
357
+ `generated_responses`. Defaults to None.
358
+ parameters (`Dict[str, Any]`, *optional*):
359
+ Additional parameters for the conversational task. Defaults to None. For more details about the available
360
+ parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
361
+ model (`str`, *optional*):
362
+ The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
363
+ a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
364
+ Defaults to None.
365
+
366
+ Returns:
367
+ `Dict`: The generated conversational output.
368
+
369
+ Raises:
370
+ [`InferenceTimeoutError`]:
371
+ If the model is unavailable or the request times out.
372
+ `HTTPError`:
373
+ If the request fails with an HTTP error status code other than HTTP 503.
374
+
375
+ Example:
376
+ ```py
377
+ >>> from huggingface_hub import InferenceClient
378
+ >>> client = InferenceClient()
379
+ >>> output = client.conversational("Hi, who are you?")
380
+ >>> output
381
+ {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
382
+ >>> client.conversational(
383
+ ... "Wow, that's scary!",
384
+ ... generated_responses=output["conversation"]["generated_responses"],
385
+ ... past_user_inputs=output["conversation"]["past_user_inputs"],
386
+ ... )
387
+ ```
388
+ """
389
+ payload: Dict[str, Any] = {"inputs": {"text": text}}
390
+ if generated_responses is not None:
391
+ payload["inputs"]["generated_responses"] = generated_responses
392
+ if past_user_inputs is not None:
393
+ payload["inputs"]["past_user_inputs"] = past_user_inputs
394
+ if parameters is not None:
395
+ payload["parameters"] = parameters
396
+ response = self.post(json=payload, model=model, task="conversational")
397
+ return _bytes_to_dict(response) # type: ignore
398
+
399
+ def visual_question_answering(
400
+ self,
401
+ image: ContentT,
402
+ question: str,
403
+ *,
404
+ model: Optional[str] = None,
405
+ ) -> List[str]:
406
+ """
407
+ Answering open-ended questions based on an image.
408
+
409
+ Args:
410
+ image (`Union[str, Path, bytes, BinaryIO]`):
411
+ The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
412
+ question (`str`):
413
+ Question to be answered.
414
+ model (`str`, *optional*):
415
+ The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
416
+ a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
417
+ Defaults to None.
418
+
419
+ Returns:
420
+ `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
421
+
422
+ Raises:
423
+ `InferenceTimeoutError`:
424
+ If the model is unavailable or the request times out.
425
+ `HTTPError`:
426
+ If the request fails with an HTTP error status code other than HTTP 503.
427
+
428
+ Example:
429
+ ```py
430
+ >>> from huggingface_hub import InferenceClient
431
+ >>> client = InferenceClient()
432
+ >>> client.visual_question_answering(
433
+ ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
434
+ ... question="What is the animal doing?"
435
+ ... )
436
+ [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
437
+ ```
438
+ """
439
+ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
440
+ response = self.post(json=payload, model=model, task="visual-question-answering")
441
+ return _bytes_to_list(response)
442
+
443
+ def document_question_answering(
444
+ self,
445
+ image: ContentT,
446
+ question: str,
447
+ *,
448
+ model: Optional[str] = None,
449
+ ) -> List[QuestionAnsweringOutput]:
450
+ """
451
+ Answer questions on document images.
452
+
453
+ Args:
454
+ image (`Union[str, Path, bytes, BinaryIO]`):
455
+ The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
456
+ question (`str`):
457
+ Question to be answered.
458
+ model (`str`, *optional*):
459
+ The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
460
+ a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
461
+ Defaults to None.
462
+
463
+ Returns:
464
+ `List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
465
+
466
+ Raises:
467
+ [`InferenceTimeoutError`]:
468
+ If the model is unavailable or the request times out.
469
+ `HTTPError`:
470
+ If the request fails with an HTTP error status code other than HTTP 503.
471
+
472
+ Example:
473
+ ```py
474
+ >>> from huggingface_hub import InferenceClient
475
+ >>> client = InferenceClient()
476
+ >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
477
+ [{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
478
+ ```
479
+ """
480
+ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
481
+ response = self.post(json=payload, model=model, task="document-question-answering")
482
+ return _bytes_to_list(response)
483
+
484
+ def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
485
+ """
486
+ Generate embeddings for a given text.
487
+
488
+ Args:
489
+ text (`str`):
490
+ The text to embed.
491
+ model (`str`, *optional*):
492
+ The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
493
+ a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
494
+ Defaults to None.
495
+
496
+ Returns:
497
+ `np.ndarray`: The embedding representing the input text as a float32 numpy array.
498
+
499
+ Raises:
500
+ [`InferenceTimeoutError`]:
501
+ If the model is unavailable or the request times out.
502
+ `HTTPError`:
503
+ If the request fails with an HTTP error status code other than HTTP 503.
504
+
505
+ Example:
506
+ ```py
507
+ >>> from huggingface_hub import InferenceClient
508
+ >>> client = InferenceClient()
509
+ >>> client.feature_extraction("Hi, who are you?")
510
+ array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ],
511
+ [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ],
512
+ ...,
513
+ [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
514
+ ```
515
+ """
516
+ response = self.post(json={"inputs": text}, model=model, task="feature-extraction")
517
+ np = _import_numpy()
518
+ return np.array(_bytes_to_dict(response), dtype="float32")
519
+
520
+ def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]:
521
+ """
522
+ Fill in a hole with a missing word (token to be precise).
523
+
524
+ Args:
525
+ text (`str`):
526
+ a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask).
527
+ model (`str`, *optional*):
528
+ The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
529
+ a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
530
+ Defaults to None.
531
+
532
+ Returns:
533
+ `List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated
534
+ probability, token reference, and completed text.
535
+
536
+ Raises:
537
+ [`InferenceTimeoutError`]:
538
+ If the model is unavailable or the request times out.
539
+ `HTTPError`:
540
+ If the request fails with an HTTP error status code other than HTTP 503.
541
+
542
+ Example:
543
+ ```py
544
+ >>> from huggingface_hub import InferenceClient
545
+ >>> client = InferenceClient()
546
+ >>> client.fill_mask("The goal of life is <mask>.")
547
+ [{'score': 0.06897063553333282,
548
+ 'token': 11098,
549
+ 'token_str': ' happiness',
550
+ 'sequence': 'The goal of life is happiness.'},
551
+ {'score': 0.06554922461509705,
552
+ 'token': 45075,
553
+ 'token_str': ' immortality',
554
+ 'sequence': 'The goal of life is immortality.'}]
555
+ ```
556
+ """
557
+ response = self.post(json={"inputs": text}, model=model, task="fill-mask")
558
+ return _bytes_to_list(response)
559
+
560
+ def image_classification(
561
+ self,
562
+ image: ContentT,
563
+ *,
564
+ model: Optional[str] = None,
565
+ ) -> List[ClassificationOutput]:
566
+ """
567
+ Perform image classification on the given image using the specified model.
568
+
569
+ Args:
570
+ image (`Union[str, Path, bytes, BinaryIO]`):
571
+ The image to classify. It can be raw bytes, an image file, or a URL to an online image.
572
+ model (`str`, *optional*):
573
+ The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
574
+ deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
575
+
576
+ Returns:
577
+ `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
578
+
579
+ Raises:
580
+ [`InferenceTimeoutError`]:
581
+ If the model is unavailable or the request times out.
582
+ `HTTPError`:
583
+ If the request fails with an HTTP error status code other than HTTP 503.
584
+
585
+ Example:
586
+ ```py
587
+ >>> from huggingface_hub import InferenceClient
588
+ >>> client = InferenceClient()
589
+ >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
590
+ [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
591
+ ```
592
+ """
593
+ response = self.post(data=image, model=model, task="image-classification")
594
+ return _bytes_to_list(response)
595
+
596
+ def image_segmentation(
597
+ self,
598
+ image: ContentT,
599
+ *,
600
+ model: Optional[str] = None,
601
+ ) -> List[ImageSegmentationOutput]:
602
+ """
603
+ Perform image segmentation on the given image using the specified model.
604
+
605
+ <Tip warning={true}>
606
+
607
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
608
+
609
+ </Tip>
610
+
611
+ Args:
612
+ image (`Union[str, Path, bytes, BinaryIO]`):
613
+ The image to segment. It can be raw bytes, an image file, or a URL to an online image.
614
+ model (`str`, *optional*):
615
+ The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
616
+ deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
617
+
618
+ Returns:
619
+ `List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes.
620
+
621
+ Raises:
622
+ [`InferenceTimeoutError`]:
623
+ If the model is unavailable or the request times out.
624
+ `HTTPError`:
625
+ If the request fails with an HTTP error status code other than HTTP 503.
626
+
627
+ Example:
628
+ ```py
629
+ >>> from huggingface_hub import InferenceClient
630
+ >>> client = InferenceClient()
631
+ >>> client.image_segmentation("cat.jpg"):
632
+ [{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...]
633
+ ```
634
+ """
635
+
636
+ # Segment
637
+ response = self.post(data=image, model=model, task="image-segmentation")
638
+ output = _bytes_to_dict(response)
639
+
640
+ # Parse masks as PIL Image
641
+ if not isinstance(output, list):
642
+ raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
643
+ for item in output:
644
+ item["mask"] = _b64_to_image(item["mask"])
645
+ return output
646
+
647
+ def image_to_image(
648
+ self,
649
+ image: ContentT,
650
+ prompt: Optional[str] = None,
651
+ *,
652
+ negative_prompt: Optional[str] = None,
653
+ height: Optional[int] = None,
654
+ width: Optional[int] = None,
655
+ num_inference_steps: Optional[int] = None,
656
+ guidance_scale: Optional[float] = None,
657
+ model: Optional[str] = None,
658
+ **kwargs,
659
+ ) -> "Image":
660
+ """
661
+ Perform image-to-image translation using a specified model.
662
+
663
+ <Tip warning={true}>
664
+
665
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
666
+
667
+ </Tip>
668
+
669
+ Args:
670
+ image (`Union[str, Path, bytes, BinaryIO]`):
671
+ The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
672
+ prompt (`str`, *optional*):
673
+ The text prompt to guide the image generation.
674
+ negative_prompt (`str`, *optional*):
675
+ A negative prompt to guide the translation process.
676
+ height (`int`, *optional*):
677
+ The height in pixels of the generated image.
678
+ width (`int`, *optional*):
679
+ The width in pixels of the generated image.
680
+ num_inference_steps (`int`, *optional*):
681
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
682
+ expense of slower inference.
683
+ guidance_scale (`float`, *optional*):
684
+ Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
685
+ usually at the expense of lower image quality.
686
+ model (`str`, *optional*):
687
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
688
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
689
+
690
+ Returns:
691
+ `Image`: The translated image.
692
+
693
+ Raises:
694
+ [`InferenceTimeoutError`]:
695
+ If the model is unavailable or the request times out.
696
+ `HTTPError`:
697
+ If the request fails with an HTTP error status code other than HTTP 503.
698
+
699
+ Example:
700
+ ```py
701
+ >>> from huggingface_hub import InferenceClient
702
+ >>> client = InferenceClient()
703
+ >>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
704
+ >>> image.save("tiger.jpg")
705
+ ```
706
+ """
707
+ parameters = {
708
+ "prompt": prompt,
709
+ "negative_prompt": negative_prompt,
710
+ "height": height,
711
+ "width": width,
712
+ "num_inference_steps": num_inference_steps,
713
+ "guidance_scale": guidance_scale,
714
+ **kwargs,
715
+ }
716
+ if all(parameter is None for parameter in parameters.values()):
717
+ # Either only an image to send => send as raw bytes
718
+ data = image
719
+ payload: Optional[Dict[str, Any]] = None
720
+ else:
721
+ # Or an image + some parameters => use base64 encoding
722
+ data = None
723
+ payload = {"inputs": _b64_encode(image)}
724
+ for key, value in parameters.items():
725
+ if value is not None:
726
+ payload.setdefault("parameters", {})[key] = value
727
+
728
+ response = self.post(json=payload, data=data, model=model, task="image-to-image")
729
+ return _bytes_to_image(response)
730
+
731
+ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
732
+ """
733
+ Takes an input image and return text.
734
+
735
+ Models can have very different outputs depending on your use case (image captioning, optical character recognition
736
+ (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
737
+
738
+ Args:
739
+ image (`Union[str, Path, bytes, BinaryIO]`):
740
+ The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
741
+ model (`str`, *optional*):
742
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
743
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
744
+
745
+ Returns:
746
+ `str`: The generated text.
747
+
748
+ Raises:
749
+ [`InferenceTimeoutError`]:
750
+ If the model is unavailable or the request times out.
751
+ `HTTPError`:
752
+ If the request fails with an HTTP error status code other than HTTP 503.
753
+
754
+ Example:
755
+ ```py
756
+ >>> from huggingface_hub import InferenceClient
757
+ >>> client = InferenceClient()
758
+ >>> client.image_to_text("cat.jpg")
759
+ 'a cat standing in a grassy field '
760
+ >>> client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
761
+ 'a dog laying on the grass next to a flower pot '
762
+ ```
763
+ """
764
+ response = self.post(data=image, model=model, task="image-to-text")
765
+ return _bytes_to_dict(response)[0]["generated_text"]
766
+
767
+ def list_deployed_models(
768
+ self, frameworks: Union[None, str, Literal["all"], List[str]] = None
769
+ ) -> Dict[str, List[str]]:
770
+ """
771
+ List models currently deployed on the Inference API service.
772
+
773
+ This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
774
+ are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
775
+ specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
776
+ in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
777
+ frameworks are checked, the more time it will take.
778
+
779
+ <Tip>
780
+
781
+ This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
782
+ check its availability, you can directly use [`~InferenceClient.get_model_status`].
783
+
784
+ </Tip>
785
+
786
+ Args:
787
+ frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
788
+ The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
789
+ "all", all available frameworks will be tested. It is also possible to provide a single framework or a
790
+ custom set of frameworks to check.
791
+
792
+ Returns:
793
+ `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
794
+
795
+ Example:
796
+ ```python
797
+ >>> from huggingface_hub import InferenceClient
798
+ >>> client = InferenceClient()
799
+
800
+ # Discover zero-shot-classification models currently deployed
801
+ >>> models = client.list_deployed_models()
802
+ >>> models["zero-shot-classification"]
803
+ ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
804
+
805
+ # List from only 1 framework
806
+ >>> client.list_deployed_models("text-generation-inference")
807
+ {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
808
+ ```
809
+ """
810
+ # Resolve which frameworks to check
811
+ if frameworks is None:
812
+ frameworks = MAIN_INFERENCE_API_FRAMEWORKS
813
+ elif frameworks == "all":
814
+ frameworks = ALL_INFERENCE_API_FRAMEWORKS
815
+ elif isinstance(frameworks, str):
816
+ frameworks = [frameworks]
817
+ frameworks = list(set(frameworks))
818
+
819
+ # Fetch them iteratively
820
+ models_by_task: Dict[str, List[str]] = {}
821
+
822
+ def _unpack_response(framework: str, items: List[Dict]) -> None:
823
+ for model in items:
824
+ if framework == "sentence-transformers":
825
+ # Model running with the `sentence-transformers` framework can work with both tasks even if not
826
+ # branded as such in the API response
827
+ models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
828
+ models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
829
+ else:
830
+ models_by_task.setdefault(model["task"], []).append(model["model_id"])
831
+
832
+ for framework in frameworks:
833
+ response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers)
834
+ hf_raise_for_status(response)
835
+ _unpack_response(framework, response.json())
836
+
837
+ # Sort alphabetically for discoverability and return
838
+ for task, models in models_by_task.items():
839
+ models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
840
+ return models_by_task
841
+
842
+ def object_detection(
843
+ self,
844
+ image: ContentT,
845
+ *,
846
+ model: Optional[str] = None,
847
+ ) -> List[ObjectDetectionOutput]:
848
+ """
849
+ Perform object detection on the given image using the specified model.
850
+
851
+ <Tip warning={true}>
852
+
853
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
854
+
855
+ </Tip>
856
+
857
+ Args:
858
+ image (`Union[str, Path, bytes, BinaryIO]`):
859
+ The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
860
+ model (`str`, *optional*):
861
+ The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
862
+ deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
863
+
864
+ Returns:
865
+ `List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes.
866
+
867
+ Raises:
868
+ [`InferenceTimeoutError`]:
869
+ If the model is unavailable or the request times out.
870
+ `HTTPError`:
871
+ If the request fails with an HTTP error status code other than HTTP 503.
872
+ `ValueError`:
873
+ If the request output is not a List.
874
+
875
+ Example:
876
+ ```py
877
+ >>> from huggingface_hub import InferenceClient
878
+ >>> client = InferenceClient()
879
+ >>> client.object_detection("people.jpg"):
880
+ [{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ]
881
+ ```
882
+ """
883
+ # detect objects
884
+ response = self.post(data=image, model=model, task="object-detection")
885
+ output = _bytes_to_dict(response)
886
+ if not isinstance(output, list):
887
+ raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
888
+ return output
889
+
890
+ def question_answering(
891
+ self, question: str, context: str, *, model: Optional[str] = None
892
+ ) -> QuestionAnsweringOutput:
893
+ """
894
+ Retrieve the answer to a question from a given text.
895
+
896
+ Args:
897
+ question (`str`):
898
+ Question to be answered.
899
+ context (`str`):
900
+ The context of the question.
901
+ model (`str`):
902
+ The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
903
+ a deployed Inference Endpoint.
904
+
905
+ Returns:
906
+ `Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
907
+
908
+ Raises:
909
+ [`InferenceTimeoutError`]:
910
+ If the model is unavailable or the request times out.
911
+ `HTTPError`:
912
+ If the request fails with an HTTP error status code other than HTTP 503.
913
+
914
+ Example:
915
+ ```py
916
+ >>> from huggingface_hub import InferenceClient
917
+ >>> client = InferenceClient()
918
+ >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
919
+ {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
920
+ ```
921
+ """
922
+
923
+ payload: Dict[str, Any] = {"question": question, "context": context}
924
+ response = self.post(
925
+ json=payload,
926
+ model=model,
927
+ task="question-answering",
928
+ )
929
+ return _bytes_to_dict(response) # type: ignore
930
+
931
+ def sentence_similarity(
932
+ self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
933
+ ) -> List[float]:
934
+ """
935
+ Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.
936
+
937
+ Args:
938
+ sentence (`str`):
939
+ The main sentence to compare to others.
940
+ other_sentences (`List[str]`):
941
+ The list of sentences to compare to.
942
+ model (`str`, *optional*):
943
+ The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
944
+ a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
945
+ Defaults to None.
946
+
947
+ Returns:
948
+ `List[float]`: The embedding representing the input text.
949
+
950
+ Raises:
951
+ [`InferenceTimeoutError`]:
952
+ If the model is unavailable or the request times out.
953
+ `HTTPError`:
954
+ If the request fails with an HTTP error status code other than HTTP 503.
955
+
956
+ Example:
957
+ ```py
958
+ >>> from huggingface_hub import InferenceClient
959
+ >>> client = InferenceClient()
960
+ >>> client.sentence_similarity(
961
+ ... "Machine learning is so easy.",
962
+ ... other_sentences=[
963
+ ... "Deep learning is so straightforward.",
964
+ ... "This is so difficult, like rocket science.",
965
+ ... "I can't believe how much I struggled with this.",
966
+ ... ],
967
+ ... )
968
+ [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
969
+ ```
970
+ """
971
+ response = self.post(
972
+ json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}},
973
+ model=model,
974
+ task="sentence-similarity",
975
+ )
976
+ return _bytes_to_list(response)
977
+
978
+ def summarization(
979
+ self,
980
+ text: str,
981
+ *,
982
+ parameters: Optional[Dict[str, Any]] = None,
983
+ model: Optional[str] = None,
984
+ ) -> str:
985
+ """
986
+ Generate a summary of a given text using a specified model.
987
+
988
+ Args:
989
+ text (`str`):
990
+ The input text to summarize.
991
+ parameters (`Dict[str, Any]`, *optional*):
992
+ Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task)
993
+ for more details.
994
+ model (`str`, *optional*):
995
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
996
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
997
+
998
+ Returns:
999
+ `str`: The generated summary text.
1000
+
1001
+ Raises:
1002
+ [`InferenceTimeoutError`]:
1003
+ If the model is unavailable or the request times out.
1004
+ `HTTPError`:
1005
+ If the request fails with an HTTP error status code other than HTTP 503.
1006
+
1007
+ Example:
1008
+ ```py
1009
+ >>> from huggingface_hub import InferenceClient
1010
+ >>> client = InferenceClient()
1011
+ >>> client.summarization("The Eiffel tower...")
1012
+ 'The Eiffel tower is one of the most famous landmarks in the world....'
1013
+ ```
1014
+ """
1015
+ payload: Dict[str, Any] = {"inputs": text}
1016
+ if parameters is not None:
1017
+ payload["parameters"] = parameters
1018
+ response = self.post(json=payload, model=model, task="summarization")
1019
+ return _bytes_to_dict(response)[0]["summary_text"]
1020
+
1021
+ def table_question_answering(
1022
+ self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
1023
+ ) -> TableQuestionAnsweringOutput:
1024
+ """
1025
+ Retrieve the answer to a question from information given in a table.
1026
+
1027
+ Args:
1028
+ table (`str`):
1029
+ A table of data represented as a dict of lists where entries are headers and the lists are all the
1030
+ values, all lists must have the same size.
1031
+ query (`str`):
1032
+ The query in plain text that you want to ask the table.
1033
+ model (`str`):
1034
+ The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
1035
+ Hub or a URL to a deployed Inference Endpoint.
1036
+
1037
+ Returns:
1038
+ `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used.
1039
+
1040
+ Raises:
1041
+ [`InferenceTimeoutError`]:
1042
+ If the model is unavailable or the request times out.
1043
+ `HTTPError`:
1044
+ If the request fails with an HTTP error status code other than HTTP 503.
1045
+
1046
+ Example:
1047
+ ```py
1048
+ >>> from huggingface_hub import InferenceClient
1049
+ >>> client = InferenceClient()
1050
+ >>> query = "How many stars does the transformers repository have?"
1051
+ >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
1052
+ >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
1053
+ {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
1054
+ ```
1055
+ """
1056
+ response = self.post(
1057
+ json={
1058
+ "query": query,
1059
+ "table": table,
1060
+ },
1061
+ model=model,
1062
+ task="table-question-answering",
1063
+ )
1064
+ return _bytes_to_dict(response) # type: ignore
1065
+
1066
+ def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
1067
+ """
1068
+ Classifying a target category (a group) based on a set of attributes.
1069
+
1070
+ Args:
1071
+ table (`Dict[str, Any]`):
1072
+ Set of attributes to classify.
1073
+ model (`str`):
1074
+ The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1075
+ a deployed Inference Endpoint.
1076
+
1077
+ Returns:
1078
+ `List`: a list of labels, one per row in the initial table.
1079
+
1080
+ Raises:
1081
+ [`InferenceTimeoutError`]:
1082
+ If the model is unavailable or the request times out.
1083
+ `HTTPError`:
1084
+ If the request fails with an HTTP error status code other than HTTP 503.
1085
+
1086
+ Example:
1087
+ ```py
1088
+ >>> from huggingface_hub import InferenceClient
1089
+ >>> client = InferenceClient()
1090
+ >>> table = {
1091
+ ... "fixed_acidity": ["7.4", "7.8", "10.3"],
1092
+ ... "volatile_acidity": ["0.7", "0.88", "0.32"],
1093
+ ... "citric_acid": ["0", "0", "0.45"],
1094
+ ... "residual_sugar": ["1.9", "2.6", "6.4"],
1095
+ ... "chlorides": ["0.076", "0.098", "0.073"],
1096
+ ... "free_sulfur_dioxide": ["11", "25", "5"],
1097
+ ... "total_sulfur_dioxide": ["34", "67", "13"],
1098
+ ... "density": ["0.9978", "0.9968", "0.9976"],
1099
+ ... "pH": ["3.51", "3.2", "3.23"],
1100
+ ... "sulphates": ["0.56", "0.68", "0.82"],
1101
+ ... "alcohol": ["9.4", "9.8", "12.6"],
1102
+ ... }
1103
+ >>> client.tabular_classification(table=table, model="julien-c/wine-quality")
1104
+ ["5", "5", "5"]
1105
+ ```
1106
+ """
1107
+ response = self.post(json={"table": table}, model=model, task="tabular-classification")
1108
+ return _bytes_to_list(response)
1109
+
1110
+ def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]:
1111
+ """
1112
+ Predicting a numerical target value given a set of attributes/features in a table.
1113
+
1114
+ Args:
1115
+ table (`Dict[str, Any]`):
1116
+ Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
1117
+ model (`str`):
1118
+ The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1119
+ a deployed Inference Endpoint.
1120
+
1121
+ Returns:
1122
+ `List`: a list of predicted numerical target values.
1123
+
1124
+ Raises:
1125
+ [`InferenceTimeoutError`]:
1126
+ If the model is unavailable or the request times out.
1127
+ `HTTPError`:
1128
+ If the request fails with an HTTP error status code other than HTTP 503.
1129
+
1130
+ Example:
1131
+ ```py
1132
+ >>> from huggingface_hub import InferenceClient
1133
+ >>> client = InferenceClient()
1134
+ >>> table = {
1135
+ ... "Height": ["11.52", "12.48", "12.3778"],
1136
+ ... "Length1": ["23.2", "24", "23.9"],
1137
+ ... "Length2": ["25.4", "26.3", "26.5"],
1138
+ ... "Length3": ["30", "31.2", "31.1"],
1139
+ ... "Species": ["Bream", "Bream", "Bream"],
1140
+ ... "Width": ["4.02", "4.3056", "4.6961"],
1141
+ ... }
1142
+ >>> client.tabular_regression(table, model="scikit-learn/Fish-Weight")
1143
+ [110, 120, 130]
1144
+ ```
1145
+ """
1146
+ response = self.post(json={"table": table}, model=model, task="tabular-regression")
1147
+ return _bytes_to_list(response)
1148
+
1149
+ def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
1150
+ """
1151
+ Perform text classification (e.g. sentiment-analysis) on the given text.
1152
+
1153
+ Args:
1154
+ text (`str`):
1155
+ A string to be classified.
1156
+ model (`str`, *optional*):
1157
+ The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1158
+ a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used.
1159
+ Defaults to None.
1160
+
1161
+ Returns:
1162
+ `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
1163
+
1164
+ Raises:
1165
+ [`InferenceTimeoutError`]:
1166
+ If the model is unavailable or the request times out.
1167
+ `HTTPError`:
1168
+ If the request fails with an HTTP error status code other than HTTP 503.
1169
+
1170
+ Example:
1171
+ ```py
1172
+ >>> from huggingface_hub import InferenceClient
1173
+ >>> client = InferenceClient()
1174
+ >>> client.text_classification("I like you")
1175
+ [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
1176
+ ```
1177
+ """
1178
+ response = self.post(json={"inputs": text}, model=model, task="text-classification")
1179
+ return _bytes_to_list(response)[0]
1180
+
1181
+ @overload
1182
+ def text_generation( # type: ignore
1183
+ self,
1184
+ prompt: str,
1185
+ *,
1186
+ details: Literal[False] = ...,
1187
+ stream: Literal[False] = ...,
1188
+ model: Optional[str] = None,
1189
+ do_sample: bool = False,
1190
+ max_new_tokens: int = 20,
1191
+ best_of: Optional[int] = None,
1192
+ repetition_penalty: Optional[float] = None,
1193
+ return_full_text: bool = False,
1194
+ seed: Optional[int] = None,
1195
+ stop_sequences: Optional[List[str]] = None,
1196
+ temperature: Optional[float] = None,
1197
+ top_k: Optional[int] = None,
1198
+ top_p: Optional[float] = None,
1199
+ truncate: Optional[int] = None,
1200
+ typical_p: Optional[float] = None,
1201
+ watermark: bool = False,
1202
+ ) -> str:
1203
+ ...
1204
+
1205
+ @overload
1206
+ def text_generation( # type: ignore
1207
+ self,
1208
+ prompt: str,
1209
+ *,
1210
+ details: Literal[True] = ...,
1211
+ stream: Literal[False] = ...,
1212
+ model: Optional[str] = None,
1213
+ do_sample: bool = False,
1214
+ max_new_tokens: int = 20,
1215
+ best_of: Optional[int] = None,
1216
+ repetition_penalty: Optional[float] = None,
1217
+ return_full_text: bool = False,
1218
+ seed: Optional[int] = None,
1219
+ stop_sequences: Optional[List[str]] = None,
1220
+ temperature: Optional[float] = None,
1221
+ top_k: Optional[int] = None,
1222
+ top_p: Optional[float] = None,
1223
+ truncate: Optional[int] = None,
1224
+ typical_p: Optional[float] = None,
1225
+ watermark: bool = False,
1226
+ ) -> TextGenerationResponse:
1227
+ ...
1228
+
1229
+ @overload
1230
+ def text_generation( # type: ignore
1231
+ self,
1232
+ prompt: str,
1233
+ *,
1234
+ details: Literal[False] = ...,
1235
+ stream: Literal[True] = ...,
1236
+ model: Optional[str] = None,
1237
+ do_sample: bool = False,
1238
+ max_new_tokens: int = 20,
1239
+ best_of: Optional[int] = None,
1240
+ repetition_penalty: Optional[float] = None,
1241
+ return_full_text: bool = False,
1242
+ seed: Optional[int] = None,
1243
+ stop_sequences: Optional[List[str]] = None,
1244
+ temperature: Optional[float] = None,
1245
+ top_k: Optional[int] = None,
1246
+ top_p: Optional[float] = None,
1247
+ truncate: Optional[int] = None,
1248
+ typical_p: Optional[float] = None,
1249
+ watermark: bool = False,
1250
+ ) -> Iterable[str]:
1251
+ ...
1252
+
1253
+ @overload
1254
+ def text_generation(
1255
+ self,
1256
+ prompt: str,
1257
+ *,
1258
+ details: Literal[True] = ...,
1259
+ stream: Literal[True] = ...,
1260
+ model: Optional[str] = None,
1261
+ do_sample: bool = False,
1262
+ max_new_tokens: int = 20,
1263
+ best_of: Optional[int] = None,
1264
+ repetition_penalty: Optional[float] = None,
1265
+ return_full_text: bool = False,
1266
+ seed: Optional[int] = None,
1267
+ stop_sequences: Optional[List[str]] = None,
1268
+ temperature: Optional[float] = None,
1269
+ top_k: Optional[int] = None,
1270
+ top_p: Optional[float] = None,
1271
+ truncate: Optional[int] = None,
1272
+ typical_p: Optional[float] = None,
1273
+ watermark: bool = False,
1274
+ ) -> Iterable[TextGenerationStreamResponse]:
1275
+ ...
1276
+
1277
+ def text_generation(
1278
+ self,
1279
+ prompt: str,
1280
+ *,
1281
+ details: bool = False,
1282
+ stream: bool = False,
1283
+ model: Optional[str] = None,
1284
+ do_sample: bool = False,
1285
+ max_new_tokens: int = 20,
1286
+ best_of: Optional[int] = None,
1287
+ repetition_penalty: Optional[float] = None,
1288
+ return_full_text: bool = False,
1289
+ seed: Optional[int] = None,
1290
+ stop_sequences: Optional[List[str]] = None,
1291
+ temperature: Optional[float] = None,
1292
+ top_k: Optional[int] = None,
1293
+ top_p: Optional[float] = None,
1294
+ truncate: Optional[int] = None,
1295
+ typical_p: Optional[float] = None,
1296
+ watermark: bool = False,
1297
+ decoder_input_details: bool = False,
1298
+ ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
1299
+ """
1300
+ Given a prompt, generate the following text.
1301
+
1302
+ It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allow
1303
+ early failures.
1304
+
1305
+ API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
1306
+ go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
1307
+ default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
1308
+ not exactly the same. This method is compatible with both approaches but some parameters are only available for
1309
+ `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
1310
+ continues correctly.
1311
+
1312
+ To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
1313
+
1314
+ Args:
1315
+ prompt (`str`):
1316
+ Input text.
1317
+ details (`bool`, *optional*):
1318
+ By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
1319
+ probabilities, seed, finish reason, etc.). Only available for models running on with the
1320
+ `text-generation-inference` backend.
1321
+ stream (`bool`, *optional*):
1322
+ By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
1323
+ tokens to be returned. Only available for models running on with the `text-generation-inference`
1324
+ backend.
1325
+ model (`str`, *optional*):
1326
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1327
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1328
+ do_sample (`bool`):
1329
+ Activate logits sampling
1330
+ max_new_tokens (`int`):
1331
+ Maximum number of generated tokens
1332
+ best_of (`int`):
1333
+ Generate best_of sequences and return the one if the highest token logprobs
1334
+ repetition_penalty (`float`):
1335
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
1336
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
1337
+ return_full_text (`bool`):
1338
+ Whether to prepend the prompt to the generated text
1339
+ seed (`int`):
1340
+ Random sampling seed
1341
+ stop_sequences (`List[str]`):
1342
+ Stop generating tokens if a member of `stop_sequences` is generated
1343
+ temperature (`float`):
1344
+ The value used to module the logits distribution.
1345
+ top_k (`int`):
1346
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
1347
+ top_p (`float`):
1348
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
1349
+ higher are kept for generation.
1350
+ truncate (`int`):
1351
+ Truncate inputs tokens to the given size
1352
+ typical_p (`float`):
1353
+ Typical Decoding mass
1354
+ See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
1355
+ watermark (`bool`):
1356
+ Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
1357
+ decoder_input_details (`bool`):
1358
+ Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
1359
+ into account. Defaults to `False`.
1360
+
1361
+ Returns:
1362
+ `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
1363
+ Generated text returned from the server:
1364
+ - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
1365
+ - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
1366
+ - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
1367
+ - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
1368
+
1369
+ Raises:
1370
+ `ValidationError`:
1371
+ If input values are not valid. No HTTP call is made to the server.
1372
+ [`InferenceTimeoutError`]:
1373
+ If the model is unavailable or the request times out.
1374
+ `HTTPError`:
1375
+ If the request fails with an HTTP error status code other than HTTP 503.
1376
+
1377
+ Example:
1378
+ ```py
1379
+ >>> from huggingface_hub import InferenceClient
1380
+ >>> client = InferenceClient()
1381
+
1382
+ # Case 1: generate text
1383
+ >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
1384
+ '100% open source and built to be easy to use.'
1385
+
1386
+ # Case 2: iterate over the generated tokens. Useful for large generation.
1387
+ >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
1388
+ ... print(token)
1389
+ 100
1390
+ %
1391
+ open
1392
+ source
1393
+ and
1394
+ built
1395
+ to
1396
+ be
1397
+ easy
1398
+ to
1399
+ use
1400
+ .
1401
+
1402
+ # Case 3: get more details about the generation process.
1403
+ >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
1404
+ TextGenerationResponse(
1405
+ generated_text='100% open source and built to be easy to use.',
1406
+ details=Details(
1407
+ finish_reason=<FinishReason.Length: 'length'>,
1408
+ generated_tokens=12,
1409
+ seed=None,
1410
+ prefill=[
1411
+ InputToken(id=487, text='The', logprob=None),
1412
+ InputToken(id=53789, text=' hugging', logprob=-13.171875),
1413
+ (...)
1414
+ InputToken(id=204, text=' ', logprob=-7.0390625)
1415
+ ],
1416
+ tokens=[
1417
+ Token(id=1425, text='100', logprob=-1.0175781, special=False),
1418
+ Token(id=16, text='%', logprob=-0.0463562, special=False),
1419
+ (...)
1420
+ Token(id=25, text='.', logprob=-0.5703125, special=False)
1421
+ ],
1422
+ best_of_sequences=None
1423
+ )
1424
+ )
1425
+
1426
+ # Case 4: iterate over the generated tokens with more details.
1427
+ # Last object is more complete, containing the full generated text and the finish reason.
1428
+ >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
1429
+ ... print(details)
1430
+ ...
1431
+ TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
1432
+ TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
1433
+ TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
1434
+ TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
1435
+ TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
1436
+ TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
1437
+ TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
1438
+ TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
1439
+ TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
1440
+ TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
1441
+ TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
1442
+ TextGenerationStreamResponse(token=Token(
1443
+ id=25,
1444
+ text='.',
1445
+ logprob=-0.5703125,
1446
+ special=False),
1447
+ generated_text='100% open source and built to be easy to use.',
1448
+ details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
1449
+ )
1450
+ ```
1451
+ """
1452
+ # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
1453
+ # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
1454
+
1455
+ if decoder_input_details and not details:
1456
+ warnings.warn(
1457
+ "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
1458
+ " the output from the server will be truncated."
1459
+ )
1460
+ decoder_input_details = False
1461
+
1462
+ # Validate parameters
1463
+ parameters = TextGenerationParameters(
1464
+ best_of=best_of,
1465
+ details=details,
1466
+ do_sample=do_sample,
1467
+ max_new_tokens=max_new_tokens,
1468
+ repetition_penalty=repetition_penalty,
1469
+ return_full_text=return_full_text,
1470
+ seed=seed,
1471
+ stop=stop_sequences if stop_sequences is not None else [],
1472
+ temperature=temperature,
1473
+ top_k=top_k,
1474
+ top_p=top_p,
1475
+ truncate=truncate,
1476
+ typical_p=typical_p,
1477
+ watermark=watermark,
1478
+ decoder_input_details=decoder_input_details,
1479
+ )
1480
+ request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
1481
+ payload = asdict(request)
1482
+
1483
+ # Remove some parameters if not a TGI server
1484
+ if not _is_tgi_server(model):
1485
+ ignored_parameters = []
1486
+ for key in "watermark", "stop", "details", "decoder_input_details":
1487
+ if payload["parameters"][key] is not None:
1488
+ ignored_parameters.append(key)
1489
+ del payload["parameters"][key]
1490
+ if len(ignored_parameters) > 0:
1491
+ warnings.warn(
1492
+ "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
1493
+ f" {ignored_parameters}.",
1494
+ UserWarning,
1495
+ )
1496
+ if details:
1497
+ warnings.warn(
1498
+ "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
1499
+ " be ignored meaning only the generated text will be returned.",
1500
+ UserWarning,
1501
+ )
1502
+ details = False
1503
+ if stream:
1504
+ raise ValueError(
1505
+ "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
1506
+ " Please pass `stream=False` as input."
1507
+ )
1508
+
1509
+ # Handle errors separately for more precise error messages
1510
+ try:
1511
+ bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
1512
+ except HTTPError as e:
1513
+ if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e):
1514
+ _set_as_non_tgi(model)
1515
+ return self.text_generation( # type: ignore
1516
+ prompt=prompt,
1517
+ details=details,
1518
+ stream=stream,
1519
+ model=model,
1520
+ do_sample=do_sample,
1521
+ max_new_tokens=max_new_tokens,
1522
+ best_of=best_of,
1523
+ repetition_penalty=repetition_penalty,
1524
+ return_full_text=return_full_text,
1525
+ seed=seed,
1526
+ stop_sequences=stop_sequences,
1527
+ temperature=temperature,
1528
+ top_k=top_k,
1529
+ top_p=top_p,
1530
+ truncate=truncate,
1531
+ typical_p=typical_p,
1532
+ watermark=watermark,
1533
+ decoder_input_details=decoder_input_details,
1534
+ )
1535
+ raise_text_generation_error(e)
1536
+
1537
+ # Parse output
1538
+ if stream:
1539
+ return _stream_text_generation_response(bytes_output, details) # type: ignore
1540
+
1541
+ data = _bytes_to_dict(bytes_output)[0]
1542
+ return TextGenerationResponse(**data) if details else data["generated_text"]
1543
+
1544
+ def text_to_image(
1545
+ self,
1546
+ prompt: str,
1547
+ *,
1548
+ negative_prompt: Optional[str] = None,
1549
+ height: Optional[float] = None,
1550
+ width: Optional[float] = None,
1551
+ num_inference_steps: Optional[float] = None,
1552
+ guidance_scale: Optional[float] = None,
1553
+ model: Optional[str] = None,
1554
+ **kwargs,
1555
+ ) -> "Image":
1556
+ """
1557
+ Generate an image based on a given text using a specified model.
1558
+
1559
+ <Tip warning={true}>
1560
+
1561
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1562
+
1563
+ </Tip>
1564
+
1565
+ Args:
1566
+ prompt (`str`):
1567
+ The prompt to generate an image from.
1568
+ negative_prompt (`str`, *optional*):
1569
+ An optional negative prompt for the image generation.
1570
+ height (`float`, *optional*):
1571
+ The height in pixels of the image to generate.
1572
+ width (`float`, *optional*):
1573
+ The width in pixels of the image to generate.
1574
+ num_inference_steps (`int`, *optional*):
1575
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1576
+ expense of slower inference.
1577
+ guidance_scale (`float`, *optional*):
1578
+ Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1579
+ usually at the expense of lower image quality.
1580
+ model (`str`, *optional*):
1581
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1582
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1583
+
1584
+ Returns:
1585
+ `Image`: The generated image.
1586
+
1587
+ Raises:
1588
+ [`InferenceTimeoutError`]:
1589
+ If the model is unavailable or the request times out.
1590
+ `HTTPError`:
1591
+ If the request fails with an HTTP error status code other than HTTP 503.
1592
+
1593
+ Example:
1594
+ ```py
1595
+ >>> from huggingface_hub import InferenceClient
1596
+ >>> client = InferenceClient()
1597
+
1598
+ >>> image = client.text_to_image("An astronaut riding a horse on the moon.")
1599
+ >>> image.save("astronaut.png")
1600
+
1601
+ >>> image = client.text_to_image(
1602
+ ... "An astronaut riding a horse on the moon.",
1603
+ ... negative_prompt="low resolution, blurry",
1604
+ ... model="stabilityai/stable-diffusion-2-1",
1605
+ ... )
1606
+ >>> image.save("better_astronaut.png")
1607
+ ```
1608
+ """
1609
+ payload = {"inputs": prompt}
1610
+ parameters = {
1611
+ "negative_prompt": negative_prompt,
1612
+ "height": height,
1613
+ "width": width,
1614
+ "num_inference_steps": num_inference_steps,
1615
+ "guidance_scale": guidance_scale,
1616
+ **kwargs,
1617
+ }
1618
+ for key, value in parameters.items():
1619
+ if value is not None:
1620
+ payload.setdefault("parameters", {})[key] = value # type: ignore
1621
+ response = self.post(json=payload, model=model, task="text-to-image")
1622
+ return _bytes_to_image(response)
1623
+
1624
+ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
1625
+ """
1626
+ Synthesize an audio of a voice pronouncing a given text.
1627
+
1628
+ Args:
1629
+ text (`str`):
1630
+ The text to synthesize.
1631
+ model (`str`, *optional*):
1632
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1633
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1634
+
1635
+ Returns:
1636
+ `bytes`: The generated audio.
1637
+
1638
+ Raises:
1639
+ [`InferenceTimeoutError`]:
1640
+ If the model is unavailable or the request times out.
1641
+ `HTTPError`:
1642
+ If the request fails with an HTTP error status code other than HTTP 503.
1643
+
1644
+ Example:
1645
+ ```py
1646
+ >>> from pathlib import Path
1647
+ >>> from huggingface_hub import InferenceClient
1648
+ >>> client = InferenceClient()
1649
+
1650
+ >>> audio = client.text_to_speech("Hello world")
1651
+ >>> Path("hello_world.flac").write_bytes(audio)
1652
+ ```
1653
+ """
1654
+ return self.post(json={"inputs": text}, model=model, task="text-to-speech")
1655
+
1656
+ def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]:
1657
+ """
1658
+ Perform token classification on the given text.
1659
+ Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
1660
+
1661
+ Args:
1662
+ text (`str`):
1663
+ A string to be classified.
1664
+ model (`str`, *optional*):
1665
+ The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1666
+ a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used.
1667
+ Defaults to None.
1668
+
1669
+ Returns:
1670
+ `List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index.
1671
+
1672
+ Raises:
1673
+ [`InferenceTimeoutError`]:
1674
+ If the model is unavailable or the request times out.
1675
+ `HTTPError`:
1676
+ If the request fails with an HTTP error status code other than HTTP 503.
1677
+
1678
+ Example:
1679
+ ```py
1680
+ >>> from huggingface_hub import InferenceClient
1681
+ >>> client = InferenceClient()
1682
+ >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
1683
+ [{'entity_group': 'PER',
1684
+ 'score': 0.9971321225166321,
1685
+ 'word': 'Sarah Jessica Parker',
1686
+ 'start': 11,
1687
+ 'end': 31},
1688
+ {'entity_group': 'PER',
1689
+ 'score': 0.9773476123809814,
1690
+ 'word': 'Jessica',
1691
+ 'start': 52,
1692
+ 'end': 59}]
1693
+ ```
1694
+ """
1695
+ payload: Dict[str, Any] = {"inputs": text}
1696
+ response = self.post(
1697
+ json=payload,
1698
+ model=model,
1699
+ task="token-classification",
1700
+ )
1701
+ return _bytes_to_list(response)
1702
+
1703
+ def translation(
1704
+ self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
1705
+ ) -> str:
1706
+ """
1707
+ Convert text from one language to another.
1708
+
1709
+ Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
1710
+ your specific use case. Source and target languages usually depend on the model.
1711
+ However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
1712
+ you can use `src_lang` and `tgt_lang` arguments to pass the relevant information.
1713
+ You can find this information in the model card.
1714
+
1715
+ Args:
1716
+ text (`str`):
1717
+ A string to be translated.
1718
+ model (`str`, *optional*):
1719
+ The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1720
+ a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
1721
+ Defaults to None.
1722
+ src_lang (`str`, *optional*):
1723
+ Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
1724
+ tgt_lang (`str`, *optional*):
1725
+ Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
1726
+
1727
+ Returns:
1728
+ `str`: The generated translated text.
1729
+
1730
+ Raises:
1731
+ [`InferenceTimeoutError`]:
1732
+ If the model is unavailable or the request times out.
1733
+ `HTTPError`:
1734
+ If the request fails with an HTTP error status code other than HTTP 503.
1735
+ `ValueError`:
1736
+ If only one of the `src_lang` and `tgt_lang` arguments are provided.
1737
+
1738
+ Example:
1739
+ ```py
1740
+ >>> from huggingface_hub import InferenceClient
1741
+ >>> client = InferenceClient()
1742
+ >>> client.translation("My name is Wolfgang and I live in Berlin")
1743
+ 'Mein Name ist Wolfgang und ich lebe in Berlin.'
1744
+ >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
1745
+ "Je m'appelle Wolfgang et je vis à Berlin."
1746
+ ```
1747
+
1748
+ Specifying languages:
1749
+ ```py
1750
+ >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
1751
+ "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica"
1752
+ ```
1753
+ """
1754
+ # Throw error if only one of `src_lang` and `tgt_lang` was given
1755
+ if src_lang is not None and tgt_lang is None:
1756
+ raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")
1757
+
1758
+ if src_lang is None and tgt_lang is not None:
1759
+ raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
1760
+
1761
+ # If both `src_lang` and `tgt_lang` are given, pass them to the request body
1762
+ payload: Dict = {"inputs": text}
1763
+ if src_lang and tgt_lang:
1764
+ payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
1765
+ response = self.post(json=payload, model=model, task="translation")
1766
+ return _bytes_to_dict(response)[0]["translation_text"]
1767
+
1768
+ def zero_shot_classification(
1769
+ self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
1770
+ ) -> List[ClassificationOutput]:
1771
+ """
1772
+ Provide as input a text and a set of candidate labels to classify the input text.
1773
+
1774
+ Args:
1775
+ text (`str`):
1776
+ The input text to classify.
1777
+ labels (`List[str]`):
1778
+ List of string possible labels. There must be at least 2 labels.
1779
+ multi_label (`bool`):
1780
+ Boolean that is set to True if classes can overlap.
1781
+ model (`str`, *optional*):
1782
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1783
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1784
+
1785
+ Returns:
1786
+ `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
1787
+
1788
+ Raises:
1789
+ [`InferenceTimeoutError`]:
1790
+ If the model is unavailable or the request times out.
1791
+ `HTTPError`:
1792
+ If the request fails with an HTTP error status code other than HTTP 503.
1793
+
1794
+ Example:
1795
+ ```py
1796
+ >>> from huggingface_hub import InferenceClient
1797
+ >>> client = InferenceClient()
1798
+ >>> text = (
1799
+ ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
1800
+ ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
1801
+ ... " mysteries when he went for a run up a hill in Nice, France."
1802
+ ... )
1803
+ >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
1804
+ >>> client.zero_shot_classification(text, labels)
1805
+ [
1806
+ {"label": "scientific discovery", "score": 0.7961668968200684},
1807
+ {"label": "space & cosmos", "score": 0.18570658564567566},
1808
+ {"label": "microbiology", "score": 0.00730885099619627},
1809
+ {"label": "archeology", "score": 0.006258360575884581},
1810
+ {"label": "robots", "score": 0.004559356719255447},
1811
+ ]
1812
+ >>> client.zero_shot_classification(text, labels, multi_label=True)
1813
+ [
1814
+ {"label": "scientific discovery", "score": 0.9829297661781311},
1815
+ {"label": "space & cosmos", "score": 0.755190908908844},
1816
+ {"label": "microbiology", "score": 0.0005462635890580714},
1817
+ {"label": "archeology", "score": 0.00047131875180639327},
1818
+ {"label": "robots", "score": 0.00030448526376858354},
1819
+ ]
1820
+ ```
1821
+ """
1822
+ # Raise ValueError if input is less than 2 labels
1823
+ if len(labels) < 2:
1824
+ raise ValueError("You must specify at least 2 classes to compare.")
1825
+
1826
+ response = self.post(
1827
+ json={
1828
+ "inputs": text,
1829
+ "parameters": {
1830
+ "candidate_labels": ",".join(labels),
1831
+ "multi_label": multi_label,
1832
+ },
1833
+ },
1834
+ model=model,
1835
+ task="zero-shot-classification",
1836
+ )
1837
+ output = _bytes_to_dict(response)
1838
+ return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])]
1839
+
1840
+ def zero_shot_image_classification(
1841
+ self, image: ContentT, labels: List[str], *, model: Optional[str] = None
1842
+ ) -> List[ClassificationOutput]:
1843
+ """
1844
+ Provide input image and text labels to predict text labels for the image.
1845
+
1846
+ Args:
1847
+ image (`Union[str, Path, bytes, BinaryIO]`):
1848
+ The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
1849
+ labels (`List[str]`):
1850
+ List of string possible labels. There must be at least 2 labels.
1851
+ model (`str`, *optional*):
1852
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1853
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1854
+
1855
+ Returns:
1856
+ `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
1857
+
1858
+ Raises:
1859
+ [`InferenceTimeoutError`]:
1860
+ If the model is unavailable or the request times out.
1861
+ `HTTPError`:
1862
+ If the request fails with an HTTP error status code other than HTTP 503.
1863
+
1864
+ Example:
1865
+ ```py
1866
+ >>> from huggingface_hub import InferenceClient
1867
+ >>> client = InferenceClient()
1868
+
1869
+ >>> client.zero_shot_image_classification(
1870
+ ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
1871
+ ... labels=["dog", "cat", "horse"],
1872
+ ... )
1873
+ [{"label": "dog", "score": 0.956}, ...]
1874
+ ```
1875
+ """
1876
+ # Raise ValueError if input is less than 2 labels
1877
+ if len(labels) < 2:
1878
+ raise ValueError("You must specify at least 2 classes to compare.")
1879
+
1880
+ response = self.post(
1881
+ json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}},
1882
+ model=model,
1883
+ task="zero-shot-image-classification",
1884
+ )
1885
+ return _bytes_to_list(response)
1886
+
1887
+ def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
1888
+ model = model or self.model
1889
+
1890
+ # If model is already a URL, ignore `task` and return directly
1891
+ if model is not None and (model.startswith("http://") or model.startswith("https://")):
1892
+ return model
1893
+
1894
+ # # If no model but task is set => fetch the recommended one for this task
1895
+ if model is None:
1896
+ if task is None:
1897
+ raise ValueError(
1898
+ "You must specify at least a model (repo_id or URL) or a task, either when instantiating"
1899
+ " `InferenceClient` or when making a request."
1900
+ )
1901
+ model = self.get_recommended_model(task)
1902
+ logger.info(
1903
+ f"Using recommended model {model} for task {task}. Note that it is"
1904
+ f" encouraged to explicitly set `model='{model}'` as the recommended"
1905
+ " models list might get updated without prior notice."
1906
+ )
1907
+
1908
+ # Compute InferenceAPI url
1909
+ return (
1910
+ # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
1911
+ f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
1912
+ if task in ("feature-extraction", "sentence-similarity")
1913
+ # Otherwise, we use the default endpoint
1914
+ else f"{INFERENCE_ENDPOINT}/models/{model}"
1915
+ )
1916
+
1917
+ @staticmethod
1918
+ def get_recommended_model(task: str) -> str:
1919
+ """
1920
+ Get the model Hugging Face recommends for the input task.
1921
+
1922
+ Args:
1923
+ task (`str`):
1924
+ The Hugging Face task to get which model Hugging Face recommends.
1925
+ All available tasks can be found [here](https://huggingface.co/tasks).
1926
+
1927
+ Returns:
1928
+ `str`: Name of the model recommended for the input task.
1929
+
1930
+ Raises:
1931
+ `ValueError`: If Hugging Face has no recommendation for the input task.
1932
+ """
1933
+ model = _fetch_recommended_models().get(task)
1934
+ if model is None:
1935
+ raise ValueError(
1936
+ f"Task {task} has no recommended model. Please specify a model"
1937
+ " explicitly. Visit https://huggingface.co/tasks for more info."
1938
+ )
1939
+ return model
1940
+
1941
+ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
1942
+ """
1943
+ Get the status of a model hosted on the Inference API.
1944
+
1945
+ <Tip>
1946
+
1947
+ This endpoint is mostly useful when you already know which model you want to use and want to check its
1948
+ availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
1949
+
1950
+ </Tip>
1951
+
1952
+ Args:
1953
+ model (`str`, *optional*):
1954
+ Identifier of the model for witch the status gonna be checked. If model is not provided,
1955
+ the model associated with this instance of [`InferenceClient`] will be used. Only InferenceAPI service can be checked so the
1956
+ identifier cannot be a URL.
1957
+
1958
+
1959
+ Returns:
1960
+ [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
1961
+ about the state of the model: load, state, compute type and framework.
1962
+
1963
+ Example:
1964
+ ```py
1965
+ >>> from huggingface_hub import InferenceClient
1966
+ >>> client = InferenceClient()
1967
+ >>> client.get_model_status("bigcode/starcoder")
1968
+ ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
1969
+ ```
1970
+ """
1971
+ model = model or self.model
1972
+ if model is None:
1973
+ raise ValueError("Model id not provided.")
1974
+ if model.startswith("https://"):
1975
+ raise NotImplementedError("Model status is only available for Inference API endpoints.")
1976
+ url = f"{INFERENCE_ENDPOINT}/status/{model}"
1977
+
1978
+ response = get_session().get(url, headers=self.headers)
1979
+ hf_raise_for_status(response)
1980
+ response_data = response.json()
1981
+
1982
+ if "error" in response_data:
1983
+ raise ValueError(response_data["error"])
1984
+
1985
+ return ModelStatus(
1986
+ loaded=response_data["loaded"],
1987
+ state=response_data["state"],
1988
+ compute_type=response_data["compute_type"],
1989
+ framework=response_data["framework"],
1990
+ )
lib/python3.11/site-packages/huggingface_hub/inference/_common.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains utilities used by both the sync and async inference clients."""
16
+ import base64
17
+ import io
18
+ import json
19
+ import logging
20
+ from contextlib import contextmanager
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import (
24
+ TYPE_CHECKING,
25
+ Any,
26
+ AsyncIterable,
27
+ BinaryIO,
28
+ ContextManager,
29
+ Dict,
30
+ Generator,
31
+ Iterable,
32
+ List,
33
+ Literal,
34
+ Optional,
35
+ Set,
36
+ Union,
37
+ overload,
38
+ )
39
+
40
+ from requests import HTTPError
41
+
42
+ from ..constants import ENDPOINT
43
+ from ..utils import (
44
+ build_hf_headers,
45
+ get_session,
46
+ hf_raise_for_status,
47
+ is_aiohttp_available,
48
+ is_numpy_available,
49
+ is_pillow_available,
50
+ )
51
+ from ._text_generation import TextGenerationStreamResponse, _parse_text_generation_error
52
+
53
+
54
+ if TYPE_CHECKING:
55
+ from aiohttp import ClientResponse, ClientSession
56
+ from PIL import Image
57
+
58
+ # TYPES
59
+ UrlT = str
60
+ PathT = Union[str, Path]
61
+ BinaryT = Union[bytes, BinaryIO]
62
+ ContentT = Union[BinaryT, PathT, UrlT]
63
+
64
+ # Use to set a Accept: image/png header
65
+ TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ # Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
71
+ @dataclass
72
+ class ModelStatus:
73
+ """
74
+ This Dataclass represents the the model status in the Hugging Face Inference API.
75
+
76
+ Args:
77
+ loaded (`bool`):
78
+ If the model is currently loaded into Hugging Face's InferenceAPI. Models
79
+ are loaded on-demand, leading to the user's first request taking longer.
80
+ If a model is loaded, you can be assured that it is in a healthy state.
81
+ state (`str`):
82
+ The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
83
+ If a model's state is 'Loadable', it's not too big and has a supported
84
+ backend. Loadable models are automatically loaded when the user first
85
+ requests inference on the endpoint. This means it is transparent for the
86
+ user to load a model, except that the first call takes longer to complete.
87
+ compute_type (`str`):
88
+ The type of compute resource the model is using or will use, such as 'gpu' or 'cpu'.
89
+ framework (`str`):
90
+ The name of the framework that the model was built with, such as 'transformers'
91
+ or 'text-generation-inference'.
92
+ """
93
+
94
+ loaded: bool
95
+ state: str
96
+ compute_type: str
97
+ framework: str
98
+
99
+
100
+ class InferenceTimeoutError(HTTPError, TimeoutError):
101
+ """Error raised when a model is unavailable or the request times out."""
102
+
103
+
104
+ ## IMPORT UTILS
105
+
106
+
107
+ def _import_aiohttp():
108
+ # Make sure `aiohttp` is installed on the machine.
109
+ if not is_aiohttp_available():
110
+ raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
111
+ import aiohttp
112
+
113
+ return aiohttp
114
+
115
+
116
+ def _import_numpy():
117
+ """Make sure `numpy` is installed on the machine."""
118
+ if not is_numpy_available():
119
+ raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
120
+ import numpy
121
+
122
+ return numpy
123
+
124
+
125
+ def _import_pil_image():
126
+ """Make sure `PIL` is installed on the machine."""
127
+ if not is_pillow_available():
128
+ raise ImportError(
129
+ "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
130
+ " post-processed, use `client.post(...)` and get the raw response from the server."
131
+ )
132
+ from PIL import Image
133
+
134
+ return Image
135
+
136
+
137
+ ## RECOMMENDED MODELS
138
+
139
+ # Will be globally fetched only once (see '_fetch_recommended_models')
140
+ _RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None
141
+
142
+
143
+ def _fetch_recommended_models() -> Dict[str, Optional[str]]:
144
+ global _RECOMMENDED_MODELS
145
+ if _RECOMMENDED_MODELS is None:
146
+ response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
147
+ hf_raise_for_status(response)
148
+ _RECOMMENDED_MODELS = {
149
+ task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
150
+ }
151
+ return _RECOMMENDED_MODELS
152
+
153
+
154
+ def _first_or_none(items: List[Any]) -> Optional[Any]:
155
+ try:
156
+ return items[0] or None
157
+ except IndexError:
158
+ return None
159
+
160
+
161
+ ## ENCODING / DECODING UTILS
162
+
163
+
164
+ @overload
165
+ def _open_as_binary(content: ContentT) -> ContextManager[BinaryT]:
166
+ ... # means "if input is not None, output is not None"
167
+
168
+
169
+ @overload
170
+ def _open_as_binary(content: Literal[None]) -> ContextManager[Literal[None]]:
171
+ ... # means "if input is None, output is None"
172
+
173
+
174
+ @contextmanager # type: ignore
175
+ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
176
+ """Open `content` as a binary file, either from a URL, a local path, or raw bytes.
177
+
178
+ Do nothing if `content` is None,
179
+
180
+ TODO: handle a PIL.Image as input
181
+ TODO: handle base64 as input
182
+ """
183
+ # If content is a string => must be either a URL or a path
184
+ if isinstance(content, str):
185
+ if content.startswith("https://") or content.startswith("http://"):
186
+ logger.debug(f"Downloading content from {content}")
187
+ yield get_session().get(content).content # TODO: retrieve as stream and pipe to post request ?
188
+ return
189
+ content = Path(content)
190
+ if not content.exists():
191
+ raise FileNotFoundError(
192
+ f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
193
+ " file. To pass raw content, please encode it as bytes first."
194
+ )
195
+
196
+ # If content is a Path => open it
197
+ if isinstance(content, Path):
198
+ logger.debug(f"Opening content from {content}")
199
+ with content.open("rb") as f:
200
+ yield f
201
+ else:
202
+ # Otherwise: already a file-like object or None
203
+ yield content
204
+
205
+
206
+ def _b64_encode(content: ContentT) -> str:
207
+ """Encode a raw file (image, audio) into base64. Can be byes, an opened file, a path or a URL."""
208
+ with _open_as_binary(content) as data:
209
+ data_as_bytes = data if isinstance(data, bytes) else data.read()
210
+ return base64.b64encode(data_as_bytes).decode()
211
+
212
+
213
+ def _b64_to_image(encoded_image: str) -> "Image":
214
+ """Parse a base64-encoded string into a PIL Image."""
215
+ Image = _import_pil_image()
216
+ return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
217
+
218
+
219
+ def _bytes_to_list(content: bytes) -> List:
220
+ """Parse bytes from a Response object into a Python list.
221
+
222
+ Expects the response body to be JSON-encoded data.
223
+
224
+ NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
225
+ dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
226
+ """
227
+ return json.loads(content.decode())
228
+
229
+
230
+ def _bytes_to_dict(content: bytes) -> Dict:
231
+ """Parse bytes from a Response object into a Python dictionary.
232
+
233
+ Expects the response body to be JSON-encoded data.
234
+
235
+ NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
236
+ list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
237
+ """
238
+ return json.loads(content.decode())
239
+
240
+
241
+ def _bytes_to_image(content: bytes) -> "Image":
242
+ """Parse bytes from a Response object into a PIL Image.
243
+
244
+ Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
245
+ """
246
+ Image = _import_pil_image()
247
+ return Image.open(io.BytesIO(content))
248
+
249
+
250
+ ## STREAMING UTILS
251
+
252
+
253
+ def _stream_text_generation_response(
254
+ bytes_output_as_lines: Iterable[bytes], details: bool
255
+ ) -> Union[Iterable[str], Iterable[TextGenerationStreamResponse]]:
256
+ # Parse ServerSentEvents
257
+ for byte_payload in bytes_output_as_lines:
258
+ # Skip line
259
+ if byte_payload == b"\n":
260
+ continue
261
+
262
+ payload = byte_payload.decode("utf-8")
263
+
264
+ # Event data
265
+ if payload.startswith("data:"):
266
+ # Decode payload
267
+ json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
268
+ # Either an error as being returned
269
+ if json_payload.get("error") is not None:
270
+ raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
271
+ # Or parse token payload
272
+ output = TextGenerationStreamResponse(**json_payload)
273
+ yield output.token.text if not details else output
274
+
275
+
276
+ async def _async_stream_text_generation_response(
277
+ bytes_output_as_lines: AsyncIterable[bytes], details: bool
278
+ ) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
279
+ # Parse ServerSentEvents
280
+ async for byte_payload in bytes_output_as_lines:
281
+ # Skip line
282
+ if byte_payload == b"\n":
283
+ continue
284
+
285
+ payload = byte_payload.decode("utf-8")
286
+
287
+ # Event data
288
+ if payload.startswith("data:"):
289
+ # Decode payload
290
+ json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
291
+ # Either an error as being returned
292
+ if json_payload.get("error") is not None:
293
+ raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
294
+ # Or parse token payload
295
+ output = TextGenerationStreamResponse(**json_payload)
296
+ yield output.token.text if not details else output
297
+
298
+
299
+ async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
300
+ async for byte_payload in response.content:
301
+ yield byte_payload
302
+ await client.close()
303
+
304
+
305
+ # "TGI servers" are servers running with the `text-generation-inference` backend.
306
+ # This backend is the go-to solution to run large language models at scale. However,
307
+ # for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
308
+ # solution is still in use.
309
+ #
310
+ # Both approaches have very similar APIs, but not exactly the same. What we do first in
311
+ # the `text_generation` method is to assume the model is served via TGI. If we realize
312
+ # it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
313
+ # default API with a warning message. We remember for each model if it's a TGI server
314
+ # or not using `_NON_TGI_SERVERS` global variable.
315
+ #
316
+ # For more details, see https://github.com/huggingface/text-generation-inference and
317
+ # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
318
+
319
+ _NON_TGI_SERVERS: Set[Optional[str]] = set()
320
+
321
+
322
+ def _set_as_non_tgi(model: Optional[str]) -> None:
323
+ _NON_TGI_SERVERS.add(model)
324
+
325
+
326
+ def _is_tgi_server(model: Optional[str]) -> bool:
327
+ return model not in _NON_TGI_SERVERS
lib/python3.11/site-packages/huggingface_hub/inference/_generated/__init__.py ADDED
File without changes
lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (251 Bytes). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/_generated/__pycache__/_async_client.cpython-311.pyc ADDED
Binary file (96.9 kB). View file
 
lib/python3.11/site-packages/huggingface_hub/inference/_generated/_async_client.py ADDED
@@ -0,0 +1,2020 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # WARNING
17
+ # This entire file has been adapted from the sync-client code in `src/huggingface_hub/inference/_client.py`.
18
+ # Any change in InferenceClient will be automatically reflected in AsyncInferenceClient.
19
+ # To re-generate the code, run `make style` or `python ./utils/generate_async_inference_client.py --update`.
20
+ # WARNING
21
+ import asyncio
22
+ import logging
23
+ import time
24
+ import warnings
25
+ from dataclasses import asdict
26
+ from typing import (
27
+ TYPE_CHECKING,
28
+ Any,
29
+ AsyncIterable,
30
+ Dict,
31
+ List,
32
+ Literal,
33
+ Optional,
34
+ Union,
35
+ overload,
36
+ )
37
+
38
+ from requests.structures import CaseInsensitiveDict
39
+
40
+ from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
41
+ from huggingface_hub.inference._common import (
42
+ TASKS_EXPECTING_IMAGES,
43
+ ContentT,
44
+ InferenceTimeoutError,
45
+ ModelStatus,
46
+ _async_stream_text_generation_response,
47
+ _b64_encode,
48
+ _b64_to_image,
49
+ _bytes_to_dict,
50
+ _bytes_to_image,
51
+ _bytes_to_list,
52
+ _fetch_recommended_models,
53
+ _import_numpy,
54
+ _is_tgi_server,
55
+ _open_as_binary,
56
+ _set_as_non_tgi,
57
+ )
58
+ from huggingface_hub.inference._text_generation import (
59
+ TextGenerationParameters,
60
+ TextGenerationRequest,
61
+ TextGenerationResponse,
62
+ TextGenerationStreamResponse,
63
+ raise_text_generation_error,
64
+ )
65
+ from huggingface_hub.inference._types import (
66
+ ClassificationOutput,
67
+ ConversationalOutput,
68
+ FillMaskOutput,
69
+ ImageSegmentationOutput,
70
+ ObjectDetectionOutput,
71
+ QuestionAnsweringOutput,
72
+ TableQuestionAnsweringOutput,
73
+ TokenClassificationOutput,
74
+ )
75
+ from huggingface_hub.utils import (
76
+ build_hf_headers,
77
+ )
78
+
79
+ from .._common import _async_yield_from, _import_aiohttp
80
+
81
+
82
+ if TYPE_CHECKING:
83
+ import numpy as np
84
+ from PIL import Image
85
+
86
+ logger = logging.getLogger(__name__)
87
+
88
+
89
+ class AsyncInferenceClient:
90
+ """
91
+ Initialize a new Inference Client.
92
+
93
+ [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
94
+ seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
95
+
96
+ Args:
97
+ model (`str`, `optional`):
98
+ The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `bigcode/starcoder`
99
+ or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
100
+ automatically selected for the task.
101
+ token (`str`, *optional*):
102
+ Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
103
+ your token to the server.
104
+ timeout (`float`, `optional`):
105
+ The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
106
+ API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
107
+ headers (`Dict[str, str]`, `optional`):
108
+ Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
109
+ Values in this dictionary will override the default values.
110
+ cookies (`Dict[str, str]`, `optional`):
111
+ Additional cookies to send to the server.
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ model: Optional[str] = None,
117
+ token: Union[str, bool, None] = None,
118
+ timeout: Optional[float] = None,
119
+ headers: Optional[Dict[str, str]] = None,
120
+ cookies: Optional[Dict[str, str]] = None,
121
+ ) -> None:
122
+ self.model: Optional[str] = model
123
+ self.headers = CaseInsensitiveDict(build_hf_headers(token=token)) # contains 'authorization' + 'user-agent'
124
+ if headers is not None:
125
+ self.headers.update(headers)
126
+ self.cookies = cookies
127
+ self.timeout = timeout
128
+
129
+ def __repr__(self):
130
+ return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
131
+
132
+ @overload
133
+ async def post( # type: ignore[misc]
134
+ self,
135
+ *,
136
+ json: Optional[Union[str, Dict, List]] = None,
137
+ data: Optional[ContentT] = None,
138
+ model: Optional[str] = None,
139
+ task: Optional[str] = None,
140
+ stream: Literal[False] = ...,
141
+ ) -> bytes:
142
+ pass
143
+
144
+ @overload
145
+ async def post(
146
+ self,
147
+ *,
148
+ json: Optional[Union[str, Dict, List]] = None,
149
+ data: Optional[ContentT] = None,
150
+ model: Optional[str] = None,
151
+ task: Optional[str] = None,
152
+ stream: Literal[True] = ...,
153
+ ) -> AsyncIterable[bytes]:
154
+ pass
155
+
156
+ async def post(
157
+ self,
158
+ *,
159
+ json: Optional[Union[str, Dict, List]] = None,
160
+ data: Optional[ContentT] = None,
161
+ model: Optional[str] = None,
162
+ task: Optional[str] = None,
163
+ stream: bool = False,
164
+ ) -> Union[bytes, AsyncIterable[bytes]]:
165
+ """
166
+ Make a POST request to the inference server.
167
+
168
+ Args:
169
+ json (`Union[str, Dict, List]`, *optional*):
170
+ The JSON data to send in the request body, specific to each task. Defaults to None.
171
+ data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
172
+ The content to send in the request body, specific to each task.
173
+ It can be raw bytes, a pointer to an opened file, a local file path,
174
+ or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed,
175
+ `data` will take precedence. At least `json` or `data` must be provided. Defaults to None.
176
+ model (`str`, *optional*):
177
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
178
+ Inference Endpoint. Will override the model defined at the instance level. Defaults to None.
179
+ task (`str`, *optional*):
180
+ The task to perform on the inference. All available tasks can be found
181
+ [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not
182
+ provided. At least `model` or `task` must be provided. Defaults to None.
183
+ stream (`bool`, *optional*):
184
+ Whether to iterate over streaming APIs.
185
+
186
+ Returns:
187
+ bytes: The raw bytes returned by the server.
188
+
189
+ Raises:
190
+ [`InferenceTimeoutError`]:
191
+ If the model is unavailable or the request times out.
192
+ `aiohttp.ClientResponseError`:
193
+ If the request fails with an HTTP error status code other than HTTP 503.
194
+ """
195
+
196
+ aiohttp = _import_aiohttp()
197
+
198
+ url = self._resolve_url(model, task)
199
+
200
+ if data is not None and json is not None:
201
+ warnings.warn("Ignoring `json` as `data` is passed as binary.")
202
+
203
+ # Set Accept header if relevant
204
+ headers = self.headers.copy()
205
+ if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
206
+ headers["Accept"] = "image/png"
207
+
208
+ t0 = time.time()
209
+ timeout = self.timeout
210
+ while True:
211
+ with _open_as_binary(data) as data_as_binary:
212
+ # Do not use context manager as we don't want to close the connection immediately when returning
213
+ # a stream
214
+ client = aiohttp.ClientSession(
215
+ headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
216
+ )
217
+
218
+ try:
219
+ response = await client.post(url, json=json, data=data_as_binary)
220
+ response_error_payload = None
221
+ if response.status != 200:
222
+ try:
223
+ response_error_payload = await response.json() # get payload before connection closed
224
+ except Exception:
225
+ pass
226
+ response.raise_for_status()
227
+ if stream:
228
+ return _async_yield_from(client, response)
229
+ else:
230
+ content = await response.read()
231
+ await client.close()
232
+ return content
233
+ except asyncio.TimeoutError as error:
234
+ await client.close()
235
+ # Convert any `TimeoutError` to a `InferenceTimeoutError`
236
+ raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore
237
+ except aiohttp.ClientResponseError as error:
238
+ error.response_error_payload = response_error_payload
239
+ await client.close()
240
+ if response.status == 422 and task is not None:
241
+ error.message += f". Make sure '{task}' task is supported by the model."
242
+ if response.status == 503:
243
+ # If Model is unavailable, either raise a TimeoutError...
244
+ if timeout is not None and time.time() - t0 > timeout:
245
+ raise InferenceTimeoutError(
246
+ f"Model not loaded on the server: {url}. Please retry with a higher timeout"
247
+ f" (current: {self.timeout}).",
248
+ request=error.request,
249
+ response=error.response,
250
+ ) from error
251
+ # ...or wait 1s and retry
252
+ logger.info(f"Waiting for model to be loaded on the server: {error}")
253
+ time.sleep(1)
254
+ if timeout is not None:
255
+ timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore
256
+ continue
257
+ raise error
258
+
259
+ async def audio_classification(
260
+ self,
261
+ audio: ContentT,
262
+ *,
263
+ model: Optional[str] = None,
264
+ ) -> List[ClassificationOutput]:
265
+ """
266
+ Perform audio classification on the provided audio content.
267
+
268
+ Args:
269
+ audio (Union[str, Path, bytes, BinaryIO]):
270
+ The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an
271
+ audio file.
272
+ model (`str`, *optional*):
273
+ The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub
274
+ or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
275
+ audio classification will be used.
276
+
277
+ Returns:
278
+ `List[Dict]`: The classification output containing the predicted label and its confidence.
279
+
280
+ Raises:
281
+ [`InferenceTimeoutError`]:
282
+ If the model is unavailable or the request times out.
283
+ `aiohttp.ClientResponseError`:
284
+ If the request fails with an HTTP error status code other than HTTP 503.
285
+
286
+ Example:
287
+ ```py
288
+ # Must be run in an async context
289
+ >>> from huggingface_hub import AsyncInferenceClient
290
+ >>> client = AsyncInferenceClient()
291
+ >>> await client.audio_classification("audio.flac")
292
+ [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
293
+ ```
294
+ """
295
+ response = await self.post(data=audio, model=model, task="audio-classification")
296
+ return _bytes_to_list(response)
297
+
298
+ async def automatic_speech_recognition(
299
+ self,
300
+ audio: ContentT,
301
+ *,
302
+ model: Optional[str] = None,
303
+ ) -> str:
304
+ """
305
+ Perform automatic speech recognition (ASR or audio-to-text) on the given audio content.
306
+
307
+ Args:
308
+ audio (Union[str, Path, bytes, BinaryIO]):
309
+ The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file.
310
+ model (`str`, *optional*):
311
+ The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
312
+ Inference Endpoint. If not provided, the default recommended model for ASR will be used.
313
+
314
+ Returns:
315
+ str: The transcribed text.
316
+
317
+ Raises:
318
+ [`InferenceTimeoutError`]:
319
+ If the model is unavailable or the request times out.
320
+ `aiohttp.ClientResponseError`:
321
+ If the request fails with an HTTP error status code other than HTTP 503.
322
+
323
+ Example:
324
+ ```py
325
+ # Must be run in an async context
326
+ >>> from huggingface_hub import AsyncInferenceClient
327
+ >>> client = AsyncInferenceClient()
328
+ >>> await client.automatic_speech_recognition("hello_world.flac")
329
+ "hello world"
330
+ ```
331
+ """
332
+ response = await self.post(data=audio, model=model, task="automatic-speech-recognition")
333
+ return _bytes_to_dict(response)["text"]
334
+
335
+ async def conversational(
336
+ self,
337
+ text: str,
338
+ generated_responses: Optional[List[str]] = None,
339
+ past_user_inputs: Optional[List[str]] = None,
340
+ *,
341
+ parameters: Optional[Dict[str, Any]] = None,
342
+ model: Optional[str] = None,
343
+ ) -> ConversationalOutput:
344
+ """
345
+ Generate conversational responses based on the given input text (i.e. chat with the API).
346
+
347
+ Args:
348
+ text (`str`):
349
+ The last input from the user in the conversation.
350
+ generated_responses (`List[str]`, *optional*):
351
+ A list of strings corresponding to the earlier replies from the model. Defaults to None.
352
+ past_user_inputs (`List[str]`, *optional*):
353
+ A list of strings corresponding to the earlier replies from the user. Should be the same length as
354
+ `generated_responses`. Defaults to None.
355
+ parameters (`Dict[str, Any]`, *optional*):
356
+ Additional parameters for the conversational task. Defaults to None. For more details about the available
357
+ parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#conversational-task)
358
+ model (`str`, *optional*):
359
+ The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
360
+ a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
361
+ Defaults to None.
362
+
363
+ Returns:
364
+ `Dict`: The generated conversational output.
365
+
366
+ Raises:
367
+ [`InferenceTimeoutError`]:
368
+ If the model is unavailable or the request times out.
369
+ `aiohttp.ClientResponseError`:
370
+ If the request fails with an HTTP error status code other than HTTP 503.
371
+
372
+ Example:
373
+ ```py
374
+ # Must be run in an async context
375
+ >>> from huggingface_hub import AsyncInferenceClient
376
+ >>> client = AsyncInferenceClient()
377
+ >>> output = await client.conversational("Hi, who are you?")
378
+ >>> output
379
+ {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 async for open-end generation.']}
380
+ >>> await client.conversational(
381
+ ... "Wow, that's scary!",
382
+ ... generated_responses=output["conversation"]["generated_responses"],
383
+ ... past_user_inputs=output["conversation"]["past_user_inputs"],
384
+ ... )
385
+ ```
386
+ """
387
+ payload: Dict[str, Any] = {"inputs": {"text": text}}
388
+ if generated_responses is not None:
389
+ payload["inputs"]["generated_responses"] = generated_responses
390
+ if past_user_inputs is not None:
391
+ payload["inputs"]["past_user_inputs"] = past_user_inputs
392
+ if parameters is not None:
393
+ payload["parameters"] = parameters
394
+ response = await self.post(json=payload, model=model, task="conversational")
395
+ return _bytes_to_dict(response) # type: ignore
396
+
397
+ async def visual_question_answering(
398
+ self,
399
+ image: ContentT,
400
+ question: str,
401
+ *,
402
+ model: Optional[str] = None,
403
+ ) -> List[str]:
404
+ """
405
+ Answering open-ended questions based on an image.
406
+
407
+ Args:
408
+ image (`Union[str, Path, bytes, BinaryIO]`):
409
+ The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
410
+ question (`str`):
411
+ Question to be answered.
412
+ model (`str`, *optional*):
413
+ The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
414
+ a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used.
415
+ Defaults to None.
416
+
417
+ Returns:
418
+ `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
419
+
420
+ Raises:
421
+ `InferenceTimeoutError`:
422
+ If the model is unavailable or the request times out.
423
+ `aiohttp.ClientResponseError`:
424
+ If the request fails with an HTTP error status code other than HTTP 503.
425
+
426
+ Example:
427
+ ```py
428
+ # Must be run in an async context
429
+ >>> from huggingface_hub import AsyncInferenceClient
430
+ >>> client = AsyncInferenceClient()
431
+ >>> await client.visual_question_answering(
432
+ ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg",
433
+ ... question="What is the animal doing?"
434
+ ... )
435
+ [{'score': 0.778609573841095, 'answer': 'laying down'},{'score': 0.6957435607910156, 'answer': 'sitting'}, ...]
436
+ ```
437
+ """
438
+ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
439
+ response = await self.post(json=payload, model=model, task="visual-question-answering")
440
+ return _bytes_to_list(response)
441
+
442
+ async def document_question_answering(
443
+ self,
444
+ image: ContentT,
445
+ question: str,
446
+ *,
447
+ model: Optional[str] = None,
448
+ ) -> List[QuestionAnsweringOutput]:
449
+ """
450
+ Answer questions on document images.
451
+
452
+ Args:
453
+ image (`Union[str, Path, bytes, BinaryIO]`):
454
+ The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
455
+ question (`str`):
456
+ Question to be answered.
457
+ model (`str`, *optional*):
458
+ The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
459
+ a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
460
+ Defaults to None.
461
+
462
+ Returns:
463
+ `List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
464
+
465
+ Raises:
466
+ [`InferenceTimeoutError`]:
467
+ If the model is unavailable or the request times out.
468
+ `aiohttp.ClientResponseError`:
469
+ If the request fails with an HTTP error status code other than HTTP 503.
470
+
471
+ Example:
472
+ ```py
473
+ # Must be run in an async context
474
+ >>> from huggingface_hub import AsyncInferenceClient
475
+ >>> client = AsyncInferenceClient()
476
+ >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
477
+ [{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
478
+ ```
479
+ """
480
+ payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
481
+ response = await self.post(json=payload, model=model, task="document-question-answering")
482
+ return _bytes_to_list(response)
483
+
484
+ async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
485
+ """
486
+ Generate embeddings for a given text.
487
+
488
+ Args:
489
+ text (`str`):
490
+ The text to embed.
491
+ model (`str`, *optional*):
492
+ The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
493
+ a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
494
+ Defaults to None.
495
+
496
+ Returns:
497
+ `np.ndarray`: The embedding representing the input text as a float32 numpy array.
498
+
499
+ Raises:
500
+ [`InferenceTimeoutError`]:
501
+ If the model is unavailable or the request times out.
502
+ `aiohttp.ClientResponseError`:
503
+ If the request fails with an HTTP error status code other than HTTP 503.
504
+
505
+ Example:
506
+ ```py
507
+ # Must be run in an async context
508
+ >>> from huggingface_hub import AsyncInferenceClient
509
+ >>> client = AsyncInferenceClient()
510
+ >>> await client.feature_extraction("Hi, who are you?")
511
+ array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ],
512
+ [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ],
513
+ ...,
514
+ [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
515
+ ```
516
+ """
517
+ response = await self.post(json={"inputs": text}, model=model, task="feature-extraction")
518
+ np = _import_numpy()
519
+ return np.array(_bytes_to_dict(response), dtype="float32")
520
+
521
+ async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutput]:
522
+ """
523
+ Fill in a hole with a missing word (token to be precise).
524
+
525
+ Args:
526
+ text (`str`):
527
+ a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask).
528
+ model (`str`, *optional*):
529
+ The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to
530
+ a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used.
531
+ Defaults to None.
532
+
533
+ Returns:
534
+ `List[Dict]`: a list of fill mask output dictionaries containing the predicted label, associated
535
+ probability, token reference, and completed text.
536
+
537
+ Raises:
538
+ [`InferenceTimeoutError`]:
539
+ If the model is unavailable or the request times out.
540
+ `aiohttp.ClientResponseError`:
541
+ If the request fails with an HTTP error status code other than HTTP 503.
542
+
543
+ Example:
544
+ ```py
545
+ # Must be run in an async context
546
+ >>> from huggingface_hub import AsyncInferenceClient
547
+ >>> client = AsyncInferenceClient()
548
+ >>> await client.fill_mask("The goal of life is <mask>.")
549
+ [{'score': 0.06897063553333282,
550
+ 'token': 11098,
551
+ 'token_str': ' happiness',
552
+ 'sequence': 'The goal of life is happiness.'},
553
+ {'score': 0.06554922461509705,
554
+ 'token': 45075,
555
+ 'token_str': ' immortality',
556
+ 'sequence': 'The goal of life is immortality.'}]
557
+ ```
558
+ """
559
+ response = await self.post(json={"inputs": text}, model=model, task="fill-mask")
560
+ return _bytes_to_list(response)
561
+
562
+ async def image_classification(
563
+ self,
564
+ image: ContentT,
565
+ *,
566
+ model: Optional[str] = None,
567
+ ) -> List[ClassificationOutput]:
568
+ """
569
+ Perform image classification on the given image using the specified model.
570
+
571
+ Args:
572
+ image (`Union[str, Path, bytes, BinaryIO]`):
573
+ The image to classify. It can be raw bytes, an image file, or a URL to an online image.
574
+ model (`str`, *optional*):
575
+ The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
576
+ deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
577
+
578
+ Returns:
579
+ `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
580
+
581
+ Raises:
582
+ [`InferenceTimeoutError`]:
583
+ If the model is unavailable or the request times out.
584
+ `aiohttp.ClientResponseError`:
585
+ If the request fails with an HTTP error status code other than HTTP 503.
586
+
587
+ Example:
588
+ ```py
589
+ # Must be run in an async context
590
+ >>> from huggingface_hub import AsyncInferenceClient
591
+ >>> client = AsyncInferenceClient()
592
+ >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
593
+ [{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
594
+ ```
595
+ """
596
+ response = await self.post(data=image, model=model, task="image-classification")
597
+ return _bytes_to_list(response)
598
+
599
+ async def image_segmentation(
600
+ self,
601
+ image: ContentT,
602
+ *,
603
+ model: Optional[str] = None,
604
+ ) -> List[ImageSegmentationOutput]:
605
+ """
606
+ Perform image segmentation on the given image using the specified model.
607
+
608
+ <Tip warning={true}>
609
+
610
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
611
+
612
+ </Tip>
613
+
614
+ Args:
615
+ image (`Union[str, Path, bytes, BinaryIO]`):
616
+ The image to segment. It can be raw bytes, an image file, or a URL to an online image.
617
+ model (`str`, *optional*):
618
+ The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
619
+ deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
620
+
621
+ Returns:
622
+ `List[Dict]`: A list of dictionaries containing the segmented masks and associated attributes.
623
+
624
+ Raises:
625
+ [`InferenceTimeoutError`]:
626
+ If the model is unavailable or the request times out.
627
+ `aiohttp.ClientResponseError`:
628
+ If the request fails with an HTTP error status code other than HTTP 503.
629
+
630
+ Example:
631
+ ```py
632
+ # Must be run in an async context
633
+ >>> from huggingface_hub import AsyncInferenceClient
634
+ >>> client = AsyncInferenceClient()
635
+ >>> await client.image_segmentation("cat.jpg"):
636
+ [{'score': 0.989008, 'label': 'LABEL_184', 'mask': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x300 at 0x7FDD2B129CC0>}, ...]
637
+ ```
638
+ """
639
+
640
+ # Segment
641
+ response = await self.post(data=image, model=model, task="image-segmentation")
642
+ output = _bytes_to_dict(response)
643
+
644
+ # Parse masks as PIL Image
645
+ if not isinstance(output, list):
646
+ raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
647
+ for item in output:
648
+ item["mask"] = _b64_to_image(item["mask"])
649
+ return output
650
+
651
+ async def image_to_image(
652
+ self,
653
+ image: ContentT,
654
+ prompt: Optional[str] = None,
655
+ *,
656
+ negative_prompt: Optional[str] = None,
657
+ height: Optional[int] = None,
658
+ width: Optional[int] = None,
659
+ num_inference_steps: Optional[int] = None,
660
+ guidance_scale: Optional[float] = None,
661
+ model: Optional[str] = None,
662
+ **kwargs,
663
+ ) -> "Image":
664
+ """
665
+ Perform image-to-image translation using a specified model.
666
+
667
+ <Tip warning={true}>
668
+
669
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
670
+
671
+ </Tip>
672
+
673
+ Args:
674
+ image (`Union[str, Path, bytes, BinaryIO]`):
675
+ The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
676
+ prompt (`str`, *optional*):
677
+ The text prompt to guide the image generation.
678
+ negative_prompt (`str`, *optional*):
679
+ A negative prompt to guide the translation process.
680
+ height (`int`, *optional*):
681
+ The height in pixels of the generated image.
682
+ width (`int`, *optional*):
683
+ The width in pixels of the generated image.
684
+ num_inference_steps (`int`, *optional*):
685
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
686
+ expense of slower inference.
687
+ guidance_scale (`float`, *optional*):
688
+ Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
689
+ usually at the expense of lower image quality.
690
+ model (`str`, *optional*):
691
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
692
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
693
+
694
+ Returns:
695
+ `Image`: The translated image.
696
+
697
+ Raises:
698
+ [`InferenceTimeoutError`]:
699
+ If the model is unavailable or the request times out.
700
+ `aiohttp.ClientResponseError`:
701
+ If the request fails with an HTTP error status code other than HTTP 503.
702
+
703
+ Example:
704
+ ```py
705
+ # Must be run in an async context
706
+ >>> from huggingface_hub import AsyncInferenceClient
707
+ >>> client = AsyncInferenceClient()
708
+ >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
709
+ >>> image.save("tiger.jpg")
710
+ ```
711
+ """
712
+ parameters = {
713
+ "prompt": prompt,
714
+ "negative_prompt": negative_prompt,
715
+ "height": height,
716
+ "width": width,
717
+ "num_inference_steps": num_inference_steps,
718
+ "guidance_scale": guidance_scale,
719
+ **kwargs,
720
+ }
721
+ if all(parameter is None for parameter in parameters.values()):
722
+ # Either only an image to send => send as raw bytes
723
+ data = image
724
+ payload: Optional[Dict[str, Any]] = None
725
+ else:
726
+ # Or an image + some parameters => use base64 encoding
727
+ data = None
728
+ payload = {"inputs": _b64_encode(image)}
729
+ for key, value in parameters.items():
730
+ if value is not None:
731
+ payload.setdefault("parameters", {})[key] = value
732
+
733
+ response = await self.post(json=payload, data=data, model=model, task="image-to-image")
734
+ return _bytes_to_image(response)
735
+
736
+ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
737
+ """
738
+ Takes an input image and return text.
739
+
740
+ Models can have very different outputs depending on your use case (image captioning, optical character recognition
741
+ (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.
742
+
743
+ Args:
744
+ image (`Union[str, Path, bytes, BinaryIO]`):
745
+ The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
746
+ model (`str`, *optional*):
747
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
748
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
749
+
750
+ Returns:
751
+ `str`: The generated text.
752
+
753
+ Raises:
754
+ [`InferenceTimeoutError`]:
755
+ If the model is unavailable or the request times out.
756
+ `aiohttp.ClientResponseError`:
757
+ If the request fails with an HTTP error status code other than HTTP 503.
758
+
759
+ Example:
760
+ ```py
761
+ # Must be run in an async context
762
+ >>> from huggingface_hub import AsyncInferenceClient
763
+ >>> client = AsyncInferenceClient()
764
+ >>> await client.image_to_text("cat.jpg")
765
+ 'a cat standing in a grassy field '
766
+ >>> await client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
767
+ 'a dog laying on the grass next to a flower pot '
768
+ ```
769
+ """
770
+ response = await self.post(data=image, model=model, task="image-to-text")
771
+ return _bytes_to_dict(response)[0]["generated_text"]
772
+
773
+ async def list_deployed_models(
774
+ self, frameworks: Union[None, str, Literal["all"], List[str]] = None
775
+ ) -> Dict[str, List[str]]:
776
+ """
777
+ List models currently deployed on the Inference API service.
778
+
779
+ This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
780
+ are supported and account for 95% of the hosted models. However, if you want a complete list of models you can
781
+ specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested
782
+ in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more
783
+ frameworks are checked, the more time it will take.
784
+
785
+ <Tip>
786
+
787
+ This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
788
+ check its availability, you can directly use [`~InferenceClient.get_model_status`].
789
+
790
+ </Tip>
791
+
792
+ Args:
793
+ frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
794
+ The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
795
+ "all", all available frameworks will be tested. It is also possible to provide a single framework or a
796
+ custom set of frameworks to check.
797
+
798
+ Returns:
799
+ `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.
800
+
801
+ Example:
802
+ ```py
803
+ # Must be run in an async contextthon
804
+ >>> from huggingface_hub import AsyncInferenceClient
805
+ >>> client = AsyncInferenceClient()
806
+
807
+ # Discover zero-shot-classification models currently deployed
808
+ >>> models = await client.list_deployed_models()
809
+ >>> models["zero-shot-classification"]
810
+ ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]
811
+
812
+ # List from only 1 framework
813
+ >>> await client.list_deployed_models("text-generation-inference")
814
+ {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
815
+ ```
816
+ """
817
+ # Resolve which frameworks to check
818
+ if frameworks is None:
819
+ frameworks = MAIN_INFERENCE_API_FRAMEWORKS
820
+ elif frameworks == "all":
821
+ frameworks = ALL_INFERENCE_API_FRAMEWORKS
822
+ elif isinstance(frameworks, str):
823
+ frameworks = [frameworks]
824
+ frameworks = list(set(frameworks))
825
+
826
+ # Fetch them iteratively
827
+ models_by_task: Dict[str, List[str]] = {}
828
+
829
+ def _unpack_response(framework: str, items: List[Dict]) -> None:
830
+ for model in items:
831
+ if framework == "sentence-transformers":
832
+ # Model running with the `sentence-transformers` framework can work with both tasks even if not
833
+ # branded as such in the API response
834
+ models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
835
+ models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
836
+ else:
837
+ models_by_task.setdefault(model["task"], []).append(model["model_id"])
838
+
839
+ async def _fetch_framework(framework: str) -> None:
840
+ async with _import_aiohttp().ClientSession(headers=self.headers) as client:
841
+ response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
842
+ response.raise_for_status()
843
+ _unpack_response(framework, await response.json())
844
+
845
+ import asyncio
846
+
847
+ await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])
848
+
849
+ # Sort alphabetically for discoverability and return
850
+ for task, models in models_by_task.items():
851
+ models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
852
+ return models_by_task
853
+
854
+ async def object_detection(
855
+ self,
856
+ image: ContentT,
857
+ *,
858
+ model: Optional[str] = None,
859
+ ) -> List[ObjectDetectionOutput]:
860
+ """
861
+ Perform object detection on the given image using the specified model.
862
+
863
+ <Tip warning={true}>
864
+
865
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
866
+
867
+ </Tip>
868
+
869
+ Args:
870
+ image (`Union[str, Path, bytes, BinaryIO]`):
871
+ The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
872
+ model (`str`, *optional*):
873
+ The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
874
+ deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
875
+
876
+ Returns:
877
+ `List[ObjectDetectionOutput]`: A list of dictionaries containing the bounding boxes and associated attributes.
878
+
879
+ Raises:
880
+ [`InferenceTimeoutError`]:
881
+ If the model is unavailable or the request times out.
882
+ `aiohttp.ClientResponseError`:
883
+ If the request fails with an HTTP error status code other than HTTP 503.
884
+ `ValueError`:
885
+ If the request output is not a List.
886
+
887
+ Example:
888
+ ```py
889
+ # Must be run in an async context
890
+ >>> from huggingface_hub import AsyncInferenceClient
891
+ >>> client = AsyncInferenceClient()
892
+ >>> await client.object_detection("people.jpg"):
893
+ [{"score":0.9486683011054993,"label":"person","box":{"xmin":59,"ymin":39,"xmax":420,"ymax":510}}, ... ]
894
+ ```
895
+ """
896
+ # detect objects
897
+ response = await self.post(data=image, model=model, task="object-detection")
898
+ output = _bytes_to_dict(response)
899
+ if not isinstance(output, list):
900
+ raise ValueError(f"Server output must be a list. Got {type(output)}: {str(output)[:200]}...")
901
+ return output
902
+
903
+ async def question_answering(
904
+ self, question: str, context: str, *, model: Optional[str] = None
905
+ ) -> QuestionAnsweringOutput:
906
+ """
907
+ Retrieve the answer to a question from a given text.
908
+
909
+ Args:
910
+ question (`str`):
911
+ Question to be answered.
912
+ context (`str`):
913
+ The context of the question.
914
+ model (`str`):
915
+ The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
916
+ a deployed Inference Endpoint.
917
+
918
+ Returns:
919
+ `Dict`: a dictionary of question answering output containing the score, start index, end index, and answer.
920
+
921
+ Raises:
922
+ [`InferenceTimeoutError`]:
923
+ If the model is unavailable or the request times out.
924
+ `aiohttp.ClientResponseError`:
925
+ If the request fails with an HTTP error status code other than HTTP 503.
926
+
927
+ Example:
928
+ ```py
929
+ # Must be run in an async context
930
+ >>> from huggingface_hub import AsyncInferenceClient
931
+ >>> client = AsyncInferenceClient()
932
+ >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.")
933
+ {'score': 0.9326562285423279, 'start': 11, 'end': 16, 'answer': 'Clara'}
934
+ ```
935
+ """
936
+
937
+ payload: Dict[str, Any] = {"question": question, "context": context}
938
+ response = await self.post(
939
+ json=payload,
940
+ model=model,
941
+ task="question-answering",
942
+ )
943
+ return _bytes_to_dict(response) # type: ignore
944
+
945
+ async def sentence_similarity(
946
+ self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None
947
+ ) -> List[float]:
948
+ """
949
+ Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings.
950
+
951
+ Args:
952
+ sentence (`str`):
953
+ The main sentence to compare to others.
954
+ other_sentences (`List[str]`):
955
+ The list of sentences to compare to.
956
+ model (`str`, *optional*):
957
+ The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
958
+ a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
959
+ Defaults to None.
960
+
961
+ Returns:
962
+ `List[float]`: The embedding representing the input text.
963
+
964
+ Raises:
965
+ [`InferenceTimeoutError`]:
966
+ If the model is unavailable or the request times out.
967
+ `aiohttp.ClientResponseError`:
968
+ If the request fails with an HTTP error status code other than HTTP 503.
969
+
970
+ Example:
971
+ ```py
972
+ # Must be run in an async context
973
+ >>> from huggingface_hub import AsyncInferenceClient
974
+ >>> client = AsyncInferenceClient()
975
+ >>> await client.sentence_similarity(
976
+ ... "Machine learning is so easy.",
977
+ ... other_sentences=[
978
+ ... "Deep learning is so straightforward.",
979
+ ... "This is so difficult, like rocket science.",
980
+ ... "I can't believe how much I struggled with this.",
981
+ ... ],
982
+ ... )
983
+ [0.7785726189613342, 0.45876261591911316, 0.2906220555305481]
984
+ ```
985
+ """
986
+ response = await self.post(
987
+ json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}},
988
+ model=model,
989
+ task="sentence-similarity",
990
+ )
991
+ return _bytes_to_list(response)
992
+
993
+ async def summarization(
994
+ self,
995
+ text: str,
996
+ *,
997
+ parameters: Optional[Dict[str, Any]] = None,
998
+ model: Optional[str] = None,
999
+ ) -> str:
1000
+ """
1001
+ Generate a summary of a given text using a specified model.
1002
+
1003
+ Args:
1004
+ text (`str`):
1005
+ The input text to summarize.
1006
+ parameters (`Dict[str, Any]`, *optional*):
1007
+ Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task)
1008
+ for more details.
1009
+ model (`str`, *optional*):
1010
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1011
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1012
+
1013
+ Returns:
1014
+ `str`: The generated summary text.
1015
+
1016
+ Raises:
1017
+ [`InferenceTimeoutError`]:
1018
+ If the model is unavailable or the request times out.
1019
+ `aiohttp.ClientResponseError`:
1020
+ If the request fails with an HTTP error status code other than HTTP 503.
1021
+
1022
+ Example:
1023
+ ```py
1024
+ # Must be run in an async context
1025
+ >>> from huggingface_hub import AsyncInferenceClient
1026
+ >>> client = AsyncInferenceClient()
1027
+ >>> await client.summarization("The Eiffel tower...")
1028
+ 'The Eiffel tower is one of the most famous landmarks in the world....'
1029
+ ```
1030
+ """
1031
+ payload: Dict[str, Any] = {"inputs": text}
1032
+ if parameters is not None:
1033
+ payload["parameters"] = parameters
1034
+ response = await self.post(json=payload, model=model, task="summarization")
1035
+ return _bytes_to_dict(response)[0]["summary_text"]
1036
+
1037
+ async def table_question_answering(
1038
+ self, table: Dict[str, Any], query: str, *, model: Optional[str] = None
1039
+ ) -> TableQuestionAnsweringOutput:
1040
+ """
1041
+ Retrieve the answer to a question from information given in a table.
1042
+
1043
+ Args:
1044
+ table (`str`):
1045
+ A table of data represented as a dict of lists where entries are headers and the lists are all the
1046
+ values, all lists must have the same size.
1047
+ query (`str`):
1048
+ The query in plain text that you want to ask the table.
1049
+ model (`str`):
1050
+ The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face
1051
+ Hub or a URL to a deployed Inference Endpoint.
1052
+
1053
+ Returns:
1054
+ `Dict`: a dictionary of table question answering output containing the answer, coordinates, cells and the aggregator used.
1055
+
1056
+ Raises:
1057
+ [`InferenceTimeoutError`]:
1058
+ If the model is unavailable or the request times out.
1059
+ `aiohttp.ClientResponseError`:
1060
+ If the request fails with an HTTP error status code other than HTTP 503.
1061
+
1062
+ Example:
1063
+ ```py
1064
+ # Must be run in an async context
1065
+ >>> from huggingface_hub import AsyncInferenceClient
1066
+ >>> client = AsyncInferenceClient()
1067
+ >>> query = "How many stars does the transformers repository have?"
1068
+ >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]}
1069
+ >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq")
1070
+ {'answer': 'AVERAGE > 36542', 'coordinates': [[0, 1]], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
1071
+ ```
1072
+ """
1073
+ response = await self.post(
1074
+ json={
1075
+ "query": query,
1076
+ "table": table,
1077
+ },
1078
+ model=model,
1079
+ task="table-question-answering",
1080
+ )
1081
+ return _bytes_to_dict(response) # type: ignore
1082
+
1083
+ async def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
1084
+ """
1085
+ Classifying a target category (a group) based on a set of attributes.
1086
+
1087
+ Args:
1088
+ table (`Dict[str, Any]`):
1089
+ Set of attributes to classify.
1090
+ model (`str`):
1091
+ The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1092
+ a deployed Inference Endpoint.
1093
+
1094
+ Returns:
1095
+ `List`: a list of labels, one per row in the initial table.
1096
+
1097
+ Raises:
1098
+ [`InferenceTimeoutError`]:
1099
+ If the model is unavailable or the request times out.
1100
+ `aiohttp.ClientResponseError`:
1101
+ If the request fails with an HTTP error status code other than HTTP 503.
1102
+
1103
+ Example:
1104
+ ```py
1105
+ # Must be run in an async context
1106
+ >>> from huggingface_hub import AsyncInferenceClient
1107
+ >>> client = AsyncInferenceClient()
1108
+ >>> table = {
1109
+ ... "fixed_acidity": ["7.4", "7.8", "10.3"],
1110
+ ... "volatile_acidity": ["0.7", "0.88", "0.32"],
1111
+ ... "citric_acid": ["0", "0", "0.45"],
1112
+ ... "residual_sugar": ["1.9", "2.6", "6.4"],
1113
+ ... "chlorides": ["0.076", "0.098", "0.073"],
1114
+ ... "free_sulfur_dioxide": ["11", "25", "5"],
1115
+ ... "total_sulfur_dioxide": ["34", "67", "13"],
1116
+ ... "density": ["0.9978", "0.9968", "0.9976"],
1117
+ ... "pH": ["3.51", "3.2", "3.23"],
1118
+ ... "sulphates": ["0.56", "0.68", "0.82"],
1119
+ ... "alcohol": ["9.4", "9.8", "12.6"],
1120
+ ... }
1121
+ >>> await client.tabular_classification(table=table, model="julien-c/wine-quality")
1122
+ ["5", "5", "5"]
1123
+ ```
1124
+ """
1125
+ response = await self.post(json={"table": table}, model=model, task="tabular-classification")
1126
+ return _bytes_to_list(response)
1127
+
1128
+ async def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]:
1129
+ """
1130
+ Predicting a numerical target value given a set of attributes/features in a table.
1131
+
1132
+ Args:
1133
+ table (`Dict[str, Any]`):
1134
+ Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
1135
+ model (`str`):
1136
+ The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1137
+ a deployed Inference Endpoint.
1138
+
1139
+ Returns:
1140
+ `List`: a list of predicted numerical target values.
1141
+
1142
+ Raises:
1143
+ [`InferenceTimeoutError`]:
1144
+ If the model is unavailable or the request times out.
1145
+ `aiohttp.ClientResponseError`:
1146
+ If the request fails with an HTTP error status code other than HTTP 503.
1147
+
1148
+ Example:
1149
+ ```py
1150
+ # Must be run in an async context
1151
+ >>> from huggingface_hub import AsyncInferenceClient
1152
+ >>> client = AsyncInferenceClient()
1153
+ >>> table = {
1154
+ ... "Height": ["11.52", "12.48", "12.3778"],
1155
+ ... "Length1": ["23.2", "24", "23.9"],
1156
+ ... "Length2": ["25.4", "26.3", "26.5"],
1157
+ ... "Length3": ["30", "31.2", "31.1"],
1158
+ ... "Species": ["Bream", "Bream", "Bream"],
1159
+ ... "Width": ["4.02", "4.3056", "4.6961"],
1160
+ ... }
1161
+ >>> await client.tabular_regression(table, model="scikit-learn/Fish-Weight")
1162
+ [110, 120, 130]
1163
+ ```
1164
+ """
1165
+ response = await self.post(json={"table": table}, model=model, task="tabular-regression")
1166
+ return _bytes_to_list(response)
1167
+
1168
+ async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
1169
+ """
1170
+ Perform text classification (e.g. sentiment-analysis) on the given text.
1171
+
1172
+ Args:
1173
+ text (`str`):
1174
+ A string to be classified.
1175
+ model (`str`, *optional*):
1176
+ The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1177
+ a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used.
1178
+ Defaults to None.
1179
+
1180
+ Returns:
1181
+ `List[Dict]`: a list of dictionaries containing the predicted label and associated probability.
1182
+
1183
+ Raises:
1184
+ [`InferenceTimeoutError`]:
1185
+ If the model is unavailable or the request times out.
1186
+ `aiohttp.ClientResponseError`:
1187
+ If the request fails with an HTTP error status code other than HTTP 503.
1188
+
1189
+ Example:
1190
+ ```py
1191
+ # Must be run in an async context
1192
+ >>> from huggingface_hub import AsyncInferenceClient
1193
+ >>> client = AsyncInferenceClient()
1194
+ >>> await client.text_classification("I like you")
1195
+ [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
1196
+ ```
1197
+ """
1198
+ response = await self.post(json={"inputs": text}, model=model, task="text-classification")
1199
+ return _bytes_to_list(response)[0]
1200
+
1201
+ @overload
1202
+ async def text_generation( # type: ignore
1203
+ self,
1204
+ prompt: str,
1205
+ *,
1206
+ details: Literal[False] = ...,
1207
+ stream: Literal[False] = ...,
1208
+ model: Optional[str] = None,
1209
+ do_sample: bool = False,
1210
+ max_new_tokens: int = 20,
1211
+ best_of: Optional[int] = None,
1212
+ repetition_penalty: Optional[float] = None,
1213
+ return_full_text: bool = False,
1214
+ seed: Optional[int] = None,
1215
+ stop_sequences: Optional[List[str]] = None,
1216
+ temperature: Optional[float] = None,
1217
+ top_k: Optional[int] = None,
1218
+ top_p: Optional[float] = None,
1219
+ truncate: Optional[int] = None,
1220
+ typical_p: Optional[float] = None,
1221
+ watermark: bool = False,
1222
+ ) -> str:
1223
+ ...
1224
+
1225
+ @overload
1226
+ async def text_generation( # type: ignore
1227
+ self,
1228
+ prompt: str,
1229
+ *,
1230
+ details: Literal[True] = ...,
1231
+ stream: Literal[False] = ...,
1232
+ model: Optional[str] = None,
1233
+ do_sample: bool = False,
1234
+ max_new_tokens: int = 20,
1235
+ best_of: Optional[int] = None,
1236
+ repetition_penalty: Optional[float] = None,
1237
+ return_full_text: bool = False,
1238
+ seed: Optional[int] = None,
1239
+ stop_sequences: Optional[List[str]] = None,
1240
+ temperature: Optional[float] = None,
1241
+ top_k: Optional[int] = None,
1242
+ top_p: Optional[float] = None,
1243
+ truncate: Optional[int] = None,
1244
+ typical_p: Optional[float] = None,
1245
+ watermark: bool = False,
1246
+ ) -> TextGenerationResponse:
1247
+ ...
1248
+
1249
+ @overload
1250
+ async def text_generation( # type: ignore
1251
+ self,
1252
+ prompt: str,
1253
+ *,
1254
+ details: Literal[False] = ...,
1255
+ stream: Literal[True] = ...,
1256
+ model: Optional[str] = None,
1257
+ do_sample: bool = False,
1258
+ max_new_tokens: int = 20,
1259
+ best_of: Optional[int] = None,
1260
+ repetition_penalty: Optional[float] = None,
1261
+ return_full_text: bool = False,
1262
+ seed: Optional[int] = None,
1263
+ stop_sequences: Optional[List[str]] = None,
1264
+ temperature: Optional[float] = None,
1265
+ top_k: Optional[int] = None,
1266
+ top_p: Optional[float] = None,
1267
+ truncate: Optional[int] = None,
1268
+ typical_p: Optional[float] = None,
1269
+ watermark: bool = False,
1270
+ ) -> AsyncIterable[str]:
1271
+ ...
1272
+
1273
+ @overload
1274
+ async def text_generation(
1275
+ self,
1276
+ prompt: str,
1277
+ *,
1278
+ details: Literal[True] = ...,
1279
+ stream: Literal[True] = ...,
1280
+ model: Optional[str] = None,
1281
+ do_sample: bool = False,
1282
+ max_new_tokens: int = 20,
1283
+ best_of: Optional[int] = None,
1284
+ repetition_penalty: Optional[float] = None,
1285
+ return_full_text: bool = False,
1286
+ seed: Optional[int] = None,
1287
+ stop_sequences: Optional[List[str]] = None,
1288
+ temperature: Optional[float] = None,
1289
+ top_k: Optional[int] = None,
1290
+ top_p: Optional[float] = None,
1291
+ truncate: Optional[int] = None,
1292
+ typical_p: Optional[float] = None,
1293
+ watermark: bool = False,
1294
+ ) -> AsyncIterable[TextGenerationStreamResponse]:
1295
+ ...
1296
+
1297
+ async def text_generation(
1298
+ self,
1299
+ prompt: str,
1300
+ *,
1301
+ details: bool = False,
1302
+ stream: bool = False,
1303
+ model: Optional[str] = None,
1304
+ do_sample: bool = False,
1305
+ max_new_tokens: int = 20,
1306
+ best_of: Optional[int] = None,
1307
+ repetition_penalty: Optional[float] = None,
1308
+ return_full_text: bool = False,
1309
+ seed: Optional[int] = None,
1310
+ stop_sequences: Optional[List[str]] = None,
1311
+ temperature: Optional[float] = None,
1312
+ top_k: Optional[int] = None,
1313
+ top_p: Optional[float] = None,
1314
+ truncate: Optional[int] = None,
1315
+ typical_p: Optional[float] = None,
1316
+ watermark: bool = False,
1317
+ decoder_input_details: bool = False,
1318
+ ) -> Union[str, TextGenerationResponse, AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
1319
+ """
1320
+ Given a prompt, generate the following text.
1321
+
1322
+ It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allow
1323
+ early failures.
1324
+
1325
+ API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the
1326
+ go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
1327
+ default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
1328
+ not exactly the same. This method is compatible with both approaches but some parameters are only available for
1329
+ `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process
1330
+ continues correctly.
1331
+
1332
+ To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
1333
+
1334
+ Args:
1335
+ prompt (`str`):
1336
+ Input text.
1337
+ details (`bool`, *optional*):
1338
+ By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
1339
+ probabilities, seed, finish reason, etc.). Only available for models running on with the
1340
+ `text-generation-inference` backend.
1341
+ stream (`bool`, *optional*):
1342
+ By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
1343
+ tokens to be returned. Only available for models running on with the `text-generation-inference`
1344
+ backend.
1345
+ model (`str`, *optional*):
1346
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1347
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1348
+ do_sample (`bool`):
1349
+ Activate logits sampling
1350
+ max_new_tokens (`int`):
1351
+ Maximum number of generated tokens
1352
+ best_of (`int`):
1353
+ Generate best_of sequences and return the one if the highest token logprobs
1354
+ repetition_penalty (`float`):
1355
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
1356
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
1357
+ return_full_text (`bool`):
1358
+ Whether to prepend the prompt to the generated text
1359
+ seed (`int`):
1360
+ Random sampling seed
1361
+ stop_sequences (`List[str]`):
1362
+ Stop generating tokens if a member of `stop_sequences` is generated
1363
+ temperature (`float`):
1364
+ The value used to module the logits distribution.
1365
+ top_k (`int`):
1366
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
1367
+ top_p (`float`):
1368
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
1369
+ higher are kept for generation.
1370
+ truncate (`int`):
1371
+ Truncate inputs tokens to the given size
1372
+ typical_p (`float`):
1373
+ Typical Decoding mass
1374
+ See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
1375
+ watermark (`bool`):
1376
+ Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
1377
+ decoder_input_details (`bool`):
1378
+ Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
1379
+ into account. Defaults to `False`.
1380
+
1381
+ Returns:
1382
+ `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
1383
+ Generated text returned from the server:
1384
+ - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
1385
+ - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]`
1386
+ - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
1387
+ - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
1388
+
1389
+ Raises:
1390
+ `ValidationError`:
1391
+ If input values are not valid. No HTTP call is made to the server.
1392
+ [`InferenceTimeoutError`]:
1393
+ If the model is unavailable or the request times out.
1394
+ `aiohttp.ClientResponseError`:
1395
+ If the request fails with an HTTP error status code other than HTTP 503.
1396
+
1397
+ Example:
1398
+ ```py
1399
+ # Must be run in an async context
1400
+ >>> from huggingface_hub import AsyncInferenceClient
1401
+ >>> client = AsyncInferenceClient()
1402
+
1403
+ # Case 1: generate text
1404
+ >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
1405
+ '100% open source and built to be easy to use.'
1406
+
1407
+ # Case 2: iterate over the generated tokens. Useful async for large generation.
1408
+ >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
1409
+ ... print(token)
1410
+ 100
1411
+ %
1412
+ open
1413
+ source
1414
+ and
1415
+ built
1416
+ to
1417
+ be
1418
+ easy
1419
+ to
1420
+ use
1421
+ .
1422
+
1423
+ # Case 3: get more details about the generation process.
1424
+ >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
1425
+ TextGenerationResponse(
1426
+ generated_text='100% open source and built to be easy to use.',
1427
+ details=Details(
1428
+ finish_reason=<FinishReason.Length: 'length'>,
1429
+ generated_tokens=12,
1430
+ seed=None,
1431
+ prefill=[
1432
+ InputToken(id=487, text='The', logprob=None),
1433
+ InputToken(id=53789, text=' hugging', logprob=-13.171875),
1434
+ (...)
1435
+ InputToken(id=204, text=' ', logprob=-7.0390625)
1436
+ ],
1437
+ tokens=[
1438
+ Token(id=1425, text='100', logprob=-1.0175781, special=False),
1439
+ Token(id=16, text='%', logprob=-0.0463562, special=False),
1440
+ (...)
1441
+ Token(id=25, text='.', logprob=-0.5703125, special=False)
1442
+ ],
1443
+ best_of_sequences=None
1444
+ )
1445
+ )
1446
+
1447
+ # Case 4: iterate over the generated tokens with more details.
1448
+ # Last object is more complete, containing the full generated text and the finish reason.
1449
+ >>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
1450
+ ... print(details)
1451
+ ...
1452
+ TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
1453
+ TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
1454
+ TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
1455
+ TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
1456
+ TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
1457
+ TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
1458
+ TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
1459
+ TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
1460
+ TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
1461
+ TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
1462
+ TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
1463
+ TextGenerationStreamResponse(token=Token(
1464
+ id=25,
1465
+ text='.',
1466
+ logprob=-0.5703125,
1467
+ special=False),
1468
+ generated_text='100% open source and built to be easy to use.',
1469
+ details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
1470
+ )
1471
+ ```
1472
+ """
1473
+ # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
1474
+ # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
1475
+
1476
+ if decoder_input_details and not details:
1477
+ warnings.warn(
1478
+ "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
1479
+ " the output from the server will be truncated."
1480
+ )
1481
+ decoder_input_details = False
1482
+
1483
+ # Validate parameters
1484
+ parameters = TextGenerationParameters(
1485
+ best_of=best_of,
1486
+ details=details,
1487
+ do_sample=do_sample,
1488
+ max_new_tokens=max_new_tokens,
1489
+ repetition_penalty=repetition_penalty,
1490
+ return_full_text=return_full_text,
1491
+ seed=seed,
1492
+ stop=stop_sequences if stop_sequences is not None else [],
1493
+ temperature=temperature,
1494
+ top_k=top_k,
1495
+ top_p=top_p,
1496
+ truncate=truncate,
1497
+ typical_p=typical_p,
1498
+ watermark=watermark,
1499
+ decoder_input_details=decoder_input_details,
1500
+ )
1501
+ request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
1502
+ payload = asdict(request)
1503
+
1504
+ # Remove some parameters if not a TGI server
1505
+ if not _is_tgi_server(model):
1506
+ ignored_parameters = []
1507
+ for key in "watermark", "stop", "details", "decoder_input_details":
1508
+ if payload["parameters"][key] is not None:
1509
+ ignored_parameters.append(key)
1510
+ del payload["parameters"][key]
1511
+ if len(ignored_parameters) > 0:
1512
+ warnings.warn(
1513
+ "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
1514
+ f" {ignored_parameters}.",
1515
+ UserWarning,
1516
+ )
1517
+ if details:
1518
+ warnings.warn(
1519
+ "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
1520
+ " be ignored meaning only the generated text will be returned.",
1521
+ UserWarning,
1522
+ )
1523
+ details = False
1524
+ if stream:
1525
+ raise ValueError(
1526
+ "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream."
1527
+ " Please pass `stream=False` as input."
1528
+ )
1529
+
1530
+ # Handle errors separately for more precise error messages
1531
+ try:
1532
+ bytes_output = await self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore
1533
+ except _import_aiohttp().ClientResponseError as e:
1534
+ error_message = getattr(e, "response_error_payload", {}).get("error", "")
1535
+ if e.code == 400 and "The following `model_kwargs` are not used by the model" in error_message:
1536
+ _set_as_non_tgi(model)
1537
+ return await self.text_generation( # type: ignore
1538
+ prompt=prompt,
1539
+ details=details,
1540
+ stream=stream,
1541
+ model=model,
1542
+ do_sample=do_sample,
1543
+ max_new_tokens=max_new_tokens,
1544
+ best_of=best_of,
1545
+ repetition_penalty=repetition_penalty,
1546
+ return_full_text=return_full_text,
1547
+ seed=seed,
1548
+ stop_sequences=stop_sequences,
1549
+ temperature=temperature,
1550
+ top_k=top_k,
1551
+ top_p=top_p,
1552
+ truncate=truncate,
1553
+ typical_p=typical_p,
1554
+ watermark=watermark,
1555
+ decoder_input_details=decoder_input_details,
1556
+ )
1557
+ raise_text_generation_error(e)
1558
+
1559
+ # Parse output
1560
+ if stream:
1561
+ return _async_stream_text_generation_response(bytes_output, details) # type: ignore
1562
+
1563
+ data = _bytes_to_dict(bytes_output)[0]
1564
+ return TextGenerationResponse(**data) if details else data["generated_text"]
1565
+
1566
+ async def text_to_image(
1567
+ self,
1568
+ prompt: str,
1569
+ *,
1570
+ negative_prompt: Optional[str] = None,
1571
+ height: Optional[float] = None,
1572
+ width: Optional[float] = None,
1573
+ num_inference_steps: Optional[float] = None,
1574
+ guidance_scale: Optional[float] = None,
1575
+ model: Optional[str] = None,
1576
+ **kwargs,
1577
+ ) -> "Image":
1578
+ """
1579
+ Generate an image based on a given text using a specified model.
1580
+
1581
+ <Tip warning={true}>
1582
+
1583
+ You must have `PIL` installed if you want to work with images (`pip install Pillow`).
1584
+
1585
+ </Tip>
1586
+
1587
+ Args:
1588
+ prompt (`str`):
1589
+ The prompt to generate an image from.
1590
+ negative_prompt (`str`, *optional*):
1591
+ An optional negative prompt for the image generation.
1592
+ height (`float`, *optional*):
1593
+ The height in pixels of the image to generate.
1594
+ width (`float`, *optional*):
1595
+ The width in pixels of the image to generate.
1596
+ num_inference_steps (`int`, *optional*):
1597
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1598
+ expense of slower inference.
1599
+ guidance_scale (`float`, *optional*):
1600
+ Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1601
+ usually at the expense of lower image quality.
1602
+ model (`str`, *optional*):
1603
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1604
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1605
+
1606
+ Returns:
1607
+ `Image`: The generated image.
1608
+
1609
+ Raises:
1610
+ [`InferenceTimeoutError`]:
1611
+ If the model is unavailable or the request times out.
1612
+ `aiohttp.ClientResponseError`:
1613
+ If the request fails with an HTTP error status code other than HTTP 503.
1614
+
1615
+ Example:
1616
+ ```py
1617
+ # Must be run in an async context
1618
+ >>> from huggingface_hub import AsyncInferenceClient
1619
+ >>> client = AsyncInferenceClient()
1620
+
1621
+ >>> image = await client.text_to_image("An astronaut riding a horse on the moon.")
1622
+ >>> image.save("astronaut.png")
1623
+
1624
+ >>> image = await client.text_to_image(
1625
+ ... "An astronaut riding a horse on the moon.",
1626
+ ... negative_prompt="low resolution, blurry",
1627
+ ... model="stabilityai/stable-diffusion-2-1",
1628
+ ... )
1629
+ >>> image.save("better_astronaut.png")
1630
+ ```
1631
+ """
1632
+ payload = {"inputs": prompt}
1633
+ parameters = {
1634
+ "negative_prompt": negative_prompt,
1635
+ "height": height,
1636
+ "width": width,
1637
+ "num_inference_steps": num_inference_steps,
1638
+ "guidance_scale": guidance_scale,
1639
+ **kwargs,
1640
+ }
1641
+ for key, value in parameters.items():
1642
+ if value is not None:
1643
+ payload.setdefault("parameters", {})[key] = value # type: ignore
1644
+ response = await self.post(json=payload, model=model, task="text-to-image")
1645
+ return _bytes_to_image(response)
1646
+
1647
+ async def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
1648
+ """
1649
+ Synthesize an audio of a voice pronouncing a given text.
1650
+
1651
+ Args:
1652
+ text (`str`):
1653
+ The text to synthesize.
1654
+ model (`str`, *optional*):
1655
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1656
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1657
+
1658
+ Returns:
1659
+ `bytes`: The generated audio.
1660
+
1661
+ Raises:
1662
+ [`InferenceTimeoutError`]:
1663
+ If the model is unavailable or the request times out.
1664
+ `aiohttp.ClientResponseError`:
1665
+ If the request fails with an HTTP error status code other than HTTP 503.
1666
+
1667
+ Example:
1668
+ ```py
1669
+ # Must be run in an async context
1670
+ >>> from pathlib import Path
1671
+ >>> from huggingface_hub import AsyncInferenceClient
1672
+ >>> client = AsyncInferenceClient()
1673
+
1674
+ >>> audio = await client.text_to_speech("Hello world")
1675
+ >>> Path("hello_world.flac").write_bytes(audio)
1676
+ ```
1677
+ """
1678
+ return await self.post(json={"inputs": text}, model=model, task="text-to-speech")
1679
+
1680
+ async def token_classification(self, text: str, *, model: Optional[str] = None) -> List[TokenClassificationOutput]:
1681
+ """
1682
+ Perform token classification on the given text.
1683
+ Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
1684
+
1685
+ Args:
1686
+ text (`str`):
1687
+ A string to be classified.
1688
+ model (`str`, *optional*):
1689
+ The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1690
+ a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used.
1691
+ Defaults to None.
1692
+
1693
+ Returns:
1694
+ `List[Dict]`: List of token classification outputs containing the entity group, confidence score, word, start and end index.
1695
+
1696
+ Raises:
1697
+ [`InferenceTimeoutError`]:
1698
+ If the model is unavailable or the request times out.
1699
+ `aiohttp.ClientResponseError`:
1700
+ If the request fails with an HTTP error status code other than HTTP 503.
1701
+
1702
+ Example:
1703
+ ```py
1704
+ # Must be run in an async context
1705
+ >>> from huggingface_hub import AsyncInferenceClient
1706
+ >>> client = AsyncInferenceClient()
1707
+ >>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica")
1708
+ [{'entity_group': 'PER',
1709
+ 'score': 0.9971321225166321,
1710
+ 'word': 'Sarah Jessica Parker',
1711
+ 'start': 11,
1712
+ 'end': 31},
1713
+ {'entity_group': 'PER',
1714
+ 'score': 0.9773476123809814,
1715
+ 'word': 'Jessica',
1716
+ 'start': 52,
1717
+ 'end': 59}]
1718
+ ```
1719
+ """
1720
+ payload: Dict[str, Any] = {"inputs": text}
1721
+ response = await self.post(
1722
+ json=payload,
1723
+ model=model,
1724
+ task="token-classification",
1725
+ )
1726
+ return _bytes_to_list(response)
1727
+
1728
+ async def translation(
1729
+ self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None
1730
+ ) -> str:
1731
+ """
1732
+ Convert text from one language to another.
1733
+
1734
+ Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for
1735
+ your specific use case. Source and target languages usually depend on the model.
1736
+ However, it is possible to specify source and target languages for certain models. If you are working with one of these models,
1737
+ you can use `src_lang` and `tgt_lang` arguments to pass the relevant information.
1738
+ You can find this information in the model card.
1739
+
1740
+ Args:
1741
+ text (`str`):
1742
+ A string to be translated.
1743
+ model (`str`, *optional*):
1744
+ The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to
1745
+ a deployed Inference Endpoint. If not provided, the default recommended translation model will be used.
1746
+ Defaults to None.
1747
+ src_lang (`str`, *optional*):
1748
+ Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`.
1749
+ tgt_lang (`str`, *optional*):
1750
+ Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`.
1751
+
1752
+ Returns:
1753
+ `str`: The generated translated text.
1754
+
1755
+ Raises:
1756
+ [`InferenceTimeoutError`]:
1757
+ If the model is unavailable or the request times out.
1758
+ `aiohttp.ClientResponseError`:
1759
+ If the request fails with an HTTP error status code other than HTTP 503.
1760
+ `ValueError`:
1761
+ If only one of the `src_lang` and `tgt_lang` arguments are provided.
1762
+
1763
+ Example:
1764
+ ```py
1765
+ # Must be run in an async context
1766
+ >>> from huggingface_hub import AsyncInferenceClient
1767
+ >>> client = AsyncInferenceClient()
1768
+ >>> await client.translation("My name is Wolfgang and I live in Berlin")
1769
+ 'Mein Name ist Wolfgang und ich lebe in Berlin.'
1770
+ >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr")
1771
+ "Je m'appelle Wolfgang et je vis à Berlin."
1772
+ ```
1773
+
1774
+ Specifying languages:
1775
+ ```py
1776
+ >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX")
1777
+ "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica"
1778
+ ```
1779
+ """
1780
+ # Throw error if only one of `src_lang` and `tgt_lang` was given
1781
+ if src_lang is not None and tgt_lang is None:
1782
+ raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.")
1783
+
1784
+ if src_lang is None and tgt_lang is not None:
1785
+ raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.")
1786
+
1787
+ # If both `src_lang` and `tgt_lang` are given, pass them to the request body
1788
+ payload: Dict = {"inputs": text}
1789
+ if src_lang and tgt_lang:
1790
+ payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang}
1791
+ response = await self.post(json=payload, model=model, task="translation")
1792
+ return _bytes_to_dict(response)[0]["translation_text"]
1793
+
1794
+ async def zero_shot_classification(
1795
+ self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
1796
+ ) -> List[ClassificationOutput]:
1797
+ """
1798
+ Provide as input a text and a set of candidate labels to classify the input text.
1799
+
1800
+ Args:
1801
+ text (`str`):
1802
+ The input text to classify.
1803
+ labels (`List[str]`):
1804
+ List of string possible labels. There must be at least 2 labels.
1805
+ multi_label (`bool`):
1806
+ Boolean that is set to True if classes can overlap.
1807
+ model (`str`, *optional*):
1808
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1809
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1810
+
1811
+ Returns:
1812
+ `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
1813
+
1814
+ Raises:
1815
+ [`InferenceTimeoutError`]:
1816
+ If the model is unavailable or the request times out.
1817
+ `aiohttp.ClientResponseError`:
1818
+ If the request fails with an HTTP error status code other than HTTP 503.
1819
+
1820
+ Example:
1821
+ ```py
1822
+ # Must be run in an async context
1823
+ >>> from huggingface_hub import AsyncInferenceClient
1824
+ >>> client = AsyncInferenceClient()
1825
+ >>> text = (
1826
+ ... "A new model offers an explanation async for how the Galilean satellites formed around the solar system's"
1827
+ ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
1828
+ ... " mysteries when he went async for a run up a hill in Nice, France."
1829
+ ... )
1830
+ >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
1831
+ >>> await client.zero_shot_classification(text, labels)
1832
+ [
1833
+ {"label": "scientific discovery", "score": 0.7961668968200684},
1834
+ {"label": "space & cosmos", "score": 0.18570658564567566},
1835
+ {"label": "microbiology", "score": 0.00730885099619627},
1836
+ {"label": "archeology", "score": 0.006258360575884581},
1837
+ {"label": "robots", "score": 0.004559356719255447},
1838
+ ]
1839
+ >>> await client.zero_shot_classification(text, labels, multi_label=True)
1840
+ [
1841
+ {"label": "scientific discovery", "score": 0.9829297661781311},
1842
+ {"label": "space & cosmos", "score": 0.755190908908844},
1843
+ {"label": "microbiology", "score": 0.0005462635890580714},
1844
+ {"label": "archeology", "score": 0.00047131875180639327},
1845
+ {"label": "robots", "score": 0.00030448526376858354},
1846
+ ]
1847
+ ```
1848
+ """
1849
+ # Raise ValueError if input is less than 2 labels
1850
+ if len(labels) < 2:
1851
+ raise ValueError("You must specify at least 2 classes to compare.")
1852
+
1853
+ response = await self.post(
1854
+ json={
1855
+ "inputs": text,
1856
+ "parameters": {
1857
+ "candidate_labels": ",".join(labels),
1858
+ "multi_label": multi_label,
1859
+ },
1860
+ },
1861
+ model=model,
1862
+ task="zero-shot-classification",
1863
+ )
1864
+ output = _bytes_to_dict(response)
1865
+ return [{"label": label, "score": score} for label, score in zip(output["labels"], output["scores"])]
1866
+
1867
+ async def zero_shot_image_classification(
1868
+ self, image: ContentT, labels: List[str], *, model: Optional[str] = None
1869
+ ) -> List[ClassificationOutput]:
1870
+ """
1871
+ Provide input image and text labels to predict text labels for the image.
1872
+
1873
+ Args:
1874
+ image (`Union[str, Path, bytes, BinaryIO]`):
1875
+ The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
1876
+ labels (`List[str]`):
1877
+ List of string possible labels. There must be at least 2 labels.
1878
+ model (`str`, *optional*):
1879
+ The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
1880
+ Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
1881
+
1882
+ Returns:
1883
+ `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.
1884
+
1885
+ Raises:
1886
+ [`InferenceTimeoutError`]:
1887
+ If the model is unavailable or the request times out.
1888
+ `aiohttp.ClientResponseError`:
1889
+ If the request fails with an HTTP error status code other than HTTP 503.
1890
+
1891
+ Example:
1892
+ ```py
1893
+ # Must be run in an async context
1894
+ >>> from huggingface_hub import AsyncInferenceClient
1895
+ >>> client = AsyncInferenceClient()
1896
+
1897
+ >>> await client.zero_shot_image_classification(
1898
+ ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
1899
+ ... labels=["dog", "cat", "horse"],
1900
+ ... )
1901
+ [{"label": "dog", "score": 0.956}, ...]
1902
+ ```
1903
+ """
1904
+ # Raise ValueError if input is less than 2 labels
1905
+ if len(labels) < 2:
1906
+ raise ValueError("You must specify at least 2 classes to compare.")
1907
+
1908
+ response = await self.post(
1909
+ json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}},
1910
+ model=model,
1911
+ task="zero-shot-image-classification",
1912
+ )
1913
+ return _bytes_to_list(response)
1914
+
1915
+ def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
1916
+ model = model or self.model
1917
+
1918
+ # If model is already a URL, ignore `task` and return directly
1919
+ if model is not None and (model.startswith("http://") or model.startswith("https://")):
1920
+ return model
1921
+
1922
+ # # If no model but task is set => fetch the recommended one for this task
1923
+ if model is None:
1924
+ if task is None:
1925
+ raise ValueError(
1926
+ "You must specify at least a model (repo_id or URL) or a task, either when instantiating"
1927
+ " `InferenceClient` or when making a request."
1928
+ )
1929
+ model = self.get_recommended_model(task)
1930
+ logger.info(
1931
+ f"Using recommended model {model} for task {task}. Note that it is"
1932
+ f" encouraged to explicitly set `model='{model}'` as the recommended"
1933
+ " models list might get updated without prior notice."
1934
+ )
1935
+
1936
+ # Compute InferenceAPI url
1937
+ return (
1938
+ # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
1939
+ f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
1940
+ if task in ("feature-extraction", "sentence-similarity")
1941
+ # Otherwise, we use the default endpoint
1942
+ else f"{INFERENCE_ENDPOINT}/models/{model}"
1943
+ )
1944
+
1945
+ @staticmethod
1946
+ def get_recommended_model(task: str) -> str:
1947
+ """
1948
+ Get the model Hugging Face recommends for the input task.
1949
+
1950
+ Args:
1951
+ task (`str`):
1952
+ The Hugging Face task to get which model Hugging Face recommends.
1953
+ All available tasks can be found [here](https://huggingface.co/tasks).
1954
+
1955
+ Returns:
1956
+ `str`: Name of the model recommended for the input task.
1957
+
1958
+ Raises:
1959
+ `ValueError`: If Hugging Face has no recommendation for the input task.
1960
+ """
1961
+ model = _fetch_recommended_models().get(task)
1962
+ if model is None:
1963
+ raise ValueError(
1964
+ f"Task {task} has no recommended model. Please specify a model"
1965
+ " explicitly. Visit https://huggingface.co/tasks for more info."
1966
+ )
1967
+ return model
1968
+
1969
+ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
1970
+ """
1971
+ Get the status of a model hosted on the Inference API.
1972
+
1973
+ <Tip>
1974
+
1975
+ This endpoint is mostly useful when you already know which model you want to use and want to check its
1976
+ availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].
1977
+
1978
+ </Tip>
1979
+
1980
+ Args:
1981
+ model (`str`, *optional*):
1982
+ Identifier of the model for witch the status gonna be checked. If model is not provided,
1983
+ the model associated with this instance of [`InferenceClient`] will be used. Only InferenceAPI service can be checked so the
1984
+ identifier cannot be a URL.
1985
+
1986
+
1987
+ Returns:
1988
+ [`ModelStatus`]: An instance of ModelStatus dataclass, containing information,
1989
+ about the state of the model: load, state, compute type and framework.
1990
+
1991
+ Example:
1992
+ ```py
1993
+ # Must be run in an async context
1994
+ >>> from huggingface_hub import AsyncInferenceClient
1995
+ >>> client = AsyncInferenceClient()
1996
+ >>> await client.get_model_status("bigcode/starcoder")
1997
+ ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
1998
+ ```
1999
+ """
2000
+ model = model or self.model
2001
+ if model is None:
2002
+ raise ValueError("Model id not provided.")
2003
+ if model.startswith("https://"):
2004
+ raise NotImplementedError("Model status is only available for Inference API endpoints.")
2005
+ url = f"{INFERENCE_ENDPOINT}/status/{model}"
2006
+
2007
+ async with _import_aiohttp().ClientSession(headers=self.headers) as client:
2008
+ response = await client.get(url)
2009
+ response.raise_for_status()
2010
+ response_data = await response.json()
2011
+
2012
+ if "error" in response_data:
2013
+ raise ValueError(response_data["error"])
2014
+
2015
+ return ModelStatus(
2016
+ loaded=response_data["loaded"],
2017
+ state=response_data["state"],
2018
+ compute_type=response_data["compute_type"],
2019
+ framework=response_data["framework"],
2020
+ )
lib/python3.11/site-packages/huggingface_hub/inference/_text_generation.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Original implementation taken from the `text-generation` Python client (see https://pypi.org/project/text-generation/
17
+ # and https://github.com/huggingface/text-generation-inference/tree/main/clients/python)
18
+ #
19
+ # Changes compared to original implementation:
20
+ # - use pydantic.dataclasses instead of BaseModel
21
+ # - default to Python's dataclasses if Pydantic is not installed (same implementation but no validation)
22
+ # - added default values for all parameters (not needed in BaseModel but dataclasses yes)
23
+ # - integrated in `huggingface_hub.InferenceClient``
24
+ # - added `stream: bool` and `details: bool` in the `text_generation` method instead of having different methods for each use case
25
+ import warnings
26
+ from dataclasses import field
27
+ from enum import Enum
28
+ from typing import List, NoReturn, Optional
29
+
30
+ from requests import HTTPError
31
+
32
+ from ..utils import is_pydantic_available
33
+
34
+
35
+ if is_pydantic_available():
36
+ from pydantic import validator as pydantic_validator
37
+ from pydantic.dataclasses import dataclass
38
+
39
+ def validator(*args, **kwargs):
40
+ # Pydantic v1's `@validator` is deprecated in favor of `@field_validator`. In order to support both pydantic v1
41
+ # and v2 without changing the logic, we catch the warning message in pydantic v2 and ignore it. If we want to
42
+ # support pydantic v3 in the future, we will drop support for pydantic v1 and use `pydantic.field_validator`
43
+ # correctly.
44
+ #
45
+ # Related:
46
+ # - https://docs.pydantic.dev/latest/migration/#changes-to-validators
47
+ # - https://github.com/huggingface/huggingface_hub/pull/1837
48
+ with warnings.catch_warnings():
49
+ warnings.filterwarnings("ignore", message="Pydantic V1 style `@validator` validators are deprecated.")
50
+ return pydantic_validator(*args, **kwargs)
51
+ else:
52
+ # No validation if Pydantic is not installed
53
+ from dataclasses import dataclass # type: ignore
54
+
55
+ def validator(x): # type: ignore
56
+ return lambda y: y
57
+
58
+
59
+ @dataclass
60
+ class TextGenerationParameters:
61
+ """
62
+ Parameters for text generation.
63
+
64
+ Args:
65
+ do_sample (`bool`, *optional*):
66
+ Activate logits sampling. Defaults to False.
67
+ max_new_tokens (`int`, *optional*):
68
+ Maximum number of generated tokens. Defaults to 20.
69
+ repetition_penalty (`Optional[float]`, *optional*):
70
+ The parameter for repetition penalty. A value of 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf)
71
+ for more details. Defaults to None.
72
+ return_full_text (`bool`, *optional*):
73
+ Whether to prepend the prompt to the generated text. Defaults to False.
74
+ stop (`List[str]`, *optional*):
75
+ Stop generating tokens if a member of `stop_sequences` is generated. Defaults to an empty list.
76
+ seed (`Optional[int]`, *optional*):
77
+ Random sampling seed. Defaults to None.
78
+ temperature (`Optional[float]`, *optional*):
79
+ The value used to modulate the logits distribution. Defaults to None.
80
+ top_k (`Optional[int]`, *optional*):
81
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
82
+ top_p (`Optional[float]`, *optional*):
83
+ If set to a value less than 1, only the smallest set of most probable tokens with probabilities that add up
84
+ to `top_p` or higher are kept for generation. Defaults to None.
85
+ truncate (`Optional[int]`, *optional*):
86
+ Truncate input tokens to the given size. Defaults to None.
87
+ typical_p (`Optional[float]`, *optional*):
88
+ Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
89
+ for more information. Defaults to None.
90
+ best_of (`Optional[int]`, *optional*):
91
+ Generate `best_of` sequences and return the one with the highest token logprobs. Defaults to None.
92
+ watermark (`bool`, *optional*):
93
+ Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226). Defaults to False.
94
+ details (`bool`, *optional*):
95
+ Get generation details. Defaults to False.
96
+ decoder_input_details (`bool`, *optional*):
97
+ Get decoder input token logprobs and ids. Defaults to False.
98
+ """
99
+
100
+ # Activate logits sampling
101
+ do_sample: bool = False
102
+ # Maximum number of generated tokens
103
+ max_new_tokens: int = 20
104
+ # The parameter for repetition penalty. 1.0 means no penalty.
105
+ # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
106
+ repetition_penalty: Optional[float] = None
107
+ # Whether to prepend the prompt to the generated text
108
+ return_full_text: bool = False
109
+ # Stop generating tokens if a member of `stop_sequences` is generated
110
+ stop: List[str] = field(default_factory=lambda: [])
111
+ # Random sampling seed
112
+ seed: Optional[int] = None
113
+ # The value used to module the logits distribution.
114
+ temperature: Optional[float] = None
115
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering.
116
+ top_k: Optional[int] = None
117
+ # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
118
+ # higher are kept for generation.
119
+ top_p: Optional[float] = None
120
+ # truncate inputs tokens to the given size
121
+ truncate: Optional[int] = None
122
+ # Typical Decoding mass
123
+ # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
124
+ typical_p: Optional[float] = None
125
+ # Generate best_of sequences and return the one if the highest token logprobs
126
+ best_of: Optional[int] = None
127
+ # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
128
+ watermark: bool = False
129
+ # Get generation details
130
+ details: bool = False
131
+ # Get decoder input token logprobs and ids
132
+ decoder_input_details: bool = False
133
+
134
+ @validator("best_of")
135
+ def valid_best_of(cls, field_value, values):
136
+ if field_value is not None:
137
+ if field_value <= 0:
138
+ raise ValueError("`best_of` must be strictly positive")
139
+ if field_value > 1 and values["seed"] is not None:
140
+ raise ValueError("`seed` must not be set when `best_of` is > 1")
141
+ sampling = (
142
+ values["do_sample"]
143
+ | (values["temperature"] is not None)
144
+ | (values["top_k"] is not None)
145
+ | (values["top_p"] is not None)
146
+ | (values["typical_p"] is not None)
147
+ )
148
+ if field_value > 1 and not sampling:
149
+ raise ValueError("you must use sampling when `best_of` is > 1")
150
+
151
+ return field_value
152
+
153
+ @validator("repetition_penalty")
154
+ def valid_repetition_penalty(cls, v):
155
+ if v is not None and v <= 0:
156
+ raise ValueError("`repetition_penalty` must be strictly positive")
157
+ return v
158
+
159
+ @validator("seed")
160
+ def valid_seed(cls, v):
161
+ if v is not None and v < 0:
162
+ raise ValueError("`seed` must be positive")
163
+ return v
164
+
165
+ @validator("temperature")
166
+ def valid_temp(cls, v):
167
+ if v is not None and v <= 0:
168
+ raise ValueError("`temperature` must be strictly positive")
169
+ return v
170
+
171
+ @validator("top_k")
172
+ def valid_top_k(cls, v):
173
+ if v is not None and v <= 0:
174
+ raise ValueError("`top_k` must be strictly positive")
175
+ return v
176
+
177
+ @validator("top_p")
178
+ def valid_top_p(cls, v):
179
+ if v is not None and (v <= 0 or v >= 1.0):
180
+ raise ValueError("`top_p` must be > 0.0 and < 1.0")
181
+ return v
182
+
183
+ @validator("truncate")
184
+ def valid_truncate(cls, v):
185
+ if v is not None and v <= 0:
186
+ raise ValueError("`truncate` must be strictly positive")
187
+ return v
188
+
189
+ @validator("typical_p")
190
+ def valid_typical_p(cls, v):
191
+ if v is not None and (v <= 0 or v >= 1.0):
192
+ raise ValueError("`typical_p` must be > 0.0 and < 1.0")
193
+ return v
194
+
195
+
196
+ @dataclass
197
+ class TextGenerationRequest:
198
+ """
199
+ Request object for text generation (only for internal use).
200
+
201
+ Args:
202
+ inputs (`str`):
203
+ The prompt for text generation.
204
+ parameters (`Optional[TextGenerationParameters]`, *optional*):
205
+ Generation parameters.
206
+ stream (`bool`, *optional*):
207
+ Whether to stream output tokens. Defaults to False.
208
+ """
209
+
210
+ # Prompt
211
+ inputs: str
212
+ # Generation parameters
213
+ parameters: Optional[TextGenerationParameters] = None
214
+ # Whether to stream output tokens
215
+ stream: bool = False
216
+
217
+ @validator("inputs")
218
+ def valid_input(cls, v):
219
+ if not v:
220
+ raise ValueError("`inputs` cannot be empty")
221
+ return v
222
+
223
+ @validator("stream")
224
+ def valid_best_of_stream(cls, field_value, values):
225
+ parameters = values["parameters"]
226
+ if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
227
+ raise ValueError("`best_of` != 1 is not supported when `stream` == True")
228
+ return field_value
229
+
230
+ def __post_init__(self):
231
+ if not is_pydantic_available():
232
+ # If pydantic is not installed, we need to instantiate the nested dataclasses manually
233
+ if self.parameters is not None and isinstance(self.parameters, dict):
234
+ self.parameters = TextGenerationParameters(**self.parameters)
235
+
236
+
237
+ # Decoder input tokens
238
+ @dataclass
239
+ class InputToken:
240
+ """
241
+ Represents an input token.
242
+
243
+ Args:
244
+ id (`int`):
245
+ Token ID from the model tokenizer.
246
+ text (`str`):
247
+ Token text.
248
+ logprob (`float` or `None`):
249
+ Log probability of the token. Optional since the logprob of the first token cannot be computed.
250
+ """
251
+
252
+ # Token ID from the model tokenizer
253
+ id: int
254
+ # Token text
255
+ text: str
256
+ # Logprob
257
+ # Optional since the logprob of the first token cannot be computed
258
+ logprob: Optional[float] = None
259
+
260
+
261
+ # Generated tokens
262
+ @dataclass
263
+ class Token:
264
+ """
265
+ Represents a token.
266
+
267
+ Args:
268
+ id (`int`):
269
+ Token ID from the model tokenizer.
270
+ text (`str`):
271
+ Token text.
272
+ logprob (`float`):
273
+ Log probability of the token.
274
+ special (`bool`):
275
+ Indicates whether the token is a special token. It can be used to ignore
276
+ tokens when concatenating.
277
+ """
278
+
279
+ # Token ID from the model tokenizer
280
+ id: int
281
+ # Token text
282
+ text: str
283
+ # Logprob
284
+ logprob: float
285
+ # Is the token a special token
286
+ # Can be used to ignore tokens when concatenating
287
+ special: bool
288
+
289
+
290
+ # Generation finish reason
291
+ class FinishReason(str, Enum):
292
+ # number of generated tokens == `max_new_tokens`
293
+ Length = "length"
294
+ # the model generated its end of sequence token
295
+ EndOfSequenceToken = "eos_token"
296
+ # the model generated a text included in `stop_sequences`
297
+ StopSequence = "stop_sequence"
298
+
299
+
300
+ # Additional sequences when using the `best_of` parameter
301
+ @dataclass
302
+ class BestOfSequence:
303
+ """
304
+ Represents a best-of sequence generated during text generation.
305
+
306
+ Args:
307
+ generated_text (`str`):
308
+ The generated text.
309
+ finish_reason (`FinishReason`):
310
+ The reason for the generation to finish, represented by a `FinishReason` value.
311
+ generated_tokens (`int`):
312
+ The number of generated tokens in the sequence.
313
+ seed (`Optional[int]`):
314
+ The sampling seed if sampling was activated.
315
+ prefill (`List[InputToken]`):
316
+ The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list.
317
+ tokens (`List[Token]`):
318
+ The generated tokens. Defaults to an empty list.
319
+ """
320
+
321
+ # Generated text
322
+ generated_text: str
323
+ # Generation finish reason
324
+ finish_reason: FinishReason
325
+ # Number of generated tokens
326
+ generated_tokens: int
327
+ # Sampling seed if sampling was activated
328
+ seed: Optional[int] = None
329
+ # Decoder input tokens, empty if decoder_input_details is False
330
+ prefill: List[InputToken] = field(default_factory=lambda: [])
331
+ # Generated tokens
332
+ tokens: List[Token] = field(default_factory=lambda: [])
333
+
334
+ def __post_init__(self):
335
+ if not is_pydantic_available():
336
+ # If pydantic is not installed, we need to instantiate the nested dataclasses manually
337
+ self.prefill = [
338
+ InputToken(**input_token) if isinstance(input_token, dict) else input_token
339
+ for input_token in self.prefill
340
+ ]
341
+ self.tokens = [Token(**token) if isinstance(token, dict) else token for token in self.tokens]
342
+
343
+
344
+ # `generate` details
345
+ @dataclass
346
+ class Details:
347
+ """
348
+ Represents details of a text generation.
349
+
350
+ Args:
351
+ finish_reason (`FinishReason`):
352
+ The reason for the generation to finish, represented by a `FinishReason` value.
353
+ generated_tokens (`int`):
354
+ The number of generated tokens.
355
+ seed (`Optional[int]`):
356
+ The sampling seed if sampling was activated.
357
+ prefill (`List[InputToken]`, *optional*):
358
+ The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list.
359
+ tokens (`List[Token]`):
360
+ The generated tokens. Defaults to an empty list.
361
+ best_of_sequences (`Optional[List[BestOfSequence]]`):
362
+ Additional sequences when using the `best_of` parameter.
363
+ """
364
+
365
+ # Generation finish reason
366
+ finish_reason: FinishReason
367
+ # Number of generated tokens
368
+ generated_tokens: int
369
+ # Sampling seed if sampling was activated
370
+ seed: Optional[int] = None
371
+ # Decoder input tokens, empty if decoder_input_details is False
372
+ prefill: List[InputToken] = field(default_factory=lambda: [])
373
+ # Generated tokens
374
+ tokens: List[Token] = field(default_factory=lambda: [])
375
+ # Additional sequences when using the `best_of` parameter
376
+ best_of_sequences: Optional[List[BestOfSequence]] = None
377
+
378
+ def __post_init__(self):
379
+ if not is_pydantic_available():
380
+ # If pydantic is not installed, we need to instantiate the nested dataclasses manually
381
+ self.prefill = [
382
+ InputToken(**input_token) if isinstance(input_token, dict) else input_token
383
+ for input_token in self.prefill
384
+ ]
385
+ self.tokens = [Token(**token) if isinstance(token, dict) else token for token in self.tokens]
386
+ if self.best_of_sequences is not None:
387
+ self.best_of_sequences = [
388
+ BestOfSequence(**best_of_sequence) if isinstance(best_of_sequence, dict) else best_of_sequence
389
+ for best_of_sequence in self.best_of_sequences
390
+ ]
391
+
392
+
393
+ # `generate` return value
394
+ @dataclass
395
+ class TextGenerationResponse:
396
+ """
397
+ Represents a response for text generation.
398
+
399
+ Only returned when `details=True`, otherwise a string is returned.
400
+
401
+ Args:
402
+ generated_text (`str`):
403
+ The generated text.
404
+ details (`Optional[Details]`):
405
+ Generation details. Returned only if `details=True` is sent to the server.
406
+ """
407
+
408
+ # Generated text
409
+ generated_text: str
410
+ # Generation details
411
+ details: Optional[Details] = None
412
+
413
+ def __post_init__(self):
414
+ if not is_pydantic_available():
415
+ # If pydantic is not installed, we need to instantiate the nested dataclasses manually
416
+ if self.details is not None and isinstance(self.details, dict):
417
+ self.details = Details(**self.details)
418
+
419
+
420
+ # `generate_stream` details
421
+ @dataclass
422
+ class StreamDetails:
423
+ """
424
+ Represents details of a text generation stream.
425
+
426
+ Args:
427
+ finish_reason (`FinishReason`):
428
+ The reason for the generation to finish, represented by a `FinishReason` value.
429
+ generated_tokens (`int`):
430
+ The number of generated tokens.
431
+ seed (`Optional[int]`):
432
+ The sampling seed if sampling was activated.
433
+ """
434
+
435
+ # Generation finish reason
436
+ finish_reason: FinishReason
437
+ # Number of generated tokens
438
+ generated_tokens: int
439
+ # Sampling seed if sampling was activated
440
+ seed: Optional[int] = None
441
+
442
+
443
+ # `generate_stream` return value
444
+ @dataclass
445
+ class TextGenerationStreamResponse:
446
+ """
447
+ Represents a response for streaming text generation.
448
+
449
+ Only returned when `details=True` and `stream=True`.
450
+
451
+ Args:
452
+ token (`Token`):
453
+ The generated token.
454
+ generated_text (`Optional[str]`, *optional*):
455
+ The complete generated text. Only available when the generation is finished.
456
+ details (`Optional[StreamDetails]`, *optional*):
457
+ Generation details. Only available when the generation is finished.
458
+ """
459
+
460
+ # Generated token
461
+ token: Token
462
+ # Complete generated text
463
+ # Only available when the generation is finished
464
+ generated_text: Optional[str] = None
465
+ # Generation details
466
+ # Only available when the generation is finished
467
+ details: Optional[StreamDetails] = None
468
+
469
+ def __post_init__(self):
470
+ if not is_pydantic_available():
471
+ # If pydantic is not installed, we need to instantiate the nested dataclasses manually
472
+ if isinstance(self.token, dict):
473
+ self.token = Token(**self.token)
474
+ if self.details is not None and isinstance(self.details, dict):
475
+ self.details = StreamDetails(**self.details)
476
+
477
+
478
+ # TEXT GENERATION ERRORS
479
+ # ----------------------
480
+ # Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
481
+ # inference project (https://github.com/huggingface/text-generation-inference).
482
+ # ----------------------
483
+
484
+
485
+ class TextGenerationError(HTTPError):
486
+ """Generic error raised if text-generation went wrong."""
487
+
488
+
489
+ # Text Generation Inference Errors
490
+ class ValidationError(TextGenerationError):
491
+ """Server-side validation error."""
492
+
493
+
494
+ class GenerationError(TextGenerationError):
495
+ pass
496
+
497
+
498
+ class OverloadedError(TextGenerationError):
499
+ pass
500
+
501
+
502
+ class IncompleteGenerationError(TextGenerationError):
503
+ pass
504
+
505
+
506
+ class UnknownError(TextGenerationError):
507
+ pass
508
+
509
+
510
+ def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
511
+ """
512
+ Try to parse text-generation-inference error message and raise HTTPError in any case.
513
+
514
+ Args:
515
+ error (`HTTPError`):
516
+ The HTTPError that have been raised.
517
+ """
518
+ # Try to parse a Text Generation Inference error
519
+
520
+ try:
521
+ # Hacky way to retrieve payload in case of aiohttp error
522
+ payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
523
+ error = payload.get("error")
524
+ error_type = payload.get("error_type")
525
+ except Exception: # no payload
526
+ raise http_error
527
+
528
+ # If error_type => more information than `hf_raise_for_status`
529
+ if error_type is not None:
530
+ exception = _parse_text_generation_error(error, error_type)
531
+ raise exception from http_error
532
+
533
+ # Otherwise, fallback to default error
534
+ raise http_error
535
+
536
+
537
+ def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
538
+ if error_type == "generation":
539
+ return GenerationError(error) # type: ignore
540
+ if error_type == "incomplete_generation":
541
+ return IncompleteGenerationError(error) # type: ignore
542
+ if error_type == "overloaded":
543
+ return OverloadedError(error) # type: ignore
544
+ if error_type == "validation":
545
+ return ValidationError(error) # type: ignore
546
+ return UnknownError(error) # type: ignore
lib/python3.11/site-packages/huggingface_hub/inference/_types.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import TYPE_CHECKING, List, TypedDict
16
+
17
+
18
+ if TYPE_CHECKING:
19
+ from PIL import Image
20
+
21
+
22
+ class ClassificationOutput(TypedDict):
23
+ """Dictionary containing the output of a [`~InferenceClient.audio_classification`] and [`~InferenceClient.image_classification`] task.
24
+
25
+ Args:
26
+ label (`str`):
27
+ The label predicted by the model.
28
+ score (`float`):
29
+ The score of the label predicted by the model.
30
+ """
31
+
32
+ label: str
33
+ score: float
34
+
35
+
36
+ class ConversationalOutputConversation(TypedDict):
37
+ """Dictionary containing the "conversation" part of a [`~InferenceClient.conversational`] task.
38
+
39
+ Args:
40
+ generated_responses (`List[str]`):
41
+ A list of the responses from the model.
42
+ past_user_inputs (`List[str]`):
43
+ A list of the inputs from the user. Must be the same length as `generated_responses`.
44
+ """
45
+
46
+ generated_responses: List[str]
47
+ past_user_inputs: List[str]
48
+
49
+
50
+ class ConversationalOutput(TypedDict):
51
+ """Dictionary containing the output of a [`~InferenceClient.conversational`] task.
52
+
53
+ Args:
54
+ generated_text (`str`):
55
+ The last response from the model.
56
+ conversation (`ConversationalOutputConversation`):
57
+ The past conversation.
58
+ warnings (`List[str]`):
59
+ A list of warnings associated with the process.
60
+ """
61
+
62
+ conversation: ConversationalOutputConversation
63
+ generated_text: str
64
+ warnings: List[str]
65
+
66
+
67
+ class FillMaskOutput(TypedDict):
68
+ """Dictionary containing information about a [`~InferenceClient.fill_mask`] task.
69
+
70
+ Args:
71
+ score (`float`):
72
+ The probability of the token.
73
+ token (`int`):
74
+ The id of the token.
75
+ token_str (`str`):
76
+ The string representation of the token.
77
+ sequence (`str`):
78
+ The actual sequence of tokens that ran against the model (may contain special tokens).
79
+ """
80
+
81
+ score: float
82
+ token: int
83
+ token_str: str
84
+ sequence: str
85
+
86
+
87
+ class ImageSegmentationOutput(TypedDict):
88
+ """Dictionary containing information about a [`~InferenceClient.image_segmentation`] task. In practice, image segmentation returns a
89
+ list of `ImageSegmentationOutput` with 1 item per mask.
90
+
91
+ Args:
92
+ label (`str`):
93
+ The label corresponding to the mask.
94
+ mask (`Image`):
95
+ An Image object representing the mask predicted by the model.
96
+ score (`float`):
97
+ The score associated with the label for this mask.
98
+ """
99
+
100
+ label: str
101
+ mask: "Image"
102
+ score: float
103
+
104
+
105
+ class ObjectDetectionOutput(TypedDict):
106
+ """Dictionary containing information about a [`~InferenceClient.object_detection`] task.
107
+
108
+ Args:
109
+ label (`str`):
110
+ The label corresponding to the detected object.
111
+ box (`dict`):
112
+ A dict response of bounding box coordinates of
113
+ the detected object: xmin, ymin, xmax, ymax
114
+ score (`float`):
115
+ The score corresponding to the detected object.
116
+ """
117
+
118
+ label: str
119
+ box: dict
120
+ score: float
121
+
122
+
123
+ class QuestionAnsweringOutput(TypedDict):
124
+ """Dictionary containing information about a [`~InferenceClient.question_answering`] task.
125
+
126
+ Args:
127
+ score (`float`):
128
+ A float that represents how likely that the answer is correct.
129
+ start (`int`):
130
+ The index (string wise) of the start of the answer within context.
131
+ end (`int`):
132
+ The index (string wise) of the end of the answer within context.
133
+ answer (`str`):
134
+ A string that is the answer within the text.
135
+ """
136
+
137
+ score: float
138
+ start: int
139
+ end: int
140
+ answer: str
141
+
142
+
143
+ class TableQuestionAnsweringOutput(TypedDict):
144
+ """Dictionary containing information about a [`~InferenceClient.table_question_answering`] task.
145
+
146
+ Args:
147
+ answer (`str`):
148
+ The plaintext answer.
149
+ coordinates (`List[List[int]]`):
150
+ A list of coordinates of the cells referenced in the answer.
151
+ cells (`List[int]`):
152
+ A list of coordinates of the cells contents.
153
+ aggregator (`str`):
154
+ The aggregator used to get the answer.
155
+ """
156
+
157
+ answer: str
158
+ coordinates: List[List[int]]
159
+ cells: List[List[int]]
160
+ aggregator: str
161
+
162
+
163
+ class TokenClassificationOutput(TypedDict):
164
+ """Dictionary containing the output of a [`~InferenceClient.token_classification`] task.
165
+
166
+ Args:
167
+ entity_group (`str`):
168
+ The type for the entity being recognized (model specific).
169
+ score (`float`):
170
+ The score of the label predicted by the model.
171
+ word (`str`):
172
+ The string that was captured.
173
+ start (`int`):
174
+ The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
175
+ end (`int`):
176
+ The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
177
+ """
178
+
179
+ entity_group: str
180
+ score: float
181
+ word: str
182
+ start: int
183
+ end: int
lib/python3.11/site-packages/huggingface_hub/inference_api.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ from .constants import INFERENCE_ENDPOINT
5
+ from .hf_api import HfApi
6
+ from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
7
+ from .utils._deprecation import _deprecate_method
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ ALL_TASKS = [
14
+ # NLP
15
+ "text-classification",
16
+ "token-classification",
17
+ "table-question-answering",
18
+ "question-answering",
19
+ "zero-shot-classification",
20
+ "translation",
21
+ "summarization",
22
+ "conversational",
23
+ "feature-extraction",
24
+ "text-generation",
25
+ "text2text-generation",
26
+ "fill-mask",
27
+ "sentence-similarity",
28
+ # Audio
29
+ "text-to-speech",
30
+ "automatic-speech-recognition",
31
+ "audio-to-audio",
32
+ "audio-classification",
33
+ "voice-activity-detection",
34
+ # Computer vision
35
+ "image-classification",
36
+ "object-detection",
37
+ "image-segmentation",
38
+ "text-to-image",
39
+ "image-to-image",
40
+ # Others
41
+ "tabular-classification",
42
+ "tabular-regression",
43
+ ]
44
+
45
+
46
+ class InferenceApi:
47
+ """Client to configure requests and make calls to the HuggingFace Inference API.
48
+
49
+ Example:
50
+
51
+ ```python
52
+ >>> from huggingface_hub.inference_api import InferenceApi
53
+
54
+ >>> # Mask-fill example
55
+ >>> inference = InferenceApi("bert-base-uncased")
56
+ >>> inference(inputs="The goal of life is [MASK].")
57
+ [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
58
+
59
+ >>> # Question Answering example
60
+ >>> inference = InferenceApi("deepset/roberta-base-squad2")
61
+ >>> inputs = {
62
+ ... "question": "What's my name?",
63
+ ... "context": "My name is Clara and I live in Berkeley.",
64
+ ... }
65
+ >>> inference(inputs)
66
+ {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}
67
+
68
+ >>> # Zero-shot example
69
+ >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
70
+ >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
71
+ >>> params = {"candidate_labels": ["refund", "legal", "faq"]}
72
+ >>> inference(inputs, params)
73
+ {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
74
+
75
+ >>> # Overriding configured task
76
+ >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")
77
+
78
+ >>> # Text-to-image
79
+ >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
80
+ >>> inference("cat")
81
+ <PIL.PngImagePlugin.PngImageFile image (...)>
82
+
83
+ >>> # Return as raw response to parse the output yourself
84
+ >>> inference = InferenceApi("mio/amadeus")
85
+ >>> response = inference("hello world", raw_response=True)
86
+ >>> response.headers
87
+ {"Content-Type": "audio/flac", ...}
88
+ >>> response.content # raw bytes from server
89
+ b'(...)'
90
+ ```
91
+ """
92
+
93
+ @validate_hf_hub_args
94
+ @_deprecate_method(
95
+ version="1.0",
96
+ message=(
97
+ "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out"
98
+ " this guide to learn how to convert your script to use it:"
99
+ " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client."
100
+ ),
101
+ )
102
+ def __init__(
103
+ self,
104
+ repo_id: str,
105
+ task: Optional[str] = None,
106
+ token: Optional[str] = None,
107
+ gpu: bool = False,
108
+ ):
109
+ """Inits headers and API call information.
110
+
111
+ Args:
112
+ repo_id (``str``):
113
+ Id of repository (e.g. `user/bert-base-uncased`).
114
+ task (``str``, `optional`, defaults ``None``):
115
+ Whether to force a task instead of using task specified in the
116
+ repository.
117
+ token (`str`, `optional`):
118
+ The API token to use as HTTP bearer authorization. This is not
119
+ the authentication token. You can find the token in
120
+ https://huggingface.co/settings/token. Alternatively, you can
121
+ find both your organizations and personal API tokens using
122
+ `HfApi().whoami(token)`.
123
+ gpu (`bool`, `optional`, defaults `False`):
124
+ Whether to use GPU instead of CPU for inference(requires Startup
125
+ plan at least).
126
+ """
127
+ self.options = {"wait_for_model": True, "use_gpu": gpu}
128
+ self.headers = build_hf_headers(token=token)
129
+
130
+ # Configure task
131
+ model_info = HfApi(token=token).model_info(repo_id=repo_id)
132
+ if not model_info.pipeline_tag and not task:
133
+ raise ValueError(
134
+ "Task not specified in the repository. Please add it to the model card"
135
+ " using pipeline_tag"
136
+ " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
137
+ )
138
+
139
+ if task and task != model_info.pipeline_tag:
140
+ if task not in ALL_TASKS:
141
+ raise ValueError(f"Invalid task {task}. Make sure it's valid.")
142
+
143
+ logger.warning(
144
+ "You're using a different task than the one specified in the"
145
+ " repository. Be sure to know what you're doing :)"
146
+ )
147
+ self.task = task
148
+ else:
149
+ assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
150
+ self.task = model_info.pipeline_tag
151
+
152
+ self.api_url = f"{INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
153
+
154
+ def __repr__(self):
155
+ # Do not add headers to repr to avoid leaking token.
156
+ return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"
157
+
158
+ def __call__(
159
+ self,
160
+ inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
161
+ params: Optional[Dict] = None,
162
+ data: Optional[bytes] = None,
163
+ raw_response: bool = False,
164
+ ) -> Any:
165
+ """Make a call to the Inference API.
166
+
167
+ Args:
168
+ inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
169
+ Inputs for the prediction.
170
+ params (`Dict`, *optional*):
171
+ Additional parameters for the models. Will be sent as `parameters` in the
172
+ payload.
173
+ data (`bytes`, *optional*):
174
+ Bytes content of the request. In this case, leave `inputs` and `params` empty.
175
+ raw_response (`bool`, defaults to `False`):
176
+ If `True`, the raw `Response` object is returned. You can parse its content
177
+ as preferred. By default, the content is parsed into a more practical format
178
+ (json dictionary or PIL Image for example).
179
+ """
180
+ # Build payload
181
+ payload: Dict[str, Any] = {
182
+ "options": self.options,
183
+ }
184
+ if inputs:
185
+ payload["inputs"] = inputs
186
+ if params:
187
+ payload["parameters"] = params
188
+
189
+ # Make API call
190
+ response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data)
191
+
192
+ # Let the user handle the response
193
+ if raw_response:
194
+ return response
195
+
196
+ # By default, parse the response for the user.
197
+ content_type = response.headers.get("Content-Type") or ""
198
+ if content_type.startswith("image"):
199
+ if not is_pillow_available():
200
+ raise ImportError(
201
+ f"Task '{self.task}' returned as image but Pillow is not installed."
202
+ " Please install it (`pip install Pillow`) or pass"
203
+ " `raw_response=True` to get the raw `Response` object and parse"
204
+ " the image by yourself."
205
+ )
206
+
207
+ from PIL import Image
208
+
209
+ return Image.open(io.BytesIO(response.content))
210
+ elif content_type == "application/json":
211
+ return response.json()
212
+ else:
213
+ raise NotImplementedError(
214
+ f"{content_type} output type is not implemented yet. You can pass"
215
+ " `raw_response=True` to get the raw `Response` object and parse the"
216
+ " output by yourself."
217
+ )
lib/python3.11/site-packages/huggingface_hub/keras_mixin.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc as collections
2
+ import json
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+ from shutil import copytree
7
+ from typing import Any, Dict, List, Optional, Union
8
+
9
+ from huggingface_hub import ModelHubMixin, snapshot_download
10
+ from huggingface_hub.utils import (
11
+ get_tf_version,
12
+ is_graphviz_available,
13
+ is_pydot_available,
14
+ is_tf_available,
15
+ yaml_dump,
16
+ )
17
+
18
+ from .constants import CONFIG_NAME
19
+ from .hf_api import HfApi
20
+ from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ if is_tf_available():
26
+ import tensorflow as tf # type: ignore
27
+
28
+
29
+ def _flatten_dict(dictionary, parent_key=""):
30
+ """Flatten a nested dictionary.
31
+ Reference: https://stackoverflow.com/a/6027615/10319735
32
+
33
+ Args:
34
+ dictionary (`dict`):
35
+ The nested dictionary to be flattened.
36
+ parent_key (`str`):
37
+ The parent key to be prefixed to the children keys.
38
+ Necessary for recursing over the nested dictionary.
39
+
40
+ Returns:
41
+ The flattened dictionary.
42
+ """
43
+ items = []
44
+ for key, value in dictionary.items():
45
+ new_key = f"{parent_key}.{key}" if parent_key else key
46
+ if isinstance(value, collections.MutableMapping):
47
+ items.extend(
48
+ _flatten_dict(
49
+ value,
50
+ new_key,
51
+ ).items()
52
+ )
53
+ else:
54
+ items.append((new_key, value))
55
+ return dict(items)
56
+
57
+
58
+ def _create_hyperparameter_table(model):
59
+ """Parse hyperparameter dictionary into a markdown table."""
60
+ if model.optimizer is not None:
61
+ optimizer_params = model.optimizer.get_config()
62
+ # flatten the configuration
63
+ optimizer_params = _flatten_dict(optimizer_params)
64
+ optimizer_params["training_precision"] = tf.keras.mixed_precision.global_policy().name
65
+ table = "| Hyperparameters | Value |\n| :-- | :-- |\n"
66
+ for key, value in optimizer_params.items():
67
+ table += f"| {key} | {value} |\n"
68
+ else:
69
+ table = None
70
+ return table
71
+
72
+
73
+ def _plot_network(model, save_directory):
74
+ tf.keras.utils.plot_model(
75
+ model,
76
+ to_file=f"{save_directory}/model.png",
77
+ show_shapes=False,
78
+ show_dtype=False,
79
+ show_layer_names=True,
80
+ rankdir="TB",
81
+ expand_nested=False,
82
+ dpi=96,
83
+ layer_range=None,
84
+ )
85
+
86
+
87
+ def _create_model_card(
88
+ model,
89
+ repo_dir: Path,
90
+ plot_model: bool = True,
91
+ metadata: Optional[dict] = None,
92
+ ):
93
+ """
94
+ Creates a model card for the repository.
95
+
96
+ Do not overwrite an existing README.md file.
97
+ """
98
+ readme_path = repo_dir / "README.md"
99
+ if readme_path.exists():
100
+ return
101
+
102
+ hyperparameters = _create_hyperparameter_table(model)
103
+ if plot_model and is_graphviz_available() and is_pydot_available():
104
+ _plot_network(model, repo_dir)
105
+ if metadata is None:
106
+ metadata = {}
107
+ metadata["library_name"] = "keras"
108
+ model_card: str = "---\n"
109
+ model_card += yaml_dump(metadata, default_flow_style=False)
110
+ model_card += "---\n"
111
+ model_card += "\n## Model description\n\nMore information needed\n"
112
+ model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
113
+ model_card += "\n## Training and evaluation data\n\nMore information needed\n"
114
+ if hyperparameters is not None:
115
+ model_card += "\n## Training procedure\n"
116
+ model_card += "\n### Training hyperparameters\n"
117
+ model_card += "\nThe following hyperparameters were used during training:\n\n"
118
+ model_card += hyperparameters
119
+ model_card += "\n"
120
+ if plot_model and os.path.exists(f"{repo_dir}/model.png"):
121
+ model_card += "\n ## Model Plot\n"
122
+ model_card += "\n<details>"
123
+ model_card += "\n<summary>View Model Plot</summary>\n"
124
+ path_to_plot = "./model.png"
125
+ model_card += f"\n![Model Image]({path_to_plot})\n"
126
+ model_card += "\n</details>"
127
+
128
+ readme_path.write_text(model_card)
129
+
130
+
131
+ def save_pretrained_keras(
132
+ model,
133
+ save_directory: Union[str, Path],
134
+ config: Optional[Dict[str, Any]] = None,
135
+ include_optimizer: bool = False,
136
+ plot_model: bool = True,
137
+ tags: Optional[Union[list, str]] = None,
138
+ **model_save_kwargs,
139
+ ):
140
+ """
141
+ Saves a Keras model to save_directory in SavedModel format. Use this if
142
+ you're using the Functional or Sequential APIs.
143
+
144
+ Args:
145
+ model (`Keras.Model`):
146
+ The [Keras
147
+ model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
148
+ you'd like to save. The model must be compiled and built.
149
+ save_directory (`str` or `Path`):
150
+ Specify directory in which you want to save the Keras model.
151
+ config (`dict`, *optional*):
152
+ Configuration object to be saved alongside the model weights.
153
+ include_optimizer(`bool`, *optional*, defaults to `False`):
154
+ Whether or not to include optimizer in serialization.
155
+ plot_model (`bool`, *optional*, defaults to `True`):
156
+ Setting this to `True` will plot the model and put it in the model
157
+ card. Requires graphviz and pydot to be installed.
158
+ tags (Union[`str`,`list`], *optional*):
159
+ List of tags that are related to model or string of a single tag. See example tags
160
+ [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1).
161
+ model_save_kwargs(`dict`, *optional*):
162
+ model_save_kwargs will be passed to
163
+ [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
164
+ """
165
+ if is_tf_available():
166
+ import tensorflow as tf
167
+ else:
168
+ raise ImportError("Called a Tensorflow-specific function but could not import it.")
169
+
170
+ if not model.built:
171
+ raise ValueError("Model should be built before trying to save")
172
+
173
+ save_directory = Path(save_directory)
174
+ save_directory.mkdir(parents=True, exist_ok=True)
175
+
176
+ # saving config
177
+ if config:
178
+ if not isinstance(config, dict):
179
+ raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'")
180
+
181
+ with (save_directory / CONFIG_NAME).open("w") as f:
182
+ json.dump(config, f)
183
+
184
+ metadata = {}
185
+ if isinstance(tags, list):
186
+ metadata["tags"] = tags
187
+ elif isinstance(tags, str):
188
+ metadata["tags"] = [tags]
189
+
190
+ task_name = model_save_kwargs.pop("task_name", None)
191
+ if task_name is not None:
192
+ warnings.warn(
193
+ "`task_name` input argument is deprecated. Pass `tags` instead.",
194
+ FutureWarning,
195
+ )
196
+ if "tags" in metadata:
197
+ metadata["tags"].append(task_name)
198
+ else:
199
+ metadata["tags"] = [task_name]
200
+
201
+ if model.history is not None:
202
+ if model.history.history != {}:
203
+ path = save_directory / "history.json"
204
+ if path.exists():
205
+ warnings.warn(
206
+ "`history.json` file already exists, it will be overwritten by the history of this version.",
207
+ UserWarning,
208
+ )
209
+ with path.open("w", encoding="utf-8") as f:
210
+ json.dump(model.history.history, f, indent=2, sort_keys=True)
211
+
212
+ _create_model_card(model, save_directory, plot_model, metadata)
213
+ tf.keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
214
+
215
+
216
+ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
217
+ r"""
218
+ Instantiate a pretrained Keras model from a pre-trained model from the Hub.
219
+ The model is expected to be in `SavedModel` format.
220
+
221
+ Args:
222
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
223
+ Can be either:
224
+ - A string, the `model id` of a pretrained model hosted inside a
225
+ model repo on huggingface.co. Valid model ids can be located
226
+ at the root-level, like `bert-base-uncased`, or namespaced
227
+ under a user or organization name, like
228
+ `dbmdz/bert-base-german-cased`.
229
+ - You can add `revision` by appending `@` at the end of model_id
230
+ simply like this: `dbmdz/bert-base-german-cased@main` Revision
231
+ is the specific model version to use. It can be a branch name,
232
+ a tag name, or a commit id, since we use a git-based system
233
+ for storing models and other artifacts on huggingface.co, so
234
+ `revision` can be any identifier allowed by git.
235
+ - A path to a `directory` containing model weights saved using
236
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g.,
237
+ `./my_model_directory/`.
238
+ - `None` if you are both providing the configuration and state
239
+ dictionary (resp. with keyword arguments `config` and
240
+ `state_dict`).
241
+ force_download (`bool`, *optional*, defaults to `False`):
242
+ Whether to force the (re-)download of the model weights and
243
+ configuration files, overriding the cached versions if they exist.
244
+ resume_download (`bool`, *optional*, defaults to `False`):
245
+ Whether to delete incompletely received files. Will attempt to
246
+ resume the download if such a file exists.
247
+ proxies (`Dict[str, str]`, *optional*):
248
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.,
249
+ `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The
250
+ proxies are used on each request.
251
+ token (`str` or `bool`, *optional*):
252
+ The token to use as HTTP bearer authorization for remote files. If
253
+ `True`, will use the token generated when running `transformers-cli
254
+ login` (stored in `~/.huggingface`).
255
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
256
+ Path to a directory in which a downloaded pretrained model
257
+ configuration should be cached if the standard cache should not be
258
+ used.
259
+ local_files_only(`bool`, *optional*, defaults to `False`):
260
+ Whether to only look at local files (i.e., do not try to download
261
+ the model).
262
+ model_kwargs (`Dict`, *optional*):
263
+ model_kwargs will be passed to the model during initialization
264
+
265
+ <Tip>
266
+
267
+ Passing `token=True` is required when you want to use a private
268
+ model.
269
+
270
+ </Tip>
271
+ """
272
+ return KerasModelHubMixin.from_pretrained(*args, **kwargs)
273
+
274
+
275
+ @validate_hf_hub_args
276
+ def push_to_hub_keras(
277
+ model,
278
+ repo_id: str,
279
+ *,
280
+ config: Optional[dict] = None,
281
+ commit_message: str = "Push Keras model using huggingface_hub.",
282
+ private: bool = False,
283
+ api_endpoint: Optional[str] = None,
284
+ token: Optional[str] = None,
285
+ branch: Optional[str] = None,
286
+ create_pr: Optional[bool] = None,
287
+ allow_patterns: Optional[Union[List[str], str]] = None,
288
+ ignore_patterns: Optional[Union[List[str], str]] = None,
289
+ delete_patterns: Optional[Union[List[str], str]] = None,
290
+ log_dir: Optional[str] = None,
291
+ include_optimizer: bool = False,
292
+ tags: Optional[Union[list, str]] = None,
293
+ plot_model: bool = True,
294
+ **model_save_kwargs,
295
+ ):
296
+ """
297
+ Upload model checkpoint to the Hub.
298
+
299
+ Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
300
+ `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
301
+ details.
302
+
303
+ Args:
304
+ model (`Keras.Model`):
305
+ The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the
306
+ Hub. The model must be compiled and built.
307
+ repo_id (`str`):
308
+ ID of the repository to push to (example: `"username/my-model"`).
309
+ commit_message (`str`, *optional*, defaults to "Add Keras model"):
310
+ Message to commit while pushing.
311
+ private (`bool`, *optional*, defaults to `False`):
312
+ Whether the repository created should be private.
313
+ api_endpoint (`str`, *optional*):
314
+ The API endpoint to use when pushing the model to the hub.
315
+ token (`str`, *optional*):
316
+ The token to use as HTTP bearer authorization for remote files. If
317
+ not set, will use the token set when logging in with
318
+ `huggingface-cli login` (stored in `~/.huggingface`).
319
+ branch (`str`, *optional*):
320
+ The git branch on which to push the model. This defaults to
321
+ the default branch as specified in your repository, which
322
+ defaults to `"main"`.
323
+ create_pr (`boolean`, *optional*):
324
+ Whether or not to create a Pull Request from `branch` with that commit.
325
+ Defaults to `False`.
326
+ config (`dict`, *optional*):
327
+ Configuration object to be saved alongside the model weights.
328
+ allow_patterns (`List[str]` or `str`, *optional*):
329
+ If provided, only files matching at least one pattern are pushed.
330
+ ignore_patterns (`List[str]` or `str`, *optional*):
331
+ If provided, files matching any of the patterns are not pushed.
332
+ delete_patterns (`List[str]` or `str`, *optional*):
333
+ If provided, remote files matching any of the patterns will be deleted from the repo.
334
+ log_dir (`str`, *optional*):
335
+ TensorBoard logging directory to be pushed. The Hub automatically
336
+ hosts and displays a TensorBoard instance if log files are included
337
+ in the repository.
338
+ include_optimizer (`bool`, *optional*, defaults to `False`):
339
+ Whether or not to include optimizer during serialization.
340
+ tags (Union[`list`, `str`], *optional*):
341
+ List of tags that are related to model or string of a single tag. See example tags
342
+ [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1).
343
+ plot_model (`bool`, *optional*, defaults to `True`):
344
+ Setting this to `True` will plot the model and put it in the model
345
+ card. Requires graphviz and pydot to be installed.
346
+ model_save_kwargs(`dict`, *optional*):
347
+ model_save_kwargs will be passed to
348
+ [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
349
+
350
+ Returns:
351
+ The url of the commit of your model in the given repository.
352
+ """
353
+ api = HfApi(endpoint=api_endpoint)
354
+ repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id
355
+
356
+ # Push the files to the repo in a single commit
357
+ with SoftTemporaryDirectory() as tmp:
358
+ saved_path = Path(tmp) / repo_id
359
+ save_pretrained_keras(
360
+ model,
361
+ saved_path,
362
+ config=config,
363
+ include_optimizer=include_optimizer,
364
+ tags=tags,
365
+ plot_model=plot_model,
366
+ **model_save_kwargs,
367
+ )
368
+
369
+ # If `log_dir` provided, delete remote logs and upload new ones
370
+ if log_dir is not None:
371
+ delete_patterns = (
372
+ []
373
+ if delete_patterns is None
374
+ else (
375
+ [delete_patterns] # convert `delete_patterns` to a list
376
+ if isinstance(delete_patterns, str)
377
+ else delete_patterns
378
+ )
379
+ )
380
+ delete_patterns.append("logs/*")
381
+ copytree(log_dir, saved_path / "logs")
382
+
383
+ return api.upload_folder(
384
+ repo_type="model",
385
+ repo_id=repo_id,
386
+ folder_path=saved_path,
387
+ commit_message=commit_message,
388
+ token=token,
389
+ revision=branch,
390
+ create_pr=create_pr,
391
+ allow_patterns=allow_patterns,
392
+ ignore_patterns=ignore_patterns,
393
+ delete_patterns=delete_patterns,
394
+ )
395
+
396
+
397
+ class KerasModelHubMixin(ModelHubMixin):
398
+ """
399
+ Implementation of [`ModelHubMixin`] to provide model Hub upload/download
400
+ capabilities to Keras models.
401
+
402
+
403
+ ```python
404
+ >>> import tensorflow as tf
405
+ >>> from huggingface_hub import KerasModelHubMixin
406
+
407
+
408
+ >>> class MyModel(tf.keras.Model, KerasModelHubMixin):
409
+ ... def __init__(self, **kwargs):
410
+ ... super().__init__()
411
+ ... self.config = kwargs.pop("config", None)
412
+ ... self.dummy_inputs = ...
413
+ ... self.layer = ...
414
+
415
+ ... def call(self, *args):
416
+ ... return ...
417
+
418
+
419
+ >>> # Initialize and compile the model as you normally would
420
+ >>> model = MyModel()
421
+ >>> model.compile(...)
422
+ >>> # Build the graph by training it or passing dummy inputs
423
+ >>> _ = model(model.dummy_inputs)
424
+ >>> # Save model weights to local directory
425
+ >>> model.save_pretrained("my-awesome-model")
426
+ >>> # Push model weights to the Hub
427
+ >>> model.push_to_hub("my-awesome-model")
428
+ >>> # Download and initialize weights from the Hub
429
+ >>> model = MyModel.from_pretrained("username/super-cool-model")
430
+ ```
431
+ """
432
+
433
+ def _save_pretrained(self, save_directory):
434
+ save_pretrained_keras(self, save_directory)
435
+
436
+ @classmethod
437
+ def _from_pretrained(
438
+ cls,
439
+ model_id,
440
+ revision,
441
+ cache_dir,
442
+ force_download,
443
+ proxies,
444
+ resume_download,
445
+ local_files_only,
446
+ token,
447
+ **model_kwargs,
448
+ ):
449
+ """Here we just call [`from_pretrained_keras`] function so both the mixin and
450
+ functional APIs stay in sync.
451
+
452
+ TODO - Some args above aren't used since we are calling
453
+ snapshot_download instead of hf_hub_download.
454
+ """
455
+ if is_tf_available():
456
+ import tensorflow as tf
457
+ else:
458
+ raise ImportError("Called a TensorFlow-specific function but could not import it.")
459
+
460
+ # TODO - Figure out what to do about these config values. Config is not going to be needed to load model
461
+ cfg = model_kwargs.pop("config", None)
462
+
463
+ # Root is either a local filepath matching model_id or a cached snapshot
464
+ if not os.path.isdir(model_id):
465
+ storage_folder = snapshot_download(
466
+ repo_id=model_id,
467
+ revision=revision,
468
+ cache_dir=cache_dir,
469
+ library_name="keras",
470
+ library_version=get_tf_version(),
471
+ )
472
+ else:
473
+ storage_folder = model_id
474
+
475
+ model = tf.keras.models.load_model(storage_folder, **model_kwargs)
476
+
477
+ # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
478
+ model.config = cfg
479
+
480
+ return model
lib/python3.11/site-packages/huggingface_hub/lfs.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2019-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Git LFS related type definitions and utilities"""
16
+ import inspect
17
+ import io
18
+ import os
19
+ import re
20
+ import warnings
21
+ from contextlib import AbstractContextManager
22
+ from dataclasses import dataclass
23
+ from math import ceil
24
+ from os.path import getsize
25
+ from pathlib import Path
26
+ from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict
27
+ from urllib.parse import unquote
28
+
29
+ from huggingface_hub.constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER, REPO_TYPES_URL_PREFIXES
30
+ from huggingface_hub.utils import get_session
31
+
32
+ from .utils import (
33
+ build_hf_headers,
34
+ hf_raise_for_status,
35
+ http_backoff,
36
+ logging,
37
+ tqdm,
38
+ validate_hf_hub_args,
39
+ )
40
+ from .utils.sha import sha256, sha_fileobj
41
+
42
+
43
+ if TYPE_CHECKING:
44
+ from ._commit_api import CommitOperationAdd
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ OID_REGEX = re.compile(r"^[0-9a-f]{40}$")
49
+
50
+ LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"
51
+
52
+ LFS_HEADERS = {
53
+ "Accept": "application/vnd.git-lfs+json",
54
+ "Content-Type": "application/vnd.git-lfs+json",
55
+ }
56
+
57
+
58
+ @dataclass
59
+ class UploadInfo:
60
+ """
61
+ Dataclass holding required information to determine whether a blob
62
+ should be uploaded to the hub using the LFS protocol or the regular protocol
63
+
64
+ Args:
65
+ sha256 (`bytes`):
66
+ SHA256 hash of the blob
67
+ size (`int`):
68
+ Size in bytes of the blob
69
+ sample (`bytes`):
70
+ First 512 bytes of the blob
71
+ """
72
+
73
+ sha256: bytes
74
+ size: int
75
+ sample: bytes
76
+
77
+ @classmethod
78
+ def from_path(cls, path: str):
79
+ size = getsize(path)
80
+ with io.open(path, "rb") as file:
81
+ sample = file.peek(512)[:512]
82
+ sha = sha_fileobj(file)
83
+ return cls(size=size, sha256=sha, sample=sample)
84
+
85
+ @classmethod
86
+ def from_bytes(cls, data: bytes):
87
+ sha = sha256(data).digest()
88
+ return cls(size=len(data), sample=data[:512], sha256=sha)
89
+
90
+ @classmethod
91
+ def from_fileobj(cls, fileobj: BinaryIO):
92
+ sample = fileobj.read(512)
93
+ fileobj.seek(0, io.SEEK_SET)
94
+ sha = sha_fileobj(fileobj)
95
+ size = fileobj.tell()
96
+ fileobj.seek(0, io.SEEK_SET)
97
+ return cls(size=size, sha256=sha, sample=sample)
98
+
99
+
100
+ @validate_hf_hub_args
101
+ def post_lfs_batch_info(
102
+ upload_infos: Iterable[UploadInfo],
103
+ token: Optional[str],
104
+ repo_type: str,
105
+ repo_id: str,
106
+ revision: Optional[str] = None,
107
+ endpoint: Optional[str] = None,
108
+ ) -> Tuple[List[dict], List[dict]]:
109
+ """
110
+ Requests the LFS batch endpoint to retrieve upload instructions
111
+
112
+ Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md
113
+
114
+ Args:
115
+ upload_infos (`Iterable` of `UploadInfo`):
116
+ `UploadInfo` for the files that are being uploaded, typically obtained
117
+ from `CommitOperationAdd.upload_info`
118
+ repo_type (`str`):
119
+ Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
120
+ repo_id (`str`):
121
+ A namespace (user or an organization) and a repo name separated
122
+ by a `/`.
123
+ token (`str`, *optional*):
124
+ An authentication token ( See https://huggingface.co/settings/tokens )
125
+ revision (`str`, *optional*):
126
+ The git revision to upload to.
127
+
128
+ Returns:
129
+ `LfsBatchInfo`: 2-tuple:
130
+ - First element is the list of upload instructions from the server
131
+ - Second element is an list of errors, if any
132
+
133
+ Raises:
134
+ `ValueError`: If an argument is invalid or the server response is malformed
135
+
136
+ `HTTPError`: If the server returned an error
137
+ """
138
+ endpoint = endpoint if endpoint is not None else ENDPOINT
139
+ url_prefix = ""
140
+ if repo_type in REPO_TYPES_URL_PREFIXES:
141
+ url_prefix = REPO_TYPES_URL_PREFIXES[repo_type]
142
+ batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch"
143
+ payload: Dict = {
144
+ "operation": "upload",
145
+ "transfers": ["basic", "multipart"],
146
+ "objects": [
147
+ {
148
+ "oid": upload.sha256.hex(),
149
+ "size": upload.size,
150
+ }
151
+ for upload in upload_infos
152
+ ],
153
+ "hash_algo": "sha256",
154
+ }
155
+ if revision is not None:
156
+ payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted'
157
+ headers = {**LFS_HEADERS, **build_hf_headers(token=token or True)} # Token must be provided or retrieved
158
+ resp = get_session().post(batch_url, headers=headers, json=payload)
159
+ hf_raise_for_status(resp)
160
+ batch_info = resp.json()
161
+
162
+ objects = batch_info.get("objects", None)
163
+ if not isinstance(objects, list):
164
+ raise ValueError("Malformed response from server")
165
+
166
+ return (
167
+ [_validate_batch_actions(obj) for obj in objects if "error" not in obj],
168
+ [_validate_batch_error(obj) for obj in objects if "error" in obj],
169
+ )
170
+
171
+
172
+ class PayloadPartT(TypedDict):
173
+ partNumber: int
174
+ etag: str
175
+
176
+
177
+ class CompletionPayloadT(TypedDict):
178
+ """Payload that will be sent to the Hub when uploading multi-part."""
179
+
180
+ oid: str
181
+ parts: List[PayloadPartT]
182
+
183
+
184
+ def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: Optional[str]) -> None:
185
+ """
186
+ Handles uploading a given object to the Hub with the LFS protocol.
187
+
188
+ Can be a No-op if the content of the file is already present on the hub large file storage.
189
+
190
+ Args:
191
+ operation (`CommitOperationAdd`):
192
+ The add operation triggering this upload.
193
+ lfs_batch_action (`dict`):
194
+ Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for
195
+ more details.
196
+ token (`str`, *optional*):
197
+ A [user access token](https://hf.co/settings/tokens) to authenticate requests against the Hub
198
+
199
+ Raises:
200
+ - `ValueError` if `lfs_batch_action` is improperly formatted
201
+ - `HTTPError` if the upload resulted in an error
202
+ """
203
+ # 0. If LFS file is already present, skip upload
204
+ _validate_batch_actions(lfs_batch_action)
205
+ actions = lfs_batch_action.get("actions")
206
+ if actions is None:
207
+ # The file was already uploaded
208
+ logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload")
209
+ return
210
+
211
+ # 1. Validate server response (check required keys in dict)
212
+ upload_action = lfs_batch_action["actions"]["upload"]
213
+ _validate_lfs_action(upload_action)
214
+ verify_action = lfs_batch_action["actions"].get("verify")
215
+ if verify_action is not None:
216
+ _validate_lfs_action(verify_action)
217
+
218
+ # 2. Upload file (either single part or multi-part)
219
+ header = upload_action.get("header", {})
220
+ chunk_size = header.get("chunk_size")
221
+ if chunk_size is not None:
222
+ try:
223
+ chunk_size = int(chunk_size)
224
+ except (ValueError, TypeError):
225
+ raise ValueError(
226
+ f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'."
227
+ )
228
+ _upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_action["href"])
229
+ else:
230
+ _upload_single_part(operation=operation, upload_url=upload_action["href"])
231
+
232
+ # 3. Verify upload went well
233
+ if verify_action is not None:
234
+ _validate_lfs_action(verify_action)
235
+ verify_resp = get_session().post(
236
+ verify_action["href"],
237
+ headers=build_hf_headers(token=token or True),
238
+ json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size},
239
+ )
240
+ hf_raise_for_status(verify_resp)
241
+ logger.debug(f"{operation.path_in_repo}: Upload successful")
242
+
243
+
244
+ def _validate_lfs_action(lfs_action: dict):
245
+ """validates response from the LFS batch endpoint"""
246
+ if not (
247
+ isinstance(lfs_action.get("href"), str)
248
+ and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict))
249
+ ):
250
+ raise ValueError("lfs_action is improperly formatted")
251
+ return lfs_action
252
+
253
+
254
+ def _validate_batch_actions(lfs_batch_actions: dict):
255
+ """validates response from the LFS batch endpoint"""
256
+ if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)):
257
+ raise ValueError("lfs_batch_actions is improperly formatted")
258
+
259
+ upload_action = lfs_batch_actions.get("actions", {}).get("upload")
260
+ verify_action = lfs_batch_actions.get("actions", {}).get("verify")
261
+ if upload_action is not None:
262
+ _validate_lfs_action(upload_action)
263
+ if verify_action is not None:
264
+ _validate_lfs_action(verify_action)
265
+ return lfs_batch_actions
266
+
267
+
268
+ def _validate_batch_error(lfs_batch_error: dict):
269
+ """validates response from the LFS batch endpoint"""
270
+ if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)):
271
+ raise ValueError("lfs_batch_error is improperly formatted")
272
+ error_info = lfs_batch_error.get("error")
273
+ if not (
274
+ isinstance(error_info, dict)
275
+ and isinstance(error_info.get("message"), str)
276
+ and isinstance(error_info.get("code"), int)
277
+ ):
278
+ raise ValueError("lfs_batch_error is improperly formatted")
279
+ return lfs_batch_error
280
+
281
+
282
+ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None:
283
+ """
284
+ Uploads `fileobj` as a single PUT HTTP request (basic LFS transfer protocol)
285
+
286
+ Args:
287
+ upload_url (`str`):
288
+ The URL to PUT the file to.
289
+ fileobj:
290
+ The file-like object holding the data to upload.
291
+
292
+ Returns: `requests.Response`
293
+
294
+ Raises: `requests.HTTPError` if the upload resulted in an error
295
+ """
296
+ with operation.as_file(with_tqdm=True) as fileobj:
297
+ # S3 might raise a transient 500 error -> let's retry if that happens
298
+ response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 503))
299
+ hf_raise_for_status(response)
300
+
301
+
302
+ def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size: int, upload_url: str) -> None:
303
+ """
304
+ Uploads file using HF multipart LFS transfer protocol.
305
+ """
306
+ # 1. Get upload URLs for each part
307
+ sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size)
308
+
309
+ # 2. Upload parts (either with hf_transfer or in pure Python)
310
+ use_hf_transfer = HF_HUB_ENABLE_HF_TRANSFER
311
+ if (
312
+ HF_HUB_ENABLE_HF_TRANSFER
313
+ and not isinstance(operation.path_or_fileobj, str)
314
+ and not isinstance(operation.path_or_fileobj, Path)
315
+ ):
316
+ warnings.warn(
317
+ "hf_transfer is enabled but does not support uploading from bytes or BinaryIO, falling back to regular"
318
+ " upload"
319
+ )
320
+ use_hf_transfer = False
321
+
322
+ response_headers = (
323
+ _upload_parts_hf_transfer(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size)
324
+ if use_hf_transfer
325
+ else _upload_parts_iteratively(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size)
326
+ )
327
+
328
+ # 3. Send completion request
329
+ completion_res = get_session().post(
330
+ upload_url,
331
+ json=_get_completion_payload(response_headers, operation.upload_info.sha256.hex()),
332
+ headers=LFS_HEADERS,
333
+ )
334
+ hf_raise_for_status(completion_res)
335
+
336
+
337
+ def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: int) -> List[str]:
338
+ sorted_part_upload_urls = [
339
+ upload_url
340
+ for _, upload_url in sorted(
341
+ [
342
+ (int(part_num, 10), upload_url)
343
+ for part_num, upload_url in header.items()
344
+ if part_num.isdigit() and len(part_num) > 0
345
+ ],
346
+ key=lambda t: t[0],
347
+ )
348
+ ]
349
+ num_parts = len(sorted_part_upload_urls)
350
+ if num_parts != ceil(upload_info.size / chunk_size):
351
+ raise ValueError("Invalid server response to upload large LFS file")
352
+ return sorted_part_upload_urls
353
+
354
+
355
+ def _get_completion_payload(response_headers: List[Dict], oid: str) -> CompletionPayloadT:
356
+ parts: List[PayloadPartT] = []
357
+ for part_number, header in enumerate(response_headers):
358
+ etag = header.get("etag")
359
+ if etag is None or etag == "":
360
+ raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_number + 1}")
361
+ parts.append(
362
+ {
363
+ "partNumber": part_number + 1,
364
+ "etag": etag,
365
+ }
366
+ )
367
+ return {"oid": oid, "parts": parts}
368
+
369
+
370
+ def _upload_parts_iteratively(
371
+ operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int
372
+ ) -> List[Dict]:
373
+ headers = []
374
+ with operation.as_file(with_tqdm=True) as fileobj:
375
+ for part_idx, part_upload_url in enumerate(sorted_parts_urls):
376
+ with SliceFileObj(
377
+ fileobj,
378
+ seek_from=chunk_size * part_idx,
379
+ read_limit=chunk_size,
380
+ ) as fileobj_slice:
381
+ # S3 might raise a transient 500 error -> let's retry if that happens
382
+ part_upload_res = http_backoff(
383
+ "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 503)
384
+ )
385
+ hf_raise_for_status(part_upload_res)
386
+ headers.append(part_upload_res.headers)
387
+ return headers # type: ignore
388
+
389
+
390
+ def _upload_parts_hf_transfer(
391
+ operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int
392
+ ) -> List[Dict]:
393
+ # Upload file using an external Rust-based package. Upload is faster but support less features (no progress bars).
394
+ try:
395
+ from hf_transfer import multipart_upload
396
+ except ImportError:
397
+ raise ValueError(
398
+ "Fast uploading using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is"
399
+ " not available in your environment. Try `pip install hf_transfer`."
400
+ )
401
+
402
+ supports_callback = "callback" in inspect.signature(multipart_upload).parameters
403
+ if not supports_callback:
404
+ warnings.warn(
405
+ "You are using an outdated version of `hf_transfer`. Consider upgrading to latest version to enable progress bars using `pip install -U hf_transfer`."
406
+ )
407
+
408
+ total = operation.upload_info.size
409
+ desc = operation.path_in_repo
410
+ if len(desc) > 40:
411
+ desc = f"(…){desc[-40:]}"
412
+ disable = bool(logger.getEffectiveLevel() == logging.NOTSET)
413
+
414
+ with tqdm(unit="B", unit_scale=True, total=total, initial=0, desc=desc, disable=disable) as progress:
415
+ try:
416
+ output = multipart_upload(
417
+ file_path=operation.path_or_fileobj,
418
+ parts_urls=sorted_parts_urls,
419
+ chunk_size=chunk_size,
420
+ max_files=128,
421
+ parallel_failures=127, # could be removed
422
+ max_retries=5,
423
+ **({"callback": progress.update} if supports_callback else {}),
424
+ )
425
+ except Exception as e:
426
+ raise RuntimeError(
427
+ "An error occurred while uploading using `hf_transfer`. Consider disabling HF_HUB_ENABLE_HF_TRANSFER for"
428
+ " better error handling."
429
+ ) from e
430
+ if not supports_callback:
431
+ progress.update(total)
432
+ return output
433
+
434
+
435
+ class SliceFileObj(AbstractContextManager):
436
+ """
437
+ Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object.
438
+
439
+ This is NOT thread safe
440
+
441
+ Inspired by stackoverflow.com/a/29838711/593036
442
+
443
+ Credits to @julien-c
444
+
445
+ Args:
446
+ fileobj (`BinaryIO`):
447
+ A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course).
448
+ `fileobj` will be reset to its original position when exiting the context manager.
449
+ seek_from (`int`):
450
+ The start of the slice (offset from position 0 in bytes).
451
+ read_limit (`int`):
452
+ The maximum number of bytes to read from the slice.
453
+
454
+ Attributes:
455
+ previous_position (`int`):
456
+ The previous position
457
+
458
+ Examples:
459
+
460
+ Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327):
461
+ ```python
462
+ >>> with open("path/to/file", "rb") as file:
463
+ ... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice:
464
+ ... fslice.read(...)
465
+ ```
466
+
467
+ Reading a file in chunks of 512 bytes
468
+ ```python
469
+ >>> import os
470
+ >>> chunk_size = 512
471
+ >>> file_size = os.getsize("path/to/file")
472
+ >>> with open("path/to/file", "rb") as file:
473
+ ... for chunk_idx in range(ceil(file_size / chunk_size)):
474
+ ... with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice:
475
+ ... chunk = fslice.read(...)
476
+
477
+ ```
478
+ """
479
+
480
+ def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int):
481
+ self.fileobj = fileobj
482
+ self.seek_from = seek_from
483
+ self.read_limit = read_limit
484
+
485
+ def __enter__(self):
486
+ self._previous_position = self.fileobj.tell()
487
+ end_of_stream = self.fileobj.seek(0, os.SEEK_END)
488
+ self._len = min(self.read_limit, end_of_stream - self.seek_from)
489
+ # ^^ The actual number of bytes that can be read from the slice
490
+ self.fileobj.seek(self.seek_from, io.SEEK_SET)
491
+ return self
492
+
493
+ def __exit__(self, exc_type, exc_value, traceback):
494
+ self.fileobj.seek(self._previous_position, io.SEEK_SET)
495
+
496
+ def read(self, n: int = -1):
497
+ pos = self.tell()
498
+ if pos >= self._len:
499
+ return b""
500
+ remaining_amount = self._len - pos
501
+ data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount))
502
+ return data
503
+
504
+ def tell(self) -> int:
505
+ return self.fileobj.tell() - self.seek_from
506
+
507
+ def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
508
+ start = self.seek_from
509
+ end = start + self._len
510
+ if whence in (os.SEEK_SET, os.SEEK_END):
511
+ offset = start + offset if whence == os.SEEK_SET else end + offset
512
+ offset = max(start, min(offset, end))
513
+ whence = os.SEEK_SET
514
+ elif whence == os.SEEK_CUR:
515
+ cur_pos = self.fileobj.tell()
516
+ offset = max(start - cur_pos, min(offset, end - cur_pos))
517
+ else:
518
+ raise ValueError(f"whence value {whence} is not supported")
519
+ return self.fileobj.seek(offset, whence) - self.seek_from
520
+
521
+ def __iter__(self):
522
+ yield self.read(n=4 * 1024 * 1024)
lib/python3.11/site-packages/huggingface_hub/repocard.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Literal, Optional, Type, Union
6
+
7
+ import requests
8
+ import yaml
9
+
10
+ from huggingface_hub.file_download import hf_hub_download
11
+ from huggingface_hub.hf_api import upload_file
12
+ from huggingface_hub.repocard_data import (
13
+ CardData,
14
+ DatasetCardData,
15
+ EvalResult,
16
+ ModelCardData,
17
+ SpaceCardData,
18
+ eval_results_to_model_index,
19
+ model_index_to_eval_results,
20
+ )
21
+ from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump
22
+
23
+ from .constants import REPOCARD_NAME
24
+ from .utils import EntryNotFoundError, SoftTemporaryDirectory, validate_hf_hub_args
25
+
26
+
27
+ TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md"
28
+ TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md"
29
+
30
+ # exact same regex as in the Hub server. Please keep in sync.
31
+ # See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18
32
+ REGEX_YAML_BLOCK = re.compile(r"^(\s*---[\r\n]+)([\S\s]*?)([\r\n]+---(\r\n|\n|$))")
33
+
34
+
35
+ class RepoCard:
36
+ card_data_class = CardData
37
+ default_template_path = TEMPLATE_MODELCARD_PATH
38
+ repo_type = "model"
39
+
40
+ def __init__(self, content: str, ignore_metadata_errors: bool = False):
41
+ """Initialize a RepoCard from string content. The content should be a
42
+ Markdown file with a YAML block at the beginning and a Markdown body.
43
+
44
+ Args:
45
+ content (`str`): The content of the Markdown file.
46
+
47
+ Example:
48
+ ```python
49
+ >>> from huggingface_hub.repocard import RepoCard
50
+ >>> text = '''
51
+ ... ---
52
+ ... language: en
53
+ ... license: mit
54
+ ... ---
55
+ ...
56
+ ... # My repo
57
+ ... '''
58
+ >>> card = RepoCard(text)
59
+ >>> card.data.to_dict()
60
+ {'language': 'en', 'license': 'mit'}
61
+ >>> card.text
62
+ '\\n# My repo\\n'
63
+
64
+ ```
65
+ <Tip>
66
+ Raises the following error:
67
+
68
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
69
+ when the content of the repo card metadata is not a dictionary.
70
+
71
+ </Tip>
72
+ """
73
+
74
+ # Set the content of the RepoCard, as well as underlying .data and .text attributes.
75
+ # See the `content` property setter for more details.
76
+ self.ignore_metadata_errors = ignore_metadata_errors
77
+ self.content = content
78
+
79
+ @property
80
+ def content(self):
81
+ """The content of the RepoCard, including the YAML block and the Markdown body."""
82
+ line_break = _detect_line_ending(self._content) or "\n"
83
+ return f"---{line_break}{self.data.to_yaml(line_break=line_break)}{line_break}---{line_break}{self.text}"
84
+
85
+ @content.setter
86
+ def content(self, content: str):
87
+ """Set the content of the RepoCard."""
88
+ self._content = content
89
+
90
+ match = REGEX_YAML_BLOCK.search(content)
91
+ if match:
92
+ # Metadata found in the YAML block
93
+ yaml_block = match.group(2)
94
+ self.text = content[match.end() :]
95
+ data_dict = yaml.safe_load(yaml_block)
96
+
97
+ if data_dict is None:
98
+ data_dict = {}
99
+
100
+ # The YAML block's data should be a dictionary
101
+ if not isinstance(data_dict, dict):
102
+ raise ValueError("repo card metadata block should be a dict")
103
+ else:
104
+ # Model card without metadata... create empty metadata
105
+ warnings.warn("Repo card metadata block was not found. Setting CardData to empty.")
106
+ data_dict = {}
107
+ self.text = content
108
+
109
+ self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors)
110
+
111
+ def __str__(self):
112
+ return self.content
113
+
114
+ def save(self, filepath: Union[Path, str]):
115
+ r"""Save a RepoCard to a file.
116
+
117
+ Args:
118
+ filepath (`Union[Path, str]`): Filepath to the markdown file to save.
119
+
120
+ Example:
121
+ ```python
122
+ >>> from huggingface_hub.repocard import RepoCard
123
+ >>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card")
124
+ >>> card.save("/tmp/test.md")
125
+
126
+ ```
127
+ """
128
+ filepath = Path(filepath)
129
+ filepath.parent.mkdir(parents=True, exist_ok=True)
130
+ # Preserve newlines as in the existing file.
131
+ with open(filepath, mode="w", newline="", encoding="utf-8") as f:
132
+ f.write(str(self))
133
+
134
+ @classmethod
135
+ def load(
136
+ cls,
137
+ repo_id_or_path: Union[str, Path],
138
+ repo_type: Optional[str] = None,
139
+ token: Optional[str] = None,
140
+ ignore_metadata_errors: bool = False,
141
+ ):
142
+ """Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath.
143
+
144
+ Args:
145
+ repo_id_or_path (`Union[str, Path]`):
146
+ The repo ID associated with a Hugging Face Hub repo or a local filepath.
147
+ repo_type (`str`, *optional*):
148
+ The type of Hugging Face repo to push to. Defaults to None, which will use use "model". Other options
149
+ are "dataset" and "space". Not used when loading from a local filepath. If this is called from a child
150
+ class, the default value will be the child class's `repo_type`.
151
+ token (`str`, *optional*):
152
+ Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
153
+ ignore_metadata_errors (`str`):
154
+ If True, errors while parsing the metadata section will be ignored. Some information might be lost during
155
+ the process. Use it at your own risk.
156
+
157
+ Returns:
158
+ [`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's
159
+ README.md file or filepath.
160
+
161
+ Example:
162
+ ```python
163
+ >>> from huggingface_hub.repocard import RepoCard
164
+ >>> card = RepoCard.load("nateraw/food")
165
+ >>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"]
166
+
167
+ ```
168
+ """
169
+
170
+ if Path(repo_id_or_path).exists():
171
+ card_path = Path(repo_id_or_path)
172
+ elif isinstance(repo_id_or_path, str):
173
+ card_path = Path(
174
+ hf_hub_download(
175
+ repo_id_or_path,
176
+ REPOCARD_NAME,
177
+ repo_type=repo_type or cls.repo_type,
178
+ token=token,
179
+ )
180
+ )
181
+ else:
182
+ raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).")
183
+
184
+ # Preserve newlines in the existing file.
185
+ with card_path.open(mode="r", newline="", encoding="utf-8") as f:
186
+ return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors)
187
+
188
+ def validate(self, repo_type: Optional[str] = None):
189
+ """Validates card against Hugging Face Hub's card validation logic.
190
+ Using this function requires access to the internet, so it is only called
191
+ internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`].
192
+
193
+ Args:
194
+ repo_type (`str`, *optional*, defaults to "model"):
195
+ The type of Hugging Face repo to push to. Options are "model", "dataset", and "space".
196
+ If this function is called from a child class, the default will be the child class's `repo_type`.
197
+
198
+ <Tip>
199
+ Raises the following errors:
200
+
201
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
202
+ if the card fails validation checks.
203
+ - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
204
+ if the request to the Hub API fails for any other reason.
205
+
206
+ </Tip>
207
+ """
208
+
209
+ # If repo type is provided, otherwise, use the repo type of the card.
210
+ repo_type = repo_type or self.repo_type
211
+
212
+ body = {
213
+ "repoType": repo_type,
214
+ "content": str(self),
215
+ }
216
+ headers = {"Accept": "text/plain"}
217
+
218
+ try:
219
+ r = get_session().post("https://huggingface.co/api/validate-yaml", body, headers=headers)
220
+ r.raise_for_status()
221
+ except requests.exceptions.HTTPError as exc:
222
+ if r.status_code == 400:
223
+ raise ValueError(r.text)
224
+ else:
225
+ raise exc
226
+
227
+ def push_to_hub(
228
+ self,
229
+ repo_id: str,
230
+ token: Optional[str] = None,
231
+ repo_type: Optional[str] = None,
232
+ commit_message: Optional[str] = None,
233
+ commit_description: Optional[str] = None,
234
+ revision: Optional[str] = None,
235
+ create_pr: Optional[bool] = None,
236
+ parent_commit: Optional[str] = None,
237
+ ):
238
+ """Push a RepoCard to a Hugging Face Hub repo.
239
+
240
+ Args:
241
+ repo_id (`str`):
242
+ The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food".
243
+ token (`str`, *optional*):
244
+ Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to
245
+ the stored token.
246
+ repo_type (`str`, *optional*, defaults to "model"):
247
+ The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this
248
+ function is called by a child class, it will default to the child class's `repo_type`.
249
+ commit_message (`str`, *optional*):
250
+ The summary / title / first line of the generated commit.
251
+ commit_description (`str`, *optional*)
252
+ The description of the generated commit.
253
+ revision (`str`, *optional*):
254
+ The git revision to commit from. Defaults to the head of the `"main"` branch.
255
+ create_pr (`bool`, *optional*):
256
+ Whether or not to create a Pull Request with this commit. Defaults to `False`.
257
+ parent_commit (`str`, *optional*):
258
+ The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
259
+ If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
260
+ If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
261
+ Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
262
+ especially useful if the repo is updated / committed to concurrently.
263
+ Returns:
264
+ `str`: URL of the commit which updated the card metadata.
265
+ """
266
+
267
+ # If repo type is provided, otherwise, use the repo type of the card.
268
+ repo_type = repo_type or self.repo_type
269
+
270
+ # Validate card before pushing to hub
271
+ self.validate(repo_type=repo_type)
272
+
273
+ with SoftTemporaryDirectory() as tmpdir:
274
+ tmp_path = Path(tmpdir) / REPOCARD_NAME
275
+ tmp_path.write_text(str(self))
276
+ url = upload_file(
277
+ path_or_fileobj=str(tmp_path),
278
+ path_in_repo=REPOCARD_NAME,
279
+ repo_id=repo_id,
280
+ token=token,
281
+ repo_type=repo_type,
282
+ commit_message=commit_message,
283
+ commit_description=commit_description,
284
+ create_pr=create_pr,
285
+ revision=revision,
286
+ parent_commit=parent_commit,
287
+ )
288
+ return url
289
+
290
+ @classmethod
291
+ def from_template(
292
+ cls,
293
+ card_data: CardData,
294
+ template_path: Optional[str] = None,
295
+ **template_kwargs,
296
+ ):
297
+ """Initialize a RepoCard from a template. By default, it uses the default template.
298
+
299
+ Templates are Jinja2 templates that can be customized by passing keyword arguments.
300
+
301
+ Args:
302
+ card_data (`huggingface_hub.CardData`):
303
+ A huggingface_hub.CardData instance containing the metadata you want to include in the YAML
304
+ header of the repo card on the Hugging Face Hub.
305
+ template_path (`str`, *optional*):
306
+ A path to a markdown file with optional Jinja template variables that can be filled
307
+ in with `template_kwargs`. Defaults to the default template.
308
+
309
+ Returns:
310
+ [`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the
311
+ template.
312
+ """
313
+ if is_jinja_available():
314
+ import jinja2
315
+ else:
316
+ raise ImportError(
317
+ "Using RepoCard.from_template requires Jinja2 to be installed. Please"
318
+ " install it with `pip install Jinja2`."
319
+ )
320
+
321
+ kwargs = card_data.to_dict().copy()
322
+ kwargs.update(template_kwargs) # Template_kwargs have priority
323
+ template = jinja2.Template(Path(template_path or cls.default_template_path).read_text())
324
+ content = template.render(card_data=card_data.to_yaml(), **kwargs)
325
+ return cls(content)
326
+
327
+
328
+ class ModelCard(RepoCard):
329
+ card_data_class = ModelCardData
330
+ default_template_path = TEMPLATE_MODELCARD_PATH
331
+ repo_type = "model"
332
+
333
+ @classmethod
334
+ def from_template( # type: ignore # violates Liskov property but easier to use
335
+ cls,
336
+ card_data: ModelCardData,
337
+ template_path: Optional[str] = None,
338
+ **template_kwargs,
339
+ ):
340
+ """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here:
341
+ https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md
342
+
343
+ Templates are Jinja2 templates that can be customized by passing keyword arguments.
344
+
345
+ Args:
346
+ card_data (`huggingface_hub.ModelCardData`):
347
+ A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML
348
+ header of the model card on the Hugging Face Hub.
349
+ template_path (`str`, *optional*):
350
+ A path to a markdown file with optional Jinja template variables that can be filled
351
+ in with `template_kwargs`. Defaults to the default template.
352
+
353
+ Returns:
354
+ [`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the
355
+ template.
356
+
357
+ Example:
358
+ ```python
359
+ >>> from huggingface_hub import ModelCard, ModelCardData, EvalResult
360
+
361
+ >>> # Using the Default Template
362
+ >>> card_data = ModelCardData(
363
+ ... language='en',
364
+ ... license='mit',
365
+ ... library_name='timm',
366
+ ... tags=['image-classification', 'resnet'],
367
+ ... datasets=['beans'],
368
+ ... metrics=['accuracy'],
369
+ ... )
370
+ >>> card = ModelCard.from_template(
371
+ ... card_data,
372
+ ... model_description='This model does x + y...'
373
+ ... )
374
+
375
+ >>> # Including Evaluation Results
376
+ >>> card_data = ModelCardData(
377
+ ... language='en',
378
+ ... tags=['image-classification', 'resnet'],
379
+ ... eval_results=[
380
+ ... EvalResult(
381
+ ... task_type='image-classification',
382
+ ... dataset_type='beans',
383
+ ... dataset_name='Beans',
384
+ ... metric_type='accuracy',
385
+ ... metric_value=0.9,
386
+ ... ),
387
+ ... ],
388
+ ... model_name='my-cool-model',
389
+ ... )
390
+ >>> card = ModelCard.from_template(card_data)
391
+
392
+ >>> # Using a Custom Template
393
+ >>> card_data = ModelCardData(
394
+ ... language='en',
395
+ ... tags=['image-classification', 'resnet']
396
+ ... )
397
+ >>> card = ModelCard.from_template(
398
+ ... card_data=card_data,
399
+ ... template_path='./src/huggingface_hub/templates/modelcard_template.md',
400
+ ... custom_template_var='custom value', # will be replaced in template if it exists
401
+ ... )
402
+
403
+ ```
404
+ """
405
+ return super().from_template(card_data, template_path, **template_kwargs)
406
+
407
+
408
+ class DatasetCard(RepoCard):
409
+ card_data_class = DatasetCardData
410
+ default_template_path = TEMPLATE_DATASETCARD_PATH
411
+ repo_type = "dataset"
412
+
413
+ @classmethod
414
+ def from_template( # type: ignore # violates Liskov property but easier to use
415
+ cls,
416
+ card_data: DatasetCardData,
417
+ template_path: Optional[str] = None,
418
+ **template_kwargs,
419
+ ):
420
+ """Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here:
421
+ https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md
422
+
423
+ Templates are Jinja2 templates that can be customized by passing keyword arguments.
424
+
425
+ Args:
426
+ card_data (`huggingface_hub.DatasetCardData`):
427
+ A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML
428
+ header of the dataset card on the Hugging Face Hub.
429
+ template_path (`str`, *optional*):
430
+ A path to a markdown file with optional Jinja template variables that can be filled
431
+ in with `template_kwargs`. Defaults to the default template.
432
+
433
+ Returns:
434
+ [`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the
435
+ template.
436
+
437
+ Example:
438
+ ```python
439
+ >>> from huggingface_hub import DatasetCard, DatasetCardData
440
+
441
+ >>> # Using the Default Template
442
+ >>> card_data = DatasetCardData(
443
+ ... language='en',
444
+ ... license='mit',
445
+ ... annotations_creators='crowdsourced',
446
+ ... task_categories=['text-classification'],
447
+ ... task_ids=['sentiment-classification', 'text-scoring'],
448
+ ... multilinguality='monolingual',
449
+ ... pretty_name='My Text Classification Dataset',
450
+ ... )
451
+ >>> card = DatasetCard.from_template(
452
+ ... card_data,
453
+ ... pretty_name=card_data.pretty_name,
454
+ ... )
455
+
456
+ >>> # Using a Custom Template
457
+ >>> card_data = DatasetCardData(
458
+ ... language='en',
459
+ ... license='mit',
460
+ ... )
461
+ >>> card = DatasetCard.from_template(
462
+ ... card_data=card_data,
463
+ ... template_path='./src/huggingface_hub/templates/datasetcard_template.md',
464
+ ... custom_template_var='custom value', # will be replaced in template if it exists
465
+ ... )
466
+
467
+ ```
468
+ """
469
+ return super().from_template(card_data, template_path, **template_kwargs)
470
+
471
+
472
+ class SpaceCard(RepoCard):
473
+ card_data_class = SpaceCardData
474
+ default_template_path = TEMPLATE_MODELCARD_PATH
475
+ repo_type = "space"
476
+
477
+
478
+ def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722
479
+ """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines.
480
+
481
+ Uses same implementation as in Hub server, keep it in sync.
482
+
483
+ Returns:
484
+ str: The detected line ending of the string.
485
+ """
486
+ cr = content.count("\r")
487
+ lf = content.count("\n")
488
+ crlf = content.count("\r\n")
489
+ if cr + lf == 0:
490
+ return None
491
+ if crlf == cr and crlf == lf:
492
+ return "\r\n"
493
+ if cr > lf:
494
+ return "\r"
495
+ else:
496
+ return "\n"
497
+
498
+
499
+ def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
500
+ content = Path(local_path).read_text()
501
+ match = REGEX_YAML_BLOCK.search(content)
502
+ if match:
503
+ yaml_block = match.group(2)
504
+ data = yaml.safe_load(yaml_block)
505
+ if data is None or isinstance(data, dict):
506
+ return data
507
+ raise ValueError("repo card metadata block should be a dict")
508
+ else:
509
+ return None
510
+
511
+
512
+ def metadata_save(local_path: Union[str, Path], data: Dict) -> None:
513
+ """
514
+ Save the metadata dict in the upper YAML part Trying to preserve newlines as
515
+ in the existing file. Docs about open() with newline="" parameter:
516
+ https://docs.python.org/3/library/functions.html?highlight=open#open Does
517
+ not work with "^M" linebreaks, which are replaced by \n
518
+ """
519
+ line_break = "\n"
520
+ content = ""
521
+ # try to detect existing newline character
522
+ if os.path.exists(local_path):
523
+ with open(local_path, "r", newline="", encoding="utf8") as readme:
524
+ content = readme.read()
525
+ if isinstance(readme.newlines, tuple):
526
+ line_break = readme.newlines[0]
527
+ elif isinstance(readme.newlines, str):
528
+ line_break = readme.newlines
529
+
530
+ # creates a new file if it not
531
+ with open(local_path, "w", newline="", encoding="utf8") as readme:
532
+ data_yaml = yaml_dump(data, sort_keys=False, line_break=line_break)
533
+ # sort_keys: keep dict order
534
+ match = REGEX_YAML_BLOCK.search(content)
535
+ if match:
536
+ output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :]
537
+ else:
538
+ output = f"---{line_break}{data_yaml}---{line_break}{content}"
539
+
540
+ readme.write(output)
541
+ readme.close()
542
+
543
+
544
+ def metadata_eval_result(
545
+ *,
546
+ model_pretty_name: str,
547
+ task_pretty_name: str,
548
+ task_id: str,
549
+ metrics_pretty_name: str,
550
+ metrics_id: str,
551
+ metrics_value: Any,
552
+ dataset_pretty_name: str,
553
+ dataset_id: str,
554
+ metrics_config: Optional[str] = None,
555
+ metrics_verified: bool = False,
556
+ dataset_config: Optional[str] = None,
557
+ dataset_split: Optional[str] = None,
558
+ dataset_revision: Optional[str] = None,
559
+ metrics_verification_token: Optional[str] = None,
560
+ ) -> Dict:
561
+ """
562
+ Creates a metadata dict with the result from a model evaluated on a dataset.
563
+
564
+ Args:
565
+ model_pretty_name (`str`):
566
+ The name of the model in natural language.
567
+ task_pretty_name (`str`):
568
+ The name of a task in natural language.
569
+ task_id (`str`):
570
+ Example: automatic-speech-recognition. A task id.
571
+ metrics_pretty_name (`str`):
572
+ A name for the metric in natural language. Example: Test WER.
573
+ metrics_id (`str`):
574
+ Example: wer. A metric id from https://hf.co/metrics.
575
+ metrics_value (`Any`):
576
+ The value from the metric. Example: 20.0 or "20.0 ± 1.2".
577
+ dataset_pretty_name (`str`):
578
+ The name of the dataset in natural language.
579
+ dataset_id (`str`):
580
+ Example: common_voice. A dataset id from https://hf.co/datasets.
581
+ metrics_config (`str`, *optional*):
582
+ The name of the metric configuration used in `load_metric()`.
583
+ Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
584
+ metrics_verified (`bool`, *optional*, defaults to `False`):
585
+ Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
586
+ dataset_config (`str`, *optional*):
587
+ Example: fr. The name of the dataset configuration used in `load_dataset()`.
588
+ dataset_split (`str`, *optional*):
589
+ Example: test. The name of the dataset split used in `load_dataset()`.
590
+ dataset_revision (`str`, *optional*):
591
+ Example: 5503434ddd753f426f4b38109466949a1217c2bb. The name of the dataset dataset revision
592
+ used in `load_dataset()`.
593
+ metrics_verification_token (`bool`, *optional*):
594
+ A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
595
+
596
+ Returns:
597
+ `dict`: a metadata dict with the result from a model evaluated on a dataset.
598
+
599
+ Example:
600
+ ```python
601
+ >>> from huggingface_hub import metadata_eval_result
602
+ >>> results = metadata_eval_result(
603
+ ... model_pretty_name="RoBERTa fine-tuned on ReactionGIF",
604
+ ... task_pretty_name="Text Classification",
605
+ ... task_id="text-classification",
606
+ ... metrics_pretty_name="Accuracy",
607
+ ... metrics_id="accuracy",
608
+ ... metrics_value=0.2662102282047272,
609
+ ... dataset_pretty_name="ReactionJPEG",
610
+ ... dataset_id="julien-c/reactionjpeg",
611
+ ... dataset_config="default",
612
+ ... dataset_split="test",
613
+ ... )
614
+ >>> results == {
615
+ ... 'model-index': [
616
+ ... {
617
+ ... 'name': 'RoBERTa fine-tuned on ReactionGIF',
618
+ ... 'results': [
619
+ ... {
620
+ ... 'task': {
621
+ ... 'type': 'text-classification',
622
+ ... 'name': 'Text Classification'
623
+ ... },
624
+ ... 'dataset': {
625
+ ... 'name': 'ReactionJPEG',
626
+ ... 'type': 'julien-c/reactionjpeg',
627
+ ... 'config': 'default',
628
+ ... 'split': 'test'
629
+ ... },
630
+ ... 'metrics': [
631
+ ... {
632
+ ... 'type': 'accuracy',
633
+ ... 'value': 0.2662102282047272,
634
+ ... 'name': 'Accuracy',
635
+ ... 'verified': False
636
+ ... }
637
+ ... ]
638
+ ... }
639
+ ... ]
640
+ ... }
641
+ ... ]
642
+ ... }
643
+ True
644
+
645
+ ```
646
+ """
647
+
648
+ return {
649
+ "model-index": eval_results_to_model_index(
650
+ model_name=model_pretty_name,
651
+ eval_results=[
652
+ EvalResult(
653
+ task_name=task_pretty_name,
654
+ task_type=task_id,
655
+ metric_name=metrics_pretty_name,
656
+ metric_type=metrics_id,
657
+ metric_value=metrics_value,
658
+ dataset_name=dataset_pretty_name,
659
+ dataset_type=dataset_id,
660
+ metric_config=metrics_config,
661
+ verified=metrics_verified,
662
+ verify_token=metrics_verification_token,
663
+ dataset_config=dataset_config,
664
+ dataset_split=dataset_split,
665
+ dataset_revision=dataset_revision,
666
+ )
667
+ ],
668
+ )
669
+ }
670
+
671
+
672
+ @validate_hf_hub_args
673
+ def metadata_update(
674
+ repo_id: str,
675
+ metadata: Dict,
676
+ *,
677
+ repo_type: Optional[str] = None,
678
+ overwrite: bool = False,
679
+ token: Optional[str] = None,
680
+ commit_message: Optional[str] = None,
681
+ commit_description: Optional[str] = None,
682
+ revision: Optional[str] = None,
683
+ create_pr: bool = False,
684
+ parent_commit: Optional[str] = None,
685
+ ) -> str:
686
+ """
687
+ Updates the metadata in the README.md of a repository on the Hugging Face Hub.
688
+ If the README.md file doesn't exist yet, a new one is created with metadata and an
689
+ the default ModelCard or DatasetCard template. For `space` repo, an error is thrown
690
+ as a Space cannot exist without a `README.md` file.
691
+
692
+ Args:
693
+ repo_id (`str`):
694
+ The name of the repository.
695
+ metadata (`dict`):
696
+ A dictionary containing the metadata to be updated.
697
+ repo_type (`str`, *optional*):
698
+ Set to `"dataset"` or `"space"` if updating to a dataset or space,
699
+ `None` or `"model"` if updating to a model. Default is `None`.
700
+ overwrite (`bool`, *optional*, defaults to `False`):
701
+ If set to `True` an existing field can be overwritten, otherwise
702
+ attempting to overwrite an existing field will cause an error.
703
+ token (`str`, *optional*):
704
+ The Hugging Face authentication token.
705
+ commit_message (`str`, *optional*):
706
+ The summary / title / first line of the generated commit. Defaults to
707
+ `f"Update metadata with huggingface_hub"`
708
+ commit_description (`str` *optional*)
709
+ The description of the generated commit
710
+ revision (`str`, *optional*):
711
+ The git revision to commit from. Defaults to the head of the
712
+ `"main"` branch.
713
+ create_pr (`boolean`, *optional*):
714
+ Whether or not to create a Pull Request from `revision` with that commit.
715
+ Defaults to `False`.
716
+ parent_commit (`str`, *optional*):
717
+ The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
718
+ If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
719
+ If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
720
+ Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
721
+ especially useful if the repo is updated / committed to concurrently.
722
+ Returns:
723
+ `str`: URL of the commit which updated the card metadata.
724
+
725
+ Example:
726
+ ```python
727
+ >>> from huggingface_hub import metadata_update
728
+ >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
729
+ ... 'results': [{'dataset': {'name': 'ReactionGIF',
730
+ ... 'type': 'julien-c/reactiongif'},
731
+ ... 'metrics': [{'name': 'Recall',
732
+ ... 'type': 'recall',
733
+ ... 'value': 0.7762102282047272}],
734
+ ... 'task': {'name': 'Text Classification',
735
+ ... 'type': 'text-classification'}}]}]}
736
+ >>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata)
737
+
738
+ ```
739
+ """
740
+ commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub"
741
+
742
+ # Card class given repo_type
743
+ card_class: Type[RepoCard]
744
+ if repo_type is None or repo_type == "model":
745
+ card_class = ModelCard
746
+ elif repo_type == "dataset":
747
+ card_class = DatasetCard
748
+ elif repo_type == "space":
749
+ card_class = RepoCard
750
+ else:
751
+ raise ValueError(f"Unknown repo_type: {repo_type}")
752
+
753
+ # Either load repo_card from the Hub or create an empty one.
754
+ # NOTE: Will not create the repo if it doesn't exist.
755
+ try:
756
+ card = card_class.load(repo_id, token=token, repo_type=repo_type)
757
+ except EntryNotFoundError:
758
+ if repo_type == "space":
759
+ raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.")
760
+
761
+ # Initialize a ModelCard or DatasetCard from default template and no data.
762
+ card = card_class.from_template(CardData())
763
+
764
+ for key, value in metadata.items():
765
+ if key == "model-index":
766
+ # if the new metadata doesn't include a name, either use existing one or repo name
767
+ if "name" not in value[0]:
768
+ value[0]["name"] = getattr(card, "model_name", repo_id)
769
+ model_name, new_results = model_index_to_eval_results(value)
770
+ if card.data.eval_results is None:
771
+ card.data.eval_results = new_results
772
+ card.data.model_name = model_name
773
+ else:
774
+ existing_results = card.data.eval_results
775
+
776
+ # Iterate over new results
777
+ # Iterate over existing results
778
+ # If both results describe the same metric but value is different:
779
+ # If overwrite=True: overwrite the metric value
780
+ # Else: raise ValueError
781
+ # Else: append new result to existing ones.
782
+ for new_result in new_results:
783
+ result_found = False
784
+ for existing_result in existing_results:
785
+ if new_result.is_equal_except_value(existing_result):
786
+ if new_result != existing_result and not overwrite:
787
+ raise ValueError(
788
+ "You passed a new value for the existing metric"
789
+ f" 'name: {new_result.metric_name}, type: "
790
+ f"{new_result.metric_type}'. Set `overwrite=True`"
791
+ " to overwrite existing metrics."
792
+ )
793
+ result_found = True
794
+ existing_result.metric_value = new_result.metric_value
795
+ if existing_result.verified is True:
796
+ existing_result.verify_token = new_result.verify_token
797
+ if not result_found:
798
+ card.data.eval_results.append(new_result)
799
+ else:
800
+ # Any metadata that is not a result metric
801
+ if card.data.get(key) is not None and not overwrite and card.data.get(key) != value:
802
+ raise ValueError(
803
+ f"You passed a new value for the existing meta data field '{key}'."
804
+ " Set `overwrite=True` to overwrite existing metadata."
805
+ )
806
+ else:
807
+ card.data[key] = value
808
+
809
+ return card.push_to_hub(
810
+ repo_id,
811
+ token=token,
812
+ repo_type=repo_type,
813
+ commit_message=commit_message,
814
+ commit_description=commit_description,
815
+ create_pr=create_pr,
816
+ revision=revision,
817
+ parent_commit=parent_commit,
818
+ )
lib/python3.11/site-packages/huggingface_hub/repocard_data.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import warnings
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ from huggingface_hub.utils import yaml_dump
8
+
9
+
10
+ @dataclass
11
+ class EvalResult:
12
+ """
13
+ Flattened representation of individual evaluation results found in model-index of Model Cards.
14
+
15
+ For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1.
16
+
17
+ Args:
18
+ task_type (`str`):
19
+ The task identifier. Example: "image-classification".
20
+ dataset_type (`str`):
21
+ The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets.
22
+ dataset_name (`str`):
23
+ A pretty name for the dataset. Example: "Common Voice (French)".
24
+ metric_type (`str`):
25
+ The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics.
26
+ metric_value (`Any`):
27
+ The metric value. Example: 0.9 or "20.0 ± 1.2".
28
+ task_name (`str`, *optional*):
29
+ A pretty name for the task. Example: "Speech Recognition".
30
+ dataset_config (`str`, *optional*):
31
+ The name of the dataset configuration used in `load_dataset()`.
32
+ Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info:
33
+ https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name
34
+ dataset_split (`str`, *optional*):
35
+ The split used in `load_dataset()`. Example: "test".
36
+ dataset_revision (`str`, *optional*):
37
+ The revision (AKA Git Sha) of the dataset used in `load_dataset()`.
38
+ Example: 5503434ddd753f426f4b38109466949a1217c2bb
39
+ dataset_args (`Dict[str, Any]`, *optional*):
40
+ The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}`
41
+ metric_name (`str`, *optional*):
42
+ A pretty name for the metric. Example: "Test WER".
43
+ metric_config (`str`, *optional*):
44
+ The name of the metric configuration used in `load_metric()`.
45
+ Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
46
+ See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
47
+ metric_args (`Dict[str, Any]`, *optional*):
48
+ The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4
49
+ verified (`bool`, *optional*):
50
+ Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
51
+ verify_token (`str`, *optional*):
52
+ A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
53
+ source_name (`str`, *optional*):
54
+ The name of the source of the evaluation result. Example: "Open LLM Leaderboard".
55
+ source_url (`str`, *optional*):
56
+ The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard".
57
+ """
58
+
59
+ # Required
60
+
61
+ # The task identifier
62
+ # Example: automatic-speech-recognition
63
+ task_type: str
64
+
65
+ # The dataset identifier
66
+ # Example: common_voice. Use dataset id from https://hf.co/datasets
67
+ dataset_type: str
68
+
69
+ # A pretty name for the dataset.
70
+ # Example: Common Voice (French)
71
+ dataset_name: str
72
+
73
+ # The metric identifier
74
+ # Example: wer. Use metric id from https://hf.co/metrics
75
+ metric_type: str
76
+
77
+ # Value of the metric.
78
+ # Example: 20.0 or "20.0 ± 1.2"
79
+ metric_value: Any
80
+
81
+ # Optional
82
+
83
+ # A pretty name for the task.
84
+ # Example: Speech Recognition
85
+ task_name: Optional[str] = None
86
+
87
+ # The name of the dataset configuration used in `load_dataset()`.
88
+ # Example: fr in `load_dataset("common_voice", "fr")`.
89
+ # See the `datasets` docs for more info:
90
+ # https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name
91
+ dataset_config: Optional[str] = None
92
+
93
+ # The split used in `load_dataset()`.
94
+ # Example: test
95
+ dataset_split: Optional[str] = None
96
+
97
+ # The revision (AKA Git Sha) of the dataset used in `load_dataset()`.
98
+ # Example: 5503434ddd753f426f4b38109466949a1217c2bb
99
+ dataset_revision: Optional[str] = None
100
+
101
+ # The arguments passed during `Metric.compute()`.
102
+ # Example for `bleu`: max_order: 4
103
+ dataset_args: Optional[Dict[str, Any]] = None
104
+
105
+ # A pretty name for the metric.
106
+ # Example: Test WER
107
+ metric_name: Optional[str] = None
108
+
109
+ # The name of the metric configuration used in `load_metric()`.
110
+ # Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
111
+ # See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
112
+ metric_config: Optional[str] = None
113
+
114
+ # The arguments passed during `Metric.compute()`.
115
+ # Example for `bleu`: max_order: 4
116
+ metric_args: Optional[Dict[str, Any]] = None
117
+
118
+ # Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set.
119
+ verified: Optional[bool] = None
120
+
121
+ # A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not.
122
+ verify_token: Optional[str] = None
123
+
124
+ # The name of the source of the evaluation result.
125
+ # Example: Open LLM Leaderboard
126
+ source_name: Optional[str] = None
127
+
128
+ # The URL of the source of the evaluation result.
129
+ # Example: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard
130
+ source_url: Optional[str] = None
131
+
132
+ @property
133
+ def unique_identifier(self) -> tuple:
134
+ """Returns a tuple that uniquely identifies this evaluation."""
135
+ return (
136
+ self.task_type,
137
+ self.dataset_type,
138
+ self.dataset_config,
139
+ self.dataset_split,
140
+ self.dataset_revision,
141
+ )
142
+
143
+ def is_equal_except_value(self, other: "EvalResult") -> bool:
144
+ """
145
+ Return True if `self` and `other` describe exactly the same metric but with a
146
+ different value.
147
+ """
148
+ for key, _ in self.__dict__.items():
149
+ if key == "metric_value":
150
+ continue
151
+ # For metrics computed by Hugging Face's evaluation service, `verify_token` is derived from `metric_value`,
152
+ # so we exclude it here in the comparison.
153
+ if key != "verify_token" and getattr(self, key) != getattr(other, key):
154
+ return False
155
+ return True
156
+
157
+ def __post_init__(self) -> None:
158
+ if self.source_name is not None and self.source_url is None:
159
+ raise ValueError("If `source_name` is provided, `source_url` must also be provided.")
160
+
161
+
162
+ @dataclass
163
+ class CardData:
164
+ """Structure containing metadata from a RepoCard.
165
+
166
+ [`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`].
167
+
168
+ Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data
169
+ (example: flatten evaluation results). `CardData` behaves as a dictionary (can get, pop, set values) but do not
170
+ inherit from `dict` to allow this export step.
171
+ """
172
+
173
+ def __init__(self, ignore_metadata_errors: bool = False, **kwargs):
174
+ self.__dict__.update(kwargs)
175
+
176
+ def to_dict(self) -> Dict[str, Any]:
177
+ """Converts CardData to a dict.
178
+
179
+ Returns:
180
+ `dict`: CardData represented as a dictionary ready to be dumped to a YAML
181
+ block for inclusion in a README.md file.
182
+ """
183
+
184
+ data_dict = copy.deepcopy(self.__dict__)
185
+ self._to_dict(data_dict)
186
+ return _remove_none(data_dict)
187
+
188
+ def _to_dict(self, data_dict):
189
+ """Use this method in child classes to alter the dict representation of the data. Alter the dict in-place.
190
+
191
+ Args:
192
+ data_dict (`dict`): The raw dict representation of the card data.
193
+ """
194
+ pass
195
+
196
+ def to_yaml(self, line_break=None) -> str:
197
+ """Dumps CardData to a YAML block for inclusion in a README.md file.
198
+
199
+ Args:
200
+ line_break (str, *optional*):
201
+ The line break to use when dumping to yaml.
202
+
203
+ Returns:
204
+ `str`: CardData represented as a YAML block.
205
+ """
206
+ return yaml_dump(self.to_dict(), sort_keys=False, line_break=line_break).strip()
207
+
208
+ def __repr__(self):
209
+ return repr(self.__dict__)
210
+
211
+ def __str__(self):
212
+ return self.to_yaml()
213
+
214
+ def get(self, key: str, default: Any = None) -> Any:
215
+ """Get value for a given metadata key."""
216
+ return self.__dict__.get(key, default)
217
+
218
+ def pop(self, key: str, default: Any = None) -> Any:
219
+ """Pop value for a given metadata key."""
220
+ return self.__dict__.pop(key, default)
221
+
222
+ def __getitem__(self, key: str) -> Any:
223
+ """Get value for a given metadata key."""
224
+ return self.__dict__[key]
225
+
226
+ def __setitem__(self, key: str, value: Any) -> None:
227
+ """Set value for a given metadata key."""
228
+ self.__dict__[key] = value
229
+
230
+ def __contains__(self, key: str) -> bool:
231
+ """Check if a given metadata key is set."""
232
+ return key in self.__dict__
233
+
234
+ def __len__(self) -> int:
235
+ """Return the number of metadata keys set."""
236
+ return len(self.__dict__)
237
+
238
+
239
+ class ModelCardData(CardData):
240
+ """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
241
+
242
+ Args:
243
+ language (`Union[str, List[str]]`, *optional*):
244
+ Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
245
+ 639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
246
+ license (`str`, *optional*):
247
+ License of this model. Example: apache-2.0 or any license from
248
+ https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
249
+ library_name (`str`, *optional*):
250
+ Name of library used by this model. Example: keras or any library from
251
+ https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
252
+ Defaults to None.
253
+ tags (`List[str]`, *optional*):
254
+ List of tags to add to your model that can be used when filtering on the Hugging
255
+ Face Hub. Defaults to None.
256
+ datasets (`List[str]`, *optional*):
257
+ List of datasets that were used to train this model. Should be a dataset ID
258
+ found on https://hf.co/datasets. Defaults to None.
259
+ metrics (`List[str]`, *optional*):
260
+ List of metrics used to evaluate this model. Should be a metric name that can be found
261
+ at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
262
+ eval_results (`Union[List[EvalResult], EvalResult]`, *optional*):
263
+ List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided,
264
+ `model_name` is used to as a name on PapersWithCode's leaderboards. Defaults to `None`.
265
+ model_name (`str`, *optional*):
266
+ A name for this model. It is used along with
267
+ `eval_results` to construct the `model-index` within the card's metadata. The name
268
+ you supply here is what will be used on PapersWithCode's leaderboards. If None is provided
269
+ then the repo name is used as a default. Defaults to None.
270
+ ignore_metadata_errors (`str`):
271
+ If True, errors while parsing the metadata section will be ignored. Some information might be lost during
272
+ the process. Use it at your own risk.
273
+ kwargs (`dict`, *optional*):
274
+ Additional metadata that will be added to the model card. Defaults to None.
275
+
276
+ Example:
277
+ ```python
278
+ >>> from huggingface_hub import ModelCardData
279
+ >>> card_data = ModelCardData(
280
+ ... language="en",
281
+ ... license="mit",
282
+ ... library_name="timm",
283
+ ... tags=['image-classification', 'resnet'],
284
+ ... )
285
+ >>> card_data.to_dict()
286
+ {'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']}
287
+
288
+ ```
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ *,
294
+ language: Optional[Union[str, List[str]]] = None,
295
+ license: Optional[str] = None,
296
+ library_name: Optional[str] = None,
297
+ tags: Optional[List[str]] = None,
298
+ datasets: Optional[List[str]] = None,
299
+ metrics: Optional[List[str]] = None,
300
+ eval_results: Optional[List[EvalResult]] = None,
301
+ model_name: Optional[str] = None,
302
+ ignore_metadata_errors: bool = False,
303
+ **kwargs,
304
+ ):
305
+ self.language = language
306
+ self.license = license
307
+ self.library_name = library_name
308
+ self.tags = tags
309
+ self.datasets = datasets
310
+ self.metrics = metrics
311
+ self.eval_results = eval_results
312
+ self.model_name = model_name
313
+
314
+ model_index = kwargs.pop("model-index", None)
315
+ if model_index:
316
+ try:
317
+ model_name, eval_results = model_index_to_eval_results(model_index)
318
+ self.model_name = model_name
319
+ self.eval_results = eval_results
320
+ except (KeyError, TypeError) as error:
321
+ if ignore_metadata_errors:
322
+ warnings.warn("Invalid model-index. Not loading eval results into CardData.")
323
+ else:
324
+ raise ValueError(
325
+ f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass"
326
+ " `ignore_metadata_errors=True` to ignore this error while loading a Model Card. Warning:"
327
+ " some information will be lost. Use it at your own risk."
328
+ )
329
+
330
+ super().__init__(**kwargs)
331
+
332
+ if self.eval_results:
333
+ if type(self.eval_results) == EvalResult:
334
+ self.eval_results = [self.eval_results]
335
+ if self.model_name is None:
336
+ raise ValueError("Passing `eval_results` requires `model_name` to be set.")
337
+
338
+ def _to_dict(self, data_dict):
339
+ """Format the internal data dict. In this case, we convert eval results to a valid model index"""
340
+ if self.eval_results is not None:
341
+ data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results)
342
+ del data_dict["eval_results"], data_dict["model_name"]
343
+
344
+
345
+ class DatasetCardData(CardData):
346
+ """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
347
+
348
+ Args:
349
+ language (`List[str]`, *optional*):
350
+ Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or
351
+ 639-3 code (two/three letters), or a special value like "code", "multilingual".
352
+ license (`Union[str, List[str]]`, *optional*):
353
+ License(s) of this dataset. Example: apache-2.0 or any license from
354
+ https://huggingface.co/docs/hub/repositories-licenses.
355
+ annotations_creators (`Union[str, List[str]]`, *optional*):
356
+ How the annotations for the dataset were created.
357
+ Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'.
358
+ language_creators (`Union[str, List[str]]`, *optional*):
359
+ How the text-based data in the dataset was created.
360
+ Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other'
361
+ multilinguality (`Union[str, List[str]]`, *optional*):
362
+ Whether the dataset is multilingual.
363
+ Options are: 'monolingual', 'multilingual', 'translation', 'other'.
364
+ size_categories (`Union[str, List[str]]`, *optional*):
365
+ The number of examples in the dataset. Options are: 'n<1K', '1K<n<10K', '10K<n<100K',
366
+ '100K<n<1M', '1M<n<10M', '10M<n<100M', '100M<n<1B', '1B<n<10B', '10B<n<100B', '100B<n<1T', 'n>1T', and 'other'.
367
+ source_datasets (`List[str]]`, *optional*):
368
+ Indicates whether the dataset is an original dataset or extended from another existing dataset.
369
+ Options are: 'original' and 'extended'.
370
+ task_categories (`Union[str, List[str]]`, *optional*):
371
+ What categories of task does the dataset support?
372
+ task_ids (`Union[str, List[str]]`, *optional*):
373
+ What specific tasks does the dataset support?
374
+ paperswithcode_id (`str`, *optional*):
375
+ ID of the dataset on PapersWithCode.
376
+ pretty_name (`str`, *optional*):
377
+ A more human-readable name for the dataset. (ex. "Cats vs. Dogs")
378
+ train_eval_index (`Dict`, *optional*):
379
+ A dictionary that describes the necessary spec for doing evaluation on the Hub.
380
+ If not provided, it will be gathered from the 'train-eval-index' key of the kwargs.
381
+ config_names (`Union[str, List[str]]`, *optional*):
382
+ A list of the available dataset configs for the dataset.
383
+ """
384
+
385
+ def __init__(
386
+ self,
387
+ *,
388
+ language: Optional[Union[str, List[str]]] = None,
389
+ license: Optional[Union[str, List[str]]] = None,
390
+ annotations_creators: Optional[Union[str, List[str]]] = None,
391
+ language_creators: Optional[Union[str, List[str]]] = None,
392
+ multilinguality: Optional[Union[str, List[str]]] = None,
393
+ size_categories: Optional[Union[str, List[str]]] = None,
394
+ source_datasets: Optional[List[str]] = None,
395
+ task_categories: Optional[Union[str, List[str]]] = None,
396
+ task_ids: Optional[Union[str, List[str]]] = None,
397
+ paperswithcode_id: Optional[str] = None,
398
+ pretty_name: Optional[str] = None,
399
+ train_eval_index: Optional[Dict] = None,
400
+ config_names: Optional[Union[str, List[str]]] = None,
401
+ ignore_metadata_errors: bool = False,
402
+ **kwargs,
403
+ ):
404
+ self.annotations_creators = annotations_creators
405
+ self.language_creators = language_creators
406
+ self.language = language
407
+ self.license = license
408
+ self.multilinguality = multilinguality
409
+ self.size_categories = size_categories
410
+ self.source_datasets = source_datasets
411
+ self.task_categories = task_categories
412
+ self.task_ids = task_ids
413
+ self.paperswithcode_id = paperswithcode_id
414
+ self.pretty_name = pretty_name
415
+ self.config_names = config_names
416
+
417
+ # TODO - maybe handle this similarly to EvalResult?
418
+ self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None)
419
+ super().__init__(**kwargs)
420
+
421
+ def _to_dict(self, data_dict):
422
+ data_dict["train-eval-index"] = data_dict.pop("train_eval_index")
423
+
424
+
425
+ class SpaceCardData(CardData):
426
+ """Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
427
+
428
+ To get an exhaustive reference of Spaces configuration, please visit https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference.
429
+
430
+ Args:
431
+ title (`str`, *optional*)
432
+ Title of the Space.
433
+ sdk (`str`, *optional*)
434
+ SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`).
435
+ sdk_version (`str`, *optional*)
436
+ Version of the used SDK (if Gradio/Streamlit sdk).
437
+ python_version (`str`, *optional*)
438
+ Python version used in the Space (if Gradio/Streamlit sdk).
439
+ app_file (`str`, *optional*)
440
+ Path to your main application file (which contains either gradio or streamlit Python code, or static html code).
441
+ Path is relative to the root of the repository.
442
+ app_port (`str`, *optional*)
443
+ Port on which your application is running. Used only if sdk is `docker`.
444
+ license (`str`, *optional*)
445
+ License of this model. Example: apache-2.0 or any license from
446
+ https://huggingface.co/docs/hub/repositories-licenses.
447
+ duplicated_from (`str`, *optional*)
448
+ ID of the original Space if this is a duplicated Space.
449
+ models (List[`str`], *optional*)
450
+ List of models related to this Space. Should be a dataset ID found on https://hf.co/models.
451
+ datasets (`List[str]`, *optional*)
452
+ List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets.
453
+ tags (`List[str]`, *optional*)
454
+ List of tags to add to your Space that can be used when filtering on the Hub.
455
+ ignore_metadata_errors (`str`):
456
+ If True, errors while parsing the metadata section will be ignored. Some information might be lost during
457
+ the process. Use it at your own risk.
458
+ kwargs (`dict`, *optional*):
459
+ Additional metadata that will be added to the space card.
460
+
461
+ Example:
462
+ ```python
463
+ >>> from huggingface_hub import SpaceCardData
464
+ >>> card_data = SpaceCardData(
465
+ ... title="Dreambooth Training",
466
+ ... license="mit",
467
+ ... sdk="gradio",
468
+ ... duplicated_from="multimodalart/dreambooth-training"
469
+ ... )
470
+ >>> card_data.to_dict()
471
+ {'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'}
472
+ ```
473
+ """
474
+
475
+ def __init__(
476
+ self,
477
+ *,
478
+ title: Optional[str] = None,
479
+ sdk: Optional[str] = None,
480
+ sdk_version: Optional[str] = None,
481
+ python_version: Optional[str] = None,
482
+ app_file: Optional[str] = None,
483
+ app_port: Optional[int] = None,
484
+ license: Optional[str] = None,
485
+ duplicated_from: Optional[str] = None,
486
+ models: Optional[List[str]] = None,
487
+ datasets: Optional[List[str]] = None,
488
+ tags: Optional[List[str]] = None,
489
+ ignore_metadata_errors: bool = False,
490
+ **kwargs,
491
+ ):
492
+ self.title = title
493
+ self.sdk = sdk
494
+ self.sdk_version = sdk_version
495
+ self.python_version = python_version
496
+ self.app_file = app_file
497
+ self.app_port = app_port
498
+ self.license = license
499
+ self.duplicated_from = duplicated_from
500
+ self.models = models
501
+ self.datasets = datasets
502
+ self.tags = tags
503
+ super().__init__(**kwargs)
504
+
505
+
506
+ def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]:
507
+ """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects.
508
+
509
+ A detailed spec of the model index can be found here:
510
+ https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
511
+
512
+ Args:
513
+ model_index (`List[Dict[str, Any]]`):
514
+ A model index data structure, likely coming from a README.md file on the
515
+ Hugging Face Hub.
516
+
517
+ Returns:
518
+ model_name (`str`):
519
+ The name of the model as found in the model index. This is used as the
520
+ identifier for the model on leaderboards like PapersWithCode.
521
+ eval_results (`List[EvalResult]`):
522
+ A list of `huggingface_hub.EvalResult` objects containing the metrics
523
+ reported in the provided model_index.
524
+
525
+ Example:
526
+ ```python
527
+ >>> from huggingface_hub.repocard_data import model_index_to_eval_results
528
+ >>> # Define a minimal model index
529
+ >>> model_index = [
530
+ ... {
531
+ ... "name": "my-cool-model",
532
+ ... "results": [
533
+ ... {
534
+ ... "task": {
535
+ ... "type": "image-classification"
536
+ ... },
537
+ ... "dataset": {
538
+ ... "type": "beans",
539
+ ... "name": "Beans"
540
+ ... },
541
+ ... "metrics": [
542
+ ... {
543
+ ... "type": "accuracy",
544
+ ... "value": 0.9
545
+ ... }
546
+ ... ]
547
+ ... }
548
+ ... ]
549
+ ... }
550
+ ... ]
551
+ >>> model_name, eval_results = model_index_to_eval_results(model_index)
552
+ >>> model_name
553
+ 'my-cool-model'
554
+ >>> eval_results[0].task_type
555
+ 'image-classification'
556
+ >>> eval_results[0].metric_type
557
+ 'accuracy'
558
+
559
+ ```
560
+ """
561
+
562
+ eval_results = []
563
+ for elem in model_index:
564
+ name = elem["name"]
565
+ results = elem["results"]
566
+ for result in results:
567
+ task_type = result["task"]["type"]
568
+ task_name = result["task"].get("name")
569
+ dataset_type = result["dataset"]["type"]
570
+ dataset_name = result["dataset"]["name"]
571
+ dataset_config = result["dataset"].get("config")
572
+ dataset_split = result["dataset"].get("split")
573
+ dataset_revision = result["dataset"].get("revision")
574
+ dataset_args = result["dataset"].get("args")
575
+ source_name = result.get("source", {}).get("name")
576
+ source_url = result.get("source", {}).get("url")
577
+
578
+ for metric in result["metrics"]:
579
+ metric_type = metric["type"]
580
+ metric_value = metric["value"]
581
+ metric_name = metric.get("name")
582
+ metric_args = metric.get("args")
583
+ metric_config = metric.get("config")
584
+ verified = metric.get("verified")
585
+ verify_token = metric.get("verifyToken")
586
+
587
+ eval_result = EvalResult(
588
+ task_type=task_type, # Required
589
+ dataset_type=dataset_type, # Required
590
+ dataset_name=dataset_name, # Required
591
+ metric_type=metric_type, # Required
592
+ metric_value=metric_value, # Required
593
+ task_name=task_name,
594
+ dataset_config=dataset_config,
595
+ dataset_split=dataset_split,
596
+ dataset_revision=dataset_revision,
597
+ dataset_args=dataset_args,
598
+ metric_name=metric_name,
599
+ metric_args=metric_args,
600
+ metric_config=metric_config,
601
+ verified=verified,
602
+ verify_token=verify_token,
603
+ source_name=source_name,
604
+ source_url=source_url,
605
+ )
606
+ eval_results.append(eval_result)
607
+ return name, eval_results
608
+
609
+
610
+ def _remove_none(obj):
611
+ """
612
+ Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778
613
+ """
614
+ if isinstance(obj, (list, tuple, set)):
615
+ return type(obj)(_remove_none(x) for x in obj if x is not None)
616
+ elif isinstance(obj, dict):
617
+ return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None)
618
+ else:
619
+ return obj
620
+
621
+
622
+ def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]:
623
+ """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a
624
+ valid model-index that will be compatible with the format expected by the
625
+ Hugging Face Hub.
626
+
627
+ Args:
628
+ model_name (`str`):
629
+ Name of the model (ex. "my-cool-model"). This is used as the identifier
630
+ for the model on leaderboards like PapersWithCode.
631
+ eval_results (`List[EvalResult]`):
632
+ List of `huggingface_hub.EvalResult` objects containing the metrics to be
633
+ reported in the model-index.
634
+
635
+ Returns:
636
+ model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index.
637
+
638
+ Example:
639
+ ```python
640
+ >>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult
641
+ >>> # Define minimal eval_results
642
+ >>> eval_results = [
643
+ ... EvalResult(
644
+ ... task_type="image-classification", # Required
645
+ ... dataset_type="beans", # Required
646
+ ... dataset_name="Beans", # Required
647
+ ... metric_type="accuracy", # Required
648
+ ... metric_value=0.9, # Required
649
+ ... )
650
+ ... ]
651
+ >>> eval_results_to_model_index("my-cool-model", eval_results)
652
+ [{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}]
653
+
654
+ ```
655
+ """
656
+
657
+ # Metrics are reported on a unique task-and-dataset basis.
658
+ # Here, we make a map of those pairs and the associated EvalResults.
659
+ task_and_ds_types_map: Dict[Any, List[EvalResult]] = defaultdict(list)
660
+ for eval_result in eval_results:
661
+ task_and_ds_types_map[eval_result.unique_identifier].append(eval_result)
662
+
663
+ # Use the map from above to generate the model index data.
664
+ model_index_data = []
665
+ for results in task_and_ds_types_map.values():
666
+ # All items from `results` share same metadata
667
+ sample_result = results[0]
668
+ data = {
669
+ "task": {
670
+ "type": sample_result.task_type,
671
+ "name": sample_result.task_name,
672
+ },
673
+ "dataset": {
674
+ "name": sample_result.dataset_name,
675
+ "type": sample_result.dataset_type,
676
+ "config": sample_result.dataset_config,
677
+ "split": sample_result.dataset_split,
678
+ "revision": sample_result.dataset_revision,
679
+ "args": sample_result.dataset_args,
680
+ },
681
+ "metrics": [
682
+ {
683
+ "type": result.metric_type,
684
+ "value": result.metric_value,
685
+ "name": result.metric_name,
686
+ "config": result.metric_config,
687
+ "args": result.metric_args,
688
+ "verified": result.verified,
689
+ "verifyToken": result.verify_token,
690
+ }
691
+ for result in results
692
+ ],
693
+ }
694
+ if sample_result.source_url is not None:
695
+ source = {
696
+ "url": sample_result.source_url,
697
+ }
698
+ if sample_result.source_name is not None:
699
+ source["name"] = sample_result.source_name
700
+ data["source"] = source
701
+ model_index_data.append(data)
702
+
703
+ # TODO - Check if there cases where this list is longer than one?
704
+ # Finally, the model index itself is list of dicts.
705
+ model_index = [
706
+ {
707
+ "name": model_name,
708
+ "results": model_index_data,
709
+ }
710
+ ]
711
+ return _remove_none(model_index)
lib/python3.11/site-packages/huggingface_hub/repository.py ADDED
@@ -0,0 +1,1476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import threading
6
+ import time
7
+ from contextlib import contextmanager
8
+ from pathlib import Path
9
+ from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union
10
+ from urllib.parse import urlparse
11
+
12
+ from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES, REPOCARD_NAME
13
+ from huggingface_hub.repocard import metadata_load, metadata_save
14
+
15
+ from .hf_api import HfApi, repo_type_and_id_from_hf_id
16
+ from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
17
+ from .utils import (
18
+ SoftTemporaryDirectory,
19
+ get_token,
20
+ logging,
21
+ run_subprocess,
22
+ tqdm,
23
+ validate_hf_hub_args,
24
+ )
25
+ from .utils._deprecation import _deprecate_method
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class CommandInProgress:
32
+ """
33
+ Utility to follow commands launched asynchronously.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ title: str,
39
+ is_done_method: Callable,
40
+ status_method: Callable,
41
+ process: subprocess.Popen,
42
+ post_method: Optional[Callable] = None,
43
+ ):
44
+ self.title = title
45
+ self._is_done = is_done_method
46
+ self._status = status_method
47
+ self._process = process
48
+ self._stderr = ""
49
+ self._stdout = ""
50
+ self._post_method = post_method
51
+
52
+ @property
53
+ def is_done(self) -> bool:
54
+ """
55
+ Whether the process is done.
56
+ """
57
+ result = self._is_done()
58
+
59
+ if result and self._post_method is not None:
60
+ self._post_method()
61
+ self._post_method = None
62
+
63
+ return result
64
+
65
+ @property
66
+ def status(self) -> int:
67
+ """
68
+ The exit code/status of the current action. Will return `0` if the
69
+ command has completed successfully, and a number between 1 and 255 if
70
+ the process errored-out.
71
+
72
+ Will return -1 if the command is still ongoing.
73
+ """
74
+ return self._status()
75
+
76
+ @property
77
+ def failed(self) -> bool:
78
+ """
79
+ Whether the process errored-out.
80
+ """
81
+ return self.status > 0
82
+
83
+ @property
84
+ def stderr(self) -> str:
85
+ """
86
+ The current output message on the standard error.
87
+ """
88
+ if self._process.stderr is not None:
89
+ self._stderr += self._process.stderr.read()
90
+ return self._stderr
91
+
92
+ @property
93
+ def stdout(self) -> str:
94
+ """
95
+ The current output message on the standard output.
96
+ """
97
+ if self._process.stdout is not None:
98
+ self._stdout += self._process.stdout.read()
99
+ return self._stdout
100
+
101
+ def __repr__(self):
102
+ status = self.status
103
+
104
+ if status == -1:
105
+ status = "running"
106
+
107
+ return (
108
+ f"[{self.title} command, status code: {status},"
109
+ f" {'in progress.' if not self.is_done else 'finished.'} PID:"
110
+ f" {self._process.pid}]"
111
+ )
112
+
113
+
114
+ def is_git_repo(folder: Union[str, Path]) -> bool:
115
+ """
116
+ Check if the folder is the root or part of a git repository
117
+
118
+ Args:
119
+ folder (`str`):
120
+ The folder in which to run the command.
121
+
122
+ Returns:
123
+ `bool`: `True` if the repository is part of a repository, `False`
124
+ otherwise.
125
+ """
126
+ folder_exists = os.path.exists(os.path.join(folder, ".git"))
127
+ git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
128
+ return folder_exists and git_branch.returncode == 0
129
+
130
+
131
+ def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool:
132
+ """
133
+ Check if the folder is a local clone of the remote_url
134
+
135
+ Args:
136
+ folder (`str` or `Path`):
137
+ The folder in which to run the command.
138
+ remote_url (`str`):
139
+ The url of a git repository.
140
+
141
+ Returns:
142
+ `bool`: `True` if the repository is a local clone of the remote
143
+ repository specified, `False` otherwise.
144
+ """
145
+ if not is_git_repo(folder):
146
+ return False
147
+
148
+ remotes = run_subprocess("git remote -v", folder).stdout
149
+
150
+ # Remove token for the test with remotes.
151
+ remote_url = re.sub(r"https://.*@", "https://", remote_url)
152
+ remotes = [re.sub(r"https://.*@", "https://", remote) for remote in remotes.split()]
153
+ return remote_url in remotes
154
+
155
+
156
+ def is_tracked_with_lfs(filename: Union[str, Path]) -> bool:
157
+ """
158
+ Check if the file passed is tracked with git-lfs.
159
+
160
+ Args:
161
+ filename (`str` or `Path`):
162
+ The filename to check.
163
+
164
+ Returns:
165
+ `bool`: `True` if the file passed is tracked with git-lfs, `False`
166
+ otherwise.
167
+ """
168
+ folder = Path(filename).parent
169
+ filename = Path(filename).name
170
+
171
+ try:
172
+ p = run_subprocess("git check-attr -a".split() + [filename], folder)
173
+ attributes = p.stdout.strip()
174
+ except subprocess.CalledProcessError as exc:
175
+ if not is_git_repo(folder):
176
+ return False
177
+ else:
178
+ raise OSError(exc.stderr)
179
+
180
+ if len(attributes) == 0:
181
+ return False
182
+
183
+ found_lfs_tag = {"diff": False, "merge": False, "filter": False}
184
+
185
+ for attribute in attributes.split("\n"):
186
+ for tag in found_lfs_tag.keys():
187
+ if tag in attribute and "lfs" in attribute:
188
+ found_lfs_tag[tag] = True
189
+
190
+ return all(found_lfs_tag.values())
191
+
192
+
193
+ def is_git_ignored(filename: Union[str, Path]) -> bool:
194
+ """
195
+ Check if file is git-ignored. Supports nested .gitignore files.
196
+
197
+ Args:
198
+ filename (`str` or `Path`):
199
+ The filename to check.
200
+
201
+ Returns:
202
+ `bool`: `True` if the file passed is ignored by `git`, `False`
203
+ otherwise.
204
+ """
205
+ folder = Path(filename).parent
206
+ filename = Path(filename).name
207
+
208
+ try:
209
+ p = run_subprocess("git check-ignore".split() + [filename], folder, check=False)
210
+ # Will return exit code 1 if not gitignored
211
+ is_ignored = not bool(p.returncode)
212
+ except subprocess.CalledProcessError as exc:
213
+ raise OSError(exc.stderr)
214
+
215
+ return is_ignored
216
+
217
+
218
+ def is_binary_file(filename: Union[str, Path]) -> bool:
219
+ """
220
+ Check if file is a binary file.
221
+
222
+ Args:
223
+ filename (`str` or `Path`):
224
+ The filename to check.
225
+
226
+ Returns:
227
+ `bool`: `True` if the file passed is a binary file, `False` otherwise.
228
+ """
229
+ try:
230
+ with open(filename, "rb") as f:
231
+ content = f.read(10 * (1024**2)) # Read a maximum of 10MB
232
+
233
+ # Code sample taken from the following stack overflow thread
234
+ # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391
235
+ text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
236
+ return bool(content.translate(None, text_chars))
237
+ except UnicodeDecodeError:
238
+ return True
239
+
240
+
241
+ def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]:
242
+ """
243
+ Returns a list of filenames that are to be staged.
244
+
245
+ Args:
246
+ pattern (`str` or `Path`):
247
+ The pattern of filenames to check. Put `.` to get all files.
248
+ folder (`str` or `Path`):
249
+ The folder in which to run the command.
250
+
251
+ Returns:
252
+ `List[str]`: List of files that are to be staged.
253
+ """
254
+ try:
255
+ p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder)
256
+ if len(p.stdout.strip()):
257
+ files = p.stdout.strip().split("\n")
258
+ else:
259
+ files = []
260
+ except subprocess.CalledProcessError as exc:
261
+ raise EnvironmentError(exc.stderr)
262
+
263
+ return files
264
+
265
+
266
+ def is_tracked_upstream(folder: Union[str, Path]) -> bool:
267
+ """
268
+ Check if the current checked-out branch is tracked upstream.
269
+
270
+ Args:
271
+ folder (`str` or `Path`):
272
+ The folder in which to run the command.
273
+
274
+ Returns:
275
+ `bool`: `True` if the current checked-out branch is tracked upstream,
276
+ `False` otherwise.
277
+ """
278
+ try:
279
+ run_subprocess("git rev-parse --symbolic-full-name --abbrev-ref @{u}", folder)
280
+ return True
281
+ except subprocess.CalledProcessError as exc:
282
+ if "HEAD" in exc.stderr:
283
+ raise OSError("No branch checked out")
284
+
285
+ return False
286
+
287
+
288
+ def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int:
289
+ """
290
+ Check the number of commits that would be pushed upstream
291
+
292
+ Args:
293
+ folder (`str` or `Path`):
294
+ The folder in which to run the command.
295
+ upstream (`str`, *optional*):
296
+ The name of the upstream repository with which the comparison should be
297
+ made.
298
+
299
+ Returns:
300
+ `int`: Number of commits that would be pushed upstream were a `git
301
+ push` to proceed.
302
+ """
303
+ try:
304
+ result = run_subprocess(f"git cherry -v {upstream or ''}", folder)
305
+ return len(result.stdout.split("\n")) - 1
306
+ except subprocess.CalledProcessError as exc:
307
+ raise EnvironmentError(exc.stderr)
308
+
309
+
310
+ class PbarT(TypedDict):
311
+ # Used to store an opened progress bar in `_lfs_log_progress`
312
+ bar: tqdm
313
+ past_bytes: int
314
+
315
+
316
+ @contextmanager
317
+ def _lfs_log_progress():
318
+ """
319
+ This is a context manager that will log the Git LFS progress of cleaning,
320
+ smudging, pulling and pushing.
321
+ """
322
+
323
+ if logger.getEffectiveLevel() >= logging.ERROR:
324
+ try:
325
+ yield
326
+ except Exception:
327
+ pass
328
+ return
329
+
330
+ def output_progress(stopping_event: threading.Event):
331
+ """
332
+ To be launched as a separate thread with an event meaning it should stop
333
+ the tail.
334
+ """
335
+ # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value)
336
+ pbars: Dict[Tuple[str, str], PbarT] = {}
337
+
338
+ def close_pbars():
339
+ for pbar in pbars.values():
340
+ pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"])
341
+ pbar["bar"].refresh()
342
+ pbar["bar"].close()
343
+
344
+ def tail_file(filename) -> Iterator[str]:
345
+ """
346
+ Creates a generator to be iterated through, which will return each
347
+ line one by one. Will stop tailing the file if the stopping_event is
348
+ set.
349
+ """
350
+ with open(filename, "r") as file:
351
+ current_line = ""
352
+ while True:
353
+ if stopping_event.is_set():
354
+ close_pbars()
355
+ break
356
+
357
+ line_bit = file.readline()
358
+ if line_bit is not None and not len(line_bit.strip()) == 0:
359
+ current_line += line_bit
360
+ if current_line.endswith("\n"):
361
+ yield current_line
362
+ current_line = ""
363
+ else:
364
+ time.sleep(1)
365
+
366
+ # If the file isn't created yet, wait for a few seconds before trying again.
367
+ # Can be interrupted with the stopping_event.
368
+ while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]):
369
+ if stopping_event.is_set():
370
+ close_pbars()
371
+ return
372
+
373
+ time.sleep(2)
374
+
375
+ for line in tail_file(os.environ["GIT_LFS_PROGRESS"]):
376
+ try:
377
+ state, file_progress, byte_progress, filename = line.split()
378
+ except ValueError as error:
379
+ # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373.
380
+ raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error
381
+ description = f"{state.capitalize()} file {filename}"
382
+
383
+ current_bytes, total_bytes = byte_progress.split("/")
384
+ current_bytes_int = int(current_bytes)
385
+ total_bytes_int = int(total_bytes)
386
+
387
+ pbar = pbars.get((state, filename))
388
+ if pbar is None:
389
+ # Initialize progress bar
390
+ pbars[(state, filename)] = {
391
+ "bar": tqdm(
392
+ desc=description,
393
+ initial=current_bytes_int,
394
+ total=total_bytes_int,
395
+ unit="B",
396
+ unit_scale=True,
397
+ unit_divisor=1024,
398
+ ),
399
+ "past_bytes": int(current_bytes),
400
+ }
401
+ else:
402
+ # Update progress bar
403
+ pbar["bar"].update(current_bytes_int - pbar["past_bytes"])
404
+ pbar["past_bytes"] = current_bytes_int
405
+
406
+ current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "")
407
+
408
+ with SoftTemporaryDirectory() as tmpdir:
409
+ os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress")
410
+ logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}")
411
+
412
+ exit_event = threading.Event()
413
+ x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True)
414
+ x.start()
415
+
416
+ try:
417
+ yield
418
+ finally:
419
+ exit_event.set()
420
+ x.join()
421
+
422
+ os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value
423
+
424
+
425
+ class Repository:
426
+ """
427
+ Helper class to wrap the git and git-lfs commands.
428
+
429
+ The aim is to facilitate interacting with huggingface.co hosted model or
430
+ dataset repos, though not a lot here (if any) is actually specific to
431
+ huggingface.co.
432
+
433
+ <Tip warning={true}>
434
+
435
+ [`Repository`] is deprecated in favor of the http-based alternatives implemented in
436
+ [`HfApi`]. Given its large adoption in legacy code, the complete removal of
437
+ [`Repository`] will only happen in release `v1.0`. For more details, please read
438
+ https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.
439
+
440
+ </Tip>
441
+ """
442
+
443
+ command_queue: List[CommandInProgress]
444
+
445
+ @validate_hf_hub_args
446
+ @_deprecate_method(
447
+ version="1.0",
448
+ message=(
449
+ "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete"
450
+ " removal is only planned on next major release.\nFor more details, please read"
451
+ " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http."
452
+ ),
453
+ )
454
+ def __init__(
455
+ self,
456
+ local_dir: Union[str, Path],
457
+ clone_from: Optional[str] = None,
458
+ repo_type: Optional[str] = None,
459
+ token: Union[bool, str] = True,
460
+ git_user: Optional[str] = None,
461
+ git_email: Optional[str] = None,
462
+ revision: Optional[str] = None,
463
+ skip_lfs_files: bool = False,
464
+ client: Optional[HfApi] = None,
465
+ ):
466
+ """
467
+ Instantiate a local clone of a git repo.
468
+
469
+ If `clone_from` is set, the repo will be cloned from an existing remote repository.
470
+ If the remote repo does not exist, a `EnvironmentError` exception will be thrown.
471
+ Please create the remote repo first using [`create_repo`].
472
+
473
+ `Repository` uses the local git credentials by default. If explicitly set, the `token`
474
+ or the `git_user`/`git_email` pair will be used instead.
475
+
476
+ Args:
477
+ local_dir (`str` or `Path`):
478
+ path (e.g. `'my_trained_model/'`) to the local directory, where
479
+ the `Repository` will be initialized.
480
+ clone_from (`str`, *optional*):
481
+ Either a repository url or `repo_id`.
482
+ Example:
483
+ - `"https://huggingface.co/philschmid/playground-tests"`
484
+ - `"philschmid/playground-tests"`
485
+ repo_type (`str`, *optional*):
486
+ To set when cloning a repo from a repo_id. Default is model.
487
+ token (`bool` or `str`, *optional*):
488
+ A valid authentication token (see https://huggingface.co/settings/token).
489
+ If `None` or `True` and machine is logged in (through `huggingface-cli login`
490
+ or [`~huggingface_hub.login`]), token will be retrieved from the cache.
491
+ If `False`, token is not sent in the request header.
492
+ git_user (`str`, *optional*):
493
+ will override the `git config user.name` for committing and
494
+ pushing files to the hub.
495
+ git_email (`str`, *optional*):
496
+ will override the `git config user.email` for committing and
497
+ pushing files to the hub.
498
+ revision (`str`, *optional*):
499
+ Revision to checkout after initializing the repository. If the
500
+ revision doesn't exist, a branch will be created with that
501
+ revision name from the default branch's current HEAD.
502
+ skip_lfs_files (`bool`, *optional*, defaults to `False`):
503
+ whether to skip git-LFS files or not.
504
+ client (`HfApi`, *optional*):
505
+ Instance of [`HfApi`] to use when calling the HF Hub API. A new
506
+ instance will be created if this is left to `None`.
507
+
508
+ Raises:
509
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
510
+ if the remote repository set in `clone_from` does not exist.
511
+ """
512
+ if isinstance(local_dir, Path):
513
+ local_dir = str(local_dir)
514
+ os.makedirs(local_dir, exist_ok=True)
515
+ self.local_dir = os.path.join(os.getcwd(), local_dir)
516
+ self._repo_type = repo_type
517
+ self.command_queue = []
518
+ self.skip_lfs_files = skip_lfs_files
519
+ self.client = client if client is not None else HfApi()
520
+
521
+ self.check_git_versions()
522
+
523
+ if isinstance(token, str):
524
+ self.huggingface_token: Optional[str] = token
525
+ elif token is False:
526
+ self.huggingface_token = None
527
+ else:
528
+ # if `True` -> explicit use of the cached token
529
+ # if `None` -> implicit use of the cached token
530
+ self.huggingface_token = get_token()
531
+
532
+ if clone_from is not None:
533
+ self.clone_from(repo_url=clone_from)
534
+ else:
535
+ if is_git_repo(self.local_dir):
536
+ logger.debug("[Repository] is a valid git repo")
537
+ else:
538
+ raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.")
539
+
540
+ if self.huggingface_token is not None and (git_email is None or git_user is None):
541
+ user = self.client.whoami(self.huggingface_token)
542
+
543
+ if git_email is None:
544
+ git_email = user["email"]
545
+
546
+ if git_user is None:
547
+ git_user = user["fullname"]
548
+
549
+ if git_user is not None or git_email is not None:
550
+ self.git_config_username_and_email(git_user, git_email)
551
+
552
+ self.lfs_enable_largefiles()
553
+ self.git_credential_helper_store()
554
+
555
+ if revision is not None:
556
+ self.git_checkout(revision, create_branch_ok=True)
557
+
558
+ # This ensures that all commands exit before exiting the Python runtime.
559
+ # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations.
560
+ atexit.register(self.wait_for_commands)
561
+
562
+ @property
563
+ def current_branch(self) -> str:
564
+ """
565
+ Returns the current checked out branch.
566
+
567
+ Returns:
568
+ `str`: Current checked out branch.
569
+ """
570
+ try:
571
+ result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip()
572
+ except subprocess.CalledProcessError as exc:
573
+ raise EnvironmentError(exc.stderr)
574
+
575
+ return result
576
+
577
+ def check_git_versions(self):
578
+ """
579
+ Checks that `git` and `git-lfs` can be run.
580
+
581
+ Raises:
582
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
583
+ if `git` or `git-lfs` are not installed.
584
+ """
585
+ try:
586
+ git_version = run_subprocess("git --version", self.local_dir).stdout.strip()
587
+ except FileNotFoundError:
588
+ raise EnvironmentError("Looks like you do not have git installed, please install.")
589
+
590
+ try:
591
+ lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip()
592
+ except FileNotFoundError:
593
+ raise EnvironmentError(
594
+ "Looks like you do not have git-lfs installed, please install."
595
+ " You can install from https://git-lfs.github.com/."
596
+ " Then run `git lfs install` (you only have to do this once)."
597
+ )
598
+ logger.info(git_version + "\n" + lfs_version)
599
+
600
+ @validate_hf_hub_args
601
+ def clone_from(self, repo_url: str, token: Union[bool, str, None] = None):
602
+ """
603
+ Clone from a remote. If the folder already exists, will try to clone the
604
+ repository within it.
605
+
606
+ If this folder is a git repository with linked history, will try to
607
+ update the repository.
608
+
609
+ Args:
610
+ repo_url (`str`):
611
+ The URL from which to clone the repository
612
+ token (`Union[str, bool]`, *optional*):
613
+ Whether to use the authentication token. It can be:
614
+ - a string which is the token itself
615
+ - `False`, which would not use the authentication token
616
+ - `True`, which would fetch the authentication token from the
617
+ local folder and use it (you should be logged in for this to
618
+ work).
619
+ - `None`, which would retrieve the value of
620
+ `self.huggingface_token`.
621
+
622
+ <Tip>
623
+
624
+ Raises the following error:
625
+
626
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
627
+ if an organization token (starts with "api_org") is passed. Use must use
628
+ your own personal access token (see https://hf.co/settings/tokens).
629
+
630
+ - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
631
+ if you are trying to clone the repository in a non-empty folder, or if the
632
+ `git` operations raise errors.
633
+
634
+ </Tip>
635
+ """
636
+ token = (
637
+ token # str -> use it
638
+ if isinstance(token, str)
639
+ else (
640
+ None # `False` -> explicit no token
641
+ if token is False
642
+ else self.huggingface_token # `None` or `True` -> use default
643
+ )
644
+ )
645
+ if token is not None and token.startswith("api_org"):
646
+ raise ValueError(
647
+ "You must use your personal access token, not an Organization token"
648
+ " (see https://hf.co/settings/tokens)."
649
+ )
650
+
651
+ hub_url = self.client.endpoint
652
+ if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2):
653
+ repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url)
654
+ repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name
655
+
656
+ if repo_type is not None:
657
+ self._repo_type = repo_type
658
+
659
+ repo_url = hub_url + "/"
660
+
661
+ if self._repo_type in REPO_TYPES_URL_PREFIXES:
662
+ repo_url += REPO_TYPES_URL_PREFIXES[self._repo_type]
663
+
664
+ if token is not None:
665
+ # Add token in git url when provided
666
+ scheme = urlparse(repo_url).scheme
667
+ repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@")
668
+
669
+ repo_url += repo_id
670
+
671
+ # For error messages, it's cleaner to show the repo url without the token.
672
+ clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url)
673
+ try:
674
+ run_subprocess("git lfs install", self.local_dir)
675
+
676
+ # checks if repository is initialized in a empty repository or in one with files
677
+ if len(os.listdir(self.local_dir)) == 0:
678
+ logger.warning(f"Cloning {clean_repo_url} into local empty directory.")
679
+
680
+ with _lfs_log_progress():
681
+ env = os.environ.copy()
682
+
683
+ if self.skip_lfs_files:
684
+ env.update({"GIT_LFS_SKIP_SMUDGE": "1"})
685
+
686
+ run_subprocess(
687
+ # 'git lfs clone' is deprecated (will display a warning in the terminal)
688
+ # but we still use it as it provides a nicer UX when downloading large
689
+ # files (shows progress).
690
+ f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .",
691
+ self.local_dir,
692
+ env=env,
693
+ )
694
+ else:
695
+ # Check if the folder is the root of a git repository
696
+ if not is_git_repo(self.local_dir):
697
+ raise EnvironmentError(
698
+ "Tried to clone a repository in a non-empty folder that isn't"
699
+ f" a git repository ('{self.local_dir}'). If you really want to"
700
+ f" do this, do it manually:\n cd {self.local_dir} && git init"
701
+ " && git remote add origin && git pull origin main\n or clone"
702
+ " repo to a new folder and move your existing files there"
703
+ " afterwards."
704
+ )
705
+
706
+ if is_local_clone(self.local_dir, repo_url):
707
+ logger.warning(
708
+ f"{self.local_dir} is already a clone of {clean_repo_url}."
709
+ " Make sure you pull the latest changes with"
710
+ " `repo.git_pull()`."
711
+ )
712
+ else:
713
+ output = run_subprocess("git remote get-url origin", self.local_dir, check=False)
714
+
715
+ error_msg = (
716
+ f"Tried to clone {clean_repo_url} in an unrelated git"
717
+ " repository.\nIf you believe this is an error, please add"
718
+ f" a remote with the following URL: {clean_repo_url}."
719
+ )
720
+ if output.returncode == 0:
721
+ clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout)
722
+ error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}"
723
+ raise EnvironmentError(error_msg)
724
+
725
+ except subprocess.CalledProcessError as exc:
726
+ raise EnvironmentError(exc.stderr)
727
+
728
+ def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None):
729
+ """
730
+ Sets git username and email (only in the current repo).
731
+
732
+ Args:
733
+ git_user (`str`, *optional*):
734
+ The username to register through `git`.
735
+ git_email (`str`, *optional*):
736
+ The email to register through `git`.
737
+ """
738
+ try:
739
+ if git_user is not None:
740
+ run_subprocess("git config user.name".split() + [git_user], self.local_dir)
741
+
742
+ if git_email is not None:
743
+ run_subprocess(f"git config user.email {git_email}".split(), self.local_dir)
744
+ except subprocess.CalledProcessError as exc:
745
+ raise EnvironmentError(exc.stderr)
746
+
747
+ def git_credential_helper_store(self):
748
+ """
749
+ Sets the git credential helper to `store`
750
+ """
751
+ try:
752
+ run_subprocess("git config credential.helper store", self.local_dir)
753
+ except subprocess.CalledProcessError as exc:
754
+ raise EnvironmentError(exc.stderr)
755
+
756
+ def git_head_hash(self) -> str:
757
+ """
758
+ Get commit sha on top of HEAD.
759
+
760
+ Returns:
761
+ `str`: The current checked out commit SHA.
762
+ """
763
+ try:
764
+ p = run_subprocess("git rev-parse HEAD", self.local_dir)
765
+ return p.stdout.strip()
766
+ except subprocess.CalledProcessError as exc:
767
+ raise EnvironmentError(exc.stderr)
768
+
769
+ def git_remote_url(self) -> str:
770
+ """
771
+ Get URL to origin remote.
772
+
773
+ Returns:
774
+ `str`: The URL of the `origin` remote.
775
+ """
776
+ try:
777
+ p = run_subprocess("git config --get remote.origin.url", self.local_dir)
778
+ url = p.stdout.strip()
779
+ # Strip basic auth info.
780
+ return re.sub(r"https://.*@", "https://", url)
781
+ except subprocess.CalledProcessError as exc:
782
+ raise EnvironmentError(exc.stderr)
783
+
784
+ def git_head_commit_url(self) -> str:
785
+ """
786
+ Get URL to last commit on HEAD. We assume it's been pushed, and the url
787
+ scheme is the same one as for GitHub or HuggingFace.
788
+
789
+ Returns:
790
+ `str`: The URL to the current checked-out commit.
791
+ """
792
+ sha = self.git_head_hash()
793
+ url = self.git_remote_url()
794
+ if url.endswith("/"):
795
+ url = url[:-1]
796
+ return f"{url}/commit/{sha}"
797
+
798
+ def list_deleted_files(self) -> List[str]:
799
+ """
800
+ Returns a list of the files that are deleted in the working directory or
801
+ index.
802
+
803
+ Returns:
804
+ `List[str]`: A list of files that have been deleted in the working
805
+ directory or index.
806
+ """
807
+ try:
808
+ git_status = run_subprocess("git status -s", self.local_dir).stdout.strip()
809
+ except subprocess.CalledProcessError as exc:
810
+ raise EnvironmentError(exc.stderr)
811
+
812
+ if len(git_status) == 0:
813
+ return []
814
+
815
+ # Receives a status like the following
816
+ # D .gitignore
817
+ # D new_file.json
818
+ # AD new_file1.json
819
+ # ?? new_file2.json
820
+ # ?? new_file4.json
821
+
822
+ # Strip each line of whitespaces
823
+ modified_files_statuses = [status.strip() for status in git_status.split("\n")]
824
+
825
+ # Only keep files that are deleted using the D prefix
826
+ deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]]
827
+
828
+ # Remove the D prefix and strip to keep only the relevant filename
829
+ deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses]
830
+
831
+ return deleted_files
832
+
833
+ def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False):
834
+ """
835
+ Tell git-lfs to track files according to a pattern.
836
+
837
+ Setting the `filename` argument to `True` will treat the arguments as
838
+ literal filenames, not as patterns. Any special glob characters in the
839
+ filename will be escaped when writing to the `.gitattributes` file.
840
+
841
+ Args:
842
+ patterns (`Union[str, List[str]]`):
843
+ The pattern, or list of patterns, to track with git-lfs.
844
+ filename (`bool`, *optional*, defaults to `False`):
845
+ Whether to use the patterns as literal filenames.
846
+ """
847
+ if isinstance(patterns, str):
848
+ patterns = [patterns]
849
+ try:
850
+ for pattern in patterns:
851
+ run_subprocess(
852
+ f"git lfs track {'--filename' if filename else ''} {pattern}",
853
+ self.local_dir,
854
+ )
855
+ except subprocess.CalledProcessError as exc:
856
+ raise EnvironmentError(exc.stderr)
857
+
858
+ def lfs_untrack(self, patterns: Union[str, List[str]]):
859
+ """
860
+ Tell git-lfs to untrack those files.
861
+
862
+ Args:
863
+ patterns (`Union[str, List[str]]`):
864
+ The pattern, or list of patterns, to untrack with git-lfs.
865
+ """
866
+ if isinstance(patterns, str):
867
+ patterns = [patterns]
868
+ try:
869
+ for pattern in patterns:
870
+ run_subprocess("git lfs untrack".split() + [pattern], self.local_dir)
871
+ except subprocess.CalledProcessError as exc:
872
+ raise EnvironmentError(exc.stderr)
873
+
874
+ def lfs_enable_largefiles(self):
875
+ """
876
+ HF-specific. This enables upload support of files >5GB.
877
+ """
878
+ try:
879
+ lfs_config = "git config lfs.customtransfer.multipart"
880
+ run_subprocess(f"{lfs_config}.path huggingface-cli", self.local_dir)
881
+ run_subprocess(
882
+ f"{lfs_config}.args {LFS_MULTIPART_UPLOAD_COMMAND}",
883
+ self.local_dir,
884
+ )
885
+ except subprocess.CalledProcessError as exc:
886
+ raise EnvironmentError(exc.stderr)
887
+
888
+ def auto_track_binary_files(self, pattern: str = ".") -> List[str]:
889
+ """
890
+ Automatically track binary files with git-lfs.
891
+
892
+ Args:
893
+ pattern (`str`, *optional*, defaults to "."):
894
+ The pattern with which to track files that are binary.
895
+
896
+ Returns:
897
+ `List[str]`: List of filenames that are now tracked due to being
898
+ binary files
899
+ """
900
+ files_to_be_tracked_with_lfs = []
901
+
902
+ deleted_files = self.list_deleted_files()
903
+
904
+ for filename in files_to_be_staged(pattern, folder=self.local_dir):
905
+ if filename in deleted_files:
906
+ continue
907
+
908
+ path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)
909
+
910
+ if not (is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file)):
911
+ size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)
912
+
913
+ if size_in_mb >= 10:
914
+ logger.warning(
915
+ "Parsing a large file to check if binary or not. Tracking large"
916
+ " files using `repository.auto_track_large_files` is"
917
+ " recommended so as to not load the full file in memory."
918
+ )
919
+
920
+ is_binary = is_binary_file(path_to_file)
921
+
922
+ if is_binary:
923
+ self.lfs_track(filename)
924
+ files_to_be_tracked_with_lfs.append(filename)
925
+
926
+ # Cleanup the .gitattributes if files were deleted
927
+ self.lfs_untrack(deleted_files)
928
+
929
+ return files_to_be_tracked_with_lfs
930
+
931
+ def auto_track_large_files(self, pattern: str = ".") -> List[str]:
932
+ """
933
+ Automatically track large files (files that weigh more than 10MBs) with
934
+ git-lfs.
935
+
936
+ Args:
937
+ pattern (`str`, *optional*, defaults to "."):
938
+ The pattern with which to track files that are above 10MBs.
939
+
940
+ Returns:
941
+ `List[str]`: List of filenames that are now tracked due to their
942
+ size.
943
+ """
944
+ files_to_be_tracked_with_lfs = []
945
+
946
+ deleted_files = self.list_deleted_files()
947
+
948
+ for filename in files_to_be_staged(pattern, folder=self.local_dir):
949
+ if filename in deleted_files:
950
+ continue
951
+
952
+ path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)
953
+ size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)
954
+
955
+ if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file):
956
+ self.lfs_track(filename)
957
+ files_to_be_tracked_with_lfs.append(filename)
958
+
959
+ # Cleanup the .gitattributes if files were deleted
960
+ self.lfs_untrack(deleted_files)
961
+
962
+ return files_to_be_tracked_with_lfs
963
+
964
+ def lfs_prune(self, recent=False):
965
+ """
966
+ git lfs prune
967
+
968
+ Args:
969
+ recent (`bool`, *optional*, defaults to `False`):
970
+ Whether to prune files even if they were referenced by recent
971
+ commits. See the following
972
+ [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files)
973
+ for more information.
974
+ """
975
+ try:
976
+ with _lfs_log_progress():
977
+ result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir)
978
+ logger.info(result.stdout)
979
+ except subprocess.CalledProcessError as exc:
980
+ raise EnvironmentError(exc.stderr)
981
+
982
+ def git_pull(self, rebase: bool = False, lfs: bool = False):
983
+ """
984
+ git pull
985
+
986
+ Args:
987
+ rebase (`bool`, *optional*, defaults to `False`):
988
+ Whether to rebase the current branch on top of the upstream
989
+ branch after fetching.
990
+ lfs (`bool`, *optional*, defaults to `False`):
991
+ Whether to fetch the LFS files too. This option only changes the
992
+ behavior when a repository was cloned without fetching the LFS
993
+ files; calling `repo.git_pull(lfs=True)` will then fetch the LFS
994
+ file from the remote repository.
995
+ """
996
+ command = "git pull" if not lfs else "git lfs pull"
997
+ if rebase:
998
+ command += " --rebase"
999
+ try:
1000
+ with _lfs_log_progress():
1001
+ result = run_subprocess(command, self.local_dir)
1002
+ logger.info(result.stdout)
1003
+ except subprocess.CalledProcessError as exc:
1004
+ raise EnvironmentError(exc.stderr)
1005
+
1006
+ def git_add(self, pattern: str = ".", auto_lfs_track: bool = False):
1007
+ """
1008
+ git add
1009
+
1010
+ Setting the `auto_lfs_track` parameter to `True` will automatically
1011
+ track files that are larger than 10MB with `git-lfs`.
1012
+
1013
+ Args:
1014
+ pattern (`str`, *optional*, defaults to "."):
1015
+ The pattern with which to add files to staging.
1016
+ auto_lfs_track (`bool`, *optional*, defaults to `False`):
1017
+ Whether to automatically track large and binary files with
1018
+ git-lfs. Any file over 10MB in size, or in binary format, will
1019
+ be automatically tracked.
1020
+ """
1021
+ if auto_lfs_track:
1022
+ # Track files according to their size (>=10MB)
1023
+ tracked_files = self.auto_track_large_files(pattern)
1024
+
1025
+ # Read the remaining files and track them if they're binary
1026
+ tracked_files.extend(self.auto_track_binary_files(pattern))
1027
+
1028
+ if tracked_files:
1029
+ logger.warning(
1030
+ f"Adding files tracked by Git LFS: {tracked_files}. This may take a"
1031
+ " bit of time if the files are large."
1032
+ )
1033
+
1034
+ try:
1035
+ result = run_subprocess("git add -v".split() + [pattern], self.local_dir)
1036
+ logger.info(f"Adding to index:\n{result.stdout}\n")
1037
+ except subprocess.CalledProcessError as exc:
1038
+ raise EnvironmentError(exc.stderr)
1039
+
1040
+ def git_commit(self, commit_message: str = "commit files to HF hub"):
1041
+ """
1042
+ git commit
1043
+
1044
+ Args:
1045
+ commit_message (`str`, *optional*, defaults to "commit files to HF hub"):
1046
+ The message attributed to the commit.
1047
+ """
1048
+ try:
1049
+ result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir)
1050
+ logger.info(f"Committed:\n{result.stdout}\n")
1051
+ except subprocess.CalledProcessError as exc:
1052
+ if len(exc.stderr) > 0:
1053
+ raise EnvironmentError(exc.stderr)
1054
+ else:
1055
+ raise EnvironmentError(exc.stdout)
1056
+
1057
+ def git_push(
1058
+ self,
1059
+ upstream: Optional[str] = None,
1060
+ blocking: bool = True,
1061
+ auto_lfs_prune: bool = False,
1062
+ ) -> Union[str, Tuple[str, CommandInProgress]]:
1063
+ """
1064
+ git push
1065
+
1066
+ If used without setting `blocking`, will return url to commit on remote
1067
+ repo. If used with `blocking=True`, will return a tuple containing the
1068
+ url to commit and the command object to follow for information about the
1069
+ process.
1070
+
1071
+ Args:
1072
+ upstream (`str`, *optional*):
1073
+ Upstream to which this should push. If not specified, will push
1074
+ to the lastly defined upstream or to the default one (`origin
1075
+ main`).
1076
+ blocking (`bool`, *optional*, defaults to `True`):
1077
+ Whether the function should return only when the push has
1078
+ finished. Setting this to `False` will return an
1079
+ `CommandInProgress` object which has an `is_done` property. This
1080
+ property will be set to `True` when the push is finished.
1081
+ auto_lfs_prune (`bool`, *optional*, defaults to `False`):
1082
+ Whether to automatically prune files once they have been pushed
1083
+ to the remote.
1084
+ """
1085
+ command = "git push"
1086
+
1087
+ if upstream:
1088
+ command += f" --set-upstream {upstream}"
1089
+
1090
+ number_of_commits = commits_to_push(self.local_dir, upstream)
1091
+
1092
+ if number_of_commits > 1:
1093
+ logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.")
1094
+ if blocking:
1095
+ logger.warning("The progress bars may be unreliable.")
1096
+
1097
+ try:
1098
+ with _lfs_log_progress():
1099
+ process = subprocess.Popen(
1100
+ command.split(),
1101
+ stderr=subprocess.PIPE,
1102
+ stdout=subprocess.PIPE,
1103
+ encoding="utf-8",
1104
+ cwd=self.local_dir,
1105
+ )
1106
+
1107
+ if blocking:
1108
+ stdout, stderr = process.communicate()
1109
+ return_code = process.poll()
1110
+ process.kill()
1111
+
1112
+ if len(stderr):
1113
+ logger.warning(stderr)
1114
+
1115
+ if return_code:
1116
+ raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr)
1117
+
1118
+ except subprocess.CalledProcessError as exc:
1119
+ raise EnvironmentError(exc.stderr)
1120
+
1121
+ if not blocking:
1122
+
1123
+ def status_method():
1124
+ status = process.poll()
1125
+ if status is None:
1126
+ return -1
1127
+ else:
1128
+ return status
1129
+
1130
+ command_in_progress = CommandInProgress(
1131
+ "push",
1132
+ is_done_method=lambda: process.poll() is not None,
1133
+ status_method=status_method,
1134
+ process=process,
1135
+ post_method=self.lfs_prune if auto_lfs_prune else None,
1136
+ )
1137
+
1138
+ self.command_queue.append(command_in_progress)
1139
+
1140
+ return self.git_head_commit_url(), command_in_progress
1141
+
1142
+ if auto_lfs_prune:
1143
+ self.lfs_prune()
1144
+
1145
+ return self.git_head_commit_url()
1146
+
1147
+ def git_checkout(self, revision: str, create_branch_ok: bool = False):
1148
+ """
1149
+ git checkout a given revision
1150
+
1151
+ Specifying `create_branch_ok` to `True` will create the branch to the
1152
+ given revision if that revision doesn't exist.
1153
+
1154
+ Args:
1155
+ revision (`str`):
1156
+ The revision to checkout.
1157
+ create_branch_ok (`str`, *optional*, defaults to `False`):
1158
+ Whether creating a branch named with the `revision` passed at
1159
+ the current checked-out reference if `revision` isn't an
1160
+ existing revision is allowed.
1161
+ """
1162
+ try:
1163
+ result = run_subprocess(f"git checkout {revision}", self.local_dir)
1164
+ logger.warning(f"Checked out {revision} from {self.current_branch}.")
1165
+ logger.warning(result.stdout)
1166
+ except subprocess.CalledProcessError as exc:
1167
+ if not create_branch_ok:
1168
+ raise EnvironmentError(exc.stderr)
1169
+ else:
1170
+ try:
1171
+ result = run_subprocess(f"git checkout -b {revision}", self.local_dir)
1172
+ logger.warning(
1173
+ f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`."
1174
+ )
1175
+ logger.warning(result.stdout)
1176
+ except subprocess.CalledProcessError as exc:
1177
+ raise EnvironmentError(exc.stderr)
1178
+
1179
+ def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool:
1180
+ """
1181
+ Check if a tag exists or not.
1182
+
1183
+ Args:
1184
+ tag_name (`str`):
1185
+ The name of the tag to check.
1186
+ remote (`str`, *optional*):
1187
+ Whether to check if the tag exists on a remote. This parameter
1188
+ should be the identifier of the remote.
1189
+
1190
+ Returns:
1191
+ `bool`: Whether the tag exists.
1192
+ """
1193
+ if remote:
1194
+ try:
1195
+ result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip()
1196
+ except subprocess.CalledProcessError as exc:
1197
+ raise EnvironmentError(exc.stderr)
1198
+
1199
+ return len(result) != 0
1200
+ else:
1201
+ try:
1202
+ git_tags = run_subprocess("git tag", self.local_dir).stdout.strip()
1203
+ except subprocess.CalledProcessError as exc:
1204
+ raise EnvironmentError(exc.stderr)
1205
+
1206
+ git_tags = git_tags.split("\n")
1207
+ return tag_name in git_tags
1208
+
1209
+ def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool:
1210
+ """
1211
+ Delete a tag, both local and remote, if it exists
1212
+
1213
+ Args:
1214
+ tag_name (`str`):
1215
+ The tag name to delete.
1216
+ remote (`str`, *optional*):
1217
+ The remote on which to delete the tag.
1218
+
1219
+ Returns:
1220
+ `bool`: `True` if deleted, `False` if the tag didn't exist.
1221
+ If remote is not passed, will just be updated locally
1222
+ """
1223
+ delete_locally = True
1224
+ delete_remotely = True
1225
+
1226
+ if not self.tag_exists(tag_name):
1227
+ delete_locally = False
1228
+
1229
+ if not self.tag_exists(tag_name, remote=remote):
1230
+ delete_remotely = False
1231
+
1232
+ if delete_locally:
1233
+ try:
1234
+ run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip()
1235
+ except subprocess.CalledProcessError as exc:
1236
+ raise EnvironmentError(exc.stderr)
1237
+
1238
+ if remote and delete_remotely:
1239
+ try:
1240
+ run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip()
1241
+ except subprocess.CalledProcessError as exc:
1242
+ raise EnvironmentError(exc.stderr)
1243
+
1244
+ return True
1245
+
1246
+ def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None):
1247
+ """
1248
+ Add a tag at the current head and push it
1249
+
1250
+ If remote is None, will just be updated locally
1251
+
1252
+ If no message is provided, the tag will be lightweight. if a message is
1253
+ provided, the tag will be annotated.
1254
+
1255
+ Args:
1256
+ tag_name (`str`):
1257
+ The name of the tag to be added.
1258
+ message (`str`, *optional*):
1259
+ The message that accompanies the tag. The tag will turn into an
1260
+ annotated tag if a message is passed.
1261
+ remote (`str`, *optional*):
1262
+ The remote on which to add the tag.
1263
+ """
1264
+ if message:
1265
+ tag_args = ["git", "tag", "-a", tag_name, "-m", message]
1266
+ else:
1267
+ tag_args = ["git", "tag", tag_name]
1268
+
1269
+ try:
1270
+ run_subprocess(tag_args, self.local_dir).stdout.strip()
1271
+ except subprocess.CalledProcessError as exc:
1272
+ raise EnvironmentError(exc.stderr)
1273
+
1274
+ if remote:
1275
+ try:
1276
+ run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip()
1277
+ except subprocess.CalledProcessError as exc:
1278
+ raise EnvironmentError(exc.stderr)
1279
+
1280
+ def is_repo_clean(self) -> bool:
1281
+ """
1282
+ Return whether or not the git status is clean or not
1283
+
1284
+ Returns:
1285
+ `bool`: `True` if the git status is clean, `False` otherwise.
1286
+ """
1287
+ try:
1288
+ git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip()
1289
+ except subprocess.CalledProcessError as exc:
1290
+ raise EnvironmentError(exc.stderr)
1291
+
1292
+ return len(git_status) == 0
1293
+
1294
+ def push_to_hub(
1295
+ self,
1296
+ commit_message: str = "commit files to HF hub",
1297
+ blocking: bool = True,
1298
+ clean_ok: bool = True,
1299
+ auto_lfs_prune: bool = False,
1300
+ ) -> Union[None, str, Tuple[str, CommandInProgress]]:
1301
+ """
1302
+ Helper to add, commit, and push files to remote repository on the
1303
+ HuggingFace Hub. Will automatically track large files (>10MB).
1304
+
1305
+ Args:
1306
+ commit_message (`str`):
1307
+ Message to use for the commit.
1308
+ blocking (`bool`, *optional*, defaults to `True`):
1309
+ Whether the function should return only when the `git push` has
1310
+ finished.
1311
+ clean_ok (`bool`, *optional*, defaults to `True`):
1312
+ If True, this function will return None if the repo is
1313
+ untouched. Default behavior is to fail because the git command
1314
+ fails.
1315
+ auto_lfs_prune (`bool`, *optional*, defaults to `False`):
1316
+ Whether to automatically prune files once they have been pushed
1317
+ to the remote.
1318
+ """
1319
+ if clean_ok and self.is_repo_clean():
1320
+ logger.info("Repo currently clean. Ignoring push_to_hub")
1321
+ return None
1322
+ self.git_add(auto_lfs_track=True)
1323
+ self.git_commit(commit_message)
1324
+ return self.git_push(
1325
+ upstream=f"origin {self.current_branch}",
1326
+ blocking=blocking,
1327
+ auto_lfs_prune=auto_lfs_prune,
1328
+ )
1329
+
1330
+ @contextmanager
1331
+ def commit(
1332
+ self,
1333
+ commit_message: str,
1334
+ branch: Optional[str] = None,
1335
+ track_large_files: bool = True,
1336
+ blocking: bool = True,
1337
+ auto_lfs_prune: bool = False,
1338
+ ):
1339
+ """
1340
+ Context manager utility to handle committing to a repository. This
1341
+ automatically tracks large files (>10Mb) with git-lfs. Set the
1342
+ `track_large_files` argument to `False` if you wish to ignore that
1343
+ behavior.
1344
+
1345
+ Args:
1346
+ commit_message (`str`):
1347
+ Message to use for the commit.
1348
+ branch (`str`, *optional*):
1349
+ The branch on which the commit will appear. This branch will be
1350
+ checked-out before any operation.
1351
+ track_large_files (`bool`, *optional*, defaults to `True`):
1352
+ Whether to automatically track large files or not. Will do so by
1353
+ default.
1354
+ blocking (`bool`, *optional*, defaults to `True`):
1355
+ Whether the function should return only when the `git push` has
1356
+ finished.
1357
+ auto_lfs_prune (`bool`, defaults to `True`):
1358
+ Whether to automatically prune files once they have been pushed
1359
+ to the remote.
1360
+
1361
+ Examples:
1362
+
1363
+ ```python
1364
+ >>> with Repository(
1365
+ ... "text-files",
1366
+ ... clone_from="<user>/text-files",
1367
+ ... token=True,
1368
+ >>> ).commit("My first file :)"):
1369
+ ... with open("file.txt", "w+") as f:
1370
+ ... f.write(json.dumps({"hey": 8}))
1371
+
1372
+ >>> import torch
1373
+
1374
+ >>> model = torch.nn.Transformer()
1375
+ >>> with Repository(
1376
+ ... "torch-model",
1377
+ ... clone_from="<user>/torch-model",
1378
+ ... token=True,
1379
+ >>> ).commit("My cool model :)"):
1380
+ ... torch.save(model.state_dict(), "model.pt")
1381
+ ```
1382
+
1383
+ """
1384
+
1385
+ files_to_stage = files_to_be_staged(".", folder=self.local_dir)
1386
+
1387
+ if len(files_to_stage):
1388
+ files_in_msg = str(files_to_stage[:5])[:-1] + ", ...]" if len(files_to_stage) > 5 else str(files_to_stage)
1389
+ logger.error(
1390
+ "There exists some updated files in the local repository that are not"
1391
+ f" committed: {files_in_msg}. This may lead to errors if checking out"
1392
+ " a branch. These files and their modifications will be added to the"
1393
+ " current commit."
1394
+ )
1395
+
1396
+ if branch is not None:
1397
+ self.git_checkout(branch, create_branch_ok=True)
1398
+
1399
+ if is_tracked_upstream(self.local_dir):
1400
+ logger.warning("Pulling changes ...")
1401
+ self.git_pull(rebase=True)
1402
+ else:
1403
+ logger.warning(f"The current branch has no upstream branch. Will push to 'origin {self.current_branch}'")
1404
+
1405
+ current_working_directory = os.getcwd()
1406
+ os.chdir(os.path.join(current_working_directory, self.local_dir))
1407
+
1408
+ try:
1409
+ yield self
1410
+ finally:
1411
+ self.git_add(auto_lfs_track=track_large_files)
1412
+
1413
+ try:
1414
+ self.git_commit(commit_message)
1415
+ except OSError as e:
1416
+ # If no changes are detected, there is nothing to commit.
1417
+ if "nothing to commit" not in str(e):
1418
+ raise e
1419
+
1420
+ try:
1421
+ self.git_push(
1422
+ upstream=f"origin {self.current_branch}",
1423
+ blocking=blocking,
1424
+ auto_lfs_prune=auto_lfs_prune,
1425
+ )
1426
+ except OSError as e:
1427
+ # If no changes are detected, there is nothing to commit.
1428
+ if "could not read Username" in str(e):
1429
+ raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e
1430
+ else:
1431
+ raise e
1432
+
1433
+ os.chdir(current_working_directory)
1434
+
1435
+ def repocard_metadata_load(self) -> Optional[Dict]:
1436
+ filepath = os.path.join(self.local_dir, REPOCARD_NAME)
1437
+ if os.path.isfile(filepath):
1438
+ return metadata_load(filepath)
1439
+ return None
1440
+
1441
+ def repocard_metadata_save(self, data: Dict) -> None:
1442
+ return metadata_save(os.path.join(self.local_dir, REPOCARD_NAME), data)
1443
+
1444
+ @property
1445
+ def commands_failed(self):
1446
+ """
1447
+ Returns the asynchronous commands that failed.
1448
+ """
1449
+ return [c for c in self.command_queue if c.status > 0]
1450
+
1451
+ @property
1452
+ def commands_in_progress(self):
1453
+ """
1454
+ Returns the asynchronous commands that are currently in progress.
1455
+ """
1456
+ return [c for c in self.command_queue if not c.is_done]
1457
+
1458
+ def wait_for_commands(self):
1459
+ """
1460
+ Blocking method: blocks all subsequent execution until all commands have
1461
+ been processed.
1462
+ """
1463
+ index = 0
1464
+ for command_failed in self.commands_failed:
1465
+ logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.")
1466
+ logger.error(command_failed.stderr)
1467
+
1468
+ while self.commands_in_progress:
1469
+ if index % 10 == 0:
1470
+ logger.warning(
1471
+ f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}."
1472
+ )
1473
+
1474
+ index += 1
1475
+
1476
+ time.sleep(1)