koichi12 committed
Commit 5f815af · verified · 1 Parent(s): 9d4bc92

Add files using upload-large-folder tool
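
A minimal sketch of what an upload-large-folder push like this might look like with the `huggingface_hub` Python API (the repo id and local path below are placeholders, not values taken from this commit):

```python
# Hypothetical sketch: pushing a local folder with huggingface_hub's upload_large_folder.
# "user/some-repo" and "./local-folder" are placeholders, not values from this commit.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token saved by `huggingface-cli login`
api.upload_large_folder(
    repo_id="user/some-repo",      # placeholder target repository
    folder_path="./local-folder",  # placeholder local directory to upload
    repo_type="model",             # repo type passed explicitly
)
```

An equivalent CLI form is roughly `huggingface-cli upload-large-folder user/some-repo ./local-folder --repo-type=model`; the tool chunks the work and is designed to be resumable, which is why it is commonly used for commits that touch many files, such as this one.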

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .venv/lib/python3.11/site-packages/transformers/__pycache__/image_processing_utils.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/transformers/__pycache__/image_transforms.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/transformers/__pycache__/modelcard.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/transformers/__pycache__/training_args_tf.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/__init__.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_search.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/transformers/pipelines/__init__.py +1178 -0
  9. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_classification.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_utils.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/automatic_speech_recognition.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/base.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/depth_estimation.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/document_question_answering.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/feature_extraction.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/fill_mask.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_classification.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_feature_extraction.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_segmentation.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_text_to_text.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_image.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_text.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/mask_generation.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/object_detection.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/pt_utils.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/question_answering.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/table_question_answering.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text2text_generation.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_classification.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_generation.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_to_audio.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/token_classification.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/video_classification.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/visual_question_answering.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_audio_classification.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_classification.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_image_classification.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_object_detection.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/transformers/pipelines/audio_classification.py +234 -0
  41. .venv/lib/python3.11/site-packages/transformers/pipelines/audio_utils.py +297 -0
  42. .venv/lib/python3.11/site-packages/transformers/pipelines/automatic_speech_recognition.py +766 -0
  43. .venv/lib/python3.11/site-packages/transformers/pipelines/base.py +1484 -0
  44. .venv/lib/python3.11/site-packages/transformers/pipelines/depth_estimation.py +133 -0
  45. .venv/lib/python3.11/site-packages/transformers/pipelines/document_question_answering.py +516 -0
  46. .venv/lib/python3.11/site-packages/transformers/pipelines/feature_extraction.py +86 -0
  47. .venv/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py +273 -0
  48. .venv/lib/python3.11/site-packages/transformers/pipelines/image_classification.py +226 -0
  49. .venv/lib/python3.11/site-packages/transformers/pipelines/image_feature_extraction.py +112 -0
  50. .venv/lib/python3.11/site-packages/transformers/pipelines/image_segmentation.py +220 -0
.venv/lib/python3.11/site-packages/transformers/__pycache__/image_processing_utils.cpython-311.pyc ADDED
Binary file (14.2 kB).
 
.venv/lib/python3.11/site-packages/transformers/__pycache__/image_transforms.cpython-311.pyc ADDED
Binary file (40.8 kB).
 
.venv/lib/python3.11/site-packages/transformers/__pycache__/modelcard.cpython-311.pyc ADDED
Binary file (44.5 kB).
 
.venv/lib/python3.11/site-packages/transformers/__pycache__/training_args_tf.cpython-311.pyc ADDED
Binary file (16.2 kB).
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (9.07 kB).
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-311.pyc ADDED
Binary file (25.4 kB).
 
.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_search.cpython-311.pyc ADDED
Binary file (49.1 kB).
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__init__.py ADDED
@@ -0,0 +1,1178 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ import os
17
+ import warnings
18
+ from pathlib import Path
19
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
20
+
21
+ from huggingface_hub import model_info
22
+
23
+ from ..configuration_utils import PretrainedConfig
24
+ from ..dynamic_module_utils import get_class_from_dynamic_module
25
+ from ..feature_extraction_utils import PreTrainedFeatureExtractor
26
+ from ..image_processing_utils import BaseImageProcessor
27
+ from ..models.auto.configuration_auto import AutoConfig
28
+ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
29
+ from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
30
+ from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
31
+ from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor
32
+ from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
33
+ from ..processing_utils import ProcessorMixin
34
+ from ..tokenization_utils import PreTrainedTokenizer
35
+ from ..utils import (
36
+ CONFIG_NAME,
37
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
38
+ cached_file,
39
+ extract_commit_hash,
40
+ find_adapter_config_file,
41
+ is_kenlm_available,
42
+ is_offline_mode,
43
+ is_peft_available,
44
+ is_pyctcdecode_available,
45
+ is_tf_available,
46
+ is_torch_available,
47
+ logging,
48
+ )
49
+ from .audio_classification import AudioClassificationPipeline
50
+ from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
51
+ from .base import (
52
+ ArgumentHandler,
53
+ CsvPipelineDataFormat,
54
+ JsonPipelineDataFormat,
55
+ PipedPipelineDataFormat,
56
+ Pipeline,
57
+ PipelineDataFormat,
58
+ PipelineException,
59
+ PipelineRegistry,
60
+ get_default_model_and_revision,
61
+ infer_framework_load_model,
62
+ )
63
+ from .depth_estimation import DepthEstimationPipeline
64
+ from .document_question_answering import DocumentQuestionAnsweringPipeline
65
+ from .feature_extraction import FeatureExtractionPipeline
66
+ from .fill_mask import FillMaskPipeline
67
+ from .image_classification import ImageClassificationPipeline
68
+ from .image_feature_extraction import ImageFeatureExtractionPipeline
69
+ from .image_segmentation import ImageSegmentationPipeline
70
+ from .image_text_to_text import ImageTextToTextPipeline
71
+ from .image_to_image import ImageToImagePipeline
72
+ from .image_to_text import ImageToTextPipeline
73
+ from .mask_generation import MaskGenerationPipeline
74
+ from .object_detection import ObjectDetectionPipeline
75
+ from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
76
+ from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
77
+ from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
78
+ from .text_classification import TextClassificationPipeline
79
+ from .text_generation import TextGenerationPipeline
80
+ from .text_to_audio import TextToAudioPipeline
81
+ from .token_classification import (
82
+ AggregationStrategy,
83
+ NerPipeline,
84
+ TokenClassificationArgumentHandler,
85
+ TokenClassificationPipeline,
86
+ )
87
+ from .video_classification import VideoClassificationPipeline
88
+ from .visual_question_answering import VisualQuestionAnsweringPipeline
89
+ from .zero_shot_audio_classification import ZeroShotAudioClassificationPipeline
90
+ from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
91
+ from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
92
+ from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline
93
+
94
+
95
+ if is_tf_available():
96
+ import tensorflow as tf
97
+
98
+ from ..models.auto.modeling_tf_auto import (
99
+ TFAutoModel,
100
+ TFAutoModelForCausalLM,
101
+ TFAutoModelForImageClassification,
102
+ TFAutoModelForMaskedLM,
103
+ TFAutoModelForQuestionAnswering,
104
+ TFAutoModelForSeq2SeqLM,
105
+ TFAutoModelForSequenceClassification,
106
+ TFAutoModelForTableQuestionAnswering,
107
+ TFAutoModelForTokenClassification,
108
+ TFAutoModelForVision2Seq,
109
+ TFAutoModelForZeroShotImageClassification,
110
+ )
111
+
112
+ if is_torch_available():
113
+ import torch
114
+
115
+ from ..models.auto.modeling_auto import (
116
+ AutoModel,
117
+ AutoModelForAudioClassification,
118
+ AutoModelForCausalLM,
119
+ AutoModelForCTC,
120
+ AutoModelForDocumentQuestionAnswering,
121
+ AutoModelForImageClassification,
122
+ AutoModelForImageSegmentation,
123
+ AutoModelForImageTextToText,
124
+ AutoModelForMaskedLM,
125
+ AutoModelForMaskGeneration,
126
+ AutoModelForObjectDetection,
127
+ AutoModelForQuestionAnswering,
128
+ AutoModelForSemanticSegmentation,
129
+ AutoModelForSeq2SeqLM,
130
+ AutoModelForSequenceClassification,
131
+ AutoModelForSpeechSeq2Seq,
132
+ AutoModelForTableQuestionAnswering,
133
+ AutoModelForTextToSpectrogram,
134
+ AutoModelForTextToWaveform,
135
+ AutoModelForTokenClassification,
136
+ AutoModelForVideoClassification,
137
+ AutoModelForVision2Seq,
138
+ AutoModelForVisualQuestionAnswering,
139
+ AutoModelForZeroShotImageClassification,
140
+ AutoModelForZeroShotObjectDetection,
141
+ )
142
+
143
+
144
+ if TYPE_CHECKING:
145
+ from ..modeling_tf_utils import TFPreTrainedModel
146
+ from ..modeling_utils import PreTrainedModel
147
+ from ..tokenization_utils_fast import PreTrainedTokenizerFast
148
+
149
+
150
+ logger = logging.get_logger(__name__)
151
+
152
+
153
+ # Register all the supported tasks here
154
+ TASK_ALIASES = {
155
+ "sentiment-analysis": "text-classification",
156
+ "ner": "token-classification",
157
+ "vqa": "visual-question-answering",
158
+ "text-to-speech": "text-to-audio",
159
+ }
160
+ SUPPORTED_TASKS = {
161
+ "audio-classification": {
162
+ "impl": AudioClassificationPipeline,
163
+ "tf": (),
164
+ "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
165
+ "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
166
+ "type": "audio",
167
+ },
168
+ "automatic-speech-recognition": {
169
+ "impl": AutomaticSpeechRecognitionPipeline,
170
+ "tf": (),
171
+ "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
172
+ "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "22aad52")}},
173
+ "type": "multimodal",
174
+ },
175
+ "text-to-audio": {
176
+ "impl": TextToAudioPipeline,
177
+ "tf": (),
178
+ "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
179
+ "default": {"model": {"pt": ("suno/bark-small", "1dbd7a1")}},
180
+ "type": "text",
181
+ },
182
+ "feature-extraction": {
183
+ "impl": FeatureExtractionPipeline,
184
+ "tf": (TFAutoModel,) if is_tf_available() else (),
185
+ "pt": (AutoModel,) if is_torch_available() else (),
186
+ "default": {
187
+ "model": {
188
+ "pt": ("distilbert/distilbert-base-cased", "6ea8117"),
189
+ "tf": ("distilbert/distilbert-base-cased", "6ea8117"),
190
+ }
191
+ },
192
+ "type": "multimodal",
193
+ },
194
+ "text-classification": {
195
+ "impl": TextClassificationPipeline,
196
+ "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
197
+ "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
198
+ "default": {
199
+ "model": {
200
+ "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"),
201
+ "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"),
202
+ },
203
+ },
204
+ "type": "text",
205
+ },
206
+ "token-classification": {
207
+ "impl": TokenClassificationPipeline,
208
+ "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
209
+ "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
210
+ "default": {
211
+ "model": {
212
+ "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"),
213
+ "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"),
214
+ },
215
+ },
216
+ "type": "text",
217
+ },
218
+ "question-answering": {
219
+ "impl": QuestionAnsweringPipeline,
220
+ "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
221
+ "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
222
+ "default": {
223
+ "model": {
224
+ "pt": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5"),
225
+ "tf": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5"),
226
+ },
227
+ },
228
+ "type": "text",
229
+ },
230
+ "table-question-answering": {
231
+ "impl": TableQuestionAnsweringPipeline,
232
+ "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
233
+ "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
234
+ "default": {
235
+ "model": {
236
+ "pt": ("google/tapas-base-finetuned-wtq", "e3dde19"),
237
+ "tf": ("google/tapas-base-finetuned-wtq", "e3dde19"),
238
+ },
239
+ },
240
+ "type": "text",
241
+ },
242
+ "visual-question-answering": {
243
+ "impl": VisualQuestionAnsweringPipeline,
244
+ "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
245
+ "tf": (),
246
+ "default": {
247
+ "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "d0a1f6a")},
248
+ },
249
+ "type": "multimodal",
250
+ },
251
+ "document-question-answering": {
252
+ "impl": DocumentQuestionAnsweringPipeline,
253
+ "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
254
+ "tf": (),
255
+ "default": {
256
+ "model": {"pt": ("impira/layoutlm-document-qa", "beed3c4")},
257
+ },
258
+ "type": "multimodal",
259
+ },
260
+ "fill-mask": {
261
+ "impl": FillMaskPipeline,
262
+ "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
263
+ "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
264
+ "default": {
265
+ "model": {
266
+ "pt": ("distilbert/distilroberta-base", "fb53ab8"),
267
+ "tf": ("distilbert/distilroberta-base", "fb53ab8"),
268
+ }
269
+ },
270
+ "type": "text",
271
+ },
272
+ "summarization": {
273
+ "impl": SummarizationPipeline,
274
+ "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
275
+ "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
276
+ "default": {
277
+ "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "df1b051")}
278
+ },
279
+ "type": "text",
280
+ },
281
+ # This task is a special case as it's parametrized by SRC, TGT languages.
282
+ "translation": {
283
+ "impl": TranslationPipeline,
284
+ "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
285
+ "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
286
+ "default": {
287
+ ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
288
+ ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
289
+ ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
290
+ },
291
+ "type": "text",
292
+ },
293
+ "text2text-generation": {
294
+ "impl": Text2TextGenerationPipeline,
295
+ "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
296
+ "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
297
+ "default": {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
298
+ "type": "text",
299
+ },
300
+ "text-generation": {
301
+ "impl": TextGenerationPipeline,
302
+ "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
303
+ "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
304
+ "default": {"model": {"pt": ("openai-community/gpt2", "607a30d"), "tf": ("openai-community/gpt2", "607a30d")}},
305
+ "type": "text",
306
+ },
307
+ "zero-shot-classification": {
308
+ "impl": ZeroShotClassificationPipeline,
309
+ "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
310
+ "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
311
+ "default": {
312
+ "model": {
313
+ "pt": ("facebook/bart-large-mnli", "d7645e1"),
314
+ "tf": ("FacebookAI/roberta-large-mnli", "2a8f12d"),
315
+ },
316
+ "config": {
317
+ "pt": ("facebook/bart-large-mnli", "d7645e1"),
318
+ "tf": ("FacebookAI/roberta-large-mnli", "2a8f12d"),
319
+ },
320
+ },
321
+ "type": "text",
322
+ },
323
+ "zero-shot-image-classification": {
324
+ "impl": ZeroShotImageClassificationPipeline,
325
+ "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
326
+ "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
327
+ "default": {
328
+ "model": {
329
+ "pt": ("openai/clip-vit-base-patch32", "3d74acf"),
330
+ "tf": ("openai/clip-vit-base-patch32", "3d74acf"),
331
+ }
332
+ },
333
+ "type": "multimodal",
334
+ },
335
+ "zero-shot-audio-classification": {
336
+ "impl": ZeroShotAudioClassificationPipeline,
337
+ "tf": (),
338
+ "pt": (AutoModel,) if is_torch_available() else (),
339
+ "default": {
340
+ "model": {
341
+ "pt": ("laion/clap-htsat-fused", "cca9e28"),
342
+ }
343
+ },
344
+ "type": "multimodal",
345
+ },
346
+ "image-classification": {
347
+ "impl": ImageClassificationPipeline,
348
+ "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
349
+ "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
350
+ "default": {
351
+ "model": {
352
+ "pt": ("google/vit-base-patch16-224", "3f49326"),
353
+ "tf": ("google/vit-base-patch16-224", "3f49326"),
354
+ }
355
+ },
356
+ "type": "image",
357
+ },
358
+ "image-feature-extraction": {
359
+ "impl": ImageFeatureExtractionPipeline,
360
+ "tf": (TFAutoModel,) if is_tf_available() else (),
361
+ "pt": (AutoModel,) if is_torch_available() else (),
362
+ "default": {
363
+ "model": {
364
+ "pt": ("google/vit-base-patch16-224", "3f49326"),
365
+ "tf": ("google/vit-base-patch16-224", "3f49326"),
366
+ }
367
+ },
368
+ "type": "image",
369
+ },
370
+ "image-segmentation": {
371
+ "impl": ImageSegmentationPipeline,
372
+ "tf": (),
373
+ "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
374
+ "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "d53b52a")}},
375
+ "type": "multimodal",
376
+ },
377
+ "image-to-text": {
378
+ "impl": ImageToTextPipeline,
379
+ "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
380
+ "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
381
+ "default": {
382
+ "model": {
383
+ "pt": ("ydshieh/vit-gpt2-coco-en", "5bebf1e"),
384
+ "tf": ("ydshieh/vit-gpt2-coco-en", "5bebf1e"),
385
+ }
386
+ },
387
+ "type": "multimodal",
388
+ },
389
+ "image-text-to-text": {
390
+ "impl": ImageTextToTextPipeline,
391
+ "tf": (),
392
+ "pt": (AutoModelForImageTextToText,) if is_torch_available() else (),
393
+ "default": {
394
+ "model": {
395
+ "pt": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b"),
396
+ }
397
+ },
398
+ "type": "multimodal",
399
+ },
400
+ "object-detection": {
401
+ "impl": ObjectDetectionPipeline,
402
+ "tf": (),
403
+ "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
404
+ "default": {"model": {"pt": ("facebook/detr-resnet-50", "1d5f47b")}},
405
+ "type": "multimodal",
406
+ },
407
+ "zero-shot-object-detection": {
408
+ "impl": ZeroShotObjectDetectionPipeline,
409
+ "tf": (),
410
+ "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
411
+ "default": {"model": {"pt": ("google/owlvit-base-patch32", "cbc355f")}},
412
+ "type": "multimodal",
413
+ },
414
+ "depth-estimation": {
415
+ "impl": DepthEstimationPipeline,
416
+ "tf": (),
417
+ "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
418
+ "default": {"model": {"pt": ("Intel/dpt-large", "bc15f29")}},
419
+ "type": "image",
420
+ },
421
+ "video-classification": {
422
+ "impl": VideoClassificationPipeline,
423
+ "tf": (),
424
+ "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
425
+ "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "488eb9a")}},
426
+ "type": "video",
427
+ },
428
+ "mask-generation": {
429
+ "impl": MaskGenerationPipeline,
430
+ "tf": (),
431
+ "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
432
+ "default": {"model": {"pt": ("facebook/sam-vit-huge", "87aecf0")}},
433
+ "type": "multimodal",
434
+ },
435
+ "image-to-image": {
436
+ "impl": ImageToImagePipeline,
437
+ "tf": (),
438
+ "pt": (AutoModelForImageToImage,) if is_torch_available() else (),
439
+ "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "cee1c92")}},
440
+ "type": "image",
441
+ },
442
+ }
443
+
444
+ NO_FEATURE_EXTRACTOR_TASKS = set()
445
+ NO_IMAGE_PROCESSOR_TASKS = set()
446
+ NO_TOKENIZER_TASKS = set()
447
+
448
+ # Those model configs are special, they are generic over their task, meaning
449
+ # any tokenizer/feature_extractor might be use for a given model so we cannot
450
+ # use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
451
+ # see if the model defines such objects or not.
452
+ MULTI_MODEL_AUDIO_CONFIGS = {"SpeechEncoderDecoderConfig"}
453
+ MULTI_MODEL_VISION_CONFIGS = {"VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
454
+ for task, values in SUPPORTED_TASKS.items():
455
+ if values["type"] == "text":
456
+ NO_FEATURE_EXTRACTOR_TASKS.add(task)
457
+ NO_IMAGE_PROCESSOR_TASKS.add(task)
458
+ elif values["type"] in {"image", "video"}:
459
+ NO_TOKENIZER_TASKS.add(task)
460
+ elif values["type"] in {"audio"}:
461
+ NO_TOKENIZER_TASKS.add(task)
462
+ NO_IMAGE_PROCESSOR_TASKS.add(task)
463
+ elif values["type"] != "multimodal":
464
+ raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
465
+
466
+ PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
467
+
468
+
469
+ def get_supported_tasks() -> List[str]:
470
+ """
471
+ Returns a list of supported task strings.
472
+ """
473
+ return PIPELINE_REGISTRY.get_supported_tasks()
474
+
475
+
476
+ def get_task(model: str, token: Optional[str] = None, **deprecated_kwargs) -> str:
477
+ use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
478
+ if use_auth_token is not None:
479
+ warnings.warn(
480
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
481
+ FutureWarning,
482
+ )
483
+ if token is not None:
484
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
485
+ token = use_auth_token
486
+
487
+ if is_offline_mode():
488
+ raise RuntimeError("You cannot infer task automatically within `pipeline` when using offline mode")
489
+ try:
490
+ info = model_info(model, token=token)
491
+ except Exception as e:
492
+ raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
493
+ if not info.pipeline_tag:
494
+ raise RuntimeError(
495
+ f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
496
+ )
497
+ if getattr(info, "library_name", "transformers") not in {"transformers", "timm"}:
498
+ raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
499
+ task = info.pipeline_tag
500
+ return task
501
+
502
+
503
+ def check_task(task: str) -> Tuple[str, Dict, Any]:
504
+ """
505
+ Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
506
+ default models if they exist.
507
+
508
+ Args:
509
+ task (`str`):
510
+ The task defining which pipeline will be returned. Currently accepted tasks are:
511
+
512
+ - `"audio-classification"`
513
+ - `"automatic-speech-recognition"`
514
+ - `"conversational"`
515
+ - `"depth-estimation"`
516
+ - `"document-question-answering"`
517
+ - `"feature-extraction"`
518
+ - `"fill-mask"`
519
+ - `"image-classification"`
520
+ - `"image-feature-extraction"`
521
+ - `"image-segmentation"`
522
+ - `"image-to-text"`
523
+ - `"image-to-image"`
524
+ - `"object-detection"`
525
+ - `"question-answering"`
526
+ - `"summarization"`
527
+ - `"table-question-answering"`
528
+ - `"text2text-generation"`
529
+ - `"text-classification"` (alias `"sentiment-analysis"` available)
530
+ - `"text-generation"`
531
+ - `"text-to-audio"` (alias `"text-to-speech"` available)
532
+ - `"token-classification"` (alias `"ner"` available)
533
+ - `"translation"`
534
+ - `"translation_xx_to_yy"`
535
+ - `"video-classification"`
536
+ - `"visual-question-answering"` (alias `"vqa"` available)
537
+ - `"zero-shot-classification"`
538
+ - `"zero-shot-image-classification"`
539
+ - `"zero-shot-object-detection"`
540
+
541
+ Returns:
542
+ (normalized_task: `str`, task_defaults: `dict`, task_options: (`tuple`, None)) The normalized task name
543
+ (removed alias and options). The actual dictionary required to initialize the pipeline and some extra task
544
+ options for parametrized tasks like "translation_XX_to_YY"
545
+
546
+
547
+ """
548
+ return PIPELINE_REGISTRY.check_task(task)
549
+
550
+
551
+ def clean_custom_task(task_info):
552
+ import transformers
553
+
554
+ if "impl" not in task_info:
555
+ raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
556
+ pt_class_names = task_info.get("pt", ())
557
+ if isinstance(pt_class_names, str):
558
+ pt_class_names = [pt_class_names]
559
+ task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
560
+ tf_class_names = task_info.get("tf", ())
561
+ if isinstance(tf_class_names, str):
562
+ tf_class_names = [tf_class_names]
563
+ task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
564
+ return task_info, None
565
+
566
+
567
+ def pipeline(
568
+ task: str = None,
569
+ model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
570
+ config: Optional[Union[str, PretrainedConfig]] = None,
571
+ tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
572
+ feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
573
+ image_processor: Optional[Union[str, BaseImageProcessor]] = None,
574
+ processor: Optional[Union[str, ProcessorMixin]] = None,
575
+ framework: Optional[str] = None,
576
+ revision: Optional[str] = None,
577
+ use_fast: bool = True,
578
+ token: Optional[Union[str, bool]] = None,
579
+ device: Optional[Union[int, str, "torch.device"]] = None,
580
+ device_map=None,
581
+ torch_dtype=None,
582
+ trust_remote_code: Optional[bool] = None,
583
+ model_kwargs: Dict[str, Any] = None,
584
+ pipeline_class: Optional[Any] = None,
585
+ **kwargs,
586
+ ) -> Pipeline:
587
+ """
588
+ Utility factory method to build a [`Pipeline`].
589
+
590
+ A pipeline consists of:
591
+
592
+ - One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
593
+ [image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
594
+ - A [model](model) that generates predictions from the inputs.
595
+ - Optional post-processing steps to refine the model's output, which can also be handled by processors.
596
+
597
+ <Tip>
598
+ While there are such optional arguments as `tokenizer`, `feature_extractor`, `image_processor`, and `processor`,
599
+ they shouldn't be specified all at once. If these components are not provided, `pipeline` will try to load
600
+ required ones automatically. In case you want to provide these components explicitly, please refer to a
601
+ specific pipeline in order to get more details regarding what components are required.
602
+ </Tip>
603
+
604
+ Args:
605
+ task (`str`):
606
+ The task defining which pipeline will be returned. Currently accepted tasks are:
607
+
608
+ - `"audio-classification"`: will return a [`AudioClassificationPipeline`].
609
+ - `"automatic-speech-recognition"`: will return a [`AutomaticSpeechRecognitionPipeline`].
610
+ - `"depth-estimation"`: will return a [`DepthEstimationPipeline`].
611
+ - `"document-question-answering"`: will return a [`DocumentQuestionAnsweringPipeline`].
612
+ - `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
613
+ - `"fill-mask"`: will return a [`FillMaskPipeline`]:.
614
+ - `"image-classification"`: will return a [`ImageClassificationPipeline`].
615
+ - `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
616
+ - `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
617
+ - `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`].
618
+ - `"image-to-image"`: will return a [`ImageToImagePipeline`].
619
+ - `"image-to-text"`: will return a [`ImageToTextPipeline`].
620
+ - `"mask-generation"`: will return a [`MaskGenerationPipeline`].
621
+ - `"object-detection"`: will return a [`ObjectDetectionPipeline`].
622
+ - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
623
+ - `"summarization"`: will return a [`SummarizationPipeline`].
624
+ - `"table-question-answering"`: will return a [`TableQuestionAnsweringPipeline`].
625
+ - `"text2text-generation"`: will return a [`Text2TextGenerationPipeline`].
626
+ - `"text-classification"` (alias `"sentiment-analysis"` available): will return a
627
+ [`TextClassificationPipeline`].
628
+ - `"text-generation"`: will return a [`TextGenerationPipeline`]:.
629
+ - `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`]:.
630
+ - `"token-classification"` (alias `"ner"` available): will return a [`TokenClassificationPipeline`].
631
+ - `"translation"`: will return a [`TranslationPipeline`].
632
+ - `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
633
+ - `"video-classification"`: will return a [`VideoClassificationPipeline`].
634
+ - `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
635
+ - `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
636
+ - `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
637
+ - `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
638
+ - `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].
639
+
640
+ model (`str` or [`PreTrainedModel`] or [`TFPreTrainedModel`], *optional*):
641
+ The model that will be used by the pipeline to make predictions. This can be a model identifier or an
642
+ actual instance of a pretrained model inheriting from [`PreTrainedModel`] (for PyTorch) or
643
+ [`TFPreTrainedModel`] (for TensorFlow).
644
+
645
+ If not provided, the default for the `task` will be loaded.
646
+ config (`str` or [`PretrainedConfig`], *optional*):
647
+ The configuration that will be used by the pipeline to instantiate the model. This can be a model
648
+ identifier or an actual pretrained model configuration inheriting from [`PretrainedConfig`].
649
+
650
+ If not provided, the default configuration file for the requested model will be used. That means that if
651
+ `model` is given, its default configuration will be used. However, if `model` is not supplied, this
652
+ `task`'s default model's config is used instead.
653
+ tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
654
+ The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
655
+ identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].
656
+
657
+ If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
658
+ is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
659
+ However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
660
+ will be loaded.
661
+ feature_extractor (`str` or [`PreTrainedFeatureExtractor`], *optional*):
662
+ The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
663
+ identifier or an actual pretrained feature extractor inheriting from [`PreTrainedFeatureExtractor`].
664
+
665
+ Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
666
+ models. Multi-modal models will also require a tokenizer to be passed.
667
+
668
+ If not provided, the default feature extractor for the given `model` will be loaded (if it is a string). If
669
+ `model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
670
+ is a string). However, if `config` is also not given or not a string, then the default feature extractor
671
+ for the given `task` will be loaded.
672
+ image_processor (`str` or [`BaseImageProcessor`], *optional*):
673
+ The image processor that will be used by the pipeline to preprocess images for the model. This can be a
674
+ model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
675
+
676
+ Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
677
+ models will also require a tokenizer to be passed.
678
+
679
+ If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
680
+ `model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
681
+ a string).
682
+ processor (`str` or [`ProcessorMixin`], *optional*):
683
+ The processor that will be used by the pipeline to preprocess data for the model. This can be a model
684
+ identifier or an actual processor inheriting from [`ProcessorMixin`].
685
+
686
+ Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
687
+ requires both text and image inputs.
688
+
689
+ If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
690
+ is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
691
+ framework (`str`, *optional*):
692
+ The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
693
+ installed.
694
+
695
+ If no framework is specified, will default to the one currently installed. If no framework is specified and
696
+ both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
697
+ provided.
698
+ revision (`str`, *optional*, defaults to `"main"`):
699
+ When passing a task name or a string model identifier: The specific model version to use. It can be a
700
+ branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
701
+ artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
702
+ use_fast (`bool`, *optional*, defaults to `True`):
703
+ Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
704
+ use_auth_token (`str` or *bool*, *optional*):
705
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
706
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
707
+ device (`int` or `str` or `torch.device`):
708
+ Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
709
+ pipeline will be allocated.
710
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
711
+ Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
712
+ `device_map="auto"` to compute the most optimized `device_map` automatically (see
713
+ [here](https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.cpu_offload)
714
+ for more information).
715
+
716
+ <Tip warning={true}>
717
+
718
+ Do not use `device_map` AND `device` at the same time as they will conflict
719
+
720
+ </Tip>
721
+
722
+ torch_dtype (`str` or `torch.dtype`, *optional*):
723
+ Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
724
+ (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
725
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
726
+ Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
727
+ tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
728
+ and in which you have read the code, as it will execute code present on the Hub on your local machine.
729
+ model_kwargs (`Dict[str, Any]`, *optional*):
730
+ Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
731
+ **model_kwargs)` function.
732
+ kwargs (`Dict[str, Any]`, *optional*):
733
+ Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
734
+ corresponding pipeline class for possible values).
735
+
736
+ Returns:
737
+ [`Pipeline`]: A suitable pipeline for the task.
738
+
739
+ Examples:
740
+
741
+ ```python
742
+ >>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
743
+
744
+ >>> # Sentiment analysis pipeline
745
+ >>> analyzer = pipeline("sentiment-analysis")
746
+
747
+ >>> # Question answering pipeline, specifying the checkpoint identifier
748
+ >>> oracle = pipeline(
749
+ ... "question-answering", model="distilbert/distilbert-base-cased-distilled-squad", tokenizer="google-bert/bert-base-cased"
750
+ ... )
751
+
752
+ >>> # Named entity recognition pipeline, passing in a specific model and tokenizer
753
+ >>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
754
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
755
+ >>> recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
756
+ ```"""
757
+ if model_kwargs is None:
758
+ model_kwargs = {}
759
+ # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
760
+ # this is to keep BC).
761
+ use_auth_token = model_kwargs.pop("use_auth_token", None)
762
+ if use_auth_token is not None:
763
+ warnings.warn(
764
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
765
+ FutureWarning,
766
+ )
767
+ if token is not None:
768
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
769
+ token = use_auth_token
770
+
771
+ code_revision = kwargs.pop("code_revision", None)
772
+ commit_hash = kwargs.pop("_commit_hash", None)
773
+
774
+ hub_kwargs = {
775
+ "revision": revision,
776
+ "token": token,
777
+ "trust_remote_code": trust_remote_code,
778
+ "_commit_hash": commit_hash,
779
+ }
780
+
781
+ if task is None and model is None:
782
+ raise RuntimeError(
783
+ "Impossible to instantiate a pipeline without either a task or a model "
784
+ "being specified. "
785
+ "Please provide a task class or a model"
786
+ )
787
+
788
+ if model is None and tokenizer is not None:
789
+ raise RuntimeError(
790
+ "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer"
791
+ " may not be compatible with the default model. Please provide a PreTrainedModel class or a"
792
+ " path/identifier to a pretrained model when providing tokenizer."
793
+ )
794
+ if model is None and feature_extractor is not None:
795
+ raise RuntimeError(
796
+ "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided"
797
+ " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
798
+ " or a path/identifier to a pretrained model when providing feature_extractor."
799
+ )
800
+ if isinstance(model, Path):
801
+ model = str(model)
802
+
803
+ if commit_hash is None:
804
+ pretrained_model_name_or_path = None
805
+ if isinstance(config, str):
806
+ pretrained_model_name_or_path = config
807
+ elif config is None and isinstance(model, str):
808
+ pretrained_model_name_or_path = model
809
+
810
+ if not isinstance(config, PretrainedConfig) and pretrained_model_name_or_path is not None:
811
+ # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
812
+ resolved_config_file = cached_file(
813
+ pretrained_model_name_or_path,
814
+ CONFIG_NAME,
815
+ _raise_exceptions_for_gated_repo=False,
816
+ _raise_exceptions_for_missing_entries=False,
817
+ _raise_exceptions_for_connection_errors=False,
818
+ cache_dir=model_kwargs.get("cache_dir"),
819
+ **hub_kwargs,
820
+ )
821
+ hub_kwargs["_commit_hash"] = extract_commit_hash(resolved_config_file, commit_hash)
822
+ else:
823
+ hub_kwargs["_commit_hash"] = getattr(config, "_commit_hash", None)
824
+
825
+ # Config is the primordial information item.
826
+ # Instantiate config if needed
827
+ if isinstance(config, str):
828
+ config = AutoConfig.from_pretrained(
829
+ config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
830
+ )
831
+ hub_kwargs["_commit_hash"] = config._commit_hash
832
+ elif config is None and isinstance(model, str):
833
+ # Check for an adapter file in the model path if PEFT is available
834
+ if is_peft_available():
835
+ # `find_adapter_config_file` doesn't accept `trust_remote_code`
836
+ _hub_kwargs = {k: v for k, v in hub_kwargs.items() if k != "trust_remote_code"}
837
+ maybe_adapter_path = find_adapter_config_file(
838
+ model,
839
+ token=hub_kwargs["token"],
840
+ revision=hub_kwargs["revision"],
841
+ _commit_hash=hub_kwargs["_commit_hash"],
842
+ )
843
+
844
+ if maybe_adapter_path is not None:
845
+ with open(maybe_adapter_path, "r", encoding="utf-8") as f:
846
+ adapter_config = json.load(f)
847
+ model = adapter_config["base_model_name_or_path"]
848
+
849
+ config = AutoConfig.from_pretrained(
850
+ model, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
851
+ )
852
+ hub_kwargs["_commit_hash"] = config._commit_hash
853
+
854
+ custom_tasks = {}
855
+ if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
856
+ custom_tasks = config.custom_pipelines
857
+ if task is None and trust_remote_code is not False:
858
+ if len(custom_tasks) == 1:
859
+ task = list(custom_tasks.keys())[0]
860
+ else:
861
+ raise RuntimeError(
862
+ "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
863
+ f"one in {', '.join(custom_tasks.keys())}"
864
+ )
865
+
866
+ if task is None and model is not None:
867
+ if not isinstance(model, str):
868
+ raise RuntimeError(
869
+ "Inferring the task automatically requires to check the hub with a model_id defined as a `str`. "
870
+ f"{model} is not a valid model_id."
871
+ )
872
+ task = get_task(model, token)
873
+
874
+ # Retrieve the task
875
+ if task in custom_tasks:
876
+ normalized_task = task
877
+ targeted_task, task_options = clean_custom_task(custom_tasks[task])
878
+ if pipeline_class is None:
879
+ if not trust_remote_code:
880
+ raise ValueError(
881
+ "Loading this pipeline requires you to execute the code in the pipeline file in that"
882
+ " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
883
+ " set the option `trust_remote_code=True` to remove this error."
884
+ )
885
+ class_ref = targeted_task["impl"]
886
+ pipeline_class = get_class_from_dynamic_module(
887
+ class_ref,
888
+ model,
889
+ code_revision=code_revision,
890
+ **hub_kwargs,
891
+ )
892
+ else:
893
+ normalized_task, targeted_task, task_options = check_task(task)
894
+ if pipeline_class is None:
895
+ pipeline_class = targeted_task["impl"]
896
+
897
+ # Use default model/config/tokenizer for the task if no model is provided
898
+ if model is None:
899
+ # At that point framework might still be undetermined
900
+ model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options)
901
+ revision = revision if revision is not None else default_revision
902
+ logger.warning(
903
+ f"No model was supplied, defaulted to {model} and revision"
904
+ f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
905
+ "Using a pipeline without specifying a model name and revision in production is not recommended."
906
+ )
907
+ hub_kwargs["revision"] = revision
908
+ if config is None and isinstance(model, str):
909
+ config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
910
+ hub_kwargs["_commit_hash"] = config._commit_hash
911
+
912
+ if device_map is not None:
913
+ if "device_map" in model_kwargs:
914
+ raise ValueError(
915
+ 'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
916
+ " arguments might conflict, use only one.)"
917
+ )
918
+ if device is not None:
919
+ logger.warning(
920
+ "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
921
+ " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
922
+ )
923
+ model_kwargs["device_map"] = device_map
924
+ if torch_dtype is not None:
925
+ if "torch_dtype" in model_kwargs:
926
+ raise ValueError(
927
+ 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
928
+ " arguments might conflict, use only one.)"
929
+ )
930
+ if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
931
+ torch_dtype = getattr(torch, torch_dtype)
932
+ model_kwargs["torch_dtype"] = torch_dtype
933
+
934
+ model_name = model if isinstance(model, str) else None
935
+
936
+ # Load the correct model if possible
937
+ # Infer the framework from the model if not already defined
938
+ if isinstance(model, str) or framework is None:
939
+ model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
940
+ framework, model = infer_framework_load_model(
941
+ model,
942
+ model_classes=model_classes,
943
+ config=config,
944
+ framework=framework,
945
+ task=task,
946
+ **hub_kwargs,
947
+ **model_kwargs,
948
+ )
949
+
950
+ model_config = model.config
951
+ hub_kwargs["_commit_hash"] = model.config._commit_hash
952
+
953
+ load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
954
+ load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
955
+ load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
956
+ load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
957
+
958
+ # Check that pipeline class required loading
959
+ load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
960
+ load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
961
+ load_image_processor = load_image_processor and pipeline_class._load_image_processor
962
+ load_processor = load_processor and pipeline_class._load_processor
963
+
964
+ # If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
965
+ # `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
966
+ # vision tasks when calling `pipeline()` with `model` and only one of the `image_processor` and `feature_extractor`.
967
+ # TODO: we need to make `NO_IMAGE_PROCESSOR_TASKS` and `NO_FEATURE_EXTRACTOR_TASKS` more robust to avoid such issue.
968
+ # This block is only temporarily to make CI green.
969
+ if load_image_processor and load_feature_extractor:
970
+ load_feature_extractor = False
971
+
972
+ if (
973
+ tokenizer is None
974
+ and not load_tokenizer
975
+ and normalized_task not in NO_TOKENIZER_TASKS
976
+ # Using class name to avoid importing the real class.
977
+ and (
978
+ model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
979
+ or model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
980
+ )
981
+ ):
982
+ # This is a special category of models, that are fusions of multiple models
983
+ # so the model_config might not define a tokenizer, but it seems to be
984
+ # necessary for the task, so we're force-trying to load it.
985
+ load_tokenizer = True
986
+ if (
987
+ image_processor is None
988
+ and not load_image_processor
989
+ and normalized_task not in NO_IMAGE_PROCESSOR_TASKS
990
+ # Using class name to avoid importing the real class.
991
+ and model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
992
+ ):
993
+ # This is a special category of models, that are fusions of multiple models
994
+ # so the model_config might not define a tokenizer, but it seems to be
995
+ # necessary for the task, so we're force-trying to load it.
996
+ load_image_processor = True
997
+ if (
998
+ feature_extractor is None
999
+ and not load_feature_extractor
1000
+ and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
1001
+ # Using class name to avoid importing the real class.
1002
+ and model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
1003
+ ):
1004
+ # This is a special category of models, that are fusions of multiple models
1005
+ # so the model_config might not define a tokenizer, but it seems to be
1006
+ # necessary for the task, so we're force-trying to load it.
1007
+ load_feature_extractor = True
1008
+
1009
+ if task in NO_TOKENIZER_TASKS:
1010
+ # These will never require a tokenizer.
1011
+ # the model on the other hand might have a tokenizer, but
1012
+ # the files could be missing from the hub, instead of failing
1013
+ # on such repos, we just force to not load it.
1014
+ load_tokenizer = False
1015
+
1016
+ if task in NO_FEATURE_EXTRACTOR_TASKS:
1017
+ load_feature_extractor = False
1018
+ if task in NO_IMAGE_PROCESSOR_TASKS:
1019
+ load_image_processor = False
1020
+
1021
+ if load_tokenizer:
1022
+ # Try to infer tokenizer from model or config name (if provided as str)
1023
+ if tokenizer is None:
1024
+ if isinstance(model_name, str):
1025
+ tokenizer = model_name
1026
+ elif isinstance(config, str):
1027
+ tokenizer = config
1028
+ else:
1029
+ # Impossible to guess what is the right tokenizer here
1030
+ raise Exception(
1031
+ "Impossible to guess which tokenizer to use. "
1032
+ "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
1033
+ )
1034
+
1035
+ # Instantiate tokenizer if needed
1036
+ if isinstance(tokenizer, (str, tuple)):
1037
+ if isinstance(tokenizer, tuple):
1038
+ # For tuple we have (tokenizer name, {kwargs})
1039
+ use_fast = tokenizer[1].pop("use_fast", use_fast)
1040
+ tokenizer_identifier = tokenizer[0]
1041
+ tokenizer_kwargs = tokenizer[1]
1042
+ else:
1043
+ tokenizer_identifier = tokenizer
1044
+ tokenizer_kwargs = model_kwargs.copy()
1045
+ tokenizer_kwargs.pop("torch_dtype", None)
1046
+
1047
+ tokenizer = AutoTokenizer.from_pretrained(
1048
+ tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
1049
+ )
1050
+
1051
+ if load_image_processor:
1052
+ # Try to infer image processor from model or config name (if provided as str)
1053
+ if image_processor is None:
1054
+ if isinstance(model_name, str):
1055
+ image_processor = model_name
1056
+ elif isinstance(config, str):
1057
+ image_processor = config
1058
+ # Backward compatibility, as `feature_extractor` used to be the name
1059
+ # for `ImageProcessor`.
1060
+ elif feature_extractor is not None and isinstance(feature_extractor, BaseImageProcessor):
1061
+ image_processor = feature_extractor
1062
+ else:
1063
+ # Impossible to guess what is the right image_processor here
1064
+ raise Exception(
1065
+ "Impossible to guess which image processor to use. "
1066
+ "Please provide a PreTrainedImageProcessor class or a path/identifier "
1067
+ "to a pretrained image processor."
1068
+ )
1069
+
1070
+ # Instantiate image_processor if needed
1071
+ if isinstance(image_processor, (str, tuple)):
1072
+ image_processor = AutoImageProcessor.from_pretrained(
1073
+ image_processor, _from_pipeline=task, **hub_kwargs, **model_kwargs
1074
+ )
1075
+
1076
+ if load_feature_extractor:
1077
+ # Try to infer feature extractor from model or config name (if provided as str)
1078
+ if feature_extractor is None:
1079
+ if isinstance(model_name, str):
1080
+ feature_extractor = model_name
1081
+ elif isinstance(config, str):
1082
+ feature_extractor = config
1083
+ else:
1084
+ # Impossible to guess what is the right feature_extractor here
1085
+ raise Exception(
1086
+ "Impossible to guess which feature extractor to use. "
1087
+ "Please provide a PreTrainedFeatureExtractor class or a path/identifier "
1088
+ "to a pretrained feature extractor."
1089
+ )
1090
+
1091
+ # Instantiate feature_extractor if needed
1092
+ if isinstance(feature_extractor, (str, tuple)):
1093
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
1094
+ feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
1095
+ )
1096
+
1097
+ if (
1098
+ feature_extractor._processor_class
1099
+ and feature_extractor._processor_class.endswith("WithLM")
1100
+ and isinstance(model_name, str)
1101
+ ):
1102
+ try:
1103
+ import kenlm # to trigger `ImportError` if not installed
1104
+ from pyctcdecode import BeamSearchDecoderCTC
1105
+
1106
+ if os.path.isdir(model_name) or os.path.isfile(model_name):
1107
+ decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
1108
+ else:
1109
+ language_model_glob = os.path.join(
1110
+ BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
1111
+ )
1112
+ alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
1113
+ allow_patterns = [language_model_glob, alphabet_filename]
1114
+ decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_patterns=allow_patterns)
1115
+
1116
+ kwargs["decoder"] = decoder
1117
+ except ImportError as e:
1118
+ logger.warning(f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Error: {e}")
1119
+ if not is_kenlm_available():
1120
+ logger.warning("Try to install `kenlm`: `pip install kenlm`")
1121
+
1122
+ if not is_pyctcdecode_available():
1123
+ logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode`")
1124
+
1125
+ if load_processor:
1126
+ # Try to infer processor from model or config name (if provided as str)
1127
+ if processor is None:
1128
+ if isinstance(model_name, str):
1129
+ processor = model_name
1130
+ elif isinstance(config, str):
1131
+ processor = config
1132
+ else:
1133
+ # Impossible to guess what is the right processor here
1134
+ raise Exception(
1135
+ "Impossible to guess which processor to use. "
1136
+ "Please provide a processor instance or a path/identifier "
1137
+ "to a processor."
1138
+ )
1139
+
1140
+ # Instantiate processor if needed
1141
+ if isinstance(processor, (str, tuple)):
1142
+ processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
1143
+ if not isinstance(processor, ProcessorMixin):
1144
+ raise TypeError(
1145
+ "Processor was loaded, but it is not an instance of `ProcessorMixin`. "
1146
+ f"Got type `{type(processor)}` instead. Please check that you specified "
1147
+ "correct pipeline task for the model and model has processor implemented and saved."
1148
+ )
1149
+
1150
+ if task == "translation" and model.config.task_specific_params:
1151
+ for key in model.config.task_specific_params:
1152
+ if key.startswith("translation"):
1153
+ task = key
1154
+ warnings.warn(
1155
+ f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
1156
+ UserWarning,
1157
+ )
1158
+ break
1159
+
1160
+ if tokenizer is not None:
1161
+ kwargs["tokenizer"] = tokenizer
1162
+
1163
+ if feature_extractor is not None:
1164
+ kwargs["feature_extractor"] = feature_extractor
1165
+
1166
+ if torch_dtype is not None:
1167
+ kwargs["torch_dtype"] = torch_dtype
1168
+
1169
+ if image_processor is not None:
1170
+ kwargs["image_processor"] = image_processor
1171
+
1172
+ if device is not None:
1173
+ kwargs["device"] = device
1174
+
1175
+ if processor is not None:
1176
+ kwargs["processor"] = processor
1177
+
1178
+ return pipeline_class(model=model, framework=framework, task=task, **kwargs)
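The block above only assembles the preprocessing components and forwards them to the task-specific pipeline class. A minimal usage sketch of the resulting factory (assuming `transformers` and `torch` are installed; the checkpoint id is the one used in the ASR docstring further down and is only an example):

```python
from transformers import pipeline

# Passing only a task and a model id: the factory infers the framework,
# loads the config/model, and auto-resolves the tokenizer/feature extractor
# that the task needs before instantiating the task-specific pipeline class.
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",  # example checkpoint id
)
print(transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac"))
```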
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (48.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_classification.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_utils.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/automatic_speech_recognition.cpython-311.pyc ADDED
Binary file (38.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/base.cpython-311.pyc ADDED
Binary file (77.9 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/depth_estimation.cpython-311.pyc ADDED
Binary file (7.7 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/document_question_answering.cpython-311.pyc ADDED
Binary file (26.6 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/feature_extraction.cpython-311.pyc ADDED
Binary file (4.45 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/fill_mask.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_classification.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_feature_extraction.cpython-311.pyc ADDED
Binary file (5.84 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_segmentation.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_text_to_text.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_image.cpython-311.pyc ADDED
Binary file (6.68 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_text.cpython-311.pyc ADDED
Binary file (9.78 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/mask_generation.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/object_detection.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/pt_utils.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/question_answering.cpython-311.pyc ADDED
Binary file (33.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/table_question_answering.cpython-311.pyc ADDED
Binary file (25.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text2text_generation.cpython-311.pyc ADDED
Binary file (21.3 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_classification.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_generation.cpython-311.pyc ADDED
Binary file (21.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_to_audio.cpython-311.pyc ADDED
Binary file (9.03 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/token_classification.cpython-311.pyc ADDED
Binary file (30.6 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/video_classification.cpython-311.pyc ADDED
Binary file (9.97 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/visual_question_answering.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_audio_classification.cpython-311.pyc ADDED
Binary file (9.07 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_classification.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_image_classification.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_object_detection.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/pipelines/audio_classification.py ADDED
@@ -0,0 +1,234 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import subprocess
15
+ from typing import Union
16
+
17
+ import numpy as np
18
+ import requests
19
+
20
+ from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
21
+ from .base import Pipeline, build_pipeline_init_args
22
+
23
+
24
+ if is_torch_available():
25
+ from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
31
+ """
32
+ Helper function to read an audio file through ffmpeg.
33
+ """
34
+ ar = f"{sampling_rate}"
35
+ ac = "1"
36
+ format_for_conversion = "f32le"
37
+ ffmpeg_command = [
38
+ "ffmpeg",
39
+ "-i",
40
+ "pipe:0",
41
+ "-ac",
42
+ ac,
43
+ "-ar",
44
+ ar,
45
+ "-f",
46
+ format_for_conversion,
47
+ "-hide_banner",
48
+ "-loglevel",
49
+ "quiet",
50
+ "pipe:1",
51
+ ]
52
+
53
+ try:
54
+ ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
55
+ except FileNotFoundError:
56
+ raise ValueError("ffmpeg was not found but is required to load audio files from filename")
57
+ output_stream = ffmpeg_process.communicate(bpayload)
58
+ out_bytes = output_stream[0]
59
+
60
+ audio = np.frombuffer(out_bytes, np.float32)
61
+ if audio.shape[0] == 0:
62
+ raise ValueError("Malformed soundfile")
63
+ return audio
64
+
65
+
66
+ @add_end_docstrings(build_pipeline_init_args(has_feature_extractor=True))
67
+ class AudioClassificationPipeline(Pipeline):
68
+ """
69
+ Audio classification pipeline using any `AutoModelForAudioClassification`. This pipeline predicts the class of a
70
+ raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio
71
+ formats.
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import pipeline
77
+
78
+ >>> classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
79
+ >>> classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
80
+ [{'score': 0.997, 'label': '_unknown_'}, {'score': 0.002, 'label': 'left'}, {'score': 0.0, 'label': 'yes'}, {'score': 0.0, 'label': 'down'}, {'score': 0.0, 'label': 'stop'}]
81
+ ```
82
+
83
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
84
+
85
+
86
+ This pipeline can currently be loaded from [`pipeline`] using the following task identifier:
87
+ `"audio-classification"`.
88
+
89
+ See the list of available models on
90
+ [huggingface.co/models](https://huggingface.co/models?filter=audio-classification).
91
+ """
92
+
93
+ def __init__(self, *args, **kwargs):
94
+ # Default, might be overridden by the model.config.
95
+ kwargs["top_k"] = kwargs.get("top_k", 5)
96
+ super().__init__(*args, **kwargs)
97
+
98
+ if self.framework != "pt":
99
+ raise ValueError(f"The {self.__class__} is only available in PyTorch.")
100
+
101
+ self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)
102
+
103
+ def __call__(
104
+ self,
105
+ inputs: Union[np.ndarray, bytes, str],
106
+ **kwargs,
107
+ ):
108
+ """
109
+ Classify the sequence(s) given as inputs. See the [`AutomaticSpeechRecognitionPipeline`] documentation for more
110
+ information.
111
+
112
+ Args:
113
+ inputs (`np.ndarray` or `bytes` or `str` or `dict`):
114
+ The input is either:
115
+ - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
116
+ to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
117
+ - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
118
+ same way.
119
+ - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
120
+ Raw audio at the correct sampling rate (no further check will be done)
121
+ - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
122
+ pipeline do the resampling. The dict must be either be in the format `{"sampling_rate": int,
123
+ "raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
124
+ `"array"` is used to denote the raw audio waveform.
125
+ top_k (`int`, *optional*, defaults to None):
126
+ The number of top labels that will be returned by the pipeline. If the provided number is `None` or
127
+ higher than the number of labels available in the model configuration, it will default to the number of
128
+ labels.
129
+ function_to_apply(`str`, *optional*, defaults to "softmax"):
130
+ The function to apply to the model output. By default, the pipeline will apply the softmax function to
131
+ the output of the model. Valid options: ["softmax", "sigmoid", "none"]. Note that passing Python's
132
+ built-in `None` will default to "softmax", so you need to pass the string "none" to disable any
133
+ post-processing.
134
+
135
+ Return:
136
+ A list of `dict` with the following keys:
137
+
138
+ - **label** (`str`) -- The label predicted.
139
+ - **score** (`float`) -- The corresponding probability.
140
+ """
141
+ return super().__call__(inputs, **kwargs)
142
+
143
+ def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
144
+ # Only postprocess parameters (top_k, function_to_apply) on this pipeline right now
145
+ postprocess_params = {}
146
+ if top_k is not None:
147
+ if top_k > self.model.config.num_labels:
148
+ top_k = self.model.config.num_labels
149
+ postprocess_params["top_k"] = top_k
150
+ if function_to_apply is not None:
151
+ if function_to_apply not in ["softmax", "sigmoid", "none"]:
152
+ raise ValueError(
153
+ f"Invalid value for `function_to_apply`: {function_to_apply}. "
154
+ "Valid options are ['softmax', 'sigmoid', 'none']"
155
+ )
156
+ postprocess_params["function_to_apply"] = function_to_apply
157
+ else:
158
+ postprocess_params["function_to_apply"] = "softmax"
159
+ return {}, {}, postprocess_params
160
+
161
+ def preprocess(self, inputs):
162
+ if isinstance(inputs, str):
163
+ if inputs.startswith("http://") or inputs.startswith("https://"):
164
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
165
+ # like http_huggingface_co.png
166
+ inputs = requests.get(inputs).content
167
+ else:
168
+ with open(inputs, "rb") as f:
169
+ inputs = f.read()
170
+
171
+ if isinstance(inputs, bytes):
172
+ inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
173
+
174
+ if isinstance(inputs, dict):
175
+ # Accepting `"array"` which is the key defined in `datasets` for
176
+ # better integration
177
+ if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
178
+ raise ValueError(
179
+ "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
180
+ '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
181
+ "containing the sampling_rate associated with that array"
182
+ )
183
+
184
+ _inputs = inputs.pop("raw", None)
185
+ if _inputs is None:
186
+ # Remove path which will not be used from `datasets`.
187
+ inputs.pop("path", None)
188
+ _inputs = inputs.pop("array", None)
189
+ in_sampling_rate = inputs.pop("sampling_rate")
190
+ inputs = _inputs
191
+ if in_sampling_rate != self.feature_extractor.sampling_rate:
192
+ import torch
193
+
194
+ if is_torchaudio_available():
195
+ from torchaudio import functional as F
196
+ else:
197
+ raise ImportError(
198
+ "torchaudio is required to resample audio samples in AudioClassificationPipeline. "
199
+ "The torchaudio package can be installed through: `pip install torchaudio`."
200
+ )
201
+
202
+ inputs = F.resample(
203
+ torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
204
+ ).numpy()
205
+
206
+ if not isinstance(inputs, np.ndarray):
207
+ raise TypeError("We expect a numpy ndarray as input")
208
+ if len(inputs.shape) != 1:
209
+ raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
210
+
211
+ processed = self.feature_extractor(
212
+ inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
213
+ )
214
+ return processed
215
+
216
+ def _forward(self, model_inputs):
217
+ model_outputs = self.model(**model_inputs)
218
+ return model_outputs
219
+
220
+ def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"):
221
+ if function_to_apply == "softmax":
222
+ probs = model_outputs.logits[0].softmax(-1)
223
+ elif function_to_apply == "sigmoid":
224
+ probs = model_outputs.logits[0].sigmoid()
225
+ else:
226
+ probs = model_outputs.logits[0]
227
+ scores, ids = probs.topk(top_k)
228
+
229
+ scores = scores.tolist()
230
+ ids = ids.tolist()
231
+
232
+ labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
233
+
234
+ return labels
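A short sketch of the dict input path handled by `preprocess` above (synthetic waveform; the checkpoint id is taken from the class docstring and is only an example; `torchaudio` is assumed to be installed since the audio is not at the model's native sampling rate):

```python
import numpy as np
from transformers import pipeline

classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# One second of (random) mono audio at 8 kHz; preprocess() resamples it to the
# feature extractor's sampling rate before running the model.
waveform = np.random.randn(8000).astype(np.float32)
predictions = classifier({"sampling_rate": 8000, "raw": waveform}, top_k=3)
print(predictions)  # e.g. [{"score": ..., "label": ...}, ...]
```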
.venv/lib/python3.11/site-packages/transformers/pipelines/audio_utils.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ import datetime
3
+ import platform
4
+ import subprocess
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+
9
+
10
+ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
11
+ """
12
+ Helper function to read an audio file through ffmpeg.
13
+ """
14
+ ar = f"{sampling_rate}"
15
+ ac = "1"
16
+ format_for_conversion = "f32le"
17
+ ffmpeg_command = [
18
+ "ffmpeg",
19
+ "-i",
20
+ "pipe:0",
21
+ "-ac",
22
+ ac,
23
+ "-ar",
24
+ ar,
25
+ "-f",
26
+ format_for_conversion,
27
+ "-hide_banner",
28
+ "-loglevel",
29
+ "quiet",
30
+ "pipe:1",
31
+ ]
32
+
33
+ try:
34
+ with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
35
+ output_stream = ffmpeg_process.communicate(bpayload)
36
+ except FileNotFoundError as error:
37
+ raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
38
+ out_bytes = output_stream[0]
39
+ audio = np.frombuffer(out_bytes, np.float32)
40
+ if audio.shape[0] == 0:
41
+ raise ValueError(
42
+ "Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has "
43
+ "a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote "
44
+ "URL, ensure that the URL is the full address to **download** the audio file."
45
+ )
46
+ return audio
47
+
48
+
49
+ def ffmpeg_microphone(
50
+ sampling_rate: int,
51
+ chunk_length_s: float,
52
+ format_for_conversion: str = "f32le",
53
+ ffmpeg_input_device: Optional[str] = None,
54
+ ffmpeg_additional_args: Optional[list[str]] = None,
55
+ ):
56
+ """
57
+ Helper function to read audio from a microphone using ffmpeg. The default input device will be used unless another
58
+ input device is specified using the `ffmpeg_input_device` argument. Uses 'alsa' on Linux, 'avfoundation' on MacOS and
59
+ 'dshow' on Windows.
60
+
61
+ Arguments:
62
+ sampling_rate (`int`):
63
+ The sampling_rate to use when reading the data from the microphone. Try using the model's sampling_rate to
64
+ avoid resampling later.
65
+ chunk_length_s (`float` or `int`):
66
+ The maximum length of the audio chunks to be returned.
67
+ format_for_conversion (`str`, defaults to `f32le`):
68
+ The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
69
+ could also be used.
70
+ ffmpeg_input_device (`str`, *optional*):
71
+ The identifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
72
+ the default input device will be used. See `https://www.ffmpeg.org/ffmpeg-devices.html#Input-Devices`
73
+ for how to specify and list input devices.
74
+ ffmpeg_additional_args (`list[str]`, *optional*):
75
+ Additional arguments to pass to ffmpeg; these can include arguments like -nostdin for running as a background
76
+ process. For example, to pass -nostdin to the ffmpeg process, pass in ["-nostdin"]. If passing in flags
77
+ with multiple arguments, use the following convention (e.g. ["flag", "arg1", "arg2"]).
78
+
79
+ Returns:
80
+ A generator yielding audio chunks of `chunk_length_s` seconds as `bytes` objects of length
81
+ `int(round(sampling_rate * chunk_length_s)) * size_of_sample`.
82
+ """
83
+ ar = f"{sampling_rate}"
84
+ ac = "1"
85
+ if format_for_conversion == "s16le":
86
+ size_of_sample = 2
87
+ elif format_for_conversion == "f32le":
88
+ size_of_sample = 4
89
+ else:
90
+ raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")
91
+
92
+ system = platform.system()
93
+
94
+ if system == "Linux":
95
+ format_ = "alsa"
96
+ input_ = ffmpeg_input_device or "default"
97
+ elif system == "Darwin":
98
+ format_ = "avfoundation"
99
+ input_ = ffmpeg_input_device or ":default"
100
+ elif system == "Windows":
101
+ format_ = "dshow"
102
+ input_ = ffmpeg_input_device or _get_microphone_name()
103
+
104
+ ffmpeg_additional_args = [] if ffmpeg_additional_args is None else ffmpeg_additional_args
105
+
106
+ ffmpeg_command = [
107
+ "ffmpeg",
108
+ "-f",
109
+ format_,
110
+ "-i",
111
+ input_,
112
+ "-ac",
113
+ ac,
114
+ "-ar",
115
+ ar,
116
+ "-f",
117
+ format_for_conversion,
118
+ "-fflags",
119
+ "nobuffer",
120
+ "-hide_banner",
121
+ "-loglevel",
122
+ "quiet",
123
+ "pipe:1",
124
+ ]
125
+
126
+ ffmpeg_command.extend(ffmpeg_additional_args)
127
+
128
+ chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
129
+ iterator = _ffmpeg_stream(ffmpeg_command, chunk_len)
130
+ for item in iterator:
131
+ yield item
132
+
133
+
134
+ def ffmpeg_microphone_live(
135
+ sampling_rate: int,
136
+ chunk_length_s: float,
137
+ stream_chunk_s: Optional[int] = None,
138
+ stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
139
+ format_for_conversion: str = "f32le",
140
+ ffmpeg_input_device: Optional[str] = None,
141
+ ffmpeg_additional_args: Optional[list[str]] = None,
142
+ ):
143
+ """
144
+ Helper function to read audio from a microphone using ffmpeg. This will output `partial` overlapping chunks starting
145
+ from `stream_chunk_s` (if it is defined) until `chunk_length_s` is reached. It will make use of striding to avoid
146
+ errors on the "sides" of the various chunks. The default input device will be used unless another input device is
147
+ specified using the `ffmpeg_input_device` argument. Uses 'alsa' on Linux, 'avfoundation' on MacOS and 'dshow' on Windows.
148
+
149
+ Arguments:
150
+ sampling_rate (`int`):
151
+ The sampling_rate to use when reading the data from the microphone. Try using the model's sampling_rate to
152
+ avoid resampling later.
153
+ chunk_length_s (`float` or `int`):
154
+ The maximum length of the audio chunks to be returned. This includes any striding.
155
+ stream_chunk_s (`float` or `int`):
156
+ The length of the minimal temporary audio to be returned.
157
+ stride_length_s (`float` or `int` or `(float, float)`, *optional*):
158
+ The length of the striding to be used. Stride is used to provide context to a model on the (left, right) of
159
+ an audio sample but without using that part to actually make the prediction. Setting this does not change
160
+ the length of the chunk.
161
+ format_for_conversion (`str`, *optional*, defaults to `f32le`):
162
+ The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
163
+ could also be used.
164
+ ffmpeg_input_device (`str`, *optional*):
165
+ The identifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
166
+ the default input device will be used. See `https://www.ffmpeg.org/ffmpeg-devices.html#Input-Devices`
167
+ for how to specify and list input devices.
168
+ ffmpeg_additional_args (`list[str]`, *optional*):
169
+ Additional arguments to pass to ffmpeg; these can include arguments like -nostdin for running as a background
170
+ process. For example, to pass -nostdin to the ffmpeg process, pass in ["-nostdin"]. If passing in flags
171
+ with multiple arguments, use the following convention (e.g. ["flag", "arg1", "arg2"]).
172
+
173
+ Return:
174
+ A generator yielding dictionaries of the following form
175
+
176
+ `{"sampling_rate": int, "raw": np.array(), "partial": bool}`, with optionally a `"stride": (int, int)` key if
177
+ `stride_length_s` is defined.
178
+
179
+ `stride` and `raw` are both expressed in `samples`, and `partial` is a boolean saying whether the current yield item
180
+ is a whole chunk, or a partial temporary result to be later replaced by another larger chunk.
181
+ """
182
+ if stream_chunk_s is not None:
183
+ chunk_s = stream_chunk_s
184
+ else:
185
+ chunk_s = chunk_length_s
186
+
187
+ microphone = ffmpeg_microphone(
188
+ sampling_rate,
189
+ chunk_s,
190
+ format_for_conversion=format_for_conversion,
191
+ ffmpeg_input_device=ffmpeg_input_device,
192
+ ffmpeg_additional_args=[] if ffmpeg_additional_args is None else ffmpeg_additional_args,
193
+ )
194
+
195
+ if format_for_conversion == "s16le":
196
+ dtype = np.int16
197
+ size_of_sample = 2
198
+ elif format_for_conversion == "f32le":
199
+ dtype = np.float32
200
+ size_of_sample = 4
201
+ else:
202
+ raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")
203
+
204
+ if stride_length_s is None:
205
+ stride_length_s = chunk_length_s / 6
206
+ chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
207
+ if isinstance(stride_length_s, (int, float)):
208
+ stride_length_s = [stride_length_s, stride_length_s]
209
+
210
+ stride_left = int(round(sampling_rate * stride_length_s[0])) * size_of_sample
211
+ stride_right = int(round(sampling_rate * stride_length_s[1])) * size_of_sample
212
+ audio_time = datetime.datetime.now()
213
+ delta = datetime.timedelta(seconds=chunk_s)
214
+ for item in chunk_bytes_iter(microphone, chunk_len, stride=(stride_left, stride_right), stream=True):
215
+ # Put everything back in numpy scale
216
+ item["raw"] = np.frombuffer(item["raw"], dtype=dtype)
217
+ item["stride"] = (
218
+ item["stride"][0] // size_of_sample,
219
+ item["stride"][1] // size_of_sample,
220
+ )
221
+ item["sampling_rate"] = sampling_rate
222
+ audio_time += delta
223
+ if datetime.datetime.now() > audio_time + 10 * delta:
224
+ # We're late !! SKIP
225
+ continue
226
+ yield item
227
+
228
+
229
+ def chunk_bytes_iter(iterator, chunk_len: int, stride: Tuple[int, int], stream: bool = False):
230
+ """
231
+ Reads raw bytes from an iterator and does chunks of length `chunk_len`. Optionally adds `stride` to each chunks to
232
+ get overlaps. `stream` is used to return partial results even if a full `chunk_len` is not yet available.
233
+ """
234
+ acc = b""
235
+ stride_left, stride_right = stride
236
+ if stride_left + stride_right >= chunk_len:
237
+ raise ValueError(
238
+ f"Stride needs to be strictly smaller than chunk_len: ({stride_left}, {stride_right}) vs {chunk_len}"
239
+ )
240
+ _stride_left = 0
241
+ for raw in iterator:
242
+ acc += raw
243
+ if stream and len(acc) < chunk_len:
244
+ stride = (_stride_left, 0)
245
+ yield {"raw": acc[:chunk_len], "stride": stride, "partial": True}
246
+ else:
247
+ while len(acc) >= chunk_len:
248
+ # We are flushing the accumulator
249
+ stride = (_stride_left, stride_right)
250
+ item = {"raw": acc[:chunk_len], "stride": stride}
251
+ if stream:
252
+ item["partial"] = False
253
+ yield item
254
+ _stride_left = stride_left
255
+ acc = acc[chunk_len - stride_left - stride_right :]
256
+ # Last chunk
257
+ if len(acc) > stride_left:
258
+ item = {"raw": acc, "stride": (_stride_left, 0)}
259
+ if stream:
260
+ item["partial"] = False
261
+ yield item
262
+
263
+
264
+ def _ffmpeg_stream(ffmpeg_command, buflen: int):
265
+ """
266
+ Internal function to create the generator of data through ffmpeg
267
+ """
268
+ bufsize = 2**24  # 16 MB
269
+ try:
270
+ with subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, bufsize=bufsize) as ffmpeg_process:
271
+ while True:
272
+ raw = ffmpeg_process.stdout.read(buflen)
273
+ if raw == b"":
274
+ break
275
+ yield raw
276
+ except FileNotFoundError as error:
277
+ raise ValueError("ffmpeg was not found but is required to stream audio files from filename") from error
278
+
279
+
280
+ def _get_microphone_name():
281
+ """
282
+ Retrieve the microphone name on Windows.
283
+ """
284
+ command = ["ffmpeg", "-list_devices", "true", "-f", "dshow", "-i", ""]
285
+
286
+ try:
287
+ ffmpeg_devices = subprocess.run(command, text=True, stderr=subprocess.PIPE, encoding="utf-8")
288
+ microphone_lines = [line for line in ffmpeg_devices.stderr.splitlines() if "(audio)" in line]
289
+
290
+ if microphone_lines:
291
+ microphone_name = microphone_lines[0].split('"')[1]
292
+ print(f"Using microphone: {microphone_name}")
293
+ return f"audio={microphone_name}"
294
+ except FileNotFoundError:
295
+ print("ffmpeg was not found. Please install it or make sure it is in your system PATH.")
296
+
297
+ return "default"
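`chunk_bytes_iter` is plain byte bookkeeping, so its chunking/striding behaviour can be checked without ffmpeg or a microphone. A small sketch (chunk and stride sizes chosen arbitrarily for illustration; assumes a transformers version that ships this module):

```python
from transformers.pipelines.audio_utils import chunk_bytes_iter

# Three 10-byte "packets" stand in for the ffmpeg stream; with chunk_len=12 and
# stride=(2, 2), consecutive chunks overlap by stride_left + stride_right bytes
# and the final leftover bytes are emitted as a last, shorter chunk.
packets = iter([b"0123456789", b"abcdefghij", b"ABCDEFGHIJ"])
for item in chunk_bytes_iter(packets, chunk_len=12, stride=(2, 2)):
    print(item["raw"], item["stride"])
```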
.venv/lib/python3.11/site-packages/transformers/pipelines/automatic_speech_recognition.py ADDED
@@ -0,0 +1,766 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import warnings
15
+ from collections import defaultdict
16
+ from typing import TYPE_CHECKING, Dict, Optional, Union
17
+
18
+ import numpy as np
19
+ import requests
20
+
21
+ from ..tokenization_utils import PreTrainedTokenizer
22
+ from ..utils import is_torch_available, is_torchaudio_available, logging
23
+ from .audio_utils import ffmpeg_read
24
+ from .base import ChunkPipeline
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from pyctcdecode import BeamSearchDecoderCTC
29
+
30
+ from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
31
+ from ..modeling_utils import PreTrainedModel
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ if is_torch_available():
36
+ import torch
37
+
38
+ from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
39
+
40
+
41
+ def rescale_stride(stride, ratio):
42
+ """
43
+ Rescales the stride values from audio space to tokens/logits space.
44
+
45
+ (160_000, 16_000, 16_000) -> (2000, 200, 200) for instance.
46
+ """
47
+ # Shape is [B, SEQ] for tokens
48
+ # [B, SEQ, V] for logits
49
+
50
+ new_strides = []
51
+ for input_n, left, right in stride:
52
+ token_n = int(round(input_n * ratio))
53
+ left = int(round(left / input_n * token_n))
54
+ right = int(round(right / input_n * token_n))
55
+ new_stride = (token_n, left, right)
56
+ new_strides.append(new_stride)
57
+
58
+ return new_strides
59
+
60
+
61
+ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
62
+ inputs_len = inputs.shape[0]
63
+ step = chunk_len - stride_left - stride_right
64
+ for chunk_start_idx in range(0, inputs_len, step):
65
+ chunk_end_idx = chunk_start_idx + chunk_len
66
+ chunk = inputs[chunk_start_idx:chunk_end_idx]
67
+ processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
68
+ if dtype is not None:
69
+ processed = processed.to(dtype=dtype)
70
+ _stride_left = 0 if chunk_start_idx == 0 else stride_left
71
+ is_last = chunk_end_idx >= inputs_len
72
+ _stride_right = 0 if is_last else stride_right
73
+
74
+ chunk_len = chunk.shape[0]
75
+ stride = (chunk_len, _stride_left, _stride_right)
76
+ if chunk.shape[0] > _stride_left:
77
+ yield {"is_last": is_last, "stride": stride, **processed}
78
+ if is_last:
79
+ break
80
+
81
+
82
+ def _fast_find_longest_common_sequence(sequence_left, sequence_right):
83
+ seq_len_left = len(sequence_left)
84
+ seq_len_right = len(sequence_right)
85
+ counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)]
86
+ longest = 0
87
+ for i in range(seq_len_left):
88
+ for j in range(seq_len_right):
89
+ if sequence_left[i] == sequence_right[j]:
90
+ previous_counter = counter[i][j] + 1
91
+ counter[i + 1][j + 1] = previous_counter
92
+ if previous_counter > longest:
93
+ longest = previous_counter
94
+
95
+ counter = np.array(counter)
96
+ # we return the idx of the first element of the longest common sequence in the left sequence
97
+ index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1
98
+ index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1
99
+ return index_left, index_right, longest
100
+
101
+
102
+ def _find_longest_common_sequence(sequences, tokenizer):
103
+ # TODO Use a faster algorithm this can probably be done in O(n)
104
+ # using suffix array.
105
+ # It might be tedious to do because of fault tolerance.
106
+ # We actually have a really good property which is that the total sequence
107
+ # MUST be those subsequences in order.
108
+ # Also the algorithm should be more tolerant to errors.
109
+ sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
110
+ for new_seq in sequences[1:]:
111
+ new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
112
+
113
+ index = 0
114
+ max_ = 0.0
115
+ for i in range(1, len(new_sequence) + 1):
116
+ # epsilon to favor long perfect matches
117
+ eps = i / 10000.0
118
+ matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
119
+ matching = matches / i + eps
120
+ if matches > 1 and matching > max_:
121
+ index = i
122
+ max_ = matching
123
+ sequence.extend(new_sequence[index:])
124
+ return np.array(sequence)
125
+
126
+
127
+ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
128
+ """
129
+ Pipeline that aims at extracting spoken text contained within some audio.
130
+
131
+ The input can be either a raw waveform or an audio file. In the case of an audio file, ffmpeg should be installed
132
+ to support multiple audio formats.
133
+
134
+ Example:
135
+
136
+ ```python
137
+ >>> from transformers import pipeline
138
+
139
+ >>> transcriber = pipeline(model="openai/whisper-base")
140
+ >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
141
+ {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
142
+ ```
143
+
144
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
145
+
146
+ Arguments:
147
+ model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
148
+ The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
149
+ [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
150
+ feature_extractor ([`SequenceFeatureExtractor`]):
151
+ The feature extractor that will be used by the pipeline to encode waveform for the model.
152
+ tokenizer ([`PreTrainedTokenizer`]):
153
+ The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
154
+ [`PreTrainedTokenizer`].
155
+ decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
156
+ [PyCTCDecode's
157
+ BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
158
+ can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
159
+ chunk_length_s (`float`, *optional*, defaults to 0):
160
+ The input length of each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
161
+
162
+ <Tip>
163
+
164
+ For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking
165
+ blog post](https://huggingface.co/blog/asr-chunking).
166
+
167
+ </Tip>
168
+
169
+ stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
170
+ The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
171
+ the model to *see* more context and infer letters better than without this context, but the pipeline
172
+ discards the strided portions at the end to make the final reconstruction as accurate as possible.
173
+
174
+ <Tip>
175
+
176
+ For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking
177
+ blog post](https://huggingface.co/blog/asr-chunking).
178
+
179
+ </Tip>
180
+
181
+ framework (`str`, *optional*):
182
+ The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
183
+ installed. If no framework is specified, will default to the one currently installed. If no framework is
184
+ specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
185
+ no model is provided.
186
+ device (Union[`int`, `torch.device`], *optional*):
187
+ Device ordinal for CPU/GPU support. Setting this to `None` will leverage CPU; a positive integer will run the
188
+ model on the associated CUDA device id.
189
+ torch_dtype (Union[`int`, `torch.dtype`], *optional*):
190
+ The data-type (dtype) of the computation. Setting this to `None` will use float32 precision. Set to
191
+ `torch.float16` or `torch.bfloat16` to use half-precision in the respective dtypes.
192
+
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ model: "PreTrainedModel",
198
+ feature_extractor: Union["SequenceFeatureExtractor", str] = None,
199
+ tokenizer: Optional[PreTrainedTokenizer] = None,
200
+ decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
201
+ device: Union[int, "torch.device"] = None,
202
+ torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
203
+ **kwargs,
204
+ ):
205
+ # set the model type so we can check we have the right pre- and post-processing parameters
206
+ if model.config.model_type == "whisper":
207
+ self.type = "seq2seq_whisper"
208
+ elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
209
+ self.type = "seq2seq"
210
+ elif (
211
+ feature_extractor._processor_class
212
+ and feature_extractor._processor_class.endswith("WithLM")
213
+ and decoder is not None
214
+ ):
215
+ self.decoder = decoder
216
+ self.type = "ctc_with_lm"
217
+ else:
218
+ self.type = "ctc"
219
+
220
+ super().__init__(model, tokenizer, feature_extractor, device=device, torch_dtype=torch_dtype, **kwargs)
221
+
222
+ def __call__(
223
+ self,
224
+ inputs: Union[np.ndarray, bytes, str],
225
+ **kwargs,
226
+ ):
227
+ """
228
+ Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
229
+ documentation for more information.
230
+
231
+ Args:
232
+ inputs (`np.ndarray` or `bytes` or `str` or `dict`):
233
+ The input is either:
234
+ - `str` that is either the filename of a local audio file, or a public URL address to download the
235
+ audio file. The file will be read at the correct sampling rate to get the waveform using
236
+ *ffmpeg*. This requires *ffmpeg* to be installed on the system.
237
+ - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
238
+ same way.
239
+ - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
240
+ Raw audio at the correct sampling rate (no further check will be done)
241
+ - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
242
+ pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
243
+ np.array}` with, optionally, a `"stride": (left: int, right: int)` key that asks the pipeline to
244
+ ignore the first `left` samples and the last `right` samples in decoding (they are still used at
245
+ inference to provide more context to the model). Only use `stride` with CTC models.
246
+ return_timestamps (*optional*, `str` or `bool`):
247
+ Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
248
+ other sequence-to-sequence models.
249
+
250
+ For CTC models, timestamps can take one of two formats:
251
+ - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
252
+ instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
253
+ 0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
254
+ `0.6` seconds.
255
+ - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
256
+ instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
257
+ (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
258
+ before `0.9` seconds.
259
+
260
+ For the Whisper model, timestamps can take one of two formats:
261
+ - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
262
+ through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
263
+ by inspecting the cross-attention weights.
264
+ - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
265
+ For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
266
+ model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
267
+ Note that a segment of text refers to a sequence of one or more words, rather than individual
268
+ words as with word-level timestamps.
269
+ generate_kwargs (`dict`, *optional*):
270
+ The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
271
+ complete overview of generate, check the [following
272
+ guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
273
+
274
+ Return:
275
+ `Dict`: A dictionary with the following keys:
276
+ - **text** (`str`): The recognized text.
277
+ - **chunks** (*optional*, `List[Dict]`)
278
+ When using `return_timestamps`, the `chunks` will become a list containing all the various text
279
+ chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
280
+ "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
281
+ `"".join(chunk["text"] for chunk in output["chunks"])`.
282
+ """
283
+ return super().__call__(inputs, **kwargs)
284
+
285
+ def _sanitize_parameters(
286
+ self,
287
+ chunk_length_s=None,
288
+ stride_length_s=None,
289
+ ignore_warning=None,
290
+ decoder_kwargs=None,
291
+ return_timestamps=None,
292
+ return_language=None,
293
+ generate_kwargs=None,
294
+ max_new_tokens=None,
295
+ ):
296
+ # No parameters on this pipeline right now
297
+ preprocess_params = {}
298
+ if chunk_length_s is not None:
299
+ if self.type == "seq2seq" and not ignore_warning:
300
+ logger.warning(
301
+ "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
302
+ " be entirely accurate and will have caveats. More information:"
303
+ " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
304
+ " ignore_warning=True)"
305
+ )
306
+ preprocess_params["chunk_length_s"] = chunk_length_s
307
+ if stride_length_s is not None:
308
+ preprocess_params["stride_length_s"] = stride_length_s
309
+
310
+ forward_params = defaultdict(dict)
311
+ if max_new_tokens is not None:
312
+ warnings.warn(
313
+ "`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
314
+ FutureWarning,
315
+ )
316
+ forward_params["max_new_tokens"] = max_new_tokens
317
+ if generate_kwargs is not None:
318
+ if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
319
+ raise ValueError(
320
+ "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
321
+ " only 1 version"
322
+ )
323
+ forward_params.update(generate_kwargs)
324
+
325
+ postprocess_params = {}
326
+ if decoder_kwargs is not None:
327
+ postprocess_params["decoder_kwargs"] = decoder_kwargs
328
+ if return_timestamps is not None:
329
+ # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
330
+ if self.type == "seq2seq" and return_timestamps:
331
+ raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
332
+ if self.type == "ctc_with_lm" and return_timestamps != "word":
333
+ raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
334
+ if self.type == "ctc" and return_timestamps not in ["char", "word"]:
335
+ raise ValueError(
336
+ "CTC can either predict character level timestamps, or word level timestamps. "
337
+ "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
338
+ )
339
+ if self.type == "seq2seq_whisper" and return_timestamps == "char":
340
+ raise ValueError(
341
+ "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
342
+ "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
343
+ )
344
+ forward_params["return_timestamps"] = return_timestamps
345
+ postprocess_params["return_timestamps"] = return_timestamps
346
+ if return_language is not None:
347
+ if self.type != "seq2seq_whisper":
348
+ raise ValueError("Only Whisper can return language for now.")
349
+ postprocess_params["return_language"] = return_language
350
+
351
+ if self.assistant_model is not None:
352
+ forward_params["assistant_model"] = self.assistant_model
353
+ if self.assistant_tokenizer is not None:
354
+ forward_params["tokenizer"] = self.tokenizer
355
+ forward_params["assistant_tokenizer"] = self.assistant_tokenizer
356
+
357
+ return preprocess_params, forward_params, postprocess_params
358
+
359
+ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
360
+ if isinstance(inputs, str):
361
+ if inputs.startswith("http://") or inputs.startswith("https://"):
362
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
363
+ # like http_huggingface_co.png
364
+ inputs = requests.get(inputs).content
365
+ else:
366
+ with open(inputs, "rb") as f:
367
+ inputs = f.read()
368
+
369
+ if isinstance(inputs, bytes):
370
+ inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
371
+
372
+ stride = None
373
+ extra = {}
374
+ if isinstance(inputs, dict):
375
+ stride = inputs.pop("stride", None)
376
+ # Accepting `"array"` which is the key defined in `datasets` for
377
+ # better integration
378
+ if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
379
+ raise ValueError(
380
+ "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
381
+ '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
382
+ "containing the sampling_rate associated with that array"
383
+ )
384
+
385
+ _inputs = inputs.pop("raw", None)
386
+ if _inputs is None:
387
+ # Remove path which will not be used from `datasets`.
388
+ inputs.pop("path", None)
389
+ _inputs = inputs.pop("array", None)
390
+ in_sampling_rate = inputs.pop("sampling_rate")
391
+ extra = inputs
392
+ inputs = _inputs
393
+ if in_sampling_rate != self.feature_extractor.sampling_rate:
394
+ if is_torchaudio_available():
395
+ from torchaudio import functional as F
396
+ else:
397
+ raise ImportError(
398
+ "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
399
+ "The torchaudio package can be installed through: `pip install torchaudio`."
400
+ )
401
+
402
+ inputs = F.resample(
403
+ torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
404
+ ).numpy()
405
+ ratio = self.feature_extractor.sampling_rate / in_sampling_rate
406
+ else:
407
+ ratio = 1
408
+ if stride is not None:
409
+ if stride[0] + stride[1] > inputs.shape[0]:
410
+ raise ValueError("Stride is too large for input")
411
+
412
+ # Stride needs to get the chunk length here, it's going to get
413
+ # swallowed by the `feature_extractor` later, and then batching
414
+ # can add extra data in the inputs, so we need to keep track
415
+ # of the original length in the stride so we can cut properly.
416
+ stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
417
+ if not isinstance(inputs, np.ndarray):
418
+ raise TypeError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
419
+ if len(inputs.shape) != 1:
420
+ raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
421
+
422
+ if chunk_length_s:
423
+ if stride_length_s is None:
424
+ stride_length_s = chunk_length_s / 6
425
+
426
+ if isinstance(stride_length_s, (int, float)):
427
+ stride_length_s = [stride_length_s, stride_length_s]
428
+
429
+ # XXX: Careful, this variable will not exist in the `seq2seq` setting.
430
+ # Currently chunking is not possible at this level for `seq2seq` so
431
+ # it's ok.
432
+ align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
433
+ chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
434
+ stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
435
+ stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
436
+
437
+ if chunk_len < stride_left + stride_right:
438
+ raise ValueError("Chunk length must be superior to stride length")
439
+
440
+ for item in chunk_iter(
441
+ inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
442
+ ):
443
+ yield {**item, **extra}
444
+ else:
445
+ if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
446
+ processed = self.feature_extractor(
447
+ inputs,
448
+ sampling_rate=self.feature_extractor.sampling_rate,
449
+ truncation=False,
450
+ padding="longest",
451
+ return_tensors="pt",
452
+ return_attention_mask=True,
453
+ )
454
+ else:
455
+ if self.type == "seq2seq_whisper" and stride is None:
456
+ processed = self.feature_extractor(
457
+ inputs,
458
+ sampling_rate=self.feature_extractor.sampling_rate,
459
+ return_tensors="pt",
460
+ return_token_timestamps=True,
461
+ return_attention_mask=True,
462
+ )
463
+ extra["num_frames"] = processed.pop("num_frames")
464
+ else:
465
+ processed = self.feature_extractor(
466
+ inputs,
467
+ sampling_rate=self.feature_extractor.sampling_rate,
468
+ return_tensors="pt",
469
+ return_attention_mask=True,
470
+ )
471
+ if self.torch_dtype is not None:
472
+ processed = processed.to(dtype=self.torch_dtype)
473
+ if stride is not None:
474
+ if self.type == "seq2seq":
475
+ raise ValueError("Stride is only usable with CTC models, try removing it !")
476
+
477
+ processed["stride"] = stride
478
+ yield {"is_last": True, **processed, **extra}
479
+
480
+ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
481
+ attention_mask = model_inputs.pop("attention_mask", None)
482
+ stride = model_inputs.pop("stride", None)
483
+ num_frames = model_inputs.pop("num_frames", None)
484
+ is_last = model_inputs.pop("is_last")
485
+
486
+ if stride is not None and num_frames is not None:
487
+ raise ValueError("num_frames must be used only when stride is None")
488
+
489
+ if self.type in {"seq2seq", "seq2seq_whisper"}:
490
+ # Consume values so we can let extra information flow freely through
491
+ # the pipeline (important for `partial` in microphone)
492
+ if "input_features" in model_inputs:
493
+ inputs = model_inputs.pop("input_features")
494
+ elif "input_values" in model_inputs:
495
+ inputs = model_inputs.pop("input_values")
496
+ else:
497
+ raise ValueError(
498
+ "Seq2Seq speech recognition model requires either a "
499
+ f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
500
+ )
501
+
502
+ # custom processing for Whisper timestamps and word-level timestamps
503
+ if return_timestamps and self.type == "seq2seq_whisper":
504
+ generate_kwargs["return_timestamps"] = return_timestamps
505
+ if return_timestamps == "word":
506
+ generate_kwargs["return_token_timestamps"] = True
507
+ generate_kwargs["return_segments"] = True
508
+
509
+ if stride is not None:
510
+ if isinstance(stride, tuple):
511
+ generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
512
+ else:
513
+ generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
514
+ else:
515
+ generate_kwargs["num_frames"] = num_frames
516
+
517
+ # User-defined `generation_config` passed to the pipeline call take precedence
518
+ if "generation_config" not in generate_kwargs:
519
+ generate_kwargs["generation_config"] = self.generation_config
520
+
521
+ tokens = self.model.generate(
522
+ inputs=inputs,
523
+ attention_mask=attention_mask,
524
+ **generate_kwargs,
525
+ )
526
+ # whisper longform generation stores timestamps in "segments"
527
+ if return_timestamps == "word" and self.type == "seq2seq_whisper":
528
+ if "segments" not in tokens:
529
+ out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
530
+ else:
531
+ token_timestamps = [
532
+ torch.cat([segment["token_timestamps"] for segment in segment_list])
533
+ for segment_list in tokens["segments"]
534
+ ]
535
+ out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
536
+ else:
537
+ out = {"tokens": tokens}
538
+ if self.type == "seq2seq_whisper":
539
+ if stride is not None:
540
+ out["stride"] = stride
541
+
542
+ else:
543
+ inputs = {
544
+ self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
545
+ "attention_mask": attention_mask,
546
+ }
547
+ outputs = self.model(**inputs)
548
+ logits = outputs.logits
549
+
550
+ if self.type == "ctc_with_lm":
551
+ out = {"logits": logits}
552
+ else:
553
+ out = {"tokens": logits.argmax(dim=-1)}
554
+ if stride is not None:
555
+ # Send stride to `postprocess`.
556
+ # it needs to be handled there where
557
+ # the pieces are to be concatenated.
558
+ ratio = 1 / self.model.config.inputs_to_logits_ratio
559
+ if isinstance(stride, tuple):
560
+ out["stride"] = rescale_stride([stride], ratio)[0]
561
+ else:
562
+ out["stride"] = rescale_stride(stride, ratio)
563
+ # Leftover
564
+ extra = model_inputs
565
+ return {"is_last": is_last, **out, **extra}
566
+
567
+ def postprocess(
568
+ self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
569
+ ):
570
+ # Optional return types
571
+ optional = {}
572
+
573
+ final_items = []
574
+ key = "logits" if self.type == "ctc_with_lm" else "tokens"
575
+ stride = None
576
+ for outputs in model_outputs:
577
+ if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
578
+ items = outputs[key].to(torch.float32).numpy()
579
+ else:
580
+ items = outputs[key].numpy()
581
+ stride = outputs.get("stride", None)
582
+ if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
583
+ total_n, left, right = stride
584
+ # Total_n might be < logits.shape[1]
585
+ # because of padding, that's why
586
+ # we need to reconstruct this information
587
+ # This won't work with left padding (which doesn't exist right now)
588
+ right_n = total_n - right
589
+ items = items[:, left:right_n]
590
+ final_items.append(items)
591
+
592
+ if stride and self.type == "seq2seq":
593
+ items = _find_longest_common_sequence(final_items, self.tokenizer)
594
+ elif self.type == "seq2seq_whisper":
595
+ time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
596
+ # Send the chunking back to seconds, it's easier to handle in whisper
597
+ sampling_rate = self.feature_extractor.sampling_rate
598
+ for output in model_outputs:
599
+ if "stride" in output:
600
+ chunk_len, stride_left, stride_right = output["stride"]
601
+ # Go back in seconds
602
+ chunk_len /= sampling_rate
603
+ stride_left /= sampling_rate
604
+ stride_right /= sampling_rate
605
+ output["stride"] = chunk_len, stride_left, stride_right
606
+
607
+ text, optional = self.tokenizer._decode_asr(
608
+ model_outputs,
609
+ return_timestamps=return_timestamps,
610
+ return_language=return_language,
611
+ time_precision=time_precision,
612
+ )
613
+ else:
614
+ items = np.concatenate(final_items, axis=1)
615
+ items = items.squeeze(0)
616
+
617
+ if self.type == "ctc_with_lm":
618
+ if decoder_kwargs is None:
619
+ decoder_kwargs = {}
620
+ beams = self.decoder.decode_beams(items, **decoder_kwargs)
621
+ text = beams[0][0]
622
+ if return_timestamps:
623
+ # Simply cast from pyctcdecode format to wav2vec2 format to leverage
624
+ # pre-existing code later
625
+ chunk_offset = beams[0][2]
626
+ offsets = []
627
+ for word, (start_offset, end_offset) in chunk_offset:
628
+ offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
629
+ elif self.type != "seq2seq_whisper":
630
+ skip_special_tokens = self.type != "ctc"
631
+ text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
632
+ if return_timestamps:
633
+ offsets = self.tokenizer.decode(
634
+ items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
635
+ )["char_offsets"]
636
+ if return_timestamps == "word":
637
+ offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)
638
+
639
+ if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
640
+ chunks = []
641
+ for item in offsets:
642
+ start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
643
+ start /= self.feature_extractor.sampling_rate
644
+
645
+ stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
646
+ stop /= self.feature_extractor.sampling_rate
647
+
648
+ chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
649
+ optional["chunks"] = chunks
650
+
651
+ extra = defaultdict(list)
652
+ for output in model_outputs:
653
+ output.pop("tokens", None)
654
+ output.pop("logits", None)
655
+ output.pop("is_last", None)
656
+ output.pop("stride", None)
657
+ output.pop("token_timestamps", None)
658
+ for k, v in output.items():
659
+ extra[k].append(v)
660
+ return {"text": text, **optional, **extra}
661
+
662
+
663
+ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
664
+ """
665
+ Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
666
+ `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
667
+ iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
668
+ processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
669
+ properly compute the final `offset`.
670
+ """
671
+ # index of the first timestamp token
672
+ timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
673
+ items = []
674
+ # approximation of the token-to-time ratio: ~0.2 seconds
675
+ time_precision = feature_extractor.chunk_length / max_source_positions
676
+ time = 0
677
+ for seq_idx, item in enumerate(sequences):
678
+ sequence, stride = item
679
+ if isinstance(sequence, list):
680
+ sequence = np.array(sequence)
681
+ chunk_len, stride_left, stride_right = stride
682
+ sequence = sequence.squeeze(0)
683
+ # get rid of the `forced_decoder_idx` that are used to parametrize the generation
684
+ begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
685
+ sequence = sequence[begin_idx:]
686
+
687
+ timestamp_tokens = sequence >= timestamp_begin
688
+ if seq_idx != 0 and sum(timestamp_tokens) > 0:
689
+ consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
690
+ last_timestamp = np.where(timestamp_tokens)[0][-1]
691
+ consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
692
+ time -= stride_left + stride_right
693
+ offset = int((time / feature_extractor.sampling_rate) / time_precision)
694
+ overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
695
+ # relevant timestamps are in the overlapping part
696
+ relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
697
+ if relevant_timestamp.shape[0] > 0:
698
+ relevant_timestamp = (
699
+ consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
700
+ )
701
+ # if a big stride is used, we need to check some of the previous items for the best overlap
702
+ best_match = 0
703
+ sliced_sequence = []
704
+ for idx, previous_sequence in enumerate(reversed(items)):
705
+ previous_tokens = previous_sequence[1:-1]
706
+ if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
707
+ break # the previous sequence is too far in the past
708
+ if len(previous_tokens) > 0:
709
+ # find the longest common sequence between the overlapping parts
710
+ index_left, index_right, match_length = _fast_find_longest_common_sequence(
711
+ sequence[1:relevant_timestamp], previous_tokens
712
+ )
713
+ # don't do anything if only 1 token was matched
714
+ if match_length > 1 and match_length > best_match:
715
+ best_match = match_length
716
+ best_idx = idx
717
+ end_of_curr_sequence_idx = (
718
+ np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
719
+ )
720
+ end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
721
+ # if all the tokens are matched, suffix
722
+ if index_left == 0 and match_length == len(previous_tokens):
723
+ sliced_sequence = np.insert(
724
+ sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
725
+ )
726
+ sliced_sequence[-1] = previous_sequence[-1]
727
+ # if part of the previous sequence is not taken
728
+ elif index_left >= 0:
729
+ sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
730
+ # let's insert the missing part of the previous sequence
731
+ previous_slice = (
732
+ previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
733
+ )
734
+ sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
735
+ sliced_sequence[-1] += offset
736
+
737
+ if len(sliced_sequence) > 0:
738
+ items[len(items) - best_idx - 1] = sliced_sequence
739
+ items = items[: len(items) - best_idx]
740
+ sequence = sequence[end_of_curr_sequence_idx:]
741
+
742
+ # sequence might have changed
743
+ timestamp_tokens = sequence >= timestamp_begin
744
+ consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
745
+ if sum(timestamp_tokens) > 0:
746
+ last_timestamp = np.where(timestamp_tokens)[0][-1]
747
+ consecutive = (
748
+ np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
749
+ )
750
+
751
+ if len(consecutive) > 0:
752
+ last_slice = 0
753
+ for current_slice in consecutive:
754
+ actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
755
+ sliced_tokens = sequence[last_slice:current_slice]
756
+ duration = sliced_tokens[-1] - sliced_tokens[0]
757
+ sliced_tokens[0] = actual_offset
758
+ sliced_tokens[-1] = actual_offset + duration
759
+ items.append(sliced_tokens)
760
+ last_slice = current_slice
761
+
762
+ time += chunk_len
763
+ result = []
764
+ for i in range(len(items)):
765
+ result += items[i].tolist()
766
+ return result
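For reference, a minimal usage sketch of the chunking and timestamp options handled by the pipeline code above; the checkpoint name and audio path are placeholders, not part of the library source:

    from transformers import pipeline

    # chunk_length_s / stride_length_s map to the chunk_len / stride_left / stride_right
    # values computed in preprocess(); return_timestamps="word" exercises the offset
    # handling in postprocess().
    asr = pipeline(
        "automatic-speech-recognition",
        model="facebook/wav2vec2-base-960h",  # placeholder CTC checkpoint
        chunk_length_s=30,
        stride_length_s=(5, 5),
    )
    result = asr("audio.flac", return_timestamps="word")  # placeholder audio path
    print(result["text"])
    print(result["chunks"][:3])  # word-level {"text", "timestamp"} dicts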
.venv/lib/python3.11/site-packages/transformers/pipelines/base.py ADDED
@@ -0,0 +1,1484 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import collections
16
+ import copy
17
+ import csv
18
+ import importlib
19
+ import json
20
+ import os
21
+ import pickle
22
+ import sys
23
+ import traceback
24
+ import types
25
+ import warnings
26
+ from abc import ABC, abstractmethod
27
+ from collections import UserDict
28
+ from contextlib import contextmanager
29
+ from os.path import abspath, exists
30
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
31
+
32
+ from ..dynamic_module_utils import custom_object_save
33
+ from ..feature_extraction_utils import PreTrainedFeatureExtractor
34
+ from ..image_processing_utils import BaseImageProcessor
35
+ from ..modelcard import ModelCard
36
+ from ..models.auto import AutoConfig, AutoTokenizer
37
+ from ..processing_utils import ProcessorMixin
38
+ from ..tokenization_utils import PreTrainedTokenizer
39
+ from ..utils import (
40
+ ModelOutput,
41
+ PushToHubMixin,
42
+ add_end_docstrings,
43
+ copy_func,
44
+ infer_framework,
45
+ is_tf_available,
46
+ is_torch_available,
47
+ is_torch_cuda_available,
48
+ is_torch_mlu_available,
49
+ is_torch_mps_available,
50
+ is_torch_musa_available,
51
+ is_torch_npu_available,
52
+ is_torch_xpu_available,
53
+ logging,
54
+ )
55
+
56
+
57
+ GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
58
+
59
+ if is_tf_available():
60
+ import tensorflow as tf
61
+
62
+ from ..models.auto.modeling_tf_auto import TFAutoModel
63
+
64
+ if is_torch_available():
65
+ import torch
66
+ from torch.utils.data import DataLoader, Dataset
67
+
68
+ from ..models.auto.modeling_auto import AutoModel
69
+
70
+ # Re-export for backward compatibility
71
+ from .pt_utils import KeyDataset
72
+ else:
73
+ Dataset = None
74
+ KeyDataset = None
75
+
76
+ if TYPE_CHECKING:
77
+ from ..modeling_tf_utils import TFPreTrainedModel
78
+ from ..modeling_utils import PreTrainedModel
79
+
80
+
81
+ logger = logging.get_logger(__name__)
82
+
83
+
84
+ def no_collate_fn(items):
85
+ if len(items) != 1:
86
+ raise ValueError("This collate_fn is meant to be used with batch_size=1")
87
+ return items[0]
88
+
89
+
90
+ def _pad(items, key, padding_value, padding_side):
91
+ batch_size = len(items)
92
+ if isinstance(items[0][key], torch.Tensor):
93
+ # Others include `attention_mask` etc...
94
+ shape = items[0][key].shape
95
+ dim = len(shape)
96
+ if dim == 1:
97
+ # We have a list of 1-dim torch tensors, which can be stacked without padding
98
+ return torch.cat([item[key] for item in items], dim=0)
99
+ if key in ["pixel_values", "image"]:
100
+ # This is probably an image, so padding shouldn't be necessary
101
+ # B, C, H, W
102
+ return torch.cat([item[key] for item in items], dim=0)
103
+ elif dim == 4 and key == "input_features":
104
+ # this is probably a mel spectrogram batched
105
+ return torch.cat([item[key] for item in items], dim=0)
106
+ max_length = max(item[key].shape[1] for item in items)
107
+ min_length = min(item[key].shape[1] for item in items)
108
+ dtype = items[0][key].dtype
109
+
110
+ if dim == 2:
111
+ if max_length == min_length:
112
+ # Bypass for `ImageGPT` which doesn't provide a padding value, yet
113
+ # we can consistently pad since the size should be matching
114
+ return torch.cat([item[key] for item in items], dim=0)
115
+ tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
116
+ elif dim == 3:
117
+ tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
118
+ elif dim == 4:
119
+ tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value
120
+
121
+ for i, item in enumerate(items):
122
+ if dim == 2:
123
+ if padding_side == "left":
124
+ tensor[i, -len(item[key][0]) :] = item[key][0].clone()
125
+ else:
126
+ tensor[i, : len(item[key][0])] = item[key][0].clone()
127
+ elif dim == 3:
128
+ if padding_side == "left":
129
+ tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
130
+ else:
131
+ tensor[i, : len(item[key][0]), :] = item[key][0].clone()
132
+ elif dim == 4:
133
+ if padding_side == "left":
134
+ tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
135
+ else:
136
+ tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()
137
+
138
+ return tensor
139
+ else:
140
+ return [item[key] for item in items]
141
+
142
+
143
+ def pad_collate_fn(tokenizer, feature_extractor):
144
+ # Tokenizer
145
+ t_padding_side = None
146
+ # Feature extractor
147
+ f_padding_side = None
148
+ if tokenizer is None and feature_extractor is None:
149
+ raise ValueError("Pipeline without tokenizer or feature_extractor cannot do batching")
150
+ if tokenizer is not None:
151
+ if tokenizer.pad_token_id is None:
152
+ raise ValueError(
153
+ "Pipeline with tokenizer without pad_token cannot do batching. You can try to set it with "
154
+ "`pipe.tokenizer.pad_token_id = model.config.eos_token_id`."
155
+ )
156
+ else:
157
+ t_padding_value = tokenizer.pad_token_id
158
+ t_padding_side = tokenizer.padding_side
159
+ if feature_extractor is not None:
160
+ # Feature extractor can be images, where no padding is expected
161
+ f_padding_value = getattr(feature_extractor, "padding_value", None)
162
+ f_padding_side = getattr(feature_extractor, "padding_side", None)
163
+
164
+ if t_padding_side is not None and f_padding_side is not None and t_padding_side != f_padding_side:
165
+ raise ValueError(
166
+ f"The feature extractor, and tokenizer don't agree on padding side {t_padding_side} != {f_padding_side}"
167
+ )
168
+ padding_side = "right"
169
+ if t_padding_side is not None:
170
+ padding_side = t_padding_side
171
+ if f_padding_side is not None:
172
+ padding_side = f_padding_side
173
+
174
+ def inner(items):
175
+ keys = set(items[0].keys())
176
+ for item in items:
177
+ if set(item.keys()) != keys:
178
+ raise ValueError(
179
+ f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} !="
180
+ f" {keys})"
181
+ )
182
+ # input_values, input_pixels, input_ids, ...
183
+ padded = {}
184
+ for key in keys:
185
+ if key in {"input_ids"}:
186
+ # ImageGPT uses a feature extractor
187
+ if tokenizer is None and feature_extractor is not None:
188
+ _padding_value = f_padding_value
189
+ else:
190
+ _padding_value = t_padding_value
191
+ elif key in {"input_values", "pixel_values", "input_features"}:
192
+ _padding_value = f_padding_value
193
+ elif key in {"p_mask", "special_tokens_mask"}:
194
+ _padding_value = 1
195
+ elif key in {"attention_mask", "token_type_ids"}:
196
+ _padding_value = 0
197
+ else:
198
+ # This is likely another random key, maybe even user-provided
199
+ _padding_value = 0
200
+ padded[key] = _pad(items, key, _padding_value, padding_side)
201
+ return padded
202
+
203
+ return inner
204
+
205
+
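As a rough illustration (not part of the library code), the collate function built above can be plugged into a plain PyTorch DataLoader; `tokenizer`, `tokenized_dataset`, and `model` below are assumed to exist:

    from torch.utils.data import DataLoader
    from transformers.pipelines.base import pad_collate_fn

    collate = pad_collate_fn(tokenizer, None)  # tokenizer must define pad_token_id
    # each dataset item is assumed to be a dict of 1 x seq_len tensors,
    # as produced by a pipeline preprocess step
    loader = DataLoader(tokenized_dataset, batch_size=8, collate_fn=collate)
    for batch in loader:
        # input_ids / attention_mask are padded to the longest item in the batch,
        # following the tokenizer's padding_side ("left" or "right")
        outputs = model(**batch)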
206
+ def infer_framework_load_model(
207
+ model,
208
+ config: AutoConfig,
209
+ model_classes: Optional[Dict[str, Tuple[type]]] = None,
210
+ task: Optional[str] = None,
211
+ framework: Optional[str] = None,
212
+ **model_kwargs,
213
+ ):
214
+ """
215
+ Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model).
216
+
217
+ If `model` is instantiated, this function will just infer the framework from the model class. Otherwise `model` is
218
+ actually a checkpoint name and this method will try to instantiate it using `model_classes`. Since we don't want to
219
+ instantiate the model twice, this model is returned for use by the pipeline.
220
+
221
+ If both frameworks are installed and available for `model`, PyTorch is selected.
222
+
223
+ Args:
224
+ model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
225
+ The model to infer the framework from. If `str`, a checkpoint name.
226
+ config ([`AutoConfig`]):
227
+ The config associated with the model to help using the correct class
228
+ model_classes (dictionary `str` to `type`, *optional*):
229
+ A mapping framework to class.
230
+ task (`str`):
231
+ The task defining which pipeline will be returned.
232
+ model_kwargs:
233
+ Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
234
+ **model_kwargs)` function.
235
+
236
+ Returns:
237
+ `Tuple`: A tuple framework, model.
238
+ """
239
+ if not is_tf_available() and not is_torch_available():
240
+ raise RuntimeError(
241
+ "At least one of TensorFlow 2.0 or PyTorch should be installed. "
242
+ "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
243
+ "To install PyTorch, read the instructions at https://pytorch.org/."
244
+ )
245
+ if isinstance(model, str):
246
+ model_kwargs["_from_pipeline"] = task
247
+ class_tuple = ()
248
+ look_pt = is_torch_available() and framework in {"pt", None}
249
+ look_tf = is_tf_available() and framework in {"tf", None}
250
+ if model_classes:
251
+ if look_pt:
252
+ class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
253
+ if look_tf:
254
+ class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
255
+ if config.architectures:
256
+ classes = []
257
+ for architecture in config.architectures:
258
+ transformers_module = importlib.import_module("transformers")
259
+ if look_pt:
260
+ _class = getattr(transformers_module, architecture, None)
261
+ if _class is not None:
262
+ classes.append(_class)
263
+ if look_tf:
264
+ _class = getattr(transformers_module, f"TF{architecture}", None)
265
+ if _class is not None:
266
+ classes.append(_class)
267
+ class_tuple = class_tuple + tuple(classes)
268
+
269
+ if len(class_tuple) == 0:
270
+ raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")
271
+
272
+ all_traceback = {}
273
+ for model_class in class_tuple:
274
+ kwargs = model_kwargs.copy()
275
+ if framework == "pt" and model.endswith(".h5"):
276
+ kwargs["from_tf"] = True
277
+ logger.warning(
278
+ "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
279
+ "Trying to load the model with PyTorch."
280
+ )
281
+ elif framework == "tf" and model.endswith(".bin"):
282
+ kwargs["from_pt"] = True
283
+ logger.warning(
284
+ "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
285
+ "Trying to load the model with Tensorflow."
286
+ )
287
+
288
+ try:
289
+ model = model_class.from_pretrained(model, **kwargs)
290
+ if hasattr(model, "eval"):
291
+ model = model.eval()
292
+ # Stop loading on the first successful load.
293
+ break
294
+ except (OSError, ValueError):
295
+ all_traceback[model_class.__name__] = traceback.format_exc()
296
+ continue
297
+
298
+ if isinstance(model, str):
299
+ error = ""
300
+ for class_name, trace in all_traceback.items():
301
+ error += f"while loading with {class_name}, an error is thrown:\n{trace}\n"
302
+ raise ValueError(
303
+ f"Could not load model {model} with any of the following classes: {class_tuple}. See the original errors:\n\n{error}\n"
304
+ )
305
+
306
+ if framework is None:
307
+ framework = infer_framework(model.__class__)
308
+ return framework, model
309
+
310
+
311
+ def infer_framework_from_model(
312
+ model,
313
+ model_classes: Optional[Dict[str, Tuple[type]]] = None,
314
+ task: Optional[str] = None,
315
+ framework: Optional[str] = None,
316
+ **model_kwargs,
317
+ ):
318
+ """
319
+ Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model).
320
+
321
+ If `model` is instantiated, this function will just infer the framework from the model class. Otherwise `model` is
322
+ actually a checkpoint name and this method will try to instantiate it using `model_classes`. Since we don't want to
323
+ instantiate the model twice, this model is returned for use by the pipeline.
324
+
325
+ If both frameworks are installed and available for `model`, PyTorch is selected.
326
+
327
+ Args:
328
+ model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
329
+ The model to infer the framework from. If `str`, a checkpoint name.
330
+ model_classes (dictionary `str` to `type`, *optional*):
331
+ A mapping framework to class.
332
+ task (`str`):
333
+ The task defining which pipeline will be returned.
334
+ model_kwargs:
335
+ Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
336
+ **model_kwargs)` function.
337
+
338
+ Returns:
339
+ `Tuple`: A tuple framework, model.
340
+ """
341
+ if isinstance(model, str):
342
+ config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
343
+ else:
344
+ config = model.config
345
+ return infer_framework_load_model(
346
+ model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
347
+ )
348
+
349
+
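A short sketch of how this helper is typically called (the checkpoint name is a placeholder):

    from transformers.pipelines.base import infer_framework_from_model

    framework, model = infer_framework_from_model("distilbert-base-uncased")
    print(framework)  # "pt" when PyTorch is installed (PyTorch is preferred if both frameworks are available)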
350
+ def get_framework(model, revision: Optional[str] = None):
351
+ """
352
+ Select framework (TensorFlow or PyTorch) to use.
353
+
354
+ Args:
355
+ model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
356
+ If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
357
+ the model name). If no specific model is provided, defaults to using PyTorch.
358
+ """
359
+ warnings.warn(
360
+ "`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
361
+ FutureWarning,
362
+ )
363
+ if not is_tf_available() and not is_torch_available():
364
+ raise RuntimeError(
365
+ "At least one of TensorFlow 2.0 or PyTorch should be installed. "
366
+ "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
367
+ "To install PyTorch, read the instructions at https://pytorch.org/."
368
+ )
369
+ if isinstance(model, str):
370
+ if is_torch_available() and not is_tf_available():
371
+ model = AutoModel.from_pretrained(model, revision=revision)
372
+ elif is_tf_available() and not is_torch_available():
373
+ model = TFAutoModel.from_pretrained(model, revision=revision)
374
+ else:
375
+ try:
376
+ model = AutoModel.from_pretrained(model, revision=revision)
377
+ except OSError:
378
+ model = TFAutoModel.from_pretrained(model, revision=revision)
379
+
380
+ framework = infer_framework(model.__class__)
381
+ return framework
382
+
383
+
384
+ def get_default_model_and_revision(
385
+ targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
386
+ ) -> Union[str, Tuple[str, str]]:
387
+ """
388
+ Select a default model to use for a given task. Defaults to pytorch if ambiguous.
389
+
390
+ Args:
391
+ targeted_task (`Dict`):
392
+ Dictionary representing the given task, that should contain default models
393
+
394
+ framework (`str`, None)
395
+ "pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
396
+
397
+ task_options (`Any`, None)
398
+ Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for
399
+ translation task.
400
+
401
+ Returns
402
+
403
+ `str` The model string representing the default model for this pipeline
404
+ """
405
+ if is_torch_available() and not is_tf_available():
406
+ framework = "pt"
407
+ elif is_tf_available() and not is_torch_available():
408
+ framework = "tf"
409
+
410
+ defaults = targeted_task["default"]
411
+ if task_options:
412
+ if task_options not in defaults:
413
+ raise ValueError(f"The task does not provide any default models for options {task_options}")
414
+ default_models = defaults[task_options]["model"]
415
+ elif "model" in defaults:
416
+ default_models = targeted_task["default"]["model"]
417
+ else:
418
+ # XXX This error message needs to be updated to be more generic if more tasks are going to become
419
+ # parametrized
420
+ raise ValueError('The task defaults can\'t be correctly selected. You probably meant "translation_XX_to_YY"')
421
+
422
+ if framework is None:
423
+ framework = "pt"
424
+
425
+ return default_models[framework]
426
+
427
+
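For illustration only, a hedged sketch of the `targeted_task` mapping this function expects; the checkpoint name and revision below are made up:

    from transformers.pipelines.base import get_default_model_and_revision

    targeted_task = {
        "default": {
            "model": {
                "pt": ("example-org/example-asr-model", "main"),  # hypothetical (model, revision) pair
                "tf": ("example-org/example-asr-model", "main"),
            }
        }
    }
    default = get_default_model_and_revision(targeted_task, framework=None, task_options=None)
    # -> ("example-org/example-asr-model", "main"); PyTorch is preferred when the framework is ambiguous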
428
+ def load_assistant_model(
429
+ model: "PreTrainedModel",
430
+ assistant_model: Optional[Union[str, "PreTrainedModel"]],
431
+ assistant_tokenizer: Optional[PreTrainedTokenizer],
432
+ ) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
433
+ """
434
+ Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
435
+
436
+ Args:
437
+ model ([`PreTrainedModel`]):
438
+ The main model that will be used by the pipeline to make predictions.
439
+ assistant_model (`str` or [`PreTrainedModel`], *optional*):
440
+ The assistant model that will be used by the pipeline to make predictions.
441
+ assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
442
+ The assistant tokenizer that will be used by the pipeline to encode data for the model.
443
+
444
+ Returns:
445
+ Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
446
+ """
447
+ if not model.can_generate() or assistant_model is None:
448
+ return None, None
449
+
450
+ if not isinstance(model, PreTrainedModel):
451
+ raise ValueError(
452
+ "Assisted generation, triggered by the `assistant_model` argument, is only available for "
453
+ "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
454
+ )
455
+
456
+ # If the model is passed as a string, load the model and the corresponding tokenizer
457
+ if isinstance(assistant_model, str):
458
+ assistant_config = AutoConfig.from_pretrained(assistant_model)
459
+ _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
460
+ loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
461
+ loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
462
+ else:
463
+ loaded_assistant_model = assistant_model
464
+ loaded_assistant_tokenizer = assistant_tokenizer
465
+
466
+ # Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
467
+ # tokenizer
468
+ same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
469
+ same_special_tokens = all(
470
+ getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
471
+ for token in ("eos_token_id", "pad_token_id", "bos_token_id")
472
+ )
473
+ if same_vocab_size and same_special_tokens:
474
+ loaded_assistant_tokenizer = None
475
+ elif loaded_assistant_tokenizer is None:
476
+ raise ValueError(
477
+ "The assistant model has a different tokenizer than the main model. You should pass the assistant "
478
+ "tokenizer."
479
+ )
480
+
481
+ return loaded_assistant_model, loaded_assistant_tokenizer
482
+
483
+
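In user code this helper is reached through the `assistant_model` argument of `pipeline(...)`, which `Pipeline.__init__` below pops from its kwargs; a hedged sketch with placeholder checkpoints:

    from transformers import pipeline

    generator = pipeline(
        "text-generation",
        model="example-org/large-lm",            # placeholder main model
        assistant_model="example-org/small-lm",  # placeholder draft model for assisted generation
    )
    print(generator("Hello", max_new_tokens=20)[0]["generated_text"])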
484
+ class PipelineException(Exception):
485
+ """
486
+ Raised by a [`Pipeline`] when handling __call__.
487
+
488
+ Args:
489
+ task (`str`): The task of the pipeline.
490
+ model (`str`): The model used by the pipeline.
491
+ reason (`str`): The error message to display.
492
+ """
493
+
494
+ def __init__(self, task: str, model: str, reason: str):
495
+ super().__init__(reason)
496
+
497
+ self.task = task
498
+ self.model = model
499
+
500
+
501
+ class ArgumentHandler(ABC):
502
+ """
503
+ Base interface for handling arguments for each [`~pipelines.Pipeline`].
504
+ """
505
+
506
+ @abstractmethod
507
+ def __call__(self, *args, **kwargs):
508
+ raise NotImplementedError()
509
+
510
+
511
+ class PipelineDataFormat:
512
+ """
513
+ Base class for all the data formats supported by pipelines, both for reading and writing. Supported data formats
514
+ currently include:
515
+
516
+ - JSON
517
+ - CSV
518
+ - stdin/stdout (pipe)
519
+
520
+ `PipelineDataFormat` also includes some utilities to work with multi-column data, like mapping from dataset columns to
521
+ pipeline keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
522
+
523
+ Args:
524
+ output_path (`str`): Where to save the outgoing data.
525
+ input_path (`str`): Where to look for the input data.
526
+ column (`str`): The column to read.
527
+ overwrite (`bool`, *optional*, defaults to `False`):
528
+ Whether or not to overwrite the `output_path`.
529
+ """
530
+
531
+ SUPPORTED_FORMATS = ["json", "csv", "pipe"]
532
+
533
+ def __init__(
534
+ self,
535
+ output_path: Optional[str],
536
+ input_path: Optional[str],
537
+ column: Optional[str],
538
+ overwrite: bool = False,
539
+ ):
540
+ self.output_path = output_path
541
+ self.input_path = input_path
542
+ self.column = column.split(",") if column is not None else [""]
543
+ self.is_multi_columns = len(self.column) > 1
544
+
545
+ if self.is_multi_columns:
546
+ self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]
547
+
548
+ if output_path is not None and not overwrite:
549
+ if exists(abspath(self.output_path)):
550
+ raise OSError(f"{self.output_path} already exists on disk")
551
+
552
+ if input_path is not None:
553
+ if not exists(abspath(self.input_path)):
554
+ raise OSError(f"{self.input_path} doesnt exist on disk")
555
+
556
+ @abstractmethod
557
+ def __iter__(self):
558
+ raise NotImplementedError()
559
+
560
+ @abstractmethod
561
+ def save(self, data: Union[dict, List[dict]]):
562
+ """
563
+ Save the provided data object with the representation for the current [`~pipelines.PipelineDataFormat`].
564
+
565
+ Args:
566
+ data (`dict` or list of `dict`): The data to store.
567
+ """
568
+ raise NotImplementedError()
569
+
570
+ def save_binary(self, data: Union[dict, List[dict]]) -> str:
571
+ """
572
+ Save the provided data object as a pickle-formatted binary data on the disk.
573
+
574
+ Args:
575
+ data (`dict` or list of `dict`): The data to store.
576
+
577
+ Returns:
578
+ `str`: Path where the data has been saved.
579
+ """
580
+ path, _ = os.path.splitext(self.output_path)
581
+ binary_path = os.path.extsep.join((path, "pickle"))
582
+
583
+ with open(binary_path, "wb+") as f_output:
584
+ pickle.dump(data, f_output)
585
+
586
+ return binary_path
587
+
588
+ @staticmethod
589
+ def from_str(
590
+ format: str,
591
+ output_path: Optional[str],
592
+ input_path: Optional[str],
593
+ column: Optional[str],
594
+ overwrite=False,
595
+ ) -> "PipelineDataFormat":
596
+ """
597
+ Creates an instance of the right subclass of [`~pipelines.PipelineDataFormat`] depending on `format`.
598
+
599
+ Args:
600
+ format (`str`):
601
+ The format of the desired pipeline. Acceptable values are `"json"`, `"csv"` or `"pipe"`.
602
+ output_path (`str`, *optional*):
603
+ Where to save the outgoing data.
604
+ input_path (`str`, *optional*):
605
+ Where to look for the input data.
606
+ column (`str`, *optional*):
607
+ The column to read.
608
+ overwrite (`bool`, *optional*, defaults to `False`):
609
+ Whether or not to overwrite the `output_path`.
610
+
611
+ Returns:
612
+ [`~pipelines.PipelineDataFormat`]: The proper data format.
613
+ """
614
+ if format == "json":
615
+ return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
616
+ elif format == "csv":
617
+ return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
618
+ elif format == "pipe":
619
+ return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
620
+ else:
621
+ raise KeyError(f"Unknown reader {format} (Available reader are json/csv/pipe)")
622
+
623
+
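A quick sketch of driving a pipeline from a CSV file through this factory; file names and the `classifier` pipeline are placeholders:

    from transformers import pipeline
    from transformers.pipelines.base import PipelineDataFormat

    fmt = PipelineDataFormat.from_str(
        "csv", output_path="predictions.csv", input_path="inputs.csv", column="text", overwrite=True
    )
    classifier = pipeline("text-classification")
    rows = [classifier(text)[0] for text in fmt]  # one {"label", "score"} dict per input row
    fmt.save(rows)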
624
+ class CsvPipelineDataFormat(PipelineDataFormat):
625
+ """
626
+ Support for pipelines using CSV data format.
627
+
628
+ Args:
629
+ output_path (`str`): Where to save the outgoing data.
630
+ input_path (`str`): Where to look for the input data.
631
+ column (`str`): The column to read.
632
+ overwrite (`bool`, *optional*, defaults to `False`):
633
+ Whether or not to overwrite the `output_path`.
634
+ """
635
+
636
+ def __init__(
637
+ self,
638
+ output_path: Optional[str],
639
+ input_path: Optional[str],
640
+ column: Optional[str],
641
+ overwrite=False,
642
+ ):
643
+ super().__init__(output_path, input_path, column, overwrite=overwrite)
644
+
645
+ def __iter__(self):
646
+ with open(self.input_path, "r") as f:
647
+ reader = csv.DictReader(f)
648
+ for row in reader:
649
+ if self.is_multi_columns:
650
+ yield {k: row[c] for k, c in self.column}
651
+ else:
652
+ yield row[self.column[0]]
653
+
654
+ def save(self, data: List[dict]):
655
+ """
656
+ Save the provided data object with the representation for the current [`~pipelines.PipelineDataFormat`].
657
+
658
+ Args:
659
+ data (`List[dict]`): The data to store.
660
+ """
661
+ with open(self.output_path, "w") as f:
662
+ if len(data) > 0:
663
+ writer = csv.DictWriter(f, list(data[0].keys()))
664
+ writer.writeheader()
665
+ writer.writerows(data)
666
+
667
+
668
+ class JsonPipelineDataFormat(PipelineDataFormat):
669
+ """
670
+ Support for pipelines using JSON file format.
671
+
672
+ Args:
673
+ output_path (`str`): Where to save the outgoing data.
674
+ input_path (`str`): Where to look for the input data.
675
+ column (`str`): The column to read.
676
+ overwrite (`bool`, *optional*, defaults to `False`):
677
+ Whether or not to overwrite the `output_path`.
678
+ """
679
+
680
+ def __init__(
681
+ self,
682
+ output_path: Optional[str],
683
+ input_path: Optional[str],
684
+ column: Optional[str],
685
+ overwrite=False,
686
+ ):
687
+ super().__init__(output_path, input_path, column, overwrite=overwrite)
688
+
689
+ with open(input_path, "r") as f:
690
+ self._entries = json.load(f)
691
+
692
+ def __iter__(self):
693
+ for entry in self._entries:
694
+ if self.is_multi_columns:
695
+ yield {k: entry[c] for k, c in self.column}
696
+ else:
697
+ yield entry[self.column[0]]
698
+
699
+ def save(self, data: dict):
700
+ """
701
+ Save the provided data object in a json file.
702
+
703
+ Args:
704
+ data (`dict`): The data to store.
705
+ """
706
+ with open(self.output_path, "w") as f:
707
+ json.dump(data, f)
708
+
709
+
710
+ class PipedPipelineDataFormat(PipelineDataFormat):
711
+ """
712
+ Read data from piped input to the python process. For multi-column data, columns should be separated by \t
713
+
714
+ If columns are provided, then the output will be a dictionary with {column_x: value_x}
715
+
716
+ Args:
717
+ output_path (`str`): Where to save the outgoing data.
718
+ input_path (`str`): Where to look for the input data.
719
+ column (`str`): The column to read.
720
+ overwrite (`bool`, *optional*, defaults to `False`):
721
+ Whether or not to overwrite the `output_path`.
722
+ """
723
+
724
+ def __iter__(self):
725
+ for line in sys.stdin:
726
+ # Split for multi-columns
727
+ if "\t" in line:
728
+ line = line.split("\t")
729
+ if self.column:
730
+ # Dictionary to map arguments
731
+ yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
732
+ else:
733
+ yield tuple(line)
734
+
735
+ # No dictionary to map arguments
736
+ else:
737
+ yield line
738
+
739
+ def save(self, data: dict):
740
+ """
741
+ Print the data.
742
+
743
+ Args:
744
+ data (`dict`): The data to store.
745
+ """
746
+ print(data)
747
+
748
+ def save_binary(self, data: Union[dict, List[dict]]) -> str:
749
+ if self.output_path is None:
750
+ raise KeyError(
751
+ "When using piped input on pipeline outputting large object requires an output file path. "
752
+ "Please provide such output path through --output argument."
753
+ )
754
+
755
+ return super().save_binary(data)
756
+
757
+
758
+ class _ScikitCompat(ABC):
759
+ """
760
+ Interface layer for Scikit and Keras compatibility.
761
+ """
762
+
763
+ @abstractmethod
764
+ def transform(self, X):
765
+ raise NotImplementedError()
766
+
767
+ @abstractmethod
768
+ def predict(self, X):
769
+ raise NotImplementedError()
770
+
771
+
772
+ def build_pipeline_init_args(
773
+ has_tokenizer: bool = False,
774
+ has_feature_extractor: bool = False,
775
+ has_image_processor: bool = False,
776
+ has_processor: bool = False,
777
+ supports_binary_output: bool = True,
778
+ ) -> str:
779
+ docstring = r"""
780
+ Arguments:
781
+ model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
782
+ The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
783
+ [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow."""
784
+ if has_tokenizer:
785
+ docstring += r"""
786
+ tokenizer ([`PreTrainedTokenizer`]):
787
+ The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
788
+ [`PreTrainedTokenizer`]."""
789
+ if has_feature_extractor:
790
+ docstring += r"""
791
+ feature_extractor ([`SequenceFeatureExtractor`]):
792
+ The feature extractor that will be used by the pipeline to encode data for the model. This object inherits from
793
+ [`SequenceFeatureExtractor`]."""
794
+ if has_image_processor:
795
+ docstring += r"""
796
+ image_processor ([`BaseImageProcessor`]):
797
+ The image processor that will be used by the pipeline to encode data for the model. This object inherits from
798
+ [`BaseImageProcessor`]."""
799
+ if has_processor:
800
+ docstring += r"""
801
+ processor ([`ProcessorMixin`]):
802
+ The processor that will be used by the pipeline to encode data for the model. This object inherits from
803
+ [`ProcessorMixin`]. Processor is a composite object that might contain `tokenizer`, `feature_extractor`, and
804
+ `image_processor`."""
805
+ docstring += r"""
806
+ modelcard (`str` or [`ModelCard`], *optional*):
807
+ Model card attributed to the model for this pipeline.
808
+ framework (`str`, *optional*):
809
+ The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
810
+ installed.
811
+
812
+ If no framework is specified, will default to the one currently installed. If no framework is specified and
813
+ both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
814
+ provided.
815
+ task (`str`, defaults to `""`):
816
+ A task-identifier for the pipeline.
817
+ num_workers (`int`, *optional*, defaults to 8):
818
+ When the pipeline will use *DataLoader* (when passing a dataset, on GPU for a Pytorch model), the number of
819
+ workers to be used.
820
+ batch_size (`int`, *optional*, defaults to 1):
821
+ When the pipeline will use *DataLoader* (when passing a dataset, on GPU for a Pytorch model), the size of
822
+ the batch to use, for inference this is not always beneficial, please read [Batching with
823
+ pipelines](https://huggingface.co/transformers/main_classes/pipelines.html#pipeline-batching) .
824
+ args_parser ([`~pipelines.ArgumentHandler`], *optional*):
825
+ Reference to the object in charge of parsing supplied pipeline parameters.
826
+ device (`int`, *optional*, defaults to -1):
827
+ Device ordinal for CPU/GPU support. Setting this to -1 will use the CPU; a positive integer will run the model on
828
+ the associated CUDA device id. You can also pass a native `torch.device` or a `str`.
829
+ torch_dtype (`str` or `torch.dtype`, *optional*):
830
+ Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
831
+ (`torch.float16`, `torch.bfloat16`, ... or `"auto"`)"""
832
+ if supports_binary_output:
833
+ docstring += r"""
834
+ binary_output (`bool`, *optional*, defaults to `False`):
835
+ Flag indicating whether the output of the pipeline should be serialized (i.e., pickle format) or returned as
836
+ the raw output data (e.g., text)."""
837
+ return docstring
838
+
839
+
840
+ PIPELINE_INIT_ARGS = build_pipeline_init_args(
841
+ has_tokenizer=True,
842
+ has_feature_extractor=True,
843
+ has_image_processor=True,
844
+ has_processor=True,
845
+ supports_binary_output=True,
846
+ )
847
+
848
+
849
+ if is_torch_available():
850
+ from transformers.pipelines.pt_utils import (
851
+ PipelineChunkIterator,
852
+ PipelineDataset,
853
+ PipelineIterator,
854
+ PipelinePackIterator,
855
+ )
856
+
857
+
858
+ @add_end_docstrings(
859
+ build_pipeline_init_args(
860
+ has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True
861
+ )
862
+ )
863
+ class Pipeline(_ScikitCompat, PushToHubMixin):
864
+ """
865
+ The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
866
+ different pipelines.
867
+
868
+ Base class implementing pipelined operations. Pipeline workflow is defined as a sequence of the following
869
+ operations:
870
+
871
+ Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output
872
+
873
+ Pipeline supports running on CPU or GPU through the device argument (see below).
874
+
875
+ Some pipelines, like [`FeatureExtractionPipeline`] (`'feature-extraction'`), output large tensor objects
876
+ as nested lists. In order to avoid dumping such a large structure as textual data, we provide the `binary_output`
877
+ constructor argument. If set to `True`, the output will be stored in the pickle format.
878
+ """
879
+
880
+ # Historically we have pipelines working with `tokenizer`, `feature_extractor`, and `image_processor`
881
+ # as separate processing components. While we have `processor` class that combines them, some pipelines
882
+ # might still operate with these components separately.
883
+ # With the addition of `processor` to `pipeline`, we want to avoid:
884
+ # - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately;
885
+ # - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`,
886
+ # because `processor` will load required sub-components by itself.
887
+ # Below flags allow granular control over loading components and set to be backward compatible with current
888
+ # pipelines logic. You may override these flags when creating your pipeline. For example, for
889
+ # `zero-shot-object-detection` pipeline which operates with `processor` you should set `_load_processor=True`
890
+ # and all the rest flags to `False` to avoid unnecessary loading of the components.
891
+ _load_processor = False
892
+ _load_image_processor = True
893
+ _load_feature_extractor = True
894
+ _load_tokenizer = True
895
+
896
+ default_input_names = None
897
+
898
+ def __init__(
899
+ self,
900
+ model: Union["PreTrainedModel", "TFPreTrainedModel"],
901
+ tokenizer: Optional[PreTrainedTokenizer] = None,
902
+ feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
903
+ image_processor: Optional[BaseImageProcessor] = None,
904
+ processor: Optional[ProcessorMixin] = None,
905
+ modelcard: Optional[ModelCard] = None,
906
+ framework: Optional[str] = None,
907
+ task: str = "",
908
+ args_parser: ArgumentHandler = None,
909
+ device: Union[int, "torch.device"] = None,
910
+ torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
911
+ binary_output: bool = False,
912
+ **kwargs,
913
+ ):
914
+ if framework is None:
915
+ framework, model = infer_framework_load_model(model, config=model.config)
916
+
917
+ self.task = task
918
+ self.model = model
919
+ self.tokenizer = tokenizer
920
+ self.feature_extractor = feature_extractor
921
+ self.image_processor = image_processor
922
+ self.processor = processor
923
+ self.modelcard = modelcard
924
+ self.framework = framework
925
+
926
+ # `accelerate` device map
927
+ hf_device_map = getattr(self.model, "hf_device_map", None)
928
+
929
+ if hf_device_map is not None and device is not None:
930
+ raise ValueError(
931
+ "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
932
+ "discard the `device` argument when creating your pipeline object."
933
+ )
934
+
935
+ if device is None:
936
+ if hf_device_map is not None:
937
+ # Take the first device used by `accelerate`.
938
+ device = next(iter(hf_device_map.values()))
939
+ else:
940
+ device = 0
941
+
942
+ if is_torch_available() and self.framework == "pt":
943
+ if device == -1 and self.model.device is not None:
944
+ device = self.model.device
945
+ if isinstance(device, torch.device):
946
+ if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
947
+ raise ValueError(f'{device} is not available, you should use device="cpu" instead')
948
+ self.device = device
949
+ elif isinstance(device, str):
950
+ if "xpu" in device and not is_torch_xpu_available(check_device=True):
951
+ raise ValueError(f'{device} is not available, you should use device="cpu" instead')
952
+ self.device = torch.device(device)
953
+ elif device < 0:
954
+ self.device = torch.device("cpu")
955
+ elif is_torch_mlu_available():
956
+ self.device = torch.device(f"mlu:{device}")
957
+ elif is_torch_musa_available():
958
+ self.device = torch.device(f"musa:{device}")
959
+ elif is_torch_cuda_available():
960
+ self.device = torch.device(f"cuda:{device}")
961
+ elif is_torch_npu_available():
962
+ self.device = torch.device(f"npu:{device}")
963
+ elif is_torch_xpu_available(check_device=True):
964
+ self.device = torch.device(f"xpu:{device}")
965
+ elif is_torch_mps_available():
966
+ self.device = torch.device(f"mps:{device}")
967
+ else:
968
+ self.device = torch.device("cpu")
969
+ else:
970
+ self.device = device if device is not None else -1
971
+
972
+ logger.warning(f"Device set to use {self.device}")
973
+
974
+ self.binary_output = binary_output
975
+ # We shouldn't call `model.to()` for models loaded with accelerate, nor when the model is already on the target device
976
+ if (
977
+ self.framework == "pt"
978
+ and self.model.device != self.device
979
+ and not (isinstance(self.device, int) and self.device < 0)
980
+ and hf_device_map is None
981
+ ):
982
+ self.model.to(self.device)
983
+
984
+ # If the model can generate:
985
+ # 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
986
+ # tweaks to the generation config.
987
+ # 2 - load the assistant model if it is passed.
988
+ self.assistant_model, self.assistant_tokenizer = load_assistant_model(
989
+ self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
990
+ )
991
+ if self.model.can_generate():
992
+ self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
993
+ self.generation_config = copy.deepcopy(self.model.generation_config)
994
+ # Update the generation config with task specific params if they exist
995
+ # NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
996
+ task_specific_params = self.model.config.task_specific_params
997
+ if task_specific_params is not None and task in task_specific_params:
998
+ this_task_params = task_specific_params.get(task)
999
+ if "prefix" in this_task_params:
1000
+ self.prefix = this_task_params.pop("prefix")
1001
+ self.generation_config.update(**this_task_params)
1002
+ # If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
1003
+ if (
1004
+ self.tokenizer is not None
1005
+ and self.tokenizer.pad_token_id is not None
1006
+ and self.generation_config.pad_token_id is None
1007
+ ):
1008
+ self.generation_config.pad_token_id = self.tokenizer.pad_token_id
1009
+
1010
+ self.call_count = 0
1011
+ self._batch_size = kwargs.pop("batch_size", None)
1012
+ self._num_workers = kwargs.pop("num_workers", None)
1013
+ self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
1014
+
1015
+ # In processor only mode, we can get the modality processors from the processor
1016
+ if self.processor is not None and all(
1017
+ [self.tokenizer is None, self.feature_extractor is None, self.image_processor is None]
1018
+ ):
1019
+ self.tokenizer = getattr(self.processor, "tokenizer", None)
1020
+ self.feature_extractor = getattr(self.processor, "feature_extractor", None)
1021
+ self.image_processor = getattr(self.processor, "image_processor", None)
1022
+
1023
+ if self.image_processor is None and self.feature_extractor is not None:
1024
+ if isinstance(self.feature_extractor, BaseImageProcessor):
1025
+ # Backward compatible change, if users called
1026
+ # ImageSegmentationPipeline(.., feature_extractor=MyFeatureExtractor())
1027
+ # then we should keep working
1028
+ self.image_processor = self.feature_extractor
1029
+
1030
+ def save_pretrained(
1031
+ self,
1032
+ save_directory: Union[str, os.PathLike],
1033
+ safe_serialization: bool = True,
1034
+ **kwargs,
1035
+ ):
1036
+ """
1037
+ Save the pipeline's model and tokenizer.
1038
+
1039
+ Args:
1040
+ save_directory (`str` or `os.PathLike`):
1041
+ A path to the directory where to save the pipeline. It will be created if it doesn't exist.
1042
+ safe_serialization (`bool`):
1043
+ Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow.
1044
+ kwargs (`Dict[str, Any]`, *optional*):
1045
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
1046
+ """
1047
+ use_auth_token = kwargs.pop("use_auth_token", None)
1048
+
1049
+ if use_auth_token is not None:
1050
+ warnings.warn(
1051
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1052
+ FutureWarning,
1053
+ )
1054
+ if kwargs.get("token", None) is not None:
1055
+ raise ValueError(
1056
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1057
+ )
1058
+ kwargs["token"] = use_auth_token
1059
+
1060
+ if os.path.isfile(save_directory):
1061
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
1062
+ return
1063
+ os.makedirs(save_directory, exist_ok=True)
1064
+
1065
+ if hasattr(self, "_registered_impl"):
1066
+ # Add info to the config
1067
+ pipeline_info = self._registered_impl.copy()
1068
+ custom_pipelines = {}
1069
+ for task, info in pipeline_info.items():
1070
+ if info["impl"] != self.__class__:
1071
+ continue
1072
+
1073
+ info = info.copy()
1074
+ module_name = info["impl"].__module__
1075
+ last_module = module_name.split(".")[-1]
1076
+ # Change classes into their names/full names
1077
+ info["impl"] = f"{last_module}.{info['impl'].__name__}"
1078
+ info["pt"] = tuple(c.__name__ for c in info["pt"])
1079
+ info["tf"] = tuple(c.__name__ for c in info["tf"])
1080
+
1081
+ custom_pipelines[task] = info
1082
+ self.model.config.custom_pipelines = custom_pipelines
1083
+ # Save the pipeline custom code
1084
+ custom_object_save(self, save_directory)
1085
+
1086
+ kwargs["safe_serialization"] = safe_serialization
1087
+ self.model.save_pretrained(save_directory, **kwargs)
1088
+
1089
+ if self.tokenizer is not None:
1090
+ self.tokenizer.save_pretrained(save_directory, **kwargs)
1091
+
1092
+ if self.feature_extractor is not None:
1093
+ self.feature_extractor.save_pretrained(save_directory, **kwargs)
1094
+
1095
+ if self.image_processor is not None:
1096
+ self.image_processor.save_pretrained(save_directory, **kwargs)
1097
+
1098
+ if self.modelcard is not None:
1099
+ self.modelcard.save_pretrained(save_directory)
1100
+
1101
+ def transform(self, X):
1102
+ """
1103
+ Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
1104
+ """
1105
+ return self(X)
1106
+
1107
+ def predict(self, X):
1108
+ """
1109
+ Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
1110
+ """
1111
+ return self(X)
1112
+
1113
+ @property
1114
+ def torch_dtype(self) -> Optional["torch.dtype"]:
1115
+ """
1116
+ Torch dtype of the model (if it's a PyTorch model), `None` otherwise.
1117
+ """
1118
+ return getattr(self.model, "dtype", None)
1119
+
1120
+ @contextmanager
1121
+ def device_placement(self):
1122
+ """
1123
+ Context manager allowing tensor allocation on the user-specified device in a framework-agnostic way.
1124
+
1125
+ Returns:
1126
+ Context manager
1127
+
1128
+ Examples:
1129
+
1130
+ ```python
1131
+ # Explicitly ask for tensor allocation on CUDA device :0
1132
+ pipe = pipeline(..., device=0)
1133
+ with pipe.device_placement():
1134
+ # Every framework-specific tensor allocation will be done on the requested device
1135
+ output = pipe(...)
1136
+ ```"""
1137
+ if self.framework == "tf":
1138
+ with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
1139
+ yield
1140
+ else:
1141
+ if self.device.type == "cuda":
1142
+ with torch.cuda.device(self.device):
1143
+ yield
1144
+ elif self.device.type == "mlu":
1145
+ with torch.mlu.device(self.device):
1146
+ yield
1147
+ elif self.device.type == "musa":
1148
+ with torch.musa.device(self.device):
1149
+ yield
1150
+ else:
1151
+ yield
1152
+
1153
+ def ensure_tensor_on_device(self, **inputs):
1154
+ """
1155
+ Ensure PyTorch tensors are on the specified device.
1156
+
1157
+ Args:
1158
+ inputs (keyword arguments that should be `torch.Tensor`, the rest is ignored):
1159
+ The tensors to place on `self.device`.
1160
+ Recursive on lists **only**.
1161
+
1162
+ Return:
1163
+ `Dict[str, torch.Tensor]`: The same as `inputs` but on the proper device.
1164
+ """
1165
+ return self._ensure_tensor_on_device(inputs, self.device)
1166
+
1167
+ def _ensure_tensor_on_device(self, inputs, device):
1168
+ if isinstance(inputs, ModelOutput):
1169
+ return ModelOutput(
1170
+ {name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
1171
+ )
1172
+ elif isinstance(inputs, dict):
1173
+ return {name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
1174
+ elif isinstance(inputs, UserDict):
1175
+ return UserDict({name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()})
1176
+ elif isinstance(inputs, list):
1177
+ return [self._ensure_tensor_on_device(item, device) for item in inputs]
1178
+ elif isinstance(inputs, tuple):
1179
+ return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
1180
+ elif isinstance(inputs, torch.Tensor):
1181
+ return inputs.to(device)
1182
+ else:
1183
+ return inputs
1184
+
1185
+ def check_model_type(self, supported_models: Union[List[str], dict]):
1186
+ """
1187
+ Check if the model class is supported by the pipeline.
1188
+
1189
+ Args:
1190
+ supported_models (`List[str]` or `dict`):
1191
+ The list of models supported by the pipeline, or a dictionary with model class values.
1192
+ """
1193
+ if not isinstance(supported_models, list): # Create from a model mapping
1194
+ supported_models_names = []
1195
+ for _, model_name in supported_models.items():
1196
+ # Mapping can now contain tuples of models for the same configuration.
1197
+ if isinstance(model_name, tuple):
1198
+ supported_models_names.extend(list(model_name))
1199
+ else:
1200
+ supported_models_names.append(model_name)
1201
+ if hasattr(supported_models, "_model_mapping"):
1202
+ for _, model in supported_models._model_mapping._extra_content.items():
1203
+ if isinstance(model, tuple):
1204
+ supported_models_names.extend([m.__name__ for m in model])
1205
+ else:
1206
+ supported_models_names.append(model.__name__)
1207
+ supported_models = supported_models_names
1208
+ if self.model.__class__.__name__ not in supported_models:
1209
+ logger.error(
1210
+ f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are"
1211
+ f" {supported_models}."
1212
+ )
1213
+
1214
+ @abstractmethod
1215
+ def _sanitize_parameters(self, **pipeline_parameters):
1216
+ """
1217
+ _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
1218
+ methods. It should return 3 dictionaries of the resolved parameters used by the various `preprocess`,
1219
+ `forward` and `postprocess` methods. Do not fill the dictionaries if the caller didn't specify the kwargs. This
1220
+ lets you keep defaults in function signatures, which is more "natural".
1221
+
1222
+ It is not meant to be called directly, it will be automatically called and the final parameters resolved by
1223
+ `__init__` and `__call__`
1224
+ """
1225
+ raise NotImplementedError("_sanitize_parameters not implemented")
1226
+
1227
+ @abstractmethod
1228
+ def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
1229
+ """
1230
+ Preprocess will take the `input_` of a specific pipeline and return a dictionary of everything necessary for
1231
+ `_forward` to run properly. It should contain at least one tensor, but might have arbitrary other items.
1232
+ """
1233
+ raise NotImplementedError("preprocess not implemented")
1234
+
1235
+ @abstractmethod
1236
+ def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
1237
+ """
1238
+ _forward will receive the prepared dictionary from `preprocess` and run it on the model. This method might
1239
+ involve the GPU or the CPU and should be agnostic to it. Isolating this function is the reason for `preprocess`
1240
+ and `postprocess` to exist, so that this method, the hot path, can run as fast as possible.
1241
+
1242
+ It is not meant to be called directly, `forward` is preferred. It is basically the same but contains additional
1243
+ code surrounding `_forward` making sure tensors and models are on the same device, disabling the training part
1244
+ of the code (leading to faster inference).
1245
+ """
1246
+ raise NotImplementedError("_forward not implemented")
1247
+
1248
+ @abstractmethod
1249
+ def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
1250
+ """
1251
+ Postprocess will receive the raw outputs of the `_forward` method, generally tensors, and reformat them into
1252
+ something more friendly. Generally it will output a list or a dict of results (containing just strings and
1253
+ numbers).
1254
+ """
1255
+ raise NotImplementedError("postprocess not implemented")
1256
+
1257
+ def get_inference_context(self):
1258
+ return torch.no_grad
1259
+
1260
+ def forward(self, model_inputs, **forward_params):
1261
+ with self.device_placement():
1262
+ if self.framework == "tf":
1263
+ model_inputs["training"] = False
1264
+ model_outputs = self._forward(model_inputs, **forward_params)
1265
+ elif self.framework == "pt":
1266
+ inference_context = self.get_inference_context()
1267
+ with inference_context():
1268
+ model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
1269
+ model_outputs = self._forward(model_inputs, **forward_params)
1270
+ model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
1271
+ else:
1272
+ raise ValueError(f"Framework {self.framework} is not supported")
1273
+ return model_outputs
1274
+
1275
+ def get_iterator(
1276
+ self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
1277
+ ):
1278
+ if isinstance(inputs, collections.abc.Sized):
1279
+ dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)
1280
+ else:
1281
+ if num_workers > 1:
1282
+ logger.warning(
1283
+ "For iterable dataset using num_workers>1 is likely to result"
1284
+ " in errors since everything is iterable, setting `num_workers=1`"
1285
+ " to guarantee correctness."
1286
+ )
1287
+ num_workers = 1
1288
+ dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
1289
+ if "TOKENIZERS_PARALLELISM" not in os.environ:
1290
+ logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
1291
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
1292
+ # TODO hack by collating feature_extractor and image_processor
1293
+ feature_extractor = self.feature_extractor if self.feature_extractor is not None else self.image_processor
1294
+ collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor)
1295
+ dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
1296
+ model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
1297
+ final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
1298
+ return final_iterator
1299
+
1300
+ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs):
1301
+ if args:
1302
+ logger.warning(f"Ignoring args : {args}")
1303
+
1304
+ if num_workers is None:
1305
+ if self._num_workers is None:
1306
+ num_workers = 0
1307
+ else:
1308
+ num_workers = self._num_workers
1309
+ if batch_size is None:
1310
+ if self._batch_size is None:
1311
+ batch_size = 1
1312
+ else:
1313
+ batch_size = self._batch_size
1314
+
1315
+ preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
1316
+
1317
+ # Fuse __init__ params and __call__ params without modifying the __init__ ones.
1318
+ preprocess_params = {**self._preprocess_params, **preprocess_params}
1319
+ forward_params = {**self._forward_params, **forward_params}
1320
+ postprocess_params = {**self._postprocess_params, **postprocess_params}
1321
+
1322
+ self.call_count += 1
1323
+ if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
1324
+ logger.warning_once(
1325
+ "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
1326
+ " dataset",
1327
+ )
1328
+
1329
+ is_dataset = Dataset is not None and isinstance(inputs, Dataset)
1330
+ is_generator = isinstance(inputs, types.GeneratorType)
1331
+ is_list = isinstance(inputs, list)
1332
+
1333
+ is_iterable = is_dataset or is_generator or is_list
1334
+
1335
+ # TODO make the get_iterator work also for `tf` (and `flax`).
1336
+ can_use_iterator = self.framework == "pt" and (is_dataset or is_generator or is_list)
1337
+
1338
+ if is_list:
1339
+ if can_use_iterator:
1340
+ final_iterator = self.get_iterator(
1341
+ inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
1342
+ )
1343
+ outputs = list(final_iterator)
1344
+ return outputs
1345
+ else:
1346
+ return self.run_multi(inputs, preprocess_params, forward_params, postprocess_params)
1347
+ elif can_use_iterator:
1348
+ return self.get_iterator(
1349
+ inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
1350
+ )
1351
+ elif is_iterable:
1352
+ return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
1353
+ elif self.framework == "pt" and isinstance(self, ChunkPipeline):
1354
+ return next(
1355
+ iter(
1356
+ self.get_iterator(
1357
+ [inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
1358
+ )
1359
+ )
1360
+ )
1361
+ else:
1362
+ return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
1363
+
1364
+ def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):
1365
+ return [self.run_single(item, preprocess_params, forward_params, postprocess_params) for item in inputs]
1366
+
1367
+ def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
1368
+ model_inputs = self.preprocess(inputs, **preprocess_params)
1369
+ model_outputs = self.forward(model_inputs, **forward_params)
1370
+ outputs = self.postprocess(model_outputs, **postprocess_params)
1371
+ return outputs
1372
+
1373
+ def iterate(self, inputs, preprocess_params, forward_params, postprocess_params):
1374
+ # This function should become `get_iterator` again, this is a temporary
1375
+ # easy solution.
1376
+ for input_ in inputs:
1377
+ yield self.run_single(input_, preprocess_params, forward_params, postprocess_params)
1378
+
1379
+
1380
+ Pipeline.push_to_hub = copy_func(Pipeline.push_to_hub)
1381
+ if Pipeline.push_to_hub.__doc__ is not None:
1382
+ Pipeline.push_to_hub.__doc__ = Pipeline.push_to_hub.__doc__.format(
1383
+ object="pipe", object_class="pipeline", object_files="pipeline file"
1384
+ ).replace(".from_pretrained", "")
1385
+
1386
+
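The `__call__` logic above only batches efficiently when it is handed something iterable, which is what the GPU warning after ten sequential calls is nudging users toward. A minimal usage sketch, assuming the `datasets` package and a CUDA device are available; the checkpoint and dataset choices are illustrative:

```python
from datasets import load_dataset
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

# Build a PyTorch pipeline on GPU 0; batch_size/num_workers are forwarded to get_iterator().
pipe = pipeline(
    "text-classification",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    device=0,
)

dataset = load_dataset("imdb", split="test[:64]")

# KeyDataset exposes a single column, so preprocess() receives plain strings.
for output in pipe(KeyDataset(dataset, "text"), batch_size=8, num_workers=2):
    print(output)
```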
1387
+ class ChunkPipeline(Pipeline):
1388
+ def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
1389
+ all_outputs = []
1390
+ for model_inputs in self.preprocess(inputs, **preprocess_params):
1391
+ model_outputs = self.forward(model_inputs, **forward_params)
1392
+ all_outputs.append(model_outputs)
1393
+ outputs = self.postprocess(all_outputs, **postprocess_params)
1394
+ return outputs
1395
+
1396
+ def get_iterator(
1397
+ self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
1398
+ ):
1399
+ if "TOKENIZERS_PARALLELISM" not in os.environ:
1400
+ logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
1401
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
1402
+ if num_workers > 1:
1403
+ logger.warning(
1404
+ "For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable,"
1405
+ " setting `num_workers=1` to guarantee correctness."
1406
+ )
1407
+ num_workers = 1
1408
+ dataset = PipelineChunkIterator(inputs, self.preprocess, preprocess_params)
1409
+
1410
+ # TODO hack by collating feature_extractor and image_processor
1411
+ feature_extractor = self.feature_extractor if self.feature_extractor is not None else self.image_processor
1412
+ collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor)
1413
+ dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
1414
+ model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
1415
+ final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
1416
+ return final_iterator
1417
+
1418
+
1419
+ class PipelineRegistry:
1420
+ def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None:
1421
+ self.supported_tasks = supported_tasks
1422
+ self.task_aliases = task_aliases
1423
+
1424
+ def get_supported_tasks(self) -> List[str]:
1425
+ supported_task = list(self.supported_tasks.keys()) + list(self.task_aliases.keys())
1426
+ supported_task.sort()
1427
+ return supported_task
1428
+
1429
+ def check_task(self, task: str) -> Tuple[str, Dict, Any]:
1430
+ if task in self.task_aliases:
1431
+ task = self.task_aliases[task]
1432
+ if task in self.supported_tasks:
1433
+ targeted_task = self.supported_tasks[task]
1434
+ return task, targeted_task, None
1435
+
1436
+ if task.startswith("translation"):
1437
+ tokens = task.split("_")
1438
+ if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
1439
+ targeted_task = self.supported_tasks["translation"]
1440
+ task = "translation"
1441
+ return task, targeted_task, (tokens[1], tokens[3])
1442
+ raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
1443
+
1444
+ raise KeyError(
1445
+ f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
1446
+ )
1447
+
1448
+ def register_pipeline(
1449
+ self,
1450
+ task: str,
1451
+ pipeline_class: type,
1452
+ pt_model: Optional[Union[type, Tuple[type]]] = None,
1453
+ tf_model: Optional[Union[type, Tuple[type]]] = None,
1454
+ default: Optional[Dict] = None,
1455
+ type: Optional[str] = None,
1456
+ ) -> None:
1457
+ if task in self.supported_tasks:
1458
+ logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
1459
+
1460
+ if pt_model is None:
1461
+ pt_model = ()
1462
+ elif not isinstance(pt_model, tuple):
1463
+ pt_model = (pt_model,)
1464
+
1465
+ if tf_model is None:
1466
+ tf_model = ()
1467
+ elif not isinstance(tf_model, tuple):
1468
+ tf_model = (tf_model,)
1469
+
1470
+ task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model}
1471
+
1472
+ if default is not None:
1473
+ if "model" not in default and ("pt" in default or "tf" in default):
1474
+ default = {"model": default}
1475
+ task_impl["default"] = default
1476
+
1477
+ if type is not None:
1478
+ task_impl["type"] = type
1479
+
1480
+ self.supported_tasks[task] = task_impl
1481
+ pipeline_class._registered_impl = {task: task_impl}
1482
+
1483
+ def to_dict(self):
1484
+ return self.supported_tasks
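Taken together, the abstract hooks (`_sanitize_parameters`, `preprocess`, `_forward`, `postprocess`) and `PipelineRegistry.register_pipeline` are everything a custom pipeline needs. A minimal sketch, not part of the diff above; the `"pair-classification"` task name, the class, and the checkpoint are illustrative:

```python
from transformers import AutoModelForSequenceClassification, pipeline
from transformers.pipelines import PIPELINE_REGISTRY, Pipeline


class PairClassificationPipeline(Pipeline):
    def _sanitize_parameters(self, second_text=None, **kwargs):
        preprocess_kwargs = {}
        if second_text is not None:
            preprocess_kwargs["second_text"] = second_text
        # (preprocess_params, forward_params, postprocess_params)
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, second_text=None):
        return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        best_class = model_outputs.logits.softmax(-1).argmax(-1).item()
        return {"label": self.model.config.id2label[best_class]}


PIPELINE_REGISTRY.register_pipeline(
    "pair-classification",
    pipeline_class=PairClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)

pair_classifier = pipeline(
    "pair-classification",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)
print(pair_classifier("I like you", second_text="I love you"))
```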
.venv/lib/python3.11/site-packages/transformers/pipelines/depth_estimation.py ADDED
@@ -0,0 +1,133 @@
1
+ from typing import List, Union
2
+
3
+ from ..utils import (
4
+ add_end_docstrings,
5
+ is_torch_available,
6
+ is_vision_available,
7
+ logging,
8
+ requires_backends,
9
+ )
10
+ from .base import Pipeline, build_pipeline_init_args
11
+
12
+
13
+ if is_vision_available():
14
+ from PIL import Image
15
+
16
+ from ..image_utils import load_image
17
+
18
+ if is_torch_available():
19
+ from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
25
+ class DepthEstimationPipeline(Pipeline):
26
+ """
27
+ Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.
28
+
29
+ Example:
30
+
31
+ ```python
32
+ >>> from transformers import pipeline
33
+
34
+ >>> depth_estimator = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-base-hf")
35
+ >>> output = depth_estimator("http://images.cocodataset.org/val2017/000000039769.jpg")
36
+ >>> # This is a tensor with the values being the depth expressed in meters for each pixel
37
+ >>> output["predicted_depth"].shape
38
+ torch.Size([1, 384, 384])
39
+ ```
40
+
41
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
42
+
43
+
44
+ This depth estimation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
45
+ `"depth-estimation"`.
46
+
47
+ See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=depth-estimation).
48
+ """
49
+
50
+ def __init__(self, *args, **kwargs):
51
+ super().__init__(*args, **kwargs)
52
+ requires_backends(self, "vision")
53
+ self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
54
+
55
+ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
56
+ """
57
+ Predict the depth(s) of the image(s) passed as inputs.
58
+
59
+ Args:
60
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
61
+ The pipeline handles three types of images:
62
+
63
+ - A string containing an HTTP link pointing to an image
64
+ - A string containing a local path to an image
65
+ - An image loaded in PIL directly
66
+
67
+ The pipeline accepts either a single image or a batch of images, which must then be passed as a list.
68
+ Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
69
+ images.
70
+ parameters (`Dict`, *optional*):
71
+ A dictionary of argument names to parameter values, to control pipeline behaviour.
72
+ The only parameter available right now is `timeout`, which is the length of time, in seconds,
73
+ that the pipeline should wait before giving up on trying to download an image.
74
+ timeout (`float`, *optional*, defaults to None):
75
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
76
+ the call may block forever.
77
+
78
+ Return:
79
+ A dictionary or a list of dictionaries containing the result. If the input is a single image, it will return a
80
+ dictionary; if the input is a list of several images, it will return a list of dictionaries corresponding to
81
+ the images.
82
+
83
+ The dictionaries contain the following keys:
84
+
85
+ - **predicted_depth** (`torch.Tensor`) -- The predicted depth by the model as a `torch.Tensor`.
86
+ - **depth** (`PIL.Image`) -- The predicted depth by the model as a `PIL.Image`.
87
+ """
88
+ # Once the deprecation of `images` is completed, remove the default `None` value for `inputs`
89
+ if "images" in kwargs:
90
+ inputs = kwargs.pop("images")
91
+ if inputs is None:
92
+ raise ValueError("Cannot call the depth-estimation pipeline without an inputs argument!")
93
+ return super().__call__(inputs, **kwargs)
94
+
95
+ def _sanitize_parameters(self, timeout=None, parameters=None, **kwargs):
96
+ preprocess_params = {}
97
+ if timeout is not None:
98
+ preprocess_params["timeout"] = timeout
99
+ if isinstance(parameters, dict) and "timeout" in parameters:
100
+ preprocess_params["timeout"] = parameters["timeout"]
101
+ return preprocess_params, {}, {}
102
+
103
+ def preprocess(self, image, timeout=None):
104
+ image = load_image(image, timeout)
105
+ model_inputs = self.image_processor(images=image, return_tensors=self.framework)
106
+ if self.framework == "pt":
107
+ model_inputs = model_inputs.to(self.torch_dtype)
108
+ model_inputs["target_size"] = image.size[::-1]
109
+ return model_inputs
110
+
111
+ def _forward(self, model_inputs):
112
+ target_size = model_inputs.pop("target_size")
113
+ model_outputs = self.model(**model_inputs)
114
+ model_outputs["target_size"] = target_size
115
+ return model_outputs
116
+
117
+ def postprocess(self, model_outputs):
118
+ outputs = self.image_processor.post_process_depth_estimation(
119
+ model_outputs,
120
+ # this acts as `source_sizes` for ZoeDepth and as `target_sizes` for the rest of the models so do *not*
121
+ # replace with `target_sizes = [model_outputs["target_size"]]`
122
+ [model_outputs["target_size"]],
123
+ )
124
+
125
+ formatted_outputs = []
126
+ for output in outputs:
127
+ depth = output["predicted_depth"].detach().cpu().numpy()
128
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
129
+ depth = Image.fromarray((depth * 255).astype("uint8"))
130
+
131
+ formatted_outputs.append({"predicted_depth": output["predicted_depth"], "depth": depth})
132
+
133
+ return formatted_outputs[0] if len(outputs) == 1 else formatted_outputs
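A short usage sketch for the depth-estimation pipeline defined above, using the checkpoint already cited in its docstring; the `timeout` value and output filename are illustrative:

```python
from transformers import pipeline

depth_estimator = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-base-hf")

result = depth_estimator(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    timeout=30.0,  # give up on the image download after 30 seconds
)

print(result["predicted_depth"].shape)  # raw depth tensor from the model
result["depth"].save("depth.png")       # normalized depth map rendered as a PIL image
```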
.venv/lib/python3.11/site-packages/transformers/pipelines/document_question_answering.py ADDED
@@ -0,0 +1,516 @@
1
+ # Copyright 2022 The Impira Team and the HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+
20
+ from ..utils import (
21
+ ExplicitEnum,
22
+ add_end_docstrings,
23
+ is_pytesseract_available,
24
+ is_torch_available,
25
+ is_vision_available,
26
+ logging,
27
+ )
28
+ from .base import ChunkPipeline, build_pipeline_init_args
29
+ from .question_answering import select_starts_ends
30
+
31
+
32
+ if is_vision_available():
33
+ from PIL import Image
34
+
35
+ from ..image_utils import load_image
36
+
37
+ if is_torch_available():
38
+ import torch
39
+
40
+ from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
41
+
42
+ TESSERACT_LOADED = False
43
+ if is_pytesseract_available():
44
+ TESSERACT_LOADED = True
45
+ import pytesseract
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ # normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py.
51
+ # However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. imported) to avoid creating an
52
+ # unnecessary dependency.
53
+ def normalize_box(box, width, height):
54
+ return [
55
+ int(1000 * (box[0] / width)),
56
+ int(1000 * (box[1] / height)),
57
+ int(1000 * (box[2] / width)),
58
+ int(1000 * (box[3] / height)),
59
+ ]
60
+
61
+
62
+ def apply_tesseract(image: "Image.Image", lang: Optional[str], tesseract_config: Optional[str]):
63
+ """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
64
+ # apply OCR
65
+ data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
66
+ words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
67
+
68
+ # filter empty words and corresponding coordinates
69
+ irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
70
+ words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
71
+ left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
72
+ top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
73
+ width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
74
+ height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
75
+
76
+ # turn coordinates into (left, top, left+width, top+height) format
77
+ actual_boxes = []
78
+ for x, y, w, h in zip(left, top, width, height):
79
+ actual_box = [x, y, x + w, y + h]
80
+ actual_boxes.append(actual_box)
81
+
82
+ image_width, image_height = image.size
83
+
84
+ # finally, normalize the bounding boxes
85
+ normalized_boxes = []
86
+ for box in actual_boxes:
87
+ normalized_boxes.append(normalize_box(box, image_width, image_height))
88
+
89
+ if len(words) != len(normalized_boxes):
90
+ raise ValueError("Not as many words as there are bounding boxes")
91
+
92
+ return words, normalized_boxes
93
+
94
+
95
+ class ModelType(ExplicitEnum):
96
+ LayoutLM = "layoutlm"
97
+ LayoutLMv2andv3 = "layoutlmv2andv3"
98
+ VisionEncoderDecoder = "vision_encoder_decoder"
99
+
100
+
101
+ @add_end_docstrings(build_pipeline_init_args(has_image_processor=True, has_tokenizer=True))
102
+ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
103
+ # TODO: Update task_summary docs to include an example with document QA and then update the first sentence
104
+ """
105
+ Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are
106
+ similar to the (extractive) question answering pipeline; however, the pipeline takes an image (and optional OCR'd
107
+ words/boxes) as input instead of text context.
108
+
109
+ Example:
110
+
111
+ ```python
112
+ >>> from transformers import pipeline
113
+
114
+ >>> document_qa = pipeline(model="impira/layoutlm-document-qa")
115
+ >>> document_qa(
116
+ ... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
117
+ ... question="What is the invoice number?",
118
+ ... )
119
+ [{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
120
+ ```
121
+
122
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
123
+
124
+ This document question answering pipeline can currently be loaded from [`pipeline`] using the following task
125
+ identifier: `"document-question-answering"`.
126
+
127
+ The models that this pipeline can use are models that have been fine-tuned on a document question answering task.
128
+ See the up-to-date list of available models on
129
+ [huggingface.co/models](https://huggingface.co/models?filter=document-question-answering).
130
+ """
131
+
132
+ def __init__(self, *args, **kwargs):
133
+ super().__init__(*args, **kwargs)
134
+ if self.tokenizer is not None and not self.tokenizer.__class__.__name__.endswith("Fast"):
135
+ raise ValueError(
136
+ "`DocumentQuestionAnsweringPipeline` requires a fast tokenizer, but a slow tokenizer "
137
+ f"(`{self.tokenizer.__class__.__name__}`) is provided."
138
+ )
139
+
140
+ if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
141
+ self.model_type = ModelType.VisionEncoderDecoder
142
+ if self.model.config.encoder.model_type != "donut-swin":
143
+ raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
144
+ else:
145
+ self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
146
+ if self.model.config.__class__.__name__ == "LayoutLMConfig":
147
+ self.model_type = ModelType.LayoutLM
148
+ else:
149
+ self.model_type = ModelType.LayoutLMv2andv3
150
+
151
+ def _sanitize_parameters(
152
+ self,
153
+ padding=None,
154
+ doc_stride=None,
155
+ max_question_len=None,
156
+ lang: Optional[str] = None,
157
+ tesseract_config: Optional[str] = None,
158
+ max_answer_len=None,
159
+ max_seq_len=None,
160
+ top_k=None,
161
+ handle_impossible_answer=None,
162
+ timeout=None,
163
+ **kwargs,
164
+ ):
165
+ preprocess_params, postprocess_params = {}, {}
166
+ if padding is not None:
167
+ preprocess_params["padding"] = padding
168
+ if doc_stride is not None:
169
+ preprocess_params["doc_stride"] = doc_stride
170
+ if max_question_len is not None:
171
+ preprocess_params["max_question_len"] = max_question_len
172
+ if max_seq_len is not None:
173
+ preprocess_params["max_seq_len"] = max_seq_len
174
+ if lang is not None:
175
+ preprocess_params["lang"] = lang
176
+ if tesseract_config is not None:
177
+ preprocess_params["tesseract_config"] = tesseract_config
178
+ if timeout is not None:
179
+ preprocess_params["timeout"] = timeout
180
+
181
+ if top_k is not None:
182
+ if top_k < 1:
183
+ raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
184
+ postprocess_params["top_k"] = top_k
185
+ if max_answer_len is not None:
186
+ if max_answer_len < 1:
187
+ raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}")
188
+ postprocess_params["max_answer_len"] = max_answer_len
189
+ if handle_impossible_answer is not None:
190
+ postprocess_params["handle_impossible_answer"] = handle_impossible_answer
191
+
192
+ forward_params = {}
193
+ if self.assistant_model is not None:
194
+ forward_params["assistant_model"] = self.assistant_model
195
+ if self.assistant_tokenizer is not None:
196
+ forward_params["tokenizer"] = self.tokenizer
197
+ forward_params["assistant_tokenizer"] = self.assistant_tokenizer
198
+
199
+ return preprocess_params, forward_params, postprocess_params
200
+
201
+ def __call__(
202
+ self,
203
+ image: Union["Image.Image", str],
204
+ question: Optional[str] = None,
205
+ word_boxes: Tuple[str, List[float]] = None,
206
+ **kwargs,
207
+ ):
208
+ """
209
+ Answer the question(s) given as inputs by using the document(s). A document is defined as an image and an
210
+ optional list of (word, box) tuples which represent the text in the document. If the `word_boxes` are not
211
+ provided, it will use the Tesseract OCR engine (if available) to extract the words and boxes automatically for
212
+ LayoutLM-like models which require them as input. For Donut, no OCR is run.
213
+
214
+ You can invoke the pipeline several ways:
215
+
216
+ - `pipeline(image=image, question=question)`
217
+ - `pipeline(image=image, question=question, word_boxes=word_boxes)`
218
+ - `pipeline([{"image": image, "question": question}])`
219
+ - `pipeline([{"image": image, "question": question, "word_boxes": word_boxes}])`
220
+
221
+ Args:
222
+ image (`str` or `PIL.Image`):
223
+ The pipeline handles three types of images:
224
+
225
+ - A string containing an HTTP link pointing to an image
226
+ - A string containing a local path to an image
227
+ - An image loaded in PIL directly
228
+
229
+ The pipeline accepts either a single image or a batch of images. If given a single image, it can be
230
+ broadcasted to multiple questions.
231
+ question (`str`):
232
+ A question to ask of the document.
233
+ word_boxes (`List[str, Tuple[float, float, float, float]]`, *optional*):
234
+ A list of words and bounding boxes (normalized 0->1000). If you provide this optional input, then the
235
+ pipeline will use these words and boxes instead of running OCR on the image to derive them for models
236
+ that need them (e.g. LayoutLM). This allows you to reuse OCR'd results across many invocations of the
237
+ pipeline without having to re-run it each time.
238
+ top_k (`int`, *optional*, defaults to 1):
239
+ The number of answers to return (will be chosen by order of likelihood). Note that we return fewer than
240
+ top_k answers if there are not enough options available within the context.
241
+ doc_stride (`int`, *optional*, defaults to 128):
242
+ If the words in the document are too long to fit with the question for the model, it will be split in
243
+ several chunks with some overlap. This argument controls the size of that overlap.
244
+ max_answer_len (`int`, *optional*, defaults to 15):
245
+ The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
246
+ max_seq_len (`int`, *optional*, defaults to 384):
247
+ The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
248
+ model. The context will be split in several chunks (using `doc_stride` as overlap) if needed.
249
+ max_question_len (`int`, *optional*, defaults to 64):
250
+ The maximum length of the question after tokenization. It will be truncated if needed.
251
+ handle_impossible_answer (`bool`, *optional*, defaults to `False`):
252
+ Whether or not we accept impossible as an answer.
253
+ lang (`str`, *optional*):
254
+ Language to use while running OCR. Defaults to English.
255
+ tesseract_config (`str`, *optional*):
256
+ Additional flags to pass to tesseract while running OCR.
257
+ timeout (`float`, *optional*, defaults to None):
258
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
259
+ the call may block forever.
260
+
261
+ Return:
262
+ A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
263
+
264
+ - **score** (`float`) -- The probability associated to the answer.
265
+ - **start** (`int`) -- The start word index of the answer (in the OCR'd version of the input or provided
266
+ `word_boxes`).
267
+ - **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided
268
+ `word_boxes`).
269
+ - **answer** (`str`) -- The answer to the question.
270
+ - **words** (`list[int]`) -- The index of each word/box pair that is in the answer
271
+ """
272
+ if isinstance(question, str):
273
+ inputs = {"question": question, "image": image}
274
+ if word_boxes is not None:
275
+ inputs["word_boxes"] = word_boxes
276
+ else:
277
+ inputs = image
278
+ return super().__call__(inputs, **kwargs)
279
+
280
+ def preprocess(
281
+ self,
282
+ input,
283
+ padding="do_not_pad",
284
+ doc_stride=None,
285
+ max_seq_len=None,
286
+ word_boxes: Tuple[str, List[float]] = None,
287
+ lang=None,
288
+ tesseract_config="",
289
+ timeout=None,
290
+ ):
291
+ # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
292
+ # to support documents with enough tokens that overflow the model's window
293
+ if max_seq_len is None:
294
+ max_seq_len = self.tokenizer.model_max_length
295
+
296
+ if doc_stride is None:
297
+ doc_stride = min(max_seq_len // 2, 256)
298
+
299
+ image = None
300
+ image_features = {}
301
+ if input.get("image", None) is not None:
302
+ image = load_image(input["image"], timeout=timeout)
303
+ if self.image_processor is not None:
304
+ image_inputs = self.image_processor(images=image, return_tensors=self.framework)
305
+ if self.framework == "pt":
306
+ image_inputs = image_inputs.to(self.torch_dtype)
307
+ image_features.update(image_inputs)
308
+ elif self.feature_extractor is not None:
309
+ image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))
310
+ elif self.model_type == ModelType.VisionEncoderDecoder:
311
+ raise ValueError("If you are using a VisionEncoderDecoderModel, you must provide a feature extractor")
312
+
313
+ words, boxes = None, None
314
+ if not self.model_type == ModelType.VisionEncoderDecoder:
315
+ if "word_boxes" in input:
316
+ words = [x[0] for x in input["word_boxes"]]
317
+ boxes = [x[1] for x in input["word_boxes"]]
318
+ elif "words" in image_features and "boxes" in image_features:
319
+ words = image_features.pop("words")[0]
320
+ boxes = image_features.pop("boxes")[0]
321
+ elif image is not None:
322
+ if not TESSERACT_LOADED:
323
+ raise ValueError(
324
+ "If you provide an image without word_boxes, then the pipeline will run OCR using Tesseract,"
325
+ " but pytesseract is not available"
326
+ )
327
+ if TESSERACT_LOADED:
328
+ words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config)
329
+ else:
330
+ raise ValueError(
331
+ "You must provide an image or word_boxes. If you provide an image, the pipeline will automatically"
332
+ " run OCR to derive words and boxes"
333
+ )
334
+
335
+ if self.tokenizer.padding_side != "right":
336
+ raise ValueError(
337
+ "Document question answering only supports tokenizers whose padding side is 'right', not"
338
+ f" {self.tokenizer.padding_side}"
339
+ )
340
+
341
+ if self.model_type == ModelType.VisionEncoderDecoder:
342
+ task_prompt = f'<s_docvqa><s_question>{input["question"]}</s_question><s_answer>'
343
+ # Adapted from https://huggingface.co/spaces/nielsr/donut-docvqa/blob/main/app.py
344
+ encoding = {
345
+ "inputs": image_features["pixel_values"],
346
+ "decoder_input_ids": self.tokenizer(
347
+ task_prompt, add_special_tokens=False, return_tensors=self.framework
348
+ ).input_ids,
349
+ "return_dict_in_generate": True,
350
+ }
351
+ yield {
352
+ **encoding,
353
+ "p_mask": None,
354
+ "word_ids": None,
355
+ "words": None,
356
+ "output_attentions": True,
357
+ "is_last": True,
358
+ }
359
+ else:
360
+ tokenizer_kwargs = {}
361
+ if self.model_type == ModelType.LayoutLM:
362
+ tokenizer_kwargs["text"] = input["question"].split()
363
+ tokenizer_kwargs["text_pair"] = words
364
+ tokenizer_kwargs["is_split_into_words"] = True
365
+ else:
366
+ tokenizer_kwargs["text"] = [input["question"]]
367
+ tokenizer_kwargs["text_pair"] = [words]
368
+ tokenizer_kwargs["boxes"] = [boxes]
369
+
370
+ encoding = self.tokenizer(
371
+ padding=padding,
372
+ max_length=max_seq_len,
373
+ stride=doc_stride,
374
+ return_token_type_ids=True,
375
+ truncation="only_second",
376
+ return_overflowing_tokens=True,
377
+ **tokenizer_kwargs,
378
+ )
379
+ # TODO: check why slower `LayoutLMTokenizer` and `LayoutLMv2Tokenizer` don't have this key in outputs
380
+ # FIXME: ydshieh and/or Narsil
381
+ encoding.pop("overflow_to_sample_mapping", None) # We do not use this
382
+
383
+ num_spans = len(encoding["input_ids"])
384
+
385
+ # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
386
+ # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
387
+ # This logic mirrors the logic in the question_answering pipeline
388
+ p_mask = np.array([[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)])
389
+ for span_idx in range(num_spans):
390
+ if self.framework == "pt":
391
+ span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
392
+ if "pixel_values" in image_features:
393
+ span_encoding["image"] = image_features["pixel_values"]
394
+ else:
395
+ raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
396
+
397
+ input_ids_span_idx = encoding["input_ids"][span_idx]
398
+ # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
399
+ if self.tokenizer.cls_token_id is not None:
400
+ cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0]
401
+ for cls_index in cls_indices:
402
+ p_mask[span_idx][cls_index] = 0
403
+
404
+ # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
405
+ # for SEP tokens, and the word's bounding box for words in the original document.
406
+ if "boxes" not in tokenizer_kwargs:
407
+ bbox = []
408
+ for input_id, sequence_id, word_id in zip(
409
+ encoding.input_ids[span_idx],
410
+ encoding.sequence_ids(span_idx),
411
+ encoding.word_ids(span_idx),
412
+ ):
413
+ if sequence_id == 1:
414
+ bbox.append(boxes[word_id])
415
+ elif input_id == self.tokenizer.sep_token_id:
416
+ bbox.append([1000] * 4)
417
+ else:
418
+ bbox.append([0] * 4)
419
+
420
+ if self.framework == "pt":
421
+ span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)
422
+ elif self.framework == "tf":
423
+ raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
424
+ yield {
425
+ **span_encoding,
426
+ "p_mask": p_mask[span_idx],
427
+ "word_ids": encoding.word_ids(span_idx),
428
+ "words": words,
429
+ "is_last": span_idx == num_spans - 1,
430
+ }
431
+
432
+ def _forward(self, model_inputs, **generate_kwargs):
433
+ p_mask = model_inputs.pop("p_mask", None)
434
+ word_ids = model_inputs.pop("word_ids", None)
435
+ words = model_inputs.pop("words", None)
436
+ is_last = model_inputs.pop("is_last", False)
437
+
438
+ if self.model_type == ModelType.VisionEncoderDecoder:
439
+ # User-defined `generation_config` passed to the pipeline call take precedence
440
+ if "generation_config" not in generate_kwargs:
441
+ generate_kwargs["generation_config"] = self.generation_config
442
+
443
+ model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
444
+ else:
445
+ model_outputs = self.model(**model_inputs)
446
+
447
+ model_outputs = dict(model_outputs.items())
448
+ model_outputs["p_mask"] = p_mask
449
+ model_outputs["word_ids"] = word_ids
450
+ model_outputs["words"] = words
451
+ model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
452
+ model_outputs["is_last"] = is_last
453
+ return model_outputs
454
+
455
+ def postprocess(self, model_outputs, top_k=1, **kwargs):
456
+ if self.model_type == ModelType.VisionEncoderDecoder:
457
+ answers = [self.postprocess_encoder_decoder_single(o) for o in model_outputs]
458
+ else:
459
+ answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
460
+
461
+ answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
462
+ return answers
463
+
464
+ def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
465
+ sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0]
466
+
467
+ # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
468
+ # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
469
+ sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
470
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
471
+ ret = {
472
+ "answer": None,
473
+ }
474
+
475
+ answer = re.search(r"<s_answer>(.*)</s_answer>", sequence)
476
+ if answer is not None:
477
+ ret["answer"] = answer.group(1).strip()
478
+ return ret
479
+
480
+ def postprocess_extractive_qa(
481
+ self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=15, **kwargs
482
+ ):
483
+ min_null_score = 1000000 # large and positive
484
+ answers = []
485
+ for output in model_outputs:
486
+ words = output["words"]
487
+
488
+ starts, ends, scores, min_null_score = select_starts_ends(
489
+ start=output["start_logits"],
490
+ end=output["end_logits"],
491
+ p_mask=output["p_mask"],
492
+ attention_mask=output["attention_mask"].numpy()
493
+ if output.get("attention_mask", None) is not None
494
+ else None,
495
+ min_null_score=min_null_score,
496
+ top_k=top_k,
497
+ handle_impossible_answer=handle_impossible_answer,
498
+ max_answer_len=max_answer_len,
499
+ )
500
+ word_ids = output["word_ids"]
501
+ for start, end, score in zip(starts, ends, scores):
502
+ word_start, word_end = word_ids[start], word_ids[end]
503
+ if word_start is not None and word_end is not None:
504
+ answers.append(
505
+ {
506
+ "score": float(score),
507
+ "answer": " ".join(words[word_start : word_end + 1]),
508
+ "start": word_start,
509
+ "end": word_end,
510
+ }
511
+ )
512
+
513
+ if handle_impossible_answer:
514
+ answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0})
515
+
516
+ return answers
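Because `word_boxes` short-circuits the Tesseract step in `preprocess`, OCR output can be computed once and reused across several questions. A sketch using the checkpoint from the class docstring; the image path and the `(word, box)` values are hypothetical:

```python
from transformers import pipeline

document_qa = pipeline(model="impira/layoutlm-document-qa")

# (word, [x0, y0, x1, y1]) pairs with coordinates normalized to 0-1000,
# e.g. saved from an earlier OCR run (values below are hypothetical).
word_boxes = [
    ("Invoice", [52, 33, 180, 60]),
    ("us-001", [410, 33, 520, 60]),
]

for question in ["What is the invoice number?", "What kind of document is this?"]:
    print(document_qa(image="invoice.png", question=question, word_boxes=word_boxes))
```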
.venv/lib/python3.11/site-packages/transformers/pipelines/feature_extraction.py ADDED
@@ -0,0 +1,86 @@
1
+ from typing import Dict
2
+
3
+ from ..utils import add_end_docstrings
4
+ from .base import GenericTensor, Pipeline, build_pipeline_init_args
5
+
6
+
7
+ @add_end_docstrings(
8
+ build_pipeline_init_args(has_tokenizer=True, supports_binary_output=False),
9
+ r"""
10
+ tokenize_kwargs (`dict`, *optional*):
11
+ Additional dictionary of keyword arguments passed along to the tokenizer.
12
+ return_tensors (`bool`, *optional*):
13
+ If `True`, returns a tensor according to the specified framework, otherwise returns a list.""",
14
+ )
15
+ class FeatureExtractionPipeline(Pipeline):
16
+ """
17
+ Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
18
+ transformer, which can be used as features in downstream tasks.
19
+
20
+ Example:
21
+
22
+ ```python
23
+ >>> from transformers import pipeline
24
+
25
+ >>> extractor = pipeline(model="google-bert/bert-base-uncased", task="feature-extraction")
26
+ >>> result = extractor("This is a simple test.", return_tensors=True)
27
+ >>> result.shape # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input string.
28
+ torch.Size([1, 8, 768])
29
+ ```
30
+
31
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
32
+
33
+ This feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
34
+ `"feature-extraction"`.
35
+
36
+ All models may be used for this pipeline. See a list of all models, including community-contributed models on
37
+ [huggingface.co/models](https://huggingface.co/models).
38
+ """
39
+
40
+ def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
41
+ if tokenize_kwargs is None:
42
+ tokenize_kwargs = {}
43
+
44
+ if truncation is not None:
45
+ if "truncation" in tokenize_kwargs:
46
+ raise ValueError(
47
+ "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
48
+ )
49
+ tokenize_kwargs["truncation"] = truncation
50
+
51
+ preprocess_params = tokenize_kwargs
52
+
53
+ postprocess_params = {}
54
+ if return_tensors is not None:
55
+ postprocess_params["return_tensors"] = return_tensors
56
+
57
+ return preprocess_params, {}, postprocess_params
58
+
59
+ def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
60
+ model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs)
61
+ return model_inputs
62
+
63
+ def _forward(self, model_inputs):
64
+ model_outputs = self.model(**model_inputs)
65
+ return model_outputs
66
+
67
+ def postprocess(self, model_outputs, return_tensors=False):
68
+ # [0] is the first available tensor, logits or last_hidden_state.
69
+ if return_tensors:
70
+ return model_outputs[0]
71
+ if self.framework == "pt":
72
+ return model_outputs[0].tolist()
73
+ elif self.framework == "tf":
74
+ return model_outputs[0].numpy().tolist()
75
+
76
+ def __call__(self, *args, **kwargs):
77
+ """
78
+ Extract the features of the input(s).
79
+
80
+ Args:
81
+ args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.
82
+
83
+ Return:
84
+ A nested list of `float`: The features computed by the model.
85
+ """
86
+ return super().__call__(*args, **kwargs)
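A common follow-up to the raw hidden states is pooling them into one vector per input. A minimal sketch using the checkpoint from the docstring above; the mean pooling step is an illustrative choice, not something the pipeline does itself:

```python
from transformers import pipeline

extractor = pipeline(model="google-bert/bert-base-uncased", task="feature-extraction")

# return_tensors=True keeps the [1, seq_len, hidden_size] tensor instead of nested lists
hidden_states = extractor("This is a simple test.", return_tensors=True)

sentence_embedding = hidden_states.mean(dim=1)  # crude mean pooling over the token dimension
print(sentence_embedding.shape)  # torch.Size([1, 768])
```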
.venv/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py ADDED
@@ -0,0 +1,273 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+
5
+ from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
6
+ from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args
7
+
8
+
9
+ if is_tf_available():
10
+ import tensorflow as tf
11
+
12
+ from ..tf_utils import stable_softmax
13
+
14
+
15
+ if is_torch_available():
16
+ import torch
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ @add_end_docstrings(
23
+ build_pipeline_init_args(has_tokenizer=True),
24
+ r"""
25
+ top_k (`int`, *optional*, defaults to 5):
26
+ The number of predictions to return.
27
+ targets (`str` or `List[str]`, *optional*):
28
+ When passed, the model will limit the scores to the passed targets instead of looking up in the whole
29
+ vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
30
+ token will be used (with a warning, and that might be slower).
31
+ tokenizer_kwargs (`dict`, *optional*):
32
+ Additional dictionary of keyword arguments passed along to the tokenizer.""",
33
+ )
34
+ class FillMaskPipeline(Pipeline):
35
+ """
36
+ Masked language modeling prediction pipeline using any `ModelWithLMHead`. See the [masked language modeling
37
+ examples](../task_summary#masked-language-modeling) for more information.
38
+
39
+ Example:
40
+
41
+ ```python
42
+ >>> from transformers import pipeline
43
+
44
+ >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
45
+ >>> fill_masker("This is a simple [MASK].")
46
+ [{'score': 0.042, 'token': 3291, 'token_str': 'problem', 'sequence': 'this is a simple problem.'}, {'score': 0.031, 'token': 3160, 'token_str': 'question', 'sequence': 'this is a simple question.'}, {'score': 0.03, 'token': 8522, 'token_str': 'equation', 'sequence': 'this is a simple equation.'}, {'score': 0.027, 'token': 2028, 'token_str': 'one', 'sequence': 'this is a simple one.'}, {'score': 0.024, 'token': 3627, 'token_str': 'rule', 'sequence': 'this is a simple rule.'}]
47
+ ```
48
+
49
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
50
+
51
+ This mask filling pipeline can currently be loaded from [`pipeline`] using the following task identifier:
52
+ `"fill-mask"`.
53
+
54
+ The models that this pipeline can use are models that have been trained with a masked language modeling objective,
55
+ which includes the bi-directional models in the library. See the up-to-date list of available models on
56
+ [huggingface.co/models](https://huggingface.co/models?filter=fill-mask).
57
+
58
+ <Tip>
59
+
60
+ This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
61
+ masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
62
+ joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).
63
+
64
+ </Tip>
65
+
66
+ <Tip>
67
+
68
+ This pipeline now supports tokenizer_kwargs. For example try:
69
+
70
+ ```python
71
+ >>> from transformers import pipeline
72
+
73
+ >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
74
+ >>> tokenizer_kwargs = {"truncation": True}
75
+ >>> fill_masker(
76
+ ... "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
77
+ ... tokenizer_kwargs=tokenizer_kwargs,
78
+ ... )
79
+ ```
80
+
81
+
82
+ </Tip>
83
+
84
+
85
+ """
86
+
87
+ def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
88
+ if self.framework == "tf":
89
+ masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
90
+ elif self.framework == "pt":
91
+ masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
92
+ else:
93
+ raise ValueError("Unsupported framework")
94
+ return masked_index
95
+
96
+ def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
97
+ masked_index = self.get_masked_index(input_ids)
98
+ numel = np.prod(masked_index.shape)
99
+ if numel < 1:
100
+ raise PipelineException(
101
+ "fill-mask",
102
+ self.model.base_model_prefix,
103
+ f"No mask_token ({self.tokenizer.mask_token}) found on the input",
104
+ )
105
+
106
+ def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
107
+ if isinstance(model_inputs, list):
108
+ for model_input in model_inputs:
109
+ self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
110
+ else:
111
+ for input_ids in model_inputs["input_ids"]:
112
+ self._ensure_exactly_one_mask_token(input_ids)
113
+
114
+ def preprocess(
115
+ self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
116
+ ) -> Dict[str, GenericTensor]:
117
+ if return_tensors is None:
118
+ return_tensors = self.framework
119
+ if tokenizer_kwargs is None:
120
+ tokenizer_kwargs = {}
121
+
122
+ model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
123
+ self.ensure_exactly_one_mask_token(model_inputs)
124
+ return model_inputs
125
+
126
+ def _forward(self, model_inputs):
127
+ model_outputs = self.model(**model_inputs)
128
+ model_outputs["input_ids"] = model_inputs["input_ids"]
129
+ return model_outputs
130
+
131
+ def postprocess(self, model_outputs, top_k=5, target_ids=None):
132
+ # Cap top_k if there are targets
133
+ if target_ids is not None and target_ids.shape[0] < top_k:
134
+ top_k = target_ids.shape[0]
135
+ input_ids = model_outputs["input_ids"][0]
136
+ outputs = model_outputs["logits"]
137
+
138
+ if self.framework == "tf":
139
+ masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]
140
+
141
+ outputs = outputs.numpy()
142
+
143
+ logits = outputs[0, masked_index, :]
144
+ probs = stable_softmax(logits, axis=-1)
145
+ if target_ids is not None:
146
+ probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
147
+ probs = tf.expand_dims(probs, 0)
148
+
149
+ topk = tf.math.top_k(probs, k=top_k)
150
+ values, predictions = topk.values.numpy(), topk.indices.numpy()
151
+ else:
152
+ masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
153
+ # Fill mask pipeline supports only one ${mask_token} per sample
154
+
155
+ logits = outputs[0, masked_index, :]
156
+ probs = logits.softmax(dim=-1)
157
+ if target_ids is not None:
158
+ probs = probs[..., target_ids]
159
+
160
+ values, predictions = probs.topk(top_k)
161
+
162
+ result = []
163
+ single_mask = values.shape[0] == 1
164
+ for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
165
+ row = []
166
+ for v, p in zip(_values, _predictions):
167
+ # Copy is important since we're going to modify this array in place
168
+ tokens = input_ids.numpy().copy()
169
+ if target_ids is not None:
170
+ p = target_ids[p].tolist()
171
+
172
+ tokens[masked_index[i]] = p
173
+ # Filter padding out:
174
+ tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
175
+ # Originally we skip special tokens to give readable output.
176
+ # For multi masks though, the other [MASK] would be removed otherwise
177
+ # making the output look odd, so we add them back
178
+ sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
179
+ proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
180
+ row.append(proposition)
181
+ result.append(row)
182
+ if single_mask:
183
+ return result[0]
184
+ return result
185
+
186
+ def get_target_ids(self, targets, top_k=None):
187
+ if isinstance(targets, str):
188
+ targets = [targets]
189
+ try:
190
+ vocab = self.tokenizer.get_vocab()
191
+ except Exception:
192
+ vocab = {}
193
+ target_ids = []
194
+ for target in targets:
195
+ id_ = vocab.get(target, None)
196
+ if id_ is None:
197
+ input_ids = self.tokenizer(
198
+ target,
199
+ add_special_tokens=False,
200
+ return_attention_mask=False,
201
+ return_token_type_ids=False,
202
+ max_length=1,
203
+ truncation=True,
204
+ )["input_ids"]
205
+ if len(input_ids) == 0:
206
+ logger.warning(
207
+ f"The specified target token `{target}` does not exist in the model vocabulary. "
208
+ "We cannot replace it with anything meaningful, ignoring it"
209
+ )
210
+ continue
211
+ id_ = input_ids[0]
212
+ # XXX: If users encounter this pass
213
+ # it becomes pretty slow, so let's make sure
214
+ # The warning enables them to fix the input to
215
+ # get faster performance.
216
+ logger.warning(
217
+ f"The specified target token `{target}` does not exist in the model vocabulary. "
218
+ f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
219
+ )
220
+ target_ids.append(id_)
221
+ target_ids = list(set(target_ids))
222
+ if len(target_ids) == 0:
223
+ raise ValueError("At least one valid target must be provided when `targets` is passed.")
224
+ target_ids = np.array(target_ids)
225
+ return target_ids
226
+
227
+ def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
228
+ preprocess_params = {}
229
+
230
+ if tokenizer_kwargs is not None:
231
+ preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
232
+
233
+ postprocess_params = {}
234
+
235
+ if targets is not None:
236
+ target_ids = self.get_target_ids(targets, top_k)
237
+ postprocess_params["target_ids"] = target_ids
238
+
239
+ if top_k is not None:
240
+ postprocess_params["top_k"] = top_k
241
+
242
+ if self.tokenizer.mask_token_id is None:
243
+ raise PipelineException(
244
+ "fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
245
+ )
246
+ return preprocess_params, {}, postprocess_params
247
+
248
+ def __call__(self, inputs, **kwargs):
249
+ """
250
+ Fill the masked token in the text(s) given as inputs.
251
+
252
+ Args:
253
+ inputs (`str` or `List[str]`):
254
+ One or several texts (or one list of prompts) with masked tokens.
255
+ targets (`str` or `List[str]`, *optional*):
256
+ When passed, the model will limit the scores to the passed targets instead of looking up in the whole
257
+ vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
258
+ resulting token will be used (with a warning, and that might be slower).
259
+ top_k (`int`, *optional*):
260
+ When passed, overrides the number of predictions to return.
261
+
262
+ Return:
263
+ A list or a list of lists of `dict`: Each result comes as a list of dictionaries with the following keys:
264
+
265
+ - **sequence** (`str`) -- The corresponding input with the mask token prediction.
266
+ - **score** (`float`) -- The corresponding probability.
267
+ - **token** (`int`) -- The predicted token id (to replace the masked one).
268
+ - **token_str** (`str`) -- The predicted token (to replace the masked one).
269
+ """
270
+ outputs = super().__call__(inputs, **kwargs)
271
+ if isinstance(inputs, list) and len(inputs) == 1:
272
+ return outputs[0]
273
+ return outputs
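A brief usage sketch of the `targets` and `top_k` parameters documented above. This is an illustrative example rather than part of the file, and it assumes the `google-bert/bert-base-uncased` checkpoint shown in the docstring is reachable.

```python
from transformers import pipeline

# Illustrative sketch: restrict fill-mask predictions to a few candidate tokens.
fill_masker = pipeline("fill-mask", model="google-bert/bert-base-uncased")

# `targets` limits scoring to the listed tokens; `top_k` caps how many predictions come back.
predictions = fill_masker(
    "The capital of France is [MASK].",
    targets=["paris", "london", "berlin"],
    top_k=2,
)

for prediction in predictions:
    print(prediction["token_str"], round(prediction["score"], 4))
```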
.venv/lib/python3.11/site-packages/transformers/pipelines/image_classification.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Union
15
+
16
+ import numpy as np
17
+
18
+ from ..utils import (
19
+ ExplicitEnum,
20
+ add_end_docstrings,
21
+ is_tf_available,
22
+ is_torch_available,
23
+ is_vision_available,
24
+ logging,
25
+ requires_backends,
26
+ )
27
+ from .base import Pipeline, build_pipeline_init_args
28
+
29
+
30
+ if is_vision_available():
31
+ from PIL import Image
32
+
33
+ from ..image_utils import load_image
34
+
35
+ if is_tf_available():
36
+ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
37
+
38
+ if is_torch_available():
39
+ import torch
40
+
41
+ from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ # Copied from transformers.pipelines.text_classification.sigmoid
47
+ def sigmoid(_outputs):
48
+ return 1.0 / (1.0 + np.exp(-_outputs))
49
+
50
+
51
+ # Copied from transformers.pipelines.text_classification.softmax
52
+ def softmax(_outputs):
53
+ maxes = np.max(_outputs, axis=-1, keepdims=True)
54
+ shifted_exp = np.exp(_outputs - maxes)
55
+ return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
56
+
57
+
58
+ # Copied from transformers.pipelines.text_classification.ClassificationFunction
59
+ class ClassificationFunction(ExplicitEnum):
60
+ SIGMOID = "sigmoid"
61
+ SOFTMAX = "softmax"
62
+ NONE = "none"
63
+
64
+
65
+ @add_end_docstrings(
66
+ build_pipeline_init_args(has_image_processor=True),
67
+ r"""
68
+ function_to_apply (`str`, *optional*, defaults to `"default"`):
69
+ The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
70
+
71
+ - `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
72
+ has several labels, will apply the softmax function on the output.
73
+ - `"sigmoid"`: Applies the sigmoid function on the output.
74
+ - `"softmax"`: Applies the softmax function on the output.
75
+ - `"none"`: Does not apply any function on the output.""",
76
+ )
77
+ class ImageClassificationPipeline(Pipeline):
78
+ """
79
+ Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
80
+ image.
81
+
82
+ Example:
83
+
84
+ ```python
85
+ >>> from transformers import pipeline
86
+
87
+ >>> classifier = pipeline(model="microsoft/beit-base-patch16-224-pt22k-ft22k")
88
+ >>> classifier("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
89
+ [{'score': 0.442, 'label': 'macaw'}, {'score': 0.088, 'label': 'popinjay'}, {'score': 0.075, 'label': 'parrot'}, {'score': 0.073, 'label': 'parodist, lampooner'}, {'score': 0.046, 'label': 'poll, poll_parrot'}]
90
+ ```
91
+
92
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
93
+
94
+ This image classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
95
+ `"image-classification"`.
96
+
97
+ See the list of available models on
98
+ [huggingface.co/models](https://huggingface.co/models?filter=image-classification).
99
+ """
100
+
101
+ function_to_apply: ClassificationFunction = ClassificationFunction.NONE
102
+
103
+ def __init__(self, *args, **kwargs):
104
+ super().__init__(*args, **kwargs)
105
+ requires_backends(self, "vision")
106
+ self.check_model_type(
107
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
108
+ if self.framework == "tf"
109
+ else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
110
+ )
111
+
112
+ def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
113
+ preprocess_params = {}
114
+ if timeout is not None:
115
+ preprocess_params["timeout"] = timeout
116
+ postprocess_params = {}
117
+ if top_k is not None:
118
+ postprocess_params["top_k"] = top_k
119
+ if isinstance(function_to_apply, str):
120
+ function_to_apply = ClassificationFunction(function_to_apply.lower())
121
+ if function_to_apply is not None:
122
+ postprocess_params["function_to_apply"] = function_to_apply
123
+ return preprocess_params, {}, postprocess_params
124
+
125
+ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
126
+ """
127
+ Assign labels to the image(s) passed as inputs.
128
+
129
+ Args:
130
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
131
+ The pipeline handles three types of images:
132
+
133
+ - A string containing a http link pointing to an image
134
+ - A string containing a local path to an image
135
+ - An image loaded in PIL directly
136
+
137
+ The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
138
+ Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
139
+ images.
140
+ function_to_apply (`str`, *optional*, defaults to `"default"`):
141
+ The function to apply to the model outputs in order to retrieve the scores. Accepts four different
142
+ values:
143
+
144
+ If this argument is not specified, then it will apply the following functions according to the number
145
+ of labels:
146
+
147
+ - If the model has a single label, will apply the sigmoid function on the output.
148
+ - If the model has several labels, will apply the softmax function on the output.
149
+
150
+ Possible values are:
151
+
152
+ - `"sigmoid"`: Applies the sigmoid function on the output.
153
+ - `"softmax"`: Applies the softmax function on the output.
154
+ - `"none"`: Does not apply any function on the output.
155
+ top_k (`int`, *optional*, defaults to 5):
156
+ The number of top labels that will be returned by the pipeline. If the provided number is higher than
157
+ the number of labels available in the model configuration, it will default to the number of labels.
158
+ timeout (`float`, *optional*, defaults to None):
159
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
160
+ the call may block forever.
161
+
162
+ Return:
163
+ A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
164
+ dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
165
+ the images.
166
+
167
+ The dictionaries contain the following keys:
168
+
169
+ - **label** (`str`) -- The label identified by the model.
170
+ - **score** (`float`) -- The score attributed by the model for that label.
171
+ """
172
+ # After deprecation of this is completed, remove the default `None` value for `images`
173
+ if "images" in kwargs:
174
+ inputs = kwargs.pop("images")
175
+ if inputs is None:
176
+ raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
177
+ return super().__call__(inputs, **kwargs)
178
+
179
+ def preprocess(self, image, timeout=None):
180
+ image = load_image(image, timeout=timeout)
181
+ model_inputs = self.image_processor(images=image, return_tensors=self.framework)
182
+ if self.framework == "pt":
183
+ model_inputs = model_inputs.to(self.torch_dtype)
184
+ return model_inputs
185
+
186
+ def _forward(self, model_inputs):
187
+ model_outputs = self.model(**model_inputs)
188
+ return model_outputs
189
+
190
+ def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
191
+ if function_to_apply is None:
192
+ if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
193
+ function_to_apply = ClassificationFunction.SIGMOID
194
+ elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1:
195
+ function_to_apply = ClassificationFunction.SOFTMAX
196
+ elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
197
+ function_to_apply = self.model.config.function_to_apply
198
+ else:
199
+ function_to_apply = ClassificationFunction.NONE
200
+
201
+ if top_k > self.model.config.num_labels:
202
+ top_k = self.model.config.num_labels
203
+
204
+ outputs = model_outputs["logits"][0]
205
+ if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
206
+ outputs = outputs.to(torch.float32).numpy()
207
+ else:
208
+ outputs = outputs.numpy()
209
+
210
+ if function_to_apply == ClassificationFunction.SIGMOID:
211
+ scores = sigmoid(outputs)
212
+ elif function_to_apply == ClassificationFunction.SOFTMAX:
213
+ scores = softmax(outputs)
214
+ elif function_to_apply == ClassificationFunction.NONE:
215
+ scores = outputs
216
+ else:
217
+ raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
218
+
219
+ dict_scores = [
220
+ {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
221
+ ]
222
+ dict_scores.sort(key=lambda x: x["score"], reverse=True)
223
+ if top_k is not None:
224
+ dict_scores = dict_scores[:top_k]
225
+
226
+ return dict_scores
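An illustrative call-time sketch of the `top_k` and `function_to_apply` parameters handled by `_sanitize_parameters` above (checkpoint and image URL reused from the docstring example; not part of the file).

```python
from transformers import pipeline

# Illustrative sketch: keep the three best labels and force a softmax over the logits.
classifier = pipeline("image-classification", model="microsoft/beit-base-patch16-224-pt22k-ft22k")

predictions = classifier(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    top_k=3,
    function_to_apply="softmax",
)

for prediction in predictions:
    print(prediction["label"], round(prediction["score"], 3))
```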
.venv/lib/python3.11/site-packages/transformers/pipelines/image_feature_extraction.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import Dict
2
+
3
+ from ..utils import add_end_docstrings, is_vision_available
4
+ from .base import GenericTensor, Pipeline, build_pipeline_init_args
5
+
6
+
7
+ if is_vision_available():
8
+ from ..image_utils import load_image
9
+
10
+
11
+ @add_end_docstrings(
12
+ build_pipeline_init_args(has_image_processor=True),
13
+ """
14
+ image_processor_kwargs (`dict`, *optional*):
15
+ Additional dictionary of keyword arguments passed along to the image processor e.g.
16
+ {"size": {"height": 100, "width": 100}}
17
+ pool (`bool`, *optional*, defaults to `False`):
18
+ Whether or not to return the pooled output. If `False`, the model will return the raw hidden states.
19
+ """,
20
+ )
21
+ class ImageFeatureExtractionPipeline(Pipeline):
22
+ """
23
+ Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
24
+ transformer, which can be used as features in downstream tasks.
25
+
26
+ Example:
27
+
28
+ ```python
29
+ >>> from transformers import pipeline
30
+
31
+ >>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
32
+ >>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
33
+ >>> result.shape  # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input image.
34
+ torch.Size([1, 197, 768])
35
+ ```
36
+
37
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
38
+
39
+ This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
40
+ `"image-feature-extraction"`.
41
+
42
+ All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
43
+ [huggingface.co/models](https://huggingface.co/models).
44
+ """
45
+
46
+ def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
47
+ preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
48
+
49
+ postprocess_params = {}
50
+ if pool is not None:
51
+ postprocess_params["pool"] = pool
52
+ if return_tensors is not None:
53
+ postprocess_params["return_tensors"] = return_tensors
54
+
55
+ if "timeout" in kwargs:
56
+ preprocess_params["timeout"] = kwargs["timeout"]
57
+
58
+ return preprocess_params, {}, postprocess_params
59
+
60
+ def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
61
+ image = load_image(image, timeout=timeout)
62
+ model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
63
+ if self.framework == "pt":
64
+ model_inputs = model_inputs.to(self.torch_dtype)
65
+ return model_inputs
66
+
67
+ def _forward(self, model_inputs):
68
+ model_outputs = self.model(**model_inputs)
69
+ return model_outputs
70
+
71
+ def postprocess(self, model_outputs, pool=None, return_tensors=False):
72
+ pool = pool if pool is not None else False
73
+
74
+ if pool:
75
+ if "pooler_output" not in model_outputs:
76
+ raise ValueError(
77
+ "No pooled output was returned. Make sure the model has a `pooler` layer when using the `pool` option."
78
+ )
79
+ outputs = model_outputs["pooler_output"]
80
+ else:
81
+ # [0] is the first available tensor, logits or last_hidden_state.
82
+ outputs = model_outputs[0]
83
+
84
+ if return_tensors:
85
+ return outputs
86
+ if self.framework == "pt":
87
+ return outputs.tolist()
88
+ elif self.framework == "tf":
89
+ return outputs.numpy().tolist()
90
+
91
+ def __call__(self, *args, **kwargs):
92
+ """
93
+ Extract the features of the input(s).
94
+
95
+ Args:
96
+ images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
97
+ The pipeline handles three types of images:
98
+
99
+ - A string containing a http link pointing to an image
100
+ - A string containing a local path to an image
101
+ - An image loaded in PIL directly
102
+
103
+ The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
104
+ Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
105
+ images.
106
+ timeout (`float`, *optional*, defaults to None):
107
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
108
+ the call may block forever.
109
+ Return:
110
+ A nested list of `float`: The features computed by the model.
111
+ """
112
+ return super().__call__(*args, **kwargs)
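An illustrative sketch of the `pool`, `return_tensors`, and `image_processor_kwargs` options handled above (not part of the file; the ViT checkpoint comes from the docstring example, and pooled output is only available when the checkpoint exposes a pooler layer).

```python
from transformers import pipeline

# Illustrative sketch: return pooled features as a tensor, resizing the image first.
extractor = pipeline("image-feature-extraction", model="google/vit-base-patch16-224")

features = extractor(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    pool=True,  # use the pooler output instead of the raw hidden states
    return_tensors=True,  # keep a framework tensor rather than nested Python lists
    image_processor_kwargs={"size": {"height": 224, "width": 224}},
)

print(features.shape)  # expected to be [1, hidden_dimension], e.g. torch.Size([1, 768]) for ViT-base
```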
.venv/lib/python3.11/site-packages/transformers/pipelines/image_segmentation.py ADDED
@@ -0,0 +1,220 @@
1
+ from typing import Any, Dict, List, Union
2
+
3
+ import numpy as np
4
+
5
+ from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
6
+ from .base import Pipeline, build_pipeline_init_args
7
+
8
+
9
+ if is_vision_available():
10
+ from PIL import Image
11
+
12
+ from ..image_utils import load_image
13
+
14
+ if is_torch_available():
15
+ from ..models.auto.modeling_auto import (
16
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
17
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
18
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
19
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
20
+ )
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ Prediction = Dict[str, Any]
27
+ Predictions = List[Prediction]
28
+
29
+
30
+ @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
31
+ class ImageSegmentationPipeline(Pipeline):
32
+ """
33
+ Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
34
+ their classes.
35
+
36
+ Example:
37
+
38
+ ```python
39
+ >>> from transformers import pipeline
40
+
41
+ >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
42
+ >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
43
+ >>> len(segments)
44
+ 2
45
+
46
+ >>> segments[0]["label"]
47
+ 'bird'
48
+
49
+ >>> segments[1]["label"]
50
+ 'bird'
51
+
52
+ >>> type(segments[0]["mask"]) # This is a black and white mask showing where is the bird on the original image.
53
+ <class 'PIL.Image.Image'>
54
+
55
+ >>> segments[0]["mask"].size
56
+ (768, 512)
57
+ ```
58
+
59
+
60
+ This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
61
+ `"image-segmentation"`.
62
+
63
+ See the list of available models on
64
+ [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
65
+ """
66
+
67
+ def __init__(self, *args, **kwargs):
68
+ super().__init__(*args, **kwargs)
69
+
70
+ if self.framework == "tf":
71
+ raise ValueError(f"The {self.__class__} is only available in PyTorch.")
72
+
73
+ requires_backends(self, "vision")
74
+ mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
75
+ mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
76
+ mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
77
+ mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
78
+ self.check_model_type(mapping)
79
+
80
+ def _sanitize_parameters(self, **kwargs):
81
+ preprocess_kwargs = {}
82
+ postprocess_kwargs = {}
83
+ if "subtask" in kwargs:
84
+ postprocess_kwargs["subtask"] = kwargs["subtask"]
85
+ preprocess_kwargs["subtask"] = kwargs["subtask"]
86
+ if "threshold" in kwargs:
87
+ postprocess_kwargs["threshold"] = kwargs["threshold"]
88
+ if "mask_threshold" in kwargs:
89
+ postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
90
+ if "overlap_mask_area_threshold" in kwargs:
91
+ postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
92
+ if "timeout" in kwargs:
93
+ preprocess_kwargs["timeout"] = kwargs["timeout"]
94
+
95
+ return preprocess_kwargs, {}, postprocess_kwargs
96
+
97
+ def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
98
+ """
99
+ Perform segmentation (detect masks & classes) in the image(s) passed as inputs.
100
+
101
+ Args:
102
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
103
+ The pipeline handles three types of images:
104
+
105
+ - A string containing an HTTP(S) link pointing to an image
106
+ - A string containing a local path to an image
107
+ - An image loaded in PIL directly
108
+
109
+ The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
110
+ same format: all as HTTP(S) links, all as local paths, or all as PIL images.
111
+ subtask (`str`, *optional*):
112
+ Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
113
+ capabilities. If not set, the pipeline will attempt to resolve in the following order:
114
+ `panoptic`, `instance`, `semantic`.
115
+ threshold (`float`, *optional*, defaults to 0.9):
116
+ Probability threshold to filter out predicted masks.
117
+ mask_threshold (`float`, *optional*, defaults to 0.5):
118
+ Threshold to use when turning the predicted masks into binary values.
119
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
120
+ Mask overlap threshold to eliminate small, disconnected segments.
121
+ timeout (`float`, *optional*, defaults to None):
122
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
123
+ the call may block forever.
124
+
125
+ Return:
126
+ A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
127
+ list of dictionaries, if the input is a list of several images, will return a list of list of dictionaries
128
+ corresponding to each image.
129
+
130
+ The dictionaries contain the mask, label and score (where applicable) of each detected object and contain
131
+ the following keys:
132
+
133
+ - **label** (`str`) -- The class label identified by the model.
134
+ - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape (width, height) of
135
+ the original image. Returns a mask filled with zeros if no object is found.
136
+ - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
137
+ "object" described by the label and the mask.
138
+ """
139
+ # After deprecation of this is completed, remove the default `None` value for `images`
140
+ if "images" in kwargs:
141
+ inputs = kwargs.pop("images")
142
+ if inputs is None:
143
+ raise ValueError("Cannot call the image-segmentation pipeline without an inputs argument!")
144
+ return super().__call__(inputs, **kwargs)
145
+
146
+ def preprocess(self, image, subtask=None, timeout=None):
147
+ image = load_image(image, timeout=timeout)
148
+ target_size = [(image.height, image.width)]
149
+ if self.model.config.__class__.__name__ == "OneFormerConfig":
150
+ if subtask is None:
151
+ kwargs = {}
152
+ else:
153
+ kwargs = {"task_inputs": [subtask]}
154
+ inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
155
+ if self.framework == "pt":
156
+ inputs = inputs.to(self.torch_dtype)
157
+ inputs["task_inputs"] = self.tokenizer(
158
+ inputs["task_inputs"],
159
+ padding="max_length",
160
+ max_length=self.model.config.task_seq_len,
161
+ return_tensors=self.framework,
162
+ )["input_ids"]
163
+ else:
164
+ inputs = self.image_processor(images=[image], return_tensors="pt")
165
+ if self.framework == "pt":
166
+ inputs = inputs.to(self.torch_dtype)
167
+ inputs["target_size"] = target_size
168
+ return inputs
169
+
170
+ def _forward(self, model_inputs):
171
+ target_size = model_inputs.pop("target_size")
172
+ model_outputs = self.model(**model_inputs)
173
+ model_outputs["target_size"] = target_size
174
+ return model_outputs
175
+
176
+ def postprocess(
177
+ self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
178
+ ):
179
+ fn = None
180
+ if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
181
+ fn = self.image_processor.post_process_panoptic_segmentation
182
+ elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
183
+ fn = self.image_processor.post_process_instance_segmentation
184
+
185
+ if fn is not None:
186
+ outputs = fn(
187
+ model_outputs,
188
+ threshold=threshold,
189
+ mask_threshold=mask_threshold,
190
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
191
+ target_sizes=model_outputs["target_size"],
192
+ )[0]
193
+
194
+ annotation = []
195
+ segmentation = outputs["segmentation"]
196
+
197
+ for segment in outputs["segments_info"]:
198
+ mask = (segmentation == segment["id"]) * 255
199
+ mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
200
+ label = self.model.config.id2label[segment["label_id"]]
201
+ score = segment["score"]
202
+ annotation.append({"score": score, "label": label, "mask": mask})
203
+
204
+ elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
205
+ outputs = self.image_processor.post_process_semantic_segmentation(
206
+ model_outputs, target_sizes=model_outputs["target_size"]
207
+ )[0]
208
+
209
+ annotation = []
210
+ segmentation = outputs.numpy()
211
+ labels = np.unique(segmentation)
212
+
213
+ for label in labels:
214
+ mask = (segmentation == label) * 255
215
+ mask = Image.fromarray(mask.astype(np.uint8), mode="L")
216
+ label = self.model.config.id2label[label]
217
+ annotation.append({"score": None, "label": label, "mask": mask})
218
+ else:
219
+ raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
220
+ return annotation
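Finally, an illustrative sketch of calling the segmentation pipeline with an explicit `subtask` and the thresholds documented above (checkpoint and image URL reused from the docstring example; not part of the file).

```python
from transformers import pipeline

# Illustrative sketch: run panoptic segmentation and save each predicted mask to disk.
segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")

segments = segmenter(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    subtask="panoptic",
    threshold=0.9,
    overlap_mask_area_threshold=0.5,
)

for i, segment in enumerate(segments):
    print(segment["label"], segment["score"])
    segment["mask"].save(f"mask_{i}.png")  # each mask is a PIL.Image in mode "L"
```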