Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .venv/lib/python3.11/site-packages/transformers/__pycache__/image_processing_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/__pycache__/image_transforms.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/__pycache__/modelcard.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/__pycache__/training_args_tf.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_search.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__init__.py +1178 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/automatic_speech_recognition.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/depth_estimation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/document_question_answering.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/feature_extraction.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/fill_mask.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_feature_extraction.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_segmentation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_text_to_text.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_image.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_text.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/mask_generation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/object_detection.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/pt_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/question_answering.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/table_question_answering.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text2text_generation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_generation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_to_audio.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/token_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/video_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/visual_question_answering.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_audio_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_image_classification.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_object_detection.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/audio_classification.py +234 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/audio_utils.py +297 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/automatic_speech_recognition.py +766 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/base.py +1484 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/depth_estimation.py +133 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/document_question_answering.py +516 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/feature_extraction.py +86 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py +273 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/image_classification.py +226 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/image_feature_extraction.py +112 -0
- .venv/lib/python3.11/site-packages/transformers/pipelines/image_segmentation.py +220 -0
.venv/lib/python3.11/site-packages/transformers/__pycache__/image_processing_utils.cpython-311.pyc
ADDED
Binary file (14.2 kB)

.venv/lib/python3.11/site-packages/transformers/__pycache__/image_transforms.cpython-311.pyc
ADDED
Binary file (40.8 kB)

.venv/lib/python3.11/site-packages/transformers/__pycache__/modelcard.cpython-311.pyc
ADDED
Binary file (44.5 kB)

.venv/lib/python3.11/site-packages/transformers/__pycache__/training_args_tf.cpython-311.pyc
ADDED
Binary file (16.2 kB)

.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (9.07 kB)

.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-311.pyc
ADDED
Binary file (25.4 kB)

.venv/lib/python3.11/site-packages/transformers/generation/__pycache__/beam_search.cpython-311.pyc
ADDED
Binary file (49.1 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__init__.py
ADDED
@@ -0,0 +1,1178 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from huggingface_hub import model_info

from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import get_class_from_dynamic_module
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..image_processing_utils import BaseImageProcessor
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
    CONFIG_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    cached_file,
    extract_commit_hash,
    find_adapter_config_file,
    is_kenlm_available,
    is_offline_mode,
    is_peft_available,
    is_pyctcdecode_available,
    is_tf_available,
    is_torch_available,
    logging,
)
from .audio_classification import AudioClassificationPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .base import (
    ArgumentHandler,
    CsvPipelineDataFormat,
    JsonPipelineDataFormat,
    PipedPipelineDataFormat,
    Pipeline,
    PipelineDataFormat,
    PipelineException,
    PipelineRegistry,
    get_default_model_and_revision,
    infer_framework_load_model,
)
from .depth_estimation import DepthEstimationPipeline
from .document_question_answering import DocumentQuestionAnsweringPipeline
from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .image_feature_extraction import ImageFeatureExtractionPipeline
from .image_segmentation import ImageSegmentationPipeline
from .image_text_to_text import ImageTextToTextPipeline
from .image_to_image import ImageToImagePipeline
from .image_to_text import ImageToTextPipeline
from .mask_generation import MaskGenerationPipeline
from .object_detection import ObjectDetectionPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
from .text_classification import TextClassificationPipeline
from .text_generation import TextGenerationPipeline
from .text_to_audio import TextToAudioPipeline
from .token_classification import (
    AggregationStrategy,
    NerPipeline,
    TokenClassificationArgumentHandler,
    TokenClassificationPipeline,
)
from .video_classification import VideoClassificationPipeline
from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_audio_classification import ZeroShotAudioClassificationPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline


if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import (
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForImageClassification,
        TFAutoModelForMaskedLM,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTableQuestionAnswering,
        TFAutoModelForTokenClassification,
        TFAutoModelForVision2Seq,
        TFAutoModelForZeroShotImageClassification,
    )

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import (
        AutoModel,
        AutoModelForAudioClassification,
        AutoModelForCausalLM,
        AutoModelForCTC,
        AutoModelForDocumentQuestionAnswering,
        AutoModelForImageClassification,
        AutoModelForImageSegmentation,
        AutoModelForImageTextToText,
        AutoModelForMaskedLM,
        AutoModelForMaskGeneration,
        AutoModelForObjectDetection,
        AutoModelForQuestionAnswering,
        AutoModelForSemanticSegmentation,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForSpeechSeq2Seq,
        AutoModelForTableQuestionAnswering,
        AutoModelForTextToSpectrogram,
        AutoModelForTextToWaveform,
        AutoModelForTokenClassification,
        AutoModelForVideoClassification,
        AutoModelForVision2Seq,
        AutoModelForVisualQuestionAnswering,
        AutoModelForZeroShotImageClassification,
        AutoModelForZeroShotObjectDetection,
    )


if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel
    from ..tokenization_utils_fast import PreTrainedTokenizerFast


logger = logging.get_logger(__name__)


# Register all the supported tasks here
TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
    "vqa": "visual-question-answering",
    "text-to-speech": "text-to-audio",
}
SUPPORTED_TASKS = {
    "audio-classification": {
        "impl": AudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
        "type": "audio",
    },
    "automatic-speech-recognition": {
        "impl": AutomaticSpeechRecognitionPipeline,
        "tf": (),
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "22aad52")}},
        "type": "multimodal",
    },
    "text-to-audio": {
        "impl": TextToAudioPipeline,
        "tf": (),
        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
        "default": {"model": {"pt": ("suno/bark-small", "1dbd7a1")}},
        "type": "text",
    },
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased", "6ea8117"),
                "tf": ("distilbert/distilbert-base-cased", "6ea8117"),
            }
        },
        "type": "multimodal",
    },
    "text-classification": {
        "impl": TextClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"),
                "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"),
            },
        },
        "type": "text",
    },
    "token-classification": {
        "impl": TokenClassificationPipeline,
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"),
                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"),
            },
        },
        "type": "text",
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5"),
                "tf": ("distilbert/distilbert-base-cased-distilled-squad", "564e9b5"),
            },
        },
        "type": "text",
    },
    "table-question-answering": {
        "impl": TableQuestionAnsweringPipeline,
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
        "default": {
            "model": {
                "pt": ("google/tapas-base-finetuned-wtq", "e3dde19"),
                "tf": ("google/tapas-base-finetuned-wtq", "e3dde19"),
            },
        },
        "type": "text",
    },
    "visual-question-answering": {
        "impl": VisualQuestionAnsweringPipeline,
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "d0a1f6a")},
        },
        "type": "multimodal",
    },
    "document-question-answering": {
        "impl": DocumentQuestionAnsweringPipeline,
        "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("impira/layoutlm-document-qa", "beed3c4")},
        },
        "type": "multimodal",
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilroberta-base", "fb53ab8"),
                "tf": ("distilbert/distilroberta-base", "fb53ab8"),
            }
        },
        "type": "text",
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "df1b051")}
        },
        "type": "text",
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    "translation": {
        "impl": TranslationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
            ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
            ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
        },
        "type": "text",
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google-t5/t5-base", "a9723ea"), "tf": ("google-t5/t5-base", "a9723ea")}},
        "type": "text",
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("openai-community/gpt2", "607a30d"), "tf": ("openai-community/gpt2", "607a30d")}},
        "type": "text",
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("facebook/bart-large-mnli", "d7645e1"),
                "tf": ("FacebookAI/roberta-large-mnli", "2a8f12d"),
            },
            "config": {
                "pt": ("facebook/bart-large-mnli", "d7645e1"),
                "tf": ("FacebookAI/roberta-large-mnli", "2a8f12d"),
            },
        },
        "type": "text",
    },
    "zero-shot-image-classification": {
        "impl": ZeroShotImageClassificationPipeline,
        "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("openai/clip-vit-base-patch32", "3d74acf"),
                "tf": ("openai/clip-vit-base-patch32", "3d74acf"),
            }
        },
        "type": "multimodal",
    },
    "zero-shot-audio-classification": {
        "impl": ZeroShotAudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("laion/clap-htsat-fused", "cca9e28"),
            }
        },
        "type": "multimodal",
    },
    "image-classification": {
        "impl": ImageClassificationPipeline,
        "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "3f49326"),
                "tf": ("google/vit-base-patch16-224", "3f49326"),
            }
        },
        "type": "image",
    },
    "image-feature-extraction": {
        "impl": ImageFeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "3f49326"),
                "tf": ("google/vit-base-patch16-224", "3f49326"),
            }
        },
        "type": "image",
    },
    "image-segmentation": {
        "impl": ImageSegmentationPipeline,
        "tf": (),
        "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "d53b52a")}},
        "type": "multimodal",
    },
    "image-to-text": {
        "impl": ImageToTextPipeline,
        "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
        "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("ydshieh/vit-gpt2-coco-en", "5bebf1e"),
                "tf": ("ydshieh/vit-gpt2-coco-en", "5bebf1e"),
            }
        },
        "type": "multimodal",
    },
    "image-text-to-text": {
        "impl": ImageTextToTextPipeline,
        "tf": (),
        "pt": (AutoModelForImageTextToText,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b"),
            }
        },
        "type": "multimodal",
    },
    "object-detection": {
        "impl": ObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50", "1d5f47b")}},
        "type": "multimodal",
    },
    "zero-shot-object-detection": {
        "impl": ZeroShotObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google/owlvit-base-patch32", "cbc355f")}},
        "type": "multimodal",
    },
    "depth-estimation": {
        "impl": DepthEstimationPipeline,
        "tf": (),
        "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
        "default": {"model": {"pt": ("Intel/dpt-large", "bc15f29")}},
        "type": "image",
    },
    "video-classification": {
        "impl": VideoClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "488eb9a")}},
        "type": "video",
    },
    "mask-generation": {
        "impl": MaskGenerationPipeline,
        "tf": (),
        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/sam-vit-huge", "87aecf0")}},
        "type": "multimodal",
    },
    "image-to-image": {
        "impl": ImageToImagePipeline,
        "tf": (),
        "pt": (AutoModelForImageToImage,) if is_torch_available() else (),
        "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "cee1c92")}},
        "type": "image",
    },
}

NO_FEATURE_EXTRACTOR_TASKS = set()
NO_IMAGE_PROCESSOR_TASKS = set()
NO_TOKENIZER_TASKS = set()

# Those model configs are special, they are generic over their task, meaning
# any tokenizer/feature_extractor might be use for a given model so we cannot
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
# see if the model defines such objects or not.
MULTI_MODEL_AUDIO_CONFIGS = {"SpeechEncoderDecoderConfig"}
MULTI_MODEL_VISION_CONFIGS = {"VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
for task, values in SUPPORTED_TASKS.items():
    if values["type"] == "text":
        NO_FEATURE_EXTRACTOR_TASKS.add(task)
        NO_IMAGE_PROCESSOR_TASKS.add(task)
    elif values["type"] in {"image", "video"}:
        NO_TOKENIZER_TASKS.add(task)
    elif values["type"] in {"audio"}:
        NO_TOKENIZER_TASKS.add(task)
        NO_IMAGE_PROCESSOR_TASKS.add(task)
    elif values["type"] != "multimodal":
        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")

PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)


def get_supported_tasks() -> List[str]:
    """
    Returns a list of supported task strings.
    """
    return PIPELINE_REGISTRY.get_supported_tasks()


def get_task(model: str, token: Optional[str] = None, **deprecated_kwargs) -> str:
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if is_offline_mode():
        raise RuntimeError("You cannot infer task automatically within `pipeline` when using offline mode")
    try:
        info = model_info(model, token=token)
    except Exception as e:
        raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
    if not info.pipeline_tag:
        raise RuntimeError(
            f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
        )
    if getattr(info, "library_name", "transformers") not in {"transformers", "timm"}:
        raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
    task = info.pipeline_tag
    return task


def check_task(task: str) -> Tuple[str, Dict, Any]:
    """
    Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
    default models if they exist.

    Args:
        task (`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - `"audio-classification"`
            - `"automatic-speech-recognition"`
            - `"conversational"`
            - `"depth-estimation"`
            - `"document-question-answering"`
            - `"feature-extraction"`
            - `"fill-mask"`
            - `"image-classification"`
            - `"image-feature-extraction"`
            - `"image-segmentation"`
            - `"image-to-text"`
            - `"image-to-image"`
            - `"object-detection"`
            - `"question-answering"`
            - `"summarization"`
            - `"table-question-answering"`
            - `"text2text-generation"`
            - `"text-classification"` (alias `"sentiment-analysis"` available)
            - `"text-generation"`
            - `"text-to-audio"` (alias `"text-to-speech"` available)
            - `"token-classification"` (alias `"ner"` available)
            - `"translation"`
            - `"translation_xx_to_yy"`
            - `"video-classification"`
            - `"visual-question-answering"` (alias `"vqa"` available)
            - `"zero-shot-classification"`
            - `"zero-shot-image-classification"`
            - `"zero-shot-object-detection"`

    Returns:
        (normalized_task: `str`, task_defaults: `dict`, task_options: (`tuple`, None)) The normalized task name
        (removed alias and options). The actual dictionary required to initialize the pipeline and some extra task
        options for parametrized tasks like "translation_XX_to_YY"


    """
    return PIPELINE_REGISTRY.check_task(task)


def clean_custom_task(task_info):
    import transformers

    if "impl" not in task_info:
        raise RuntimeError("This model introduces a custom pipeline without specifying its implementation.")
    pt_class_names = task_info.get("pt", ())
    if isinstance(pt_class_names, str):
        pt_class_names = [pt_class_names]
    task_info["pt"] = tuple(getattr(transformers, c) for c in pt_class_names)
    tf_class_names = task_info.get("tf", ())
    if isinstance(tf_class_names, str):
        tf_class_names = [tf_class_names]
    task_info["tf"] = tuple(getattr(transformers, c) for c in tf_class_names)
    return task_info, None


def pipeline(
    task: str = None,
    model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
    feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
    image_processor: Optional[Union[str, BaseImageProcessor]] = None,
    processor: Optional[Union[str, ProcessorMixin]] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    use_fast: bool = True,
    token: Optional[Union[str, bool]] = None,
    device: Optional[Union[int, str, "torch.device"]] = None,
    device_map=None,
    torch_dtype=None,
    trust_remote_code: Optional[bool] = None,
    model_kwargs: Dict[str, Any] = None,
    pipeline_class: Optional[Any] = None,
    **kwargs,
) -> Pipeline:
    """
    Utility factory method to build a [`Pipeline`].

    A pipeline consists of:

    - One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
      [image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
    - A [model](model) that generates predictions from the inputs.
    - Optional post-processing steps to refine the model's output, which can also be handled by processors.

    <Tip>
    While there are such optional arguments as `tokenizer`, `feature_extractor`, `image_processor`, and `processor`,
    they shouldn't be specified all at once. If these components are not provided, `pipeline` will try to load
    required ones automatically. In case you want to provide these components explicitly, please refer to a
    specific pipeline in order to get more details regarding what components are required.
    </Tip>

    Args:
        task (`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - `"audio-classification"`: will return a [`AudioClassificationPipeline`].
            - `"automatic-speech-recognition"`: will return a [`AutomaticSpeechRecognitionPipeline`].
            - `"depth-estimation"`: will return a [`DepthEstimationPipeline`].
            - `"document-question-answering"`: will return a [`DocumentQuestionAnsweringPipeline`].
            - `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
            - `"fill-mask"`: will return a [`FillMaskPipeline`]:.
            - `"image-classification"`: will return a [`ImageClassificationPipeline`].
            - `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
            - `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
            - `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`].
            - `"image-to-image"`: will return a [`ImageToImagePipeline`].
            - `"image-to-text"`: will return a [`ImageToTextPipeline`].
            - `"mask-generation"`: will return a [`MaskGenerationPipeline`].
            - `"object-detection"`: will return a [`ObjectDetectionPipeline`].
            - `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
            - `"summarization"`: will return a [`SummarizationPipeline`].
            - `"table-question-answering"`: will return a [`TableQuestionAnsweringPipeline`].
            - `"text2text-generation"`: will return a [`Text2TextGenerationPipeline`].
            - `"text-classification"` (alias `"sentiment-analysis"` available): will return a
              [`TextClassificationPipeline`].
            - `"text-generation"`: will return a [`TextGenerationPipeline`]:.
            - `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`]:.
            - `"token-classification"` (alias `"ner"` available): will return a [`TokenClassificationPipeline`].
            - `"translation"`: will return a [`TranslationPipeline`].
            - `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
            - `"video-classification"`: will return a [`VideoClassificationPipeline`].
            - `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
            - `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
            - `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
            - `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
            - `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].

        model (`str` or [`PreTrainedModel`] or [`TFPreTrainedModel`], *optional*):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from [`PreTrainedModel`] (for PyTorch) or
            [`TFPreTrainedModel`] (for TensorFlow).

            If not provided, the default for the `task` will be loaded.
        config (`str` or [`PretrainedConfig`], *optional*):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from [`PretrainedConfig`].

            If not provided, the default configuration file for the requested model will be used. That means that if
            `model` is given, its default configuration will be used. However, if `model` is not supplied, this
            `task`'s default model's config is used instead.
        tokenizer (`str` or [`PreTrainedTokenizer`], *optional*):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained tokenizer inheriting from [`PreTrainedTokenizer`].

            If not provided, the default tokenizer for the given `model` will be loaded (if it is a string). If `model`
            is not specified or not a string, then the default tokenizer for `config` is loaded (if it is a string).
            However, if `config` is also not given or not a string, then the default tokenizer for the given `task`
            will be loaded.
        feature_extractor (`str` or [`PreTrainedFeatureExtractor`], *optional*):
            The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained feature extractor inheriting from [`PreTrainedFeatureExtractor`].

            Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
            models. Multi-modal models will also require a tokenizer to be passed.

            If not provided, the default feature extractor for the given `model` will be loaded (if it is a string). If
            `model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
            is a string). However, if `config` is also not given or not a string, then the default feature extractor
            for the given `task` will be loaded.
        image_processor (`str` or [`BaseImageProcessor`], *optional*):
            The image processor that will be used by the pipeline to preprocess images for the model. This can be a
            model identifier or an actual image processor inheriting from [`BaseImageProcessor`].

            Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
            models will also require a tokenizer to be passed.

            If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
            `model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
            a string).
        processor (`str` or [`ProcessorMixin`], *optional*):
            The processor that will be used by the pipeline to preprocess data for the model. This can be a model
            identifier or an actual processor inheriting from [`ProcessorMixin`].

            Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
            requires both text and image inputs.

            If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
            is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
        framework (`str`, *optional*):
            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
            installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
            provided.
        revision (`str`, *optional*, defaults to `"main"`):
            When passing a task name or a string model identifier: The specific model version to use. It can be a
            branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
            artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
        use_fast (`bool`, *optional*, defaults to `True`):
            Whether or not to use a Fast tokenizer if possible (a [`PreTrainedTokenizerFast`]).
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        device (`int` or `str` or `torch.device`):
            Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
            pipeline will be allocated.
        device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
            Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
            `device_map="auto"` to compute the most optimized `device_map` automatically (see
            [here](https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.cpu_offload)
            for more information).

            <Tip warning={true}>

            Do not use `device_map` AND `device` at the same time as they will conflict

            </Tip>

        torch_dtype (`str` or `torch.dtype`, *optional*):
            Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
            (`torch.float16`, `torch.bfloat16`, ... or `"auto"`).
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether or not to allow for custom code defined on the Hub in their own modeling, configuration,
            tokenization or even pipeline files. This option should only be set to `True` for repositories you trust
            and in which you have read the code, as it will execute code present on the Hub on your local machine.
        model_kwargs (`Dict[str, Any]`, *optional*):
            Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
            **model_kwargs)` function.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        [`Pipeline`]: A suitable pipeline for the task.

    Examples:

    ```python
    >>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

    >>> # Sentiment analysis pipeline
    >>> analyzer = pipeline("sentiment-analysis")

    >>> # Question answering pipeline, specifying the checkpoint identifier
    >>> oracle = pipeline(
    ...     "question-answering", model="distilbert/distilbert-base-cased-distilled-squad", tokenizer="google-bert/bert-base-cased"
    ... )

    >>> # Named entity recognition pipeline, passing in a specific model and tokenizer
    >>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    >>> recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
    ```"""
    if model_kwargs is None:
        model_kwargs = {}
    # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
    # this is to keep BC).
    use_auth_token = model_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    code_revision = kwargs.pop("code_revision", None)
    commit_hash = kwargs.pop("_commit_hash", None)

    hub_kwargs = {
        "revision": revision,
        "token": token,
        "trust_remote_code": trust_remote_code,
        "_commit_hash": commit_hash,
    }

    if task is None and model is None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline without either a task or a model "
            "being specified. "
            "Please provide a task class or a model"
        )

    if model is None and tokenizer is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer"
            " may not be compatible with the default model. Please provide a PreTrainedModel class or a"
            " path/identifier to a pretrained model when providing tokenizer."
        )
    if model is None and feature_extractor is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided"
            " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
            " or a path/identifier to a pretrained model when providing feature_extractor."
        )
    if isinstance(model, Path):
        model = str(model)

    if commit_hash is None:
        pretrained_model_name_or_path = None
        if isinstance(config, str):
            pretrained_model_name_or_path = config
        elif config is None and isinstance(model, str):
            pretrained_model_name_or_path = model

        if not isinstance(config, PretrainedConfig) and pretrained_model_name_or_path is not None:
            # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
            resolved_config_file = cached_file(
                pretrained_model_name_or_path,
                CONFIG_NAME,
                _raise_exceptions_for_gated_repo=False,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                cache_dir=model_kwargs.get("cache_dir"),
                **hub_kwargs,
            )
            hub_kwargs["_commit_hash"] = extract_commit_hash(resolved_config_file, commit_hash)
        else:
            hub_kwargs["_commit_hash"] = getattr(config, "_commit_hash", None)

    # Config is the primordial information item.
    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(
            config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
        )
        hub_kwargs["_commit_hash"] = config._commit_hash
    elif config is None and isinstance(model, str):
        # Check for an adapter file in the model path if PEFT is available
        if is_peft_available():
            # `find_adapter_config_file` doesn't accept `trust_remote_code`
            _hub_kwargs = {k: v for k, v in hub_kwargs.items() if k != "trust_remote_code"}
            maybe_adapter_path = find_adapter_config_file(
                model,
                token=hub_kwargs["token"],
                revision=hub_kwargs["revision"],
                _commit_hash=hub_kwargs["_commit_hash"],
            )

            if maybe_adapter_path is not None:
                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)
                    model = adapter_config["base_model_name_or_path"]

        config = AutoConfig.from_pretrained(
            model, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
        )
        hub_kwargs["_commit_hash"] = config._commit_hash

    custom_tasks = {}
    if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
        custom_tasks = config.custom_pipelines
        if task is None and trust_remote_code is not False:
            if len(custom_tasks) == 1:
                task = list(custom_tasks.keys())[0]
            else:
                raise RuntimeError(
                    "We can't infer the task automatically for this model as there are multiple tasks available. Pick "
                    f"one in {', '.join(custom_tasks.keys())}"
                )

    if task is None and model is not None:
        if not isinstance(model, str):
            raise RuntimeError(
                "Inferring the task automatically requires to check the hub with a model_id defined as a `str`. "
                f"{model} is not a valid model_id."
            )
        task = get_task(model, token)

    # Retrieve the task
    if task in custom_tasks:
        normalized_task = task
        targeted_task, task_options = clean_custom_task(custom_tasks[task])
        if pipeline_class is None:
            if not trust_remote_code:
                raise ValueError(
                    "Loading this pipeline requires you to execute the code in the pipeline file in that"
                    " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
                    " set the option `trust_remote_code=True` to remove this error."
                )
            class_ref = targeted_task["impl"]
            pipeline_class = get_class_from_dynamic_module(
                class_ref,
                model,
                code_revision=code_revision,
                **hub_kwargs,
            )
    else:
        normalized_task, targeted_task, task_options = check_task(task)
        if pipeline_class is None:
            pipeline_class = targeted_task["impl"]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        # At that point framework might still be undetermined
        model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options)
        revision = revision if revision is not None else default_revision
        logger.warning(
            f"No model was supplied, defaulted to {model} and revision"
            f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
            "Using a pipeline without specifying a model name and revision in production is not recommended."
        )
        hub_kwargs["revision"] = revision
    if config is None and isinstance(model, str):
        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
        hub_kwargs["_commit_hash"] = config._commit_hash
|
| 911 |
+
|
| 912 |
+
if device_map is not None:
|
| 913 |
+
if "device_map" in model_kwargs:
|
| 914 |
+
raise ValueError(
|
| 915 |
+
'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
|
| 916 |
+
" arguments might conflict, use only one.)"
|
| 917 |
+
)
|
| 918 |
+
if device is not None:
|
| 919 |
+
logger.warning(
|
| 920 |
+
"Both `device` and `device_map` are specified. `device` will override `device_map`. You"
|
| 921 |
+
" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
|
| 922 |
+
)
|
| 923 |
+
model_kwargs["device_map"] = device_map
|
| 924 |
+
if torch_dtype is not None:
|
| 925 |
+
if "torch_dtype" in model_kwargs:
|
| 926 |
+
raise ValueError(
|
| 927 |
+
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
|
| 928 |
+
" arguments might conflict, use only one.)"
|
| 929 |
+
)
|
| 930 |
+
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
|
| 931 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 932 |
+
model_kwargs["torch_dtype"] = torch_dtype
|
| 933 |
+
|
| 934 |
+
model_name = model if isinstance(model, str) else None
|
| 935 |
+
|
| 936 |
+
# Load the correct model if possible
|
| 937 |
+
# Infer the framework from the model if not already defined
|
| 938 |
+
if isinstance(model, str) or framework is None:
|
| 939 |
+
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
| 940 |
+
framework, model = infer_framework_load_model(
|
| 941 |
+
model,
|
| 942 |
+
model_classes=model_classes,
|
| 943 |
+
config=config,
|
| 944 |
+
framework=framework,
|
| 945 |
+
task=task,
|
| 946 |
+
**hub_kwargs,
|
| 947 |
+
**model_kwargs,
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
model_config = model.config
|
| 951 |
+
hub_kwargs["_commit_hash"] = model.config._commit_hash
|
| 952 |
+
|
| 953 |
+
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
| 954 |
+
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
| 955 |
+
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
|
| 956 |
+
load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
|
| 957 |
+
|
| 958 |
+
# Check that pipeline class required loading
|
| 959 |
+
load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
|
| 960 |
+
load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
|
| 961 |
+
load_image_processor = load_image_processor and pipeline_class._load_image_processor
|
| 962 |
+
load_processor = load_processor and pipeline_class._load_processor
|
| 963 |
+
|
| 964 |
+
# If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
|
| 965 |
+
# `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
|
| 966 |
+
# vision tasks when calling `pipeline()` with `model` and only one of the `image_processor` and `feature_extractor`.
|
| 967 |
+
# TODO: we need to make `NO_IMAGE_PROCESSOR_TASKS` and `NO_FEATURE_EXTRACTOR_TASKS` more robust to avoid such issue.
|
| 968 |
+
# This block is only temporarily to make CI green.
|
| 969 |
+
if load_image_processor and load_feature_extractor:
|
| 970 |
+
load_feature_extractor = False
|
| 971 |
+
|
| 972 |
+
if (
|
| 973 |
+
tokenizer is None
|
| 974 |
+
and not load_tokenizer
|
| 975 |
+
and normalized_task not in NO_TOKENIZER_TASKS
|
| 976 |
+
# Using class name to avoid importing the real class.
|
| 977 |
+
and (
|
| 978 |
+
model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
|
| 979 |
+
or model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
|
| 980 |
+
)
|
| 981 |
+
):
|
| 982 |
+
# This is a special category of models, that are fusions of multiple models
|
| 983 |
+
# so the model_config might not define a tokenizer, but it seems to be
|
| 984 |
+
# necessary for the task, so we're force-trying to load it.
|
| 985 |
+
load_tokenizer = True
|
| 986 |
+
if (
|
| 987 |
+
image_processor is None
|
| 988 |
+
and not load_image_processor
|
| 989 |
+
and normalized_task not in NO_IMAGE_PROCESSOR_TASKS
|
| 990 |
+
# Using class name to avoid importing the real class.
|
| 991 |
+
and model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
|
| 992 |
+
):
|
| 993 |
+
# This is a special category of models, that are fusions of multiple models
|
| 994 |
+
# so the model_config might not define a tokenizer, but it seems to be
|
| 995 |
+
# necessary for the task, so we're force-trying to load it.
|
| 996 |
+
load_image_processor = True
|
| 997 |
+
if (
|
| 998 |
+
feature_extractor is None
|
| 999 |
+
and not load_feature_extractor
|
| 1000 |
+
and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
|
| 1001 |
+
# Using class name to avoid importing the real class.
|
| 1002 |
+
and model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
|
| 1003 |
+
):
|
| 1004 |
+
# This is a special category of models, that are fusions of multiple models
|
| 1005 |
+
# so the model_config might not define a tokenizer, but it seems to be
|
| 1006 |
+
# necessary for the task, so we're force-trying to load it.
|
| 1007 |
+
load_feature_extractor = True
|
| 1008 |
+
|
| 1009 |
+
if task in NO_TOKENIZER_TASKS:
|
| 1010 |
+
# These will never require a tokenizer.
|
| 1011 |
+
# the model on the other hand might have a tokenizer, but
|
| 1012 |
+
# the files could be missing from the hub, instead of failing
|
| 1013 |
+
# on such repos, we just force to not load it.
|
| 1014 |
+
load_tokenizer = False
|
| 1015 |
+
|
| 1016 |
+
if task in NO_FEATURE_EXTRACTOR_TASKS:
|
| 1017 |
+
load_feature_extractor = False
|
| 1018 |
+
if task in NO_IMAGE_PROCESSOR_TASKS:
|
| 1019 |
+
load_image_processor = False
|
| 1020 |
+
|
| 1021 |
+
if load_tokenizer:
|
| 1022 |
+
# Try to infer tokenizer from model or config name (if provided as str)
|
| 1023 |
+
if tokenizer is None:
|
| 1024 |
+
if isinstance(model_name, str):
|
| 1025 |
+
tokenizer = model_name
|
| 1026 |
+
elif isinstance(config, str):
|
| 1027 |
+
tokenizer = config
|
| 1028 |
+
else:
|
| 1029 |
+
# Impossible to guess what is the right tokenizer here
|
| 1030 |
+
raise Exception(
|
| 1031 |
+
"Impossible to guess which tokenizer to use. "
|
| 1032 |
+
"Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
# Instantiate tokenizer if needed
|
| 1036 |
+
if isinstance(tokenizer, (str, tuple)):
|
| 1037 |
+
if isinstance(tokenizer, tuple):
|
| 1038 |
+
# For tuple we have (tokenizer name, {kwargs})
|
| 1039 |
+
use_fast = tokenizer[1].pop("use_fast", use_fast)
|
| 1040 |
+
tokenizer_identifier = tokenizer[0]
|
| 1041 |
+
tokenizer_kwargs = tokenizer[1]
|
| 1042 |
+
else:
|
| 1043 |
+
tokenizer_identifier = tokenizer
|
| 1044 |
+
tokenizer_kwargs = model_kwargs.copy()
|
| 1045 |
+
tokenizer_kwargs.pop("torch_dtype", None)
|
| 1046 |
+
|
| 1047 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 1048 |
+
tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
if load_image_processor:
|
| 1052 |
+
# Try to infer image processor from model or config name (if provided as str)
|
| 1053 |
+
if image_processor is None:
|
| 1054 |
+
if isinstance(model_name, str):
|
| 1055 |
+
image_processor = model_name
|
| 1056 |
+
elif isinstance(config, str):
|
| 1057 |
+
image_processor = config
|
| 1058 |
+
# Backward compatibility, as `feature_extractor` used to be the name
|
| 1059 |
+
# for `ImageProcessor`.
|
| 1060 |
+
elif feature_extractor is not None and isinstance(feature_extractor, BaseImageProcessor):
|
| 1061 |
+
image_processor = feature_extractor
|
| 1062 |
+
else:
|
| 1063 |
+
# Impossible to guess what is the right image_processor here
|
| 1064 |
+
raise Exception(
|
| 1065 |
+
"Impossible to guess which image processor to use. "
|
| 1066 |
+
"Please provide a PreTrainedImageProcessor class or a path/identifier "
|
| 1067 |
+
"to a pretrained image processor."
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
# Instantiate image_processor if needed
|
| 1071 |
+
if isinstance(image_processor, (str, tuple)):
|
| 1072 |
+
image_processor = AutoImageProcessor.from_pretrained(
|
| 1073 |
+
image_processor, _from_pipeline=task, **hub_kwargs, **model_kwargs
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
if load_feature_extractor:
|
| 1077 |
+
# Try to infer feature extractor from model or config name (if provided as str)
|
| 1078 |
+
if feature_extractor is None:
|
| 1079 |
+
if isinstance(model_name, str):
|
| 1080 |
+
feature_extractor = model_name
|
| 1081 |
+
elif isinstance(config, str):
|
| 1082 |
+
feature_extractor = config
|
| 1083 |
+
else:
|
| 1084 |
+
# Impossible to guess what is the right feature_extractor here
|
| 1085 |
+
raise Exception(
|
| 1086 |
+
"Impossible to guess which feature extractor to use. "
|
| 1087 |
+
"Please provide a PreTrainedFeatureExtractor class or a path/identifier "
|
| 1088 |
+
"to a pretrained feature extractor."
|
| 1089 |
+
)
|
| 1090 |
+
|
| 1091 |
+
# Instantiate feature_extractor if needed
|
| 1092 |
+
if isinstance(feature_extractor, (str, tuple)):
|
| 1093 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
| 1094 |
+
feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
|
| 1095 |
+
)
|
| 1096 |
+
|
| 1097 |
+
if (
|
| 1098 |
+
feature_extractor._processor_class
|
| 1099 |
+
and feature_extractor._processor_class.endswith("WithLM")
|
| 1100 |
+
and isinstance(model_name, str)
|
| 1101 |
+
):
|
| 1102 |
+
try:
|
| 1103 |
+
import kenlm # to trigger `ImportError` if not installed
|
| 1104 |
+
from pyctcdecode import BeamSearchDecoderCTC
|
| 1105 |
+
|
| 1106 |
+
if os.path.isdir(model_name) or os.path.isfile(model_name):
|
| 1107 |
+
decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
|
| 1108 |
+
else:
|
| 1109 |
+
language_model_glob = os.path.join(
|
| 1110 |
+
BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
|
| 1111 |
+
)
|
| 1112 |
+
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
| 1113 |
+
allow_patterns = [language_model_glob, alphabet_filename]
|
| 1114 |
+
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_patterns=allow_patterns)
|
| 1115 |
+
|
| 1116 |
+
kwargs["decoder"] = decoder
|
| 1117 |
+
except ImportError as e:
|
| 1118 |
+
logger.warning(f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Error: {e}")
|
| 1119 |
+
if not is_kenlm_available():
|
| 1120 |
+
logger.warning("Try to install `kenlm`: `pip install kenlm")
|
| 1121 |
+
|
| 1122 |
+
if not is_pyctcdecode_available():
|
| 1123 |
+
logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")
|
| 1124 |
+
|
| 1125 |
+
if load_processor:
|
| 1126 |
+
# Try to infer processor from model or config name (if provided as str)
|
| 1127 |
+
if processor is None:
|
| 1128 |
+
if isinstance(model_name, str):
|
| 1129 |
+
processor = model_name
|
| 1130 |
+
elif isinstance(config, str):
|
| 1131 |
+
processor = config
|
| 1132 |
+
else:
|
| 1133 |
+
# Impossible to guess what is the right processor here
|
| 1134 |
+
raise Exception(
|
| 1135 |
+
"Impossible to guess which processor to use. "
|
| 1136 |
+
"Please provide a processor instance or a path/identifier "
|
| 1137 |
+
"to a processor."
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
# Instantiate processor if needed
|
| 1141 |
+
if isinstance(processor, (str, tuple)):
|
| 1142 |
+
processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
| 1143 |
+
if not isinstance(processor, ProcessorMixin):
|
| 1144 |
+
raise TypeError(
|
| 1145 |
+
"Processor was loaded, but it is not an instance of `ProcessorMixin`. "
|
| 1146 |
+
f"Got type `{type(processor)}` instead. Please check that you specified "
|
| 1147 |
+
"correct pipeline task for the model and model has processor implemented and saved."
|
| 1148 |
+
)
|
| 1149 |
+
|
| 1150 |
+
if task == "translation" and model.config.task_specific_params:
|
| 1151 |
+
for key in model.config.task_specific_params:
|
| 1152 |
+
if key.startswith("translation"):
|
| 1153 |
+
task = key
|
| 1154 |
+
warnings.warn(
|
| 1155 |
+
f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
|
| 1156 |
+
UserWarning,
|
| 1157 |
+
)
|
| 1158 |
+
break
|
| 1159 |
+
|
| 1160 |
+
if tokenizer is not None:
|
| 1161 |
+
kwargs["tokenizer"] = tokenizer
|
| 1162 |
+
|
| 1163 |
+
if feature_extractor is not None:
|
| 1164 |
+
kwargs["feature_extractor"] = feature_extractor
|
| 1165 |
+
|
| 1166 |
+
if torch_dtype is not None:
|
| 1167 |
+
kwargs["torch_dtype"] = torch_dtype
|
| 1168 |
+
|
| 1169 |
+
if image_processor is not None:
|
| 1170 |
+
kwargs["image_processor"] = image_processor
|
| 1171 |
+
|
| 1172 |
+
if device is not None:
|
| 1173 |
+
kwargs["device"] = device
|
| 1174 |
+
|
| 1175 |
+
if processor is not None:
|
| 1176 |
+
kwargs["processor"] = processor
|
| 1177 |
+
|
| 1178 |
+
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
|
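For quick orientation, the `pipeline()` factory reconstructed above resolves the task, config, framework and the tokenizer / feature extractor / image processor / processor before instantiating the task-specific pipeline class. A minimal usage sketch follows; the checkpoint names are illustrative and not part of this diff:

```python
from transformers import pipeline

# Let the factory infer the default model for the task; as the code above shows,
# it logs a warning recommending that a model name and revision be pinned in production.
classifier = pipeline(task="text-classification")
print(classifier("Pipelines make inference a one-liner."))

# Pin a model and revision explicitly, which also skips Hub task inference.
asr = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny", revision="main")
```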
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/__init__.cpython-311.pyc ADDED (binary file, 48.2 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_classification.cpython-311.pyc ADDED (binary file, 11.8 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/audio_utils.cpython-311.pyc ADDED (binary file, 14.1 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/automatic_speech_recognition.cpython-311.pyc ADDED (binary file, 38.2 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/base.cpython-311.pyc ADDED (binary file, 77.9 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/depth_estimation.cpython-311.pyc ADDED (binary file, 7.7 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/document_question_answering.cpython-311.pyc ADDED (binary file, 26.6 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/feature_extraction.cpython-311.pyc ADDED (binary file, 4.45 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/fill_mask.cpython-311.pyc ADDED (binary file, 14.1 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_classification.cpython-311.pyc ADDED (binary file, 11.8 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_feature_extraction.cpython-311.pyc ADDED (binary file, 5.84 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_segmentation.cpython-311.pyc ADDED (binary file, 11.5 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_text_to_text.cpython-311.pyc ADDED (binary file, 20 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_image.cpython-311.pyc ADDED (binary file, 6.68 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/image_to_text.cpython-311.pyc ADDED (binary file, 9.78 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/mask_generation.cpython-311.pyc ADDED (binary file, 14.8 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/object_detection.cpython-311.pyc ADDED (binary file, 12 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/pt_utils.cpython-311.pyc ADDED (binary file, 15.5 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/question_answering.cpython-311.pyc ADDED (binary file, 33.1 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/table_question_answering.cpython-311.pyc ADDED (binary file, 25.2 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text2text_generation.cpython-311.pyc ADDED (binary file, 21.3 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_classification.cpython-311.pyc ADDED (binary file, 13.1 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_generation.cpython-311.pyc ADDED (binary file, 21.2 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/text_to_audio.cpython-311.pyc ADDED (binary file, 9.03 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/token_classification.cpython-311.pyc ADDED (binary file, 30.6 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/video_classification.cpython-311.pyc ADDED (binary file, 9.97 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/visual_question_answering.cpython-311.pyc ADDED (binary file, 12.3 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_audio_classification.cpython-311.pyc ADDED (binary file, 9.07 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_classification.cpython-311.pyc ADDED (binary file, 16.7 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_image_classification.cpython-311.pyc ADDED (binary file, 10.5 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/__pycache__/zero_shot_object_detection.cpython-311.pyc ADDED (binary file, 12.9 kB)
.venv/lib/python3.11/site-packages/transformers/pipelines/audio_classification.py
ADDED
|
@@ -0,0 +1,234 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from typing import Union

import numpy as np
import requests

from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
from .base import Pipeline, build_pipeline_init_args


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES

logger = logging.get_logger(__name__)


def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
    """
    Helper function to read an audio file through ffmpeg.
    """
    ar = f"{sampling_rate}"
    ac = "1"
    format_for_conversion = "f32le"
    ffmpeg_command = [
        "ffmpeg",
        "-i",
        "pipe:0",
        "-ac",
        ac,
        "-ar",
        ar,
        "-f",
        format_for_conversion,
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]

    try:
        ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    except FileNotFoundError:
        raise ValueError("ffmpeg was not found but is required to load audio files from filename")
    output_stream = ffmpeg_process.communicate(bpayload)
    out_bytes = output_stream[0]

    audio = np.frombuffer(out_bytes, np.float32)
    if audio.shape[0] == 0:
        raise ValueError("Malformed soundfile")
    return audio


@add_end_docstrings(build_pipeline_init_args(has_feature_extractor=True))
class AudioClassificationPipeline(Pipeline):
    """
    Audio classification pipeline using any `AutoModelForAudioClassification`. This pipeline predicts the class of a
    raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio
    formats.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
    >>> classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    [{'score': 0.997, 'label': '_unknown_'}, {'score': 0.002, 'label': 'left'}, {'score': 0.0, 'label': 'yes'}, {'score': 0.0, 'label': 'down'}, {'score': 0.0, 'label': 'stop'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)


    This pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"audio-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=audio-classification).
    """

    def __init__(self, *args, **kwargs):
        # Default, might be overriden by the model.config.
        kwargs["top_k"] = kwargs.get("top_k", 5)
        super().__init__(*args, **kwargs)

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str],
        **kwargs,
    ):
        """
        Classify the sequence(s) given as inputs. See the [`AutomaticSpeechRecognitionPipeline`] documentation for more
        information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either :
                    - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
                      to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                      Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be either be in the format `{"sampling_rate": int,
                      "raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
                      `"array"` is used to denote the raw audio waveform.
            top_k (`int`, *optional*, defaults to None):
                The number of top labels that will be returned by the pipeline. If the provided number is `None` or
                higher than the number of labels available in the model configuration, it will default to the number of
                labels.
            function_to_apply(`str`, *optional*, defaults to "softmax"):
                The function to apply to the model output. By default, the pipeline will apply the softmax function to
                the output of the model. Valid options: ["softmax", "sigmoid", "none"]. Note that passing Python's
                built-in `None` will default to "softmax", so you need to pass the string "none" to disable any
                post-processing.

        Return:
            A list of `dict` with the following keys:

            - **label** (`str`) -- The label predicted.
            - **score** (`float`) -- The corresponding probability.
        """
        return super().__call__(inputs, **kwargs)

    def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
        # No parameters on this pipeline right now
        postprocess_params = {}
        if top_k is not None:
            if top_k > self.model.config.num_labels:
                top_k = self.model.config.num_labels
            postprocess_params["top_k"] = top_k
        if function_to_apply is not None:
            if function_to_apply not in ["softmax", "sigmoid", "none"]:
                raise ValueError(
                    f"Invalid value for `function_to_apply`: {function_to_apply}. "
                    "Valid options are ['softmax', 'sigmoid', 'none']"
                )
            postprocess_params["function_to_apply"] = function_to_apply
        else:
            postprocess_params["function_to_apply"] = "softmax"
        return {}, {}, postprocess_params

    def preprocess(self, inputs):
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            inputs = _inputs
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                import torch

                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AudioClassificationPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )

                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                ).numpy()

        if not isinstance(inputs, np.ndarray):
            raise TypeError("We expect a numpy ndarray as input")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")

        processed = self.feature_extractor(
            inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
        )
        return processed

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, top_k=5, function_to_apply="softmax"):
        if function_to_apply == "softmax":
            probs = model_outputs.logits[0].softmax(-1)
        elif function_to_apply == "sigmoid":
            probs = model_outputs.logits[0].sigmoid()
        else:
            probs = model_outputs.logits[0]
        scores, ids = probs.topk(top_k)

        scores = scores.tolist()
        ids = ids.tolist()

        labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]

        return labels
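In addition to file paths, URLs and raw bytes, the `preprocess` method above accepts the `datasets`-style dictionary input and resamples it when needed. A small sketch of that path, reusing the checkpoint from the docstring example; the waveform here is synthetic and torchaudio is assumed to be installed:

```python
import numpy as np
from transformers import pipeline

classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# One second of silence at 8 kHz; since this differs from the feature
# extractor's sampling rate, the pipeline resamples it via torchaudio.
waveform = np.zeros(8000, dtype=np.float32)
print(classifier({"sampling_rate": 8000, "raw": waveform}, top_k=3))
```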
.venv/lib/python3.11/site-packages/transformers/pipelines/audio_utils.py
ADDED
|
@@ -0,0 +1,297 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
import datetime
import platform
import subprocess
from typing import Optional, Tuple, Union

import numpy as np


def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
    """
    Helper function to read an audio file through ffmpeg.
    """
    ar = f"{sampling_rate}"
    ac = "1"
    format_for_conversion = "f32le"
    ffmpeg_command = [
        "ffmpeg",
        "-i",
        "pipe:0",
        "-ac",
        ac,
        "-ar",
        ar,
        "-f",
        format_for_conversion,
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]

    try:
        with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
            output_stream = ffmpeg_process.communicate(bpayload)
    except FileNotFoundError as error:
        raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
    out_bytes = output_stream[0]
    audio = np.frombuffer(out_bytes, np.float32)
    if audio.shape[0] == 0:
        raise ValueError(
            "Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has "
            "a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote "
            "URL, ensure that the URL is the full address to **download** the audio file."
        )
    return audio


def ffmpeg_microphone(
    sampling_rate: int,
    chunk_length_s: float,
    format_for_conversion: str = "f32le",
    ffmpeg_input_device: Optional[str] = None,
    ffmpeg_additional_args: Optional[list[str]] = None,
):
    """
    Helper function to read audio from a microphone using ffmpeg. The default input device will be used unless another
    input device is specified using the `ffmpeg_input_device` argument. Uses 'alsa' on Linux, 'avfoundation' on MacOS and
    'dshow' on Windows.

    Arguments:
        sampling_rate (`int`):
            The sampling_rate to use when reading the data from the microphone. Try using the model's sampling_rate to
            avoid resampling later.
        chunk_length_s (`float` or `int`):
            The length of the maximum chunk of audio to be sent returned.
        format_for_conversion (`str`, defaults to `f32le`):
            The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
            could also be used.
        ffmpeg_input_device (`str`, *optional*):
            The identifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
            the default input device will be used. See `https://www.ffmpeg.org/ffmpeg-devices.html#Input-Devices`
            for how to specify and list input devices.
        ffmpeg_additional_args (`list[str]`, *optional*):
            Additional arguments to pass to ffmpeg, can include arguments like -nostdin for running as a background
            process. For example, to pass -nostdin to the ffmpeg process, pass in ["-nostdin"]. If passing in flags
            with multiple arguments, use the following convention (eg ["flag", "arg1", "arg2]).

    Returns:
        A generator yielding audio chunks of `chunk_length_s` seconds as `bytes` objects of length
        `int(round(sampling_rate * chunk_length_s)) * size_of_sample`.
    """
    ar = f"{sampling_rate}"
    ac = "1"
    if format_for_conversion == "s16le":
        size_of_sample = 2
    elif format_for_conversion == "f32le":
        size_of_sample = 4
    else:
        raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")

    system = platform.system()

    if system == "Linux":
        format_ = "alsa"
        input_ = ffmpeg_input_device or "default"
    elif system == "Darwin":
        format_ = "avfoundation"
        input_ = ffmpeg_input_device or ":default"
    elif system == "Windows":
        format_ = "dshow"
        input_ = ffmpeg_input_device or _get_microphone_name()

    ffmpeg_additional_args = [] if ffmpeg_additional_args is None else ffmpeg_additional_args

    ffmpeg_command = [
        "ffmpeg",
        "-f",
        format_,
        "-i",
        input_,
        "-ac",
        ac,
        "-ar",
        ar,
        "-f",
        format_for_conversion,
        "-fflags",
        "nobuffer",
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]

    ffmpeg_command.extend(ffmpeg_additional_args)

    chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
    iterator = _ffmpeg_stream(ffmpeg_command, chunk_len)
    for item in iterator:
        yield item


def ffmpeg_microphone_live(
    sampling_rate: int,
    chunk_length_s: float,
    stream_chunk_s: Optional[int] = None,
    stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
    format_for_conversion: str = "f32le",
    ffmpeg_input_device: Optional[str] = None,
    ffmpeg_additional_args: Optional[list[str]] = None,
):
    """
    Helper function to read audio from a microphone using ffmpeg. This will output `partial` overlapping chunks starting
    from `stream_chunk_s` (if it is defined) until `chunk_length_s` is reached. It will make use of striding to avoid
    errors on the "sides" of the various chunks. The default input device will be used unless another input device is
    specified using the `ffmpeg_input_device` argument. Uses 'alsa' on Linux, 'avfoundation' on MacOS and 'dshow' on Windows.

    Arguments:
        sampling_rate (`int`):
            The sampling_rate to use when reading the data from the microphone. Try using the model's sampling_rate to
            avoid resampling later.
        chunk_length_s (`float` or `int`):
            The length of the maximum chunk of audio to be sent returned. This includes the eventual striding.
        stream_chunk_s (`float` or `int`):
            The length of the minimal temporary audio to be returned.
        stride_length_s (`float` or `int` or `(float, float)`, *optional*):
            The length of the striding to be used. Stride is used to provide context to a model on the (left, right) of
            an audio sample but without using that part to actually make the prediction. Setting this does not change
            the length of the chunk.
        format_for_conversion (`str`, *optional*, defaults to `f32le`):
            The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
            could also be used.
        ffmpeg_input_device (`str`, *optional*):
            The identifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
            the default input device will be used. See `https://www.ffmpeg.org/ffmpeg-devices.html#Input-Devices`
            for how to specify and list input devices.
        ffmpeg_additional_args (`list[str]`, *optional*):
            Additional arguments to pass to ffmpeg, can include arguments like -nostdin for running as a background
            process. For example, to pass -nostdin to the ffmpeg process, pass in ["-nostdin"]. If passing in flags
            with multiple arguments, use the following convention (eg ["flag", "arg1", "arg2]).

    Return:
        A generator yielding dictionaries of the following form

        `{"sampling_rate": int, "raw": np.array(), "partial" bool}` With optionally a `"stride" (int, int)` key if
        `stride_length_s` is defined.

        `stride` and `raw` are all expressed in `samples`, and `partial` is a boolean saying if the current yield item
        is a whole chunk, or a partial temporary result to be later replaced by another larger chunk.
    """
    if stream_chunk_s is not None:
        chunk_s = stream_chunk_s
    else:
        chunk_s = chunk_length_s

    microphone = ffmpeg_microphone(
        sampling_rate,
        chunk_s,
        format_for_conversion=format_for_conversion,
        ffmpeg_input_device=ffmpeg_input_device,
        ffmpeg_additional_args=[] if ffmpeg_additional_args is None else ffmpeg_additional_args,
    )

    if format_for_conversion == "s16le":
        dtype = np.int16
        size_of_sample = 2
    elif format_for_conversion == "f32le":
        dtype = np.float32
        size_of_sample = 4
    else:
        raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")

    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6
    chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
    if isinstance(stride_length_s, (int, float)):
        stride_length_s = [stride_length_s, stride_length_s]

    stride_left = int(round(sampling_rate * stride_length_s[0])) * size_of_sample
    stride_right = int(round(sampling_rate * stride_length_s[1])) * size_of_sample
    audio_time = datetime.datetime.now()
    delta = datetime.timedelta(seconds=chunk_s)
    for item in chunk_bytes_iter(microphone, chunk_len, stride=(stride_left, stride_right), stream=True):
        # Put everything back in numpy scale
        item["raw"] = np.frombuffer(item["raw"], dtype=dtype)
        item["stride"] = (
            item["stride"][0] // size_of_sample,
            item["stride"][1] // size_of_sample,
        )
        item["sampling_rate"] = sampling_rate
        audio_time += delta
        if datetime.datetime.now() > audio_time + 10 * delta:
            # We're late !! SKIP
            continue
        yield item


def chunk_bytes_iter(iterator, chunk_len: int, stride: Tuple[int, int], stream: bool = False):
    """
    Reads raw bytes from an iterator and does chunks of length `chunk_len`. Optionally adds `stride` to each chunks to
    get overlaps. `stream` is used to return partial results even if a full `chunk_len` is not yet available.
    """
    acc = b""
    stride_left, stride_right = stride
    if stride_left + stride_right >= chunk_len:
        raise ValueError(
            f"Stride needs to be strictly smaller than chunk_len: ({stride_left}, {stride_right}) vs {chunk_len}"
        )
    _stride_left = 0
    for raw in iterator:
        acc += raw
        if stream and len(acc) < chunk_len:
            stride = (_stride_left, 0)
            yield {"raw": acc[:chunk_len], "stride": stride, "partial": True}
        else:
            while len(acc) >= chunk_len:
                # We are flushing the accumulator
                stride = (_stride_left, stride_right)
                item = {"raw": acc[:chunk_len], "stride": stride}
                if stream:
                    item["partial"] = False
                yield item
                _stride_left = stride_left
                acc = acc[chunk_len - stride_left - stride_right :]
    # Last chunk
    if len(acc) > stride_left:
        item = {"raw": acc, "stride": (_stride_left, 0)}
        if stream:
            item["partial"] = False
        yield item


def _ffmpeg_stream(ffmpeg_command, buflen: int):
    """
    Internal function to create the generator of data through ffmpeg
    """
    bufsize = 2**24  # 16Mo
    try:
        with subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, bufsize=bufsize) as ffmpeg_process:
            while True:
                raw = ffmpeg_process.stdout.read(buflen)
                if raw == b"":
                    break
                yield raw
    except FileNotFoundError as error:
        raise ValueError("ffmpeg was not found but is required to stream audio files from filename") from error


def _get_microphone_name():
    """
    Retrieve the microphone name in Windows .
    """
    command = ["ffmpeg", "-list_devices", "true", "-f", "dshow", "-i", ""]

    try:
        ffmpeg_devices = subprocess.run(command, text=True, stderr=subprocess.PIPE, encoding="utf-8")
        microphone_lines = [line for line in ffmpeg_devices.stderr.splitlines() if "(audio)" in line]

        if microphone_lines:
            microphone_name = microphone_lines[0].split('"')[1]
            print(f"Using microphone: {microphone_name}")
            return f"audio={microphone_name}"
    except FileNotFoundError:
        print("ffmpeg was not found. Please install it or make sure it is in your system PATH.")

    return "default"
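A short usage sketch for the microphone helpers above (it assumes a working ffmpeg install and a default input device; the parameter values are illustrative):

```python
from transformers.pipelines.audio_utils import ffmpeg_microphone_live

# Stream 16 kHz microphone audio: partial 1 s updates are yielded until a
# full 5 s chunk (including stride) is available; each item carries
# "raw", "sampling_rate", "stride" and "partial".
for item in ffmpeg_microphone_live(sampling_rate=16000, chunk_length_s=5.0, stream_chunk_s=1.0):
    if not item["partial"]:
        print(f"received a full chunk of {item['raw'].shape[0]} samples")
        break
```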
.venv/lib/python3.11/site-packages/transformers/pipelines/automatic_speech_recognition.py
ADDED
|
@@ -0,0 +1,766 @@
| 1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import warnings
|
| 15 |
+
from collections import defaultdict
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import requests
|
| 20 |
+
|
| 21 |
+
from ..tokenization_utils import PreTrainedTokenizer
|
| 22 |
+
from ..utils import is_torch_available, is_torchaudio_available, logging
|
| 23 |
+
from .audio_utils import ffmpeg_read
|
| 24 |
+
from .base import ChunkPipeline
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from pyctcdecode import BeamSearchDecoderCTC
|
| 29 |
+
|
| 30 |
+
from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 31 |
+
from ..modeling_utils import PreTrainedModel
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
if is_torch_available():
|
| 36 |
+
import torch
|
| 37 |
+
|
| 38 |
+
from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def rescale_stride(stride, ratio):
|
| 42 |
+
"""
|
| 43 |
+
Rescales the stride values from audio space to tokens/logits space.
|
| 44 |
+
|
| 45 |
+
(160_000, 16_000, 16_000) -> (2000, 200, 200) for instance.
|
| 46 |
+
"""
|
| 47 |
+
# Shape is [B, SEQ] for tokens
|
| 48 |
+
# [B, SEQ, V] for logits
|
| 49 |
+
|
| 50 |
+
new_strides = []
|
| 51 |
+
for input_n, left, right in stride:
|
| 52 |
+
token_n = int(round(input_n * ratio))
|
| 53 |
+
left = int(round(left / input_n * token_n))
|
| 54 |
+
right = int(round(right / input_n * token_n))
|
| 55 |
+
new_stride = (token_n, left, right)
|
| 56 |
+
new_strides.append(new_stride)
|
| 57 |
+
|
| 58 |
+
return new_strides
|
| 59 |
+
|
| 60 |
+
|
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
    for chunk_start_idx in range(0, inputs_len, step):
        chunk_end_idx = chunk_start_idx + chunk_len
        chunk = inputs[chunk_start_idx:chunk_end_idx]
        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
        if dtype is not None:
            processed = processed.to(dtype=dtype)
        _stride_left = 0 if chunk_start_idx == 0 else stride_left
        is_last = chunk_end_idx >= inputs_len
        _stride_right = 0 if is_last else stride_right

        chunk_len = chunk.shape[0]
        stride = (chunk_len, _stride_left, _stride_right)
        if chunk.shape[0] > _stride_left:
            yield {"is_last": is_last, "stride": stride, **processed}
        if is_last:
            break

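# --- Illustrative sketch (not part of the original file) ---
# chunk_iter slides a window of `chunk_len` samples with a step of
# chunk_len - stride_left - stride_right, so consecutive chunks overlap by the
# strides. Purely hypothetical numbers:
_inputs_len, _chunk_len, _stride_left, _stride_right = 100, 50, 10, 10
_step = _chunk_len - _stride_left - _stride_right
assert _step == 30 and list(range(0, _inputs_len, _step)) == [0, 30, 60, 90]
# The yielded chunks cover [0:50], [30:80] and [60:100]: only the first chunk
# has _stride_left == 0, and iteration stops at the first chunk whose end
# reaches the input length (is_last), which also has _stride_right == 0.
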
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
    seq_len_left = len(sequence_left)
    seq_len_right = len(sequence_right)
    counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)]
    longest = 0
    for i in range(seq_len_left):
        for j in range(seq_len_right):
            if sequence_left[i] == sequence_right[j]:
                previous_counter = counter[i][j] + 1
                counter[i + 1][j + 1] = previous_counter
                if previous_counter > longest:
                    longest = previous_counter

    counter = np.array(counter)
    # we return the idx of the first element of the longest common sequence in the left sequence
    index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1
    index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1
    return index_left, index_right, longest

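# --- Illustrative sketch (not part of the original file) ---
# _fast_find_longest_common_sequence returns the start index of the longest
# shared contiguous run in each sequence, plus its length:
_idx_left, _idx_right, _length = _fast_find_longest_common_sequence(
    [7, 8, 9, 3, 4, 5], [3, 4, 5, 6]
)
assert (_idx_left, _idx_right, _length) == (3, 0, 3)  # the shared run [3, 4, 5]
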
def _find_longest_common_sequence(sequences, tokenizer):
    # TODO: Use a faster algorithm; this can probably be done in O(n)
    # using suffix array.
    # It might be tedious to do because of fault tolerance.
    # We actually have a really good property which is that the total sequence
    # MUST be those subsequences in order.
    # Also the algorithm should be more tolerant to errors.
    sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
    for new_seq in sequences[1:]:
        new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]

        index = 0
        max_ = 0.0
        for i in range(1, len(new_sequence) + 1):
            # epsilon to favor long perfect matches
            eps = i / 10000.0
            matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
            matching = matches / i + eps
            if matches > 1 and matching > max_:
                index = i
                max_ = matching
        sequence.extend(new_sequence[index:])
    return np.array(sequence)

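# --- Illustrative sketch (not part of the original file) ---
# _find_longest_common_sequence stitches overlapping chunk transcriptions: for
# each new chunk it finds the suffix of the running sequence that best matches
# a prefix of the new chunk, then appends the remainder. A minimal stand-in
# tokenizer (an assumption for illustration only):
class _DemoTokenizer:
    all_special_ids = [0]


_merged = _find_longest_common_sequence(
    [np.array([[0, 1, 2, 3, 4]]), np.array([[3, 4, 5, 6, 0]])], _DemoTokenizer()
)
assert _merged.tolist() == [1, 2, 3, 4, 5, 6]
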
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
    """
    Pipeline that aims at extracting spoken text contained within some audio.

    The input can be either a raw waveform or an audio file. In the case of an audio file, ffmpeg should be installed
    to support multiple audio formats.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> transcriber = pipeline(model="openai/whisper-base")
    >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    Arguments:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
        feature_extractor ([`SequenceFeatureExtractor`]):
            The feature extractor that will be used by the pipeline to encode waveform for the model.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
        chunk_length_s (`float`, *optional*, defaults to 0):
            The input length of each chunk. If `chunk_length_s = 0` then chunking is disabled (default).

            <Tip>

            For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
            The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
            the model to *see* more context and infer letters better than without this context, but the pipeline
            discards the stride bits at the end to make the final reconstitution as perfect as possible.

            <Tip>

            For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking
            blog post](https://huggingface.co/blog/asr-chunking).

            </Tip>

        framework (`str`, *optional*):
            The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
            installed. If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
            no model is provided.
        device (Union[`int`, `torch.device`], *optional*):
            Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive integer will run
            the model on the associated CUDA device id.
        torch_dtype (Union[`int`, `torch.dtype`], *optional*):
            The data-type (dtype) of the computation. Setting this to `None` will use float32 precision. Set to
            `torch.float16` or `torch.bfloat16` to use half-precision in the respective dtypes.

    """

    def __init__(
        self,
        model: "PreTrainedModel",
        feature_extractor: Union["SequenceFeatureExtractor", str] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
        device: Union[int, "torch.device"] = None,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        **kwargs,
    ):
        # set the model type so we can check we have the right pre- and post-processing parameters
        if model.config.model_type == "whisper":
            self.type = "seq2seq_whisper"
        elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
            self.type = "seq2seq"
        elif (
            feature_extractor._processor_class
            and feature_extractor._processor_class.endswith("WithLM")
            and decoder is not None
        ):
            self.decoder = decoder
            self.type = "ctc_with_lm"
        else:
            self.type = "ctc"

        super().__init__(model, tokenizer, feature_extractor, device=device, torch_dtype=torch_dtype, **kwargs)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str],
        **kwargs,
    ):
        """
        Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
        documentation for more information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The input is either:
                    - `str` that is either the filename of a local audio file, or a public URL address to download the
                      audio file. The file will be read at the correct sampling rate to get the waveform using
                      *ffmpeg*. This requires *ffmpeg* to be installed on the system.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                        Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
                      np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
                      treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
                      inference to provide more context to the model). Only use `stride` with CTC models.
            return_timestamps (*optional*, `str` or `bool`):
                Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
                other sequence-to-sequence models.

                For CTC models, timestamps can take one of two formats:
                    - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
                        instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
                        0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
                        `0.6` seconds.
                    - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
                        instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
                        (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
                        before `0.9` seconds.

                For the Whisper model, timestamps can take one of two formats:
                    - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
                        through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
                        by inspecting the cross-attention weights.
                    - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
                        For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
                        model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
                        Note that a segment of text refers to a sequence of one or more words, rather than individual
                        words as with word-level timestamps.
            generate_kwargs (`dict`, *optional*):
                The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                complete overview of generate, check the [following
                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).

        Return:
            `Dict`: A dictionary with the following keys:
                - **text** (`str`): The recognized text.
                - **chunks** (*optional*, `List[Dict]`)
                    When using `return_timestamps`, the `chunks` will become a list containing all the various text
                    chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
                    "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                    `"".join(chunk["text"] for chunk in output["chunks"])`.
        """
        return super().__call__(inputs, **kwargs)

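    # --- Illustrative sketch (not part of the original file) ---
    # Calling the pipeline with the dict input form and word-level timestamps
    # described above. The checkpoint name and `audio_array` are placeholders.
    #
    #     from transformers import pipeline
    #
    #     transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base")
    #     out = transcriber(
    #         {"sampling_rate": 8_000, "raw": audio_array},  # resampled internally
    #         return_timestamps="word",
    #     )
    #     # out -> {"text": "...", "chunks": [{"text": " Hi", "timestamp": (0.5, 0.9)}, ...]}
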
    def _sanitize_parameters(
        self,
        chunk_length_s=None,
        stride_length_s=None,
        ignore_warning=None,
        decoder_kwargs=None,
        return_timestamps=None,
        return_language=None,
        generate_kwargs=None,
        max_new_tokens=None,
    ):
        # No parameters on this pipeline right now
        preprocess_params = {}
        if chunk_length_s is not None:
            if self.type == "seq2seq" and not ignore_warning:
                logger.warning(
                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                    " be entirely accurate and will have caveats. More information:"
                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                    " ignore_warning=True)"
                )
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s

        forward_params = defaultdict(dict)
        if max_new_tokens is not None:
            warnings.warn(
                "`max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.",
                FutureWarning,
            )
            forward_params["max_new_tokens"] = max_new_tokens
        if generate_kwargs is not None:
            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
                raise ValueError(
                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
                    " only 1 version"
                )
            forward_params.update(generate_kwargs)

        postprocess_params = {}
        if decoder_kwargs is not None:
            postprocess_params["decoder_kwargs"] = decoder_kwargs
        if return_timestamps is not None:
            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
            if self.type == "seq2seq" and return_timestamps:
                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
            if self.type == "ctc_with_lm" and return_timestamps != "word":
                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
                raise ValueError(
                    "CTC can either predict character level timestamps, or word level timestamps. "
                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
                )
            if self.type == "seq2seq_whisper" and return_timestamps == "char":
                raise ValueError(
                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
                )
            forward_params["return_timestamps"] = return_timestamps
            postprocess_params["return_timestamps"] = return_timestamps
        if return_language is not None:
            if self.type != "seq2seq_whisper":
                raise ValueError("Only Whisper can return language for now.")
            postprocess_params["return_language"] = return_language

        if self.assistant_model is not None:
            forward_params["assistant_model"] = self.assistant_model
        if self.assistant_tokenizer is not None:
            forward_params["tokenizer"] = self.tokenizer
            forward_params["assistant_tokenizer"] = self.assistant_tokenizer

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        stride = None
        extra = {}
        if isinstance(inputs, dict):
            stride = inputs.pop("stride", None)
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )

            _inputs = inputs.pop("raw", None)
            if _inputs is None:
                # Remove path which will not be used from `datasets`.
                inputs.pop("path", None)
                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
            extra = inputs
            inputs = _inputs
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )

                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                ).numpy()
                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
            else:
                ratio = 1
        if stride is not None:
            if stride[0] + stride[1] > inputs.shape[0]:
                raise ValueError("Stride is too large for input")

            # Stride needs to get the chunk length here, it's going to get
            # swallowed by the `feature_extractor` later, and then batching
            # can add extra data in the inputs, so we need to keep track
            # of the original length in the stride so we can cut properly.
            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
        if not isinstance(inputs, np.ndarray):
            raise TypeError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if chunk_length_s:
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6

            if isinstance(stride_length_s, (int, float)):
                stride_length_s = [stride_length_s, stride_length_s]

            # XXX: Careful, this variable will not exist in `seq2seq` setting.
            # Currently chunking is not possible at this level for `seq2seq` so
            # it's ok.
            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)

            if chunk_len < stride_left + stride_right:
                raise ValueError("Chunk length must be superior to stride length")

            for item in chunk_iter(
                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
            ):
                yield {**item, **extra}
        else:
            if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
                processed = self.feature_extractor(
                    inputs,
                    sampling_rate=self.feature_extractor.sampling_rate,
                    truncation=False,
                    padding="longest",
                    return_tensors="pt",
                    return_attention_mask=True,
                )
            else:
                if self.type == "seq2seq_whisper" and stride is None:
                    processed = self.feature_extractor(
                        inputs,
                        sampling_rate=self.feature_extractor.sampling_rate,
                        return_tensors="pt",
                        return_token_timestamps=True,
                        return_attention_mask=True,
                    )
                    extra["num_frames"] = processed.pop("num_frames")
                else:
                    processed = self.feature_extractor(
                        inputs,
                        sampling_rate=self.feature_extractor.sampling_rate,
                        return_tensors="pt",
                        return_attention_mask=True,
                    )
            if self.torch_dtype is not None:
                processed = processed.to(dtype=self.torch_dtype)
            if stride is not None:
                if self.type == "seq2seq":
                    raise ValueError("Stride is only usable with CTC models, try removing it !")

                processed["stride"] = stride
            yield {"is_last": True, **processed, **extra}

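    # --- Illustrative sketch (not part of the original file) ---
    # How the chunking parameters above translate into samples for a 16 kHz
    # feature extractor (illustrative numbers, with align_to == 1):
    #
    #     chunk_length_s = 30.0, stride_length_s defaults to 30.0 / 6 = 5.0
    #     chunk_len    = 30.0 * 16_000 = 480_000 samples
    #     stride_left  = stride_right = 5.0 * 16_000 = 80_000 samples
    #     480_000 >= 80_000 + 80_000, so the check above passes.
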
    def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
        attention_mask = model_inputs.pop("attention_mask", None)
        stride = model_inputs.pop("stride", None)
        num_frames = model_inputs.pop("num_frames", None)
        is_last = model_inputs.pop("is_last")

        if stride is not None and num_frames is not None:
            raise ValueError("num_frames must be used only when stride is None")

        if self.type in {"seq2seq", "seq2seq_whisper"}:
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
            if "input_features" in model_inputs:
                inputs = model_inputs.pop("input_features")
            elif "input_values" in model_inputs:
                inputs = model_inputs.pop("input_values")
            else:
                raise ValueError(
                    "Seq2Seq speech recognition model requires either a "
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )

            # custom processing for Whisper timestamps and word-level timestamps
            if return_timestamps and self.type == "seq2seq_whisper":
                generate_kwargs["return_timestamps"] = return_timestamps
                if return_timestamps == "word":
                    generate_kwargs["return_token_timestamps"] = True
                    generate_kwargs["return_segments"] = True

                    if stride is not None:
                        if isinstance(stride, tuple):
                            generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
                        else:
                            generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
                    else:
                        generate_kwargs["num_frames"] = num_frames

            # A user-defined `generation_config` passed to the pipeline call takes precedence
            if "generation_config" not in generate_kwargs:
                generate_kwargs["generation_config"] = self.generation_config

            tokens = self.model.generate(
                inputs=inputs,
                attention_mask=attention_mask,
                **generate_kwargs,
            )
            # whisper longform generation stores timestamps in "segments"
            if return_timestamps == "word" and self.type == "seq2seq_whisper":
                if "segments" not in tokens:
                    out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
                else:
                    token_timestamps = [
                        torch.cat([segment["token_timestamps"] for segment in segment_list])
                        for segment_list in tokens["segments"]
                    ]
                    out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
            else:
                out = {"tokens": tokens}
            if self.type == "seq2seq_whisper":
                if stride is not None:
                    out["stride"] = stride

        else:
            inputs = {
                self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
                "attention_mask": attention_mask,
            }
            outputs = self.model(**inputs)
            logits = outputs.logits

            if self.type == "ctc_with_lm":
                out = {"logits": logits}
            else:
                out = {"tokens": logits.argmax(dim=-1)}
            if stride is not None:
                # Send stride to `postprocess`.
                # it needs to be handled there where
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}

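    # --- Illustrative sketch (not part of the original file) ---
    # For Whisper word-level timestamps, `generate` receives the number of
    # feature frames covered by each chunk: stride[0] // hop_length. With the
    # usual Whisper hop length of 160 samples (an assumption about the
    # checkpoint, not something this file enforces):
    #
    #     stride = (480_000, 0, 80_000)  ->  num_frames = 480_000 // 160 = 3000
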
    def postprocess(
        self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
    ):
        # Optional return types
        optional = {}

        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
        stride = None
        for outputs in model_outputs:
            if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
                items = outputs[key].to(torch.float32).numpy()
            else:
                items = outputs[key].numpy()
            stride = outputs.get("stride", None)
            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
                # we need to reconstruct this information
                # This won't work with left padding (which doesn't exist right now)
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)

        if stride and self.type == "seq2seq":
            items = _find_longest_common_sequence(final_items, self.tokenizer)
        elif self.type == "seq2seq_whisper":
            time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
            # Send the chunking back to seconds, it's easier to handle in whisper
            sampling_rate = self.feature_extractor.sampling_rate
            for output in model_outputs:
                if "stride" in output:
                    chunk_len, stride_left, stride_right = output["stride"]
                    # Go back in seconds
                    chunk_len /= sampling_rate
                    stride_left /= sampling_rate
                    stride_right /= sampling_rate
                    output["stride"] = chunk_len, stride_left, stride_right

            text, optional = self.tokenizer._decode_asr(
                model_outputs,
                return_timestamps=return_timestamps,
                return_language=return_language,
                time_precision=time_precision,
            )
        else:
            items = np.concatenate(final_items, axis=1)
            items = items.squeeze(0)

        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
            beams = self.decoder.decode_beams(items, **decoder_kwargs)
            text = beams[0][0]
            if return_timestamps:
                # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                # pre-existing code later
                chunk_offset = beams[0][2]
                offsets = []
                for word, (start_offset, end_offset) in chunk_offset:
                    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
        elif self.type != "seq2seq_whisper":
            skip_special_tokens = self.type != "ctc"
            text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
            if return_timestamps:
                offsets = self.tokenizer.decode(
                    items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
                )["char_offsets"]
                if return_timestamps == "word":
                    offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)

        if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
            chunks = []
            for item in offsets:
                start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
                start /= self.feature_extractor.sampling_rate

                stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
                stop /= self.feature_extractor.sampling_rate

                chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
            optional["chunks"] = chunks

        extra = defaultdict(list)
        for output in model_outputs:
            output.pop("tokens", None)
            output.pop("logits", None)
            output.pop("is_last", None)
            output.pop("stride", None)
            output.pop("token_timestamps", None)
            for k, v in output.items():
                extra[k].append(v)
        return {"text": text, **optional, **extra}

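# --- Illustrative sketch (not part of the original file) ---
# For CTC models, the char/word offsets above are turned into seconds through
# the model's inputs_to_logits_ratio. Hypothetical Wav2Vec2-style numbers:
_ratio, _sampling_rate = 320, 16_000
_start = 25 * _ratio / _sampling_rate  # 0.5 s
_stop = 45 * _ratio / _sampling_rate   # 0.9 s
assert (_start, _stop) == (0.5, 0.9)
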
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
    """
    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
    processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
    properly compute the final `offset`.
    """
    # index of the first timestamp token
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    items = []
    # approximation of the token to time ratio : ~0.02 seconds
    time_precision = feature_extractor.chunk_length / max_source_positions
    time = 0
    for seq_idx, item in enumerate(sequences):
        sequence, stride = item
        if isinstance(sequence, list):
            sequence = np.array(sequence)
        chunk_len, stride_left, stride_right = stride
        sequence = sequence.squeeze(0)
        # get rid of the `forced_decoder_idx` that are used to parametrize the generation
        begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
        sequence = sequence[begin_idx:]

        timestamp_tokens = sequence >= timestamp_begin
        if seq_idx != 0 and sum(timestamp_tokens) > 0:
            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
            time -= stride_left + stride_right
            offset = int((time / feature_extractor.sampling_rate) / time_precision)
            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
            # relevant timestamps are in the overlapping part
            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
            if relevant_timestamp.shape[0] > 0:
                relevant_timestamp = (
                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
                )
                # if a big stride is used, we need to check some of the previous items for the best overlap
                best_match = 0
                sliced_sequence = []
                for idx, previous_sequence in enumerate(reversed(items)):
                    previous_tokens = previous_sequence[1:-1]
                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
                        break  # the previous sequence is too far in the past
                    if len(previous_tokens) > 0:
                        # find the longest common sequence between the overlapping parts
                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
                            sequence[1:relevant_timestamp], previous_tokens
                        )
                        # don't do anything if only 1 token was matched
                        if match_length > 1 and match_length > best_match:
                            best_match = match_length
                            best_idx = idx
                            end_of_curr_sequence_idx = (
                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                            )
                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                            # if all the tokens are matched, suffix
                            if index_left == 0 and match_length == len(previous_tokens):
                                sliced_sequence = np.insert(
                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
                                )
                                sliced_sequence[-1] = previous_sequence[-1]
                            # if part of the previous sequence is not taken
                            elif index_left >= 0:
                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                # let's insert the missing part of the previous sequence
                                previous_slice = (
                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
                                )
                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                sliced_sequence[-1] += offset

                if len(sliced_sequence) > 0:
                    items[len(items) - best_idx - 1] = sliced_sequence
                    items = items[: len(items) - best_idx]
                    sequence = sequence[end_of_curr_sequence_idx:]

        # sequence might have changed
        timestamp_tokens = sequence >= timestamp_begin
        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
        if sum(timestamp_tokens) > 0:
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = (
                np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
            )

        if len(consecutive) > 0:
            last_slice = 0
            for current_slice in consecutive:
                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                sliced_tokens = sequence[last_slice:current_slice]
                duration = sliced_tokens[-1] - sliced_tokens[0]
                sliced_tokens[0] = actual_offset
                sliced_tokens[-1] = actual_offset + duration
                items.append(sliced_tokens)
                last_slice = current_slice

        time += chunk_len
    result = []
    for i in range(len(items)):
        result += items[i].tolist()
    return result
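
# --- Illustrative sketch (not part of the original file) ---
# The token-to-time ratio used above, with the usual Whisper values of a 30 s
# feature window and max_source_positions = 1500 (an assumption about the
# checkpoint, not something this file enforces):
assert 30 / 1500 == 0.02  # each timestamp token advances by 20 ms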
.venv/lib/python3.11/site-packages/transformers/pipelines/base.py
ADDED
@@ -0,0 +1,1484 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import csv
import importlib
import json
import os
import pickle
import sys
import traceback
import types
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from ..dynamic_module_utils import custom_object_save
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard
from ..models.auto import AutoConfig, AutoTokenizer
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
    ModelOutput,
    PushToHubMixin,
    add_end_docstrings,
    copy_func,
    infer_framework,
    is_tf_available,
    is_torch_available,
    is_torch_cuda_available,
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_npu_available,
    is_torch_xpu_available,
    logging,
)


GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]

if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TFAutoModel

if is_torch_available():
    import torch
    from torch.utils.data import DataLoader, Dataset

    from ..models.auto.modeling_auto import AutoModel

    # Re-export for backward compatibility
    from .pt_utils import KeyDataset
else:
    Dataset = None
    KeyDataset = None

if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel


logger = logging.get_logger(__name__)

def no_collate_fn(items):
    if len(items) != 1:
        raise ValueError("This collate_fn is meant to be used with batch_size=1")
    return items[0]


def _pad(items, key, padding_value, padding_side):
    batch_size = len(items)
    if isinstance(items[0][key], torch.Tensor):
        # Others include `attention_mask` etc...
        shape = items[0][key].shape
        dim = len(shape)
        if dim == 1:
            # We have a list of 1-dim torch tensors, which can be stacked without padding
            return torch.cat([item[key] for item in items], dim=0)
        if key in ["pixel_values", "image"]:
            # This is probably an image, so padding shouldn't be necessary
            # B, C, H, W
            return torch.cat([item[key] for item in items], dim=0)
        elif dim == 4 and key == "input_features":
            # this is probably a mel spectrogram batched
            return torch.cat([item[key] for item in items], dim=0)
        max_length = max(item[key].shape[1] for item in items)
        min_length = min(item[key].shape[1] for item in items)
        dtype = items[0][key].dtype

        if dim == 2:
            if max_length == min_length:
                # Bypass for `ImageGPT` which doesn't provide a padding value, yet
                # we can consistently pad since the size should be matching
                return torch.cat([item[key] for item in items], dim=0)
            tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
        elif dim == 3:
            tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
        elif dim == 4:
            tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value

        for i, item in enumerate(items):
            if dim == 2:
                if padding_side == "left":
                    tensor[i, -len(item[key][0]) :] = item[key][0].clone()
                else:
                    tensor[i, : len(item[key][0])] = item[key][0].clone()
            elif dim == 3:
                if padding_side == "left":
                    tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
                else:
                    tensor[i, : len(item[key][0]), :] = item[key][0].clone()
            elif dim == 4:
                if padding_side == "left":
                    tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
                else:
                    tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()

        return tensor
    else:
        return [item[key] for item in items]

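# --- Illustrative sketch (not part of the original file) ---
# _pad right-pads (or left-pads) [1, seq] tensors to the longest sequence in
# the batch. Hypothetical input_ids with a padding value of 0:
_items = [{"input_ids": torch.tensor([[1, 2, 3]])}, {"input_ids": torch.tensor([[4, 5]])}]
_padded = _pad(_items, "input_ids", padding_value=0, padding_side="right")
assert _padded.tolist() == [[1, 2, 3], [4, 5, 0]]
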
def pad_collate_fn(tokenizer, feature_extractor):
    # Tokenizer
    t_padding_side = None
    # Feature extractor
    f_padding_side = None
    if tokenizer is None and feature_extractor is None:
        raise ValueError("Pipeline without tokenizer or feature_extractor cannot do batching")
    if tokenizer is not None:
        if tokenizer.pad_token_id is None:
            raise ValueError(
                "Pipeline with tokenizer without pad_token cannot do batching. You can try to set it with "
                "`pipe.tokenizer.pad_token_id = model.config.eos_token_id`."
            )
        else:
            t_padding_value = tokenizer.pad_token_id
            t_padding_side = tokenizer.padding_side
    if feature_extractor is not None:
        # Feature extractor can be images, where no padding is expected
        f_padding_value = getattr(feature_extractor, "padding_value", None)
        f_padding_side = getattr(feature_extractor, "padding_side", None)

    if t_padding_side is not None and f_padding_side is not None and t_padding_side != f_padding_side:
        raise ValueError(
            f"The feature extractor and tokenizer don't agree on padding side {t_padding_side} != {f_padding_side}"
        )
    padding_side = "right"
    if t_padding_side is not None:
        padding_side = t_padding_side
    if f_padding_side is not None:
        padding_side = f_padding_side

    def inner(items):
        keys = set(items[0].keys())
        for item in items:
            if set(item.keys()) != keys:
                raise ValueError(
                    f"The elements of the batch contain different keys. Cannot batch them ({set(item.keys())} !="
                    f" {keys})"
                )
        # input_values, input_pixels, input_ids, ...
        padded = {}
        for key in keys:
            if key in {"input_ids"}:
                # ImageGPT uses a feature extractor
                if tokenizer is None and feature_extractor is not None:
                    _padding_value = f_padding_value
                else:
                    _padding_value = t_padding_value
            elif key in {"input_values", "pixel_values", "input_features"}:
                _padding_value = f_padding_value
            elif key in {"p_mask", "special_tokens_mask"}:
                _padding_value = 1
            elif key in {"attention_mask", "token_type_ids"}:
                _padding_value = 0
            else:
                # This is likely another random key, maybe even user provided
                _padding_value = 0
            padded[key] = _pad(items, key, _padding_value, padding_side)
        return padded

    return inner

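# --- Illustrative sketch (not part of the original file) ---
# pad_collate_fn builds a collate function for torch DataLoaders. A tokenizer
# with a pad token is required; the checkpoint below is only a placeholder.
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token_id = tokenizer.eos_token_id
#     collate = pad_collate_fn(tokenizer, feature_extractor=None)
#     batch = collate([tokenizer("short", return_tensors="pt"),
#                      tokenizer("a longer sentence", return_tensors="pt")])
#     # batch["input_ids"] and batch["attention_mask"] are padded to equal length.
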
def infer_framework_load_model(
    model,
    config: AutoConfig,
    model_classes: Optional[Dict[str, Tuple[type]]] = None,
    task: Optional[str] = None,
    framework: Optional[str] = None,
    **model_kwargs,
):
    """
    Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model).

    If `model` is instantiated, this function will just infer the framework from the model class. Otherwise `model` is
    actually a checkpoint name and this method will try to instantiate it using `model_classes`. Since we don't want to
    instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for `model`, PyTorch is selected.

    Args:
        model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model to infer the framework from. If `str`, a checkpoint name.
        config ([`AutoConfig`]):
            The config associated with the model to help using the correct class
        model_classes (dictionary `str` to `type`, *optional*):
            A mapping from framework name to class.
        task (`str`):
            The task defining which pipeline will be returned.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
            **model_kwargs)` function.

    Returns:
        `Tuple`: A tuple framework, model.
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        model_kwargs["_from_pipeline"] = task
        class_tuple = ()
        look_pt = is_torch_available() and framework in {"pt", None}
        look_tf = is_tf_available() and framework in {"tf", None}
        if model_classes:
            if look_pt:
                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
            if look_tf:
                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
        if config.architectures:
            classes = []
            for architecture in config.architectures:
                transformers_module = importlib.import_module("transformers")
                if look_pt:
                    _class = getattr(transformers_module, architecture, None)
                    if _class is not None:
                        classes.append(_class)
                if look_tf:
                    _class = getattr(transformers_module, f"TF{architecture}", None)
                    if _class is not None:
                        classes.append(_class)
            class_tuple = class_tuple + tuple(classes)

        if len(class_tuple) == 0:
            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")

        all_traceback = {}
        for model_class in class_tuple:
            kwargs = model_kwargs.copy()
            if framework == "pt" and model.endswith(".h5"):
                kwargs["from_tf"] = True
                logger.warning(
                    "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                    "Trying to load the model with PyTorch."
                )
            elif framework == "tf" and model.endswith(".bin"):
                kwargs["from_pt"] = True
                logger.warning(
                    "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                    "Trying to load the model with Tensorflow."
                )

            try:
                model = model_class.from_pretrained(model, **kwargs)
                if hasattr(model, "eval"):
                    model = model.eval()
                # Stop loading on the first successful load.
                break
            except (OSError, ValueError):
                all_traceback[model_class.__name__] = traceback.format_exc()
                continue

    if isinstance(model, str):
        error = ""
        for class_name, trace in all_traceback.items():
            error += f"while loading with {class_name}, an error is thrown:\n{trace}\n"
        raise ValueError(
            f"Could not load model {model} with any of the following classes: {class_tuple}. See the original errors:\n\n{error}\n"
        )

    if framework is None:
        framework = infer_framework(model.__class__)
    return framework, model

| 311 |
+
def infer_framework_from_model(
|
| 312 |
+
model,
|
| 313 |
+
model_classes: Optional[Dict[str, Tuple[type]]] = None,
|
| 314 |
+
task: Optional[str] = None,
|
| 315 |
+
framework: Optional[str] = None,
|
| 316 |
+
**model_kwargs,
|
| 317 |
+
):
|
| 318 |
+
"""
|
| 319 |
+
Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model).
|
| 320 |
+
|
| 321 |
+
If `model` is instantiated, this function will just infer the framework from the model class. Otherwise `model` is
|
| 322 |
+
actually a checkpoint name and this method will try to instantiate it using `model_classes`. Since we don't want to
|
| 323 |
+
instantiate the model twice, this model is returned for use by the pipeline.
|
| 324 |
+
|
| 325 |
+
If both frameworks are installed and available for `model`, PyTorch is selected.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
|
| 329 |
+
The model to infer the framework from. If `str`, a checkpoint name.
|
| 330 |
+
model_classes (dictionary `str` to `type`, *optional*):
|
| 331 |
+
A mapping from framework name to model class.
|
| 332 |
+
task (`str`):
|
| 333 |
+
The task defining which pipeline will be returned.
|
| 334 |
+
model_kwargs:
|
| 335 |
+
Additional dictionary of keyword arguments passed along to the model's `from_pretrained(...,
|
| 336 |
+
**model_kwargs)` function.
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
`Tuple`: A tuple framework, model.
|
| 340 |
+
"""
|
| 341 |
+
if isinstance(model, str):
|
| 342 |
+
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
|
| 343 |
+
else:
|
| 344 |
+
config = model.config
|
| 345 |
+
return infer_framework_load_model(
|
| 346 |
+
model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
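# Illustrative sketch (not part of the original module): resolving the framework and
# instantiating the model in one call. The checkpoint and model class are example
# choices and assume PyTorch is installed.
from transformers import AutoModelForSequenceClassification

_framework, _model = infer_framework_from_model(
    "distilbert-base-uncased-finetuned-sst-2-english",
    model_classes={"pt": (AutoModelForSequenceClassification,)},
    task="sentiment-analysis",
)
# _framework == "pt" when PyTorch is available (PyTorch wins if both frameworks are installed)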
def get_framework(model, revision: Optional[str] = None):
|
| 351 |
+
"""
|
| 352 |
+
Select framework (TensorFlow or PyTorch) to use.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
model (`str`, [`PreTrainedModel`] or [`TFPreTrainedModel]`):
|
| 356 |
+
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
|
| 357 |
+
the model name). If no specific model is provided, defaults to using PyTorch.
|
| 358 |
+
"""
|
| 359 |
+
warnings.warn(
|
| 360 |
+
"`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
|
| 361 |
+
FutureWarning,
|
| 362 |
+
)
|
| 363 |
+
if not is_tf_available() and not is_torch_available():
|
| 364 |
+
raise RuntimeError(
|
| 365 |
+
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
| 366 |
+
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
| 367 |
+
"To install PyTorch, read the instructions at https://pytorch.org/."
|
| 368 |
+
)
|
| 369 |
+
if isinstance(model, str):
|
| 370 |
+
if is_torch_available() and not is_tf_available():
|
| 371 |
+
model = AutoModel.from_pretrained(model, revision=revision)
|
| 372 |
+
elif is_tf_available() and not is_torch_available():
|
| 373 |
+
model = TFAutoModel.from_pretrained(model, revision=revision)
|
| 374 |
+
else:
|
| 375 |
+
try:
|
| 376 |
+
model = AutoModel.from_pretrained(model, revision=revision)
|
| 377 |
+
except OSError:
|
| 378 |
+
model = TFAutoModel.from_pretrained(model, revision=revision)
|
| 379 |
+
|
| 380 |
+
framework = infer_framework(model.__class__)
|
| 381 |
+
return framework
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def get_default_model_and_revision(
|
| 385 |
+
targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
|
| 386 |
+
) -> Union[str, Tuple[str, str]]:
|
| 387 |
+
"""
|
| 388 |
+
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
targeted_task (`Dict`):
|
| 392 |
+
Dictionary representing the given task, that should contain default models
|
| 393 |
+
|
| 394 |
+
framework (`str`, None)
|
| 395 |
+
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
|
| 396 |
+
|
| 397 |
+
task_options (`Any`, None)
|
| 398 |
+
Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for
|
| 399 |
+
translation task.
|
| 400 |
+
|
| 401 |
+
Returns
|
| 402 |
+
|
| 403 |
+
`str` or `Tuple[str, str]`: The default model (and, when provided, its revision) for this pipeline
|
| 404 |
+
"""
|
| 405 |
+
if is_torch_available() and not is_tf_available():
|
| 406 |
+
framework = "pt"
|
| 407 |
+
elif is_tf_available() and not is_torch_available():
|
| 408 |
+
framework = "tf"
|
| 409 |
+
|
| 410 |
+
defaults = targeted_task["default"]
|
| 411 |
+
if task_options:
|
| 412 |
+
if task_options not in defaults:
|
| 413 |
+
raise ValueError(f"The task does not provide any default models for options {task_options}")
|
| 414 |
+
default_models = defaults[task_options]["model"]
|
| 415 |
+
elif "model" in defaults:
|
| 416 |
+
default_models = targeted_task["default"]["model"]
|
| 417 |
+
else:
|
| 418 |
+
# XXX This error message needs to be updated to be more generic if more tasks are going to become
|
| 419 |
+
# parametrized
|
| 420 |
+
raise ValueError('The task defaults can\'t be correctly selected. You probably meant "translation_XX_to_YY"')
|
| 421 |
+
|
| 422 |
+
if framework is None:
|
| 423 |
+
framework = "pt"
|
| 424 |
+
|
| 425 |
+
return default_models[framework]
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
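# Illustrative sketch (not part of the original module): the shape of the `targeted_task`
# dictionary expected above. The model names are placeholders, not the library's
# actual defaults.
_simple_task = {"default": {"model": {"pt": ("some-org/some-model", "main")}}}
_simple_default = get_default_model_and_revision(_simple_task, framework="pt", task_options=None)
# -> ("some-org/some-model", "main")

# Parametrized tasks (e.g. translation) key their defaults by the task options.
_translation_task = {"default": {("en", "fr"): {"model": {"pt": ("some-org/en-fr-model", "main")}}}}
_translation_default = get_default_model_and_revision(
    _translation_task, framework="pt", task_options=("en", "fr")
)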
def load_assistant_model(
|
| 429 |
+
model: "PreTrainedModel",
|
| 430 |
+
assistant_model: Optional[Union[str, "PreTrainedModel"]],
|
| 431 |
+
assistant_tokenizer: Optional[PreTrainedTokenizer],
|
| 432 |
+
) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
|
| 433 |
+
"""
|
| 434 |
+
Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
|
| 435 |
+
|
| 436 |
+
Args:
|
| 437 |
+
model ([`PreTrainedModel`]):
|
| 438 |
+
The main model that will be used by the pipeline to make predictions.
|
| 439 |
+
assistant_model (`str` or [`PreTrainedModel`], *optional*):
|
| 440 |
+
The assistant model that will be used by the pipeline to make predictions.
|
| 441 |
+
assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
|
| 442 |
+
The assistant tokenizer that will be used by the pipeline to encode data for the model.
|
| 443 |
+
|
| 444 |
+
Returns:
|
| 445 |
+
Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
|
| 446 |
+
"""
|
| 447 |
+
if not model.can_generate() or assistant_model is None:
|
| 448 |
+
return None, None
|
| 449 |
+
|
| 450 |
+
if not isinstance(model, PreTrainedModel):
|
| 451 |
+
raise ValueError(
|
| 452 |
+
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
|
| 453 |
+
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
# If the model is passed as a string, load the model and the corresponding tokenizer
|
| 457 |
+
if isinstance(assistant_model, str):
|
| 458 |
+
assistant_config = AutoConfig.from_pretrained(assistant_model)
|
| 459 |
+
_, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
|
| 460 |
+
loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
|
| 461 |
+
loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
|
| 462 |
+
else:
|
| 463 |
+
loaded_assistant_model = assistant_model
|
| 464 |
+
loaded_assistant_tokenizer = assistant_tokenizer
|
| 465 |
+
|
| 466 |
+
# Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant
|
| 467 |
+
# tokenizer
|
| 468 |
+
same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
|
| 469 |
+
same_special_tokens = all(
|
| 470 |
+
getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
|
| 471 |
+
for token in ("eos_token_id", "pad_token_id", "bos_token_id")
|
| 472 |
+
)
|
| 473 |
+
if same_vocab_size and same_special_tokens:
|
| 474 |
+
loaded_assistant_tokenizer = None
|
| 475 |
+
elif loaded_assistant_tokenizer is None:
|
| 476 |
+
raise ValueError(
|
| 477 |
+
"The assistant model has a different tokenizer than the main model. You should pass the assistant "
|
| 478 |
+
"tokenizer."
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
return loaded_assistant_model, loaded_assistant_tokenizer
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
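# Illustrative sketch (not part of the original module): assisted generation in a
# pipeline, which routes through load_assistant_model above. Both checkpoint names
# are placeholders for a large main model and a small draft model.
from transformers import pipeline

_pipe = pipeline(
    "text-generation",
    model="some-org/big-model",              # hypothetical main checkpoint
    assistant_model="some-org/small-model",  # hypothetical draft checkpoint
)
# _pipe("Hello", max_new_tokens=20) then uses the draft model to speed up decoding.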
class PipelineException(Exception):
|
| 485 |
+
"""
|
| 486 |
+
Raised by a [`Pipeline`] when handling __call__.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
task (`str`): The task of the pipeline.
|
| 490 |
+
model (`str`): The model used by the pipeline.
|
| 491 |
+
reason (`str`): The error message to display.
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def __init__(self, task: str, model: str, reason: str):
|
| 495 |
+
super().__init__(reason)
|
| 496 |
+
|
| 497 |
+
self.task = task
|
| 498 |
+
self.model = model
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class ArgumentHandler(ABC):
|
| 502 |
+
"""
|
| 503 |
+
Base interface for handling arguments for each [`~pipelines.Pipeline`].
|
| 504 |
+
"""
|
| 505 |
+
|
| 506 |
+
@abstractmethod
|
| 507 |
+
def __call__(self, *args, **kwargs):
|
| 508 |
+
raise NotImplementedError()
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class PipelineDataFormat:
|
| 512 |
+
"""
|
| 513 |
+
Base class for all the pipeline supported data format both for reading and writing. Supported data formats
|
| 514 |
+
currently includes:
|
| 515 |
+
|
| 516 |
+
- JSON
|
| 517 |
+
- CSV
|
| 518 |
+
- stdin/stdout (pipe)
|
| 519 |
+
|
| 520 |
+
`PipelineDataFormat` also includes some utilities to work with multi-columns like mapping from datasets columns to
|
| 521 |
+
pipelines keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
output_path (`str`): Where to save the outgoing data.
|
| 525 |
+
input_path (`str`): Where to look for the input data.
|
| 526 |
+
column (`str`): The column to read.
|
| 527 |
+
overwrite (`bool`, *optional*, defaults to `False`):
|
| 528 |
+
Whether or not to overwrite the `output_path`.
|
| 529 |
+
"""
|
| 530 |
+
|
| 531 |
+
SUPPORTED_FORMATS = ["json", "csv", "pipe"]
|
| 532 |
+
|
| 533 |
+
def __init__(
|
| 534 |
+
self,
|
| 535 |
+
output_path: Optional[str],
|
| 536 |
+
input_path: Optional[str],
|
| 537 |
+
column: Optional[str],
|
| 538 |
+
overwrite: bool = False,
|
| 539 |
+
):
|
| 540 |
+
self.output_path = output_path
|
| 541 |
+
self.input_path = input_path
|
| 542 |
+
self.column = column.split(",") if column is not None else [""]
|
| 543 |
+
self.is_multi_columns = len(self.column) > 1
|
| 544 |
+
|
| 545 |
+
if self.is_multi_columns:
|
| 546 |
+
self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]
|
| 547 |
+
|
| 548 |
+
if output_path is not None and not overwrite:
|
| 549 |
+
if exists(abspath(self.output_path)):
|
| 550 |
+
raise OSError(f"{self.output_path} already exists on disk")
|
| 551 |
+
|
| 552 |
+
if input_path is not None:
|
| 553 |
+
if not exists(abspath(self.input_path)):
|
| 554 |
+
raise OSError(f"{self.input_path} doesnt exist on disk")
|
| 555 |
+
|
| 556 |
+
@abstractmethod
|
| 557 |
+
def __iter__(self):
|
| 558 |
+
raise NotImplementedError()
|
| 559 |
+
|
| 560 |
+
@abstractmethod
|
| 561 |
+
def save(self, data: Union[dict, List[dict]]):
|
| 562 |
+
"""
|
| 563 |
+
Save the provided data object with the representation for the current [`~pipelines.PipelineDataFormat`].
|
| 564 |
+
|
| 565 |
+
Args:
|
| 566 |
+
data (`dict` or list of `dict`): The data to store.
|
| 567 |
+
"""
|
| 568 |
+
raise NotImplementedError()
|
| 569 |
+
|
| 570 |
+
def save_binary(self, data: Union[dict, List[dict]]) -> str:
|
| 571 |
+
"""
|
| 572 |
+
Save the provided data object as a pickle-formatted binary data on the disk.
|
| 573 |
+
|
| 574 |
+
Args:
|
| 575 |
+
data (`dict` or list of `dict`): The data to store.
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
`str`: Path where the data has been saved.
|
| 579 |
+
"""
|
| 580 |
+
path, _ = os.path.splitext(self.output_path)
|
| 581 |
+
binary_path = os.path.extsep.join((path, "pickle"))
|
| 582 |
+
|
| 583 |
+
with open(binary_path, "wb+") as f_output:
|
| 584 |
+
pickle.dump(data, f_output)
|
| 585 |
+
|
| 586 |
+
return binary_path
|
| 587 |
+
|
| 588 |
+
@staticmethod
|
| 589 |
+
def from_str(
|
| 590 |
+
format: str,
|
| 591 |
+
output_path: Optional[str],
|
| 592 |
+
input_path: Optional[str],
|
| 593 |
+
column: Optional[str],
|
| 594 |
+
overwrite=False,
|
| 595 |
+
) -> "PipelineDataFormat":
|
| 596 |
+
"""
|
| 597 |
+
Creates an instance of the right subclass of [`~pipelines.PipelineDataFormat`] depending on `format`.
|
| 598 |
+
|
| 599 |
+
Args:
|
| 600 |
+
format (`str`):
|
| 601 |
+
The format of the desired pipeline. Acceptable values are `"json"`, `"csv"` or `"pipe"`.
|
| 602 |
+
output_path (`str`, *optional*):
|
| 603 |
+
Where to save the outgoing data.
|
| 604 |
+
input_path (`str`, *optional*):
|
| 605 |
+
Where to look for the input data.
|
| 606 |
+
column (`str`, *optional*):
|
| 607 |
+
The column to read.
|
| 608 |
+
overwrite (`bool`, *optional*, defaults to `False`):
|
| 609 |
+
Whether or not to overwrite the `output_path`.
|
| 610 |
+
|
| 611 |
+
Returns:
|
| 612 |
+
[`~pipelines.PipelineDataFormat`]: The proper data format.
|
| 613 |
+
"""
|
| 614 |
+
if format == "json":
|
| 615 |
+
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
| 616 |
+
elif format == "csv":
|
| 617 |
+
return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
| 618 |
+
elif format == "pipe":
|
| 619 |
+
return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
|
| 620 |
+
else:
|
| 621 |
+
raise KeyError(f"Unknown reader {format} (Available reader are json/csv/pipe)")
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
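# Illustrative sketch (not part of the original module): reading one CSV column and
# saving pipeline outputs through the factory above. The file paths and column name
# are hypothetical.
_reader = PipelineDataFormat.from_str(
    format="csv",
    output_path="outputs.csv",
    input_path="reviews.csv",   # assumed to exist and to contain a "text" column
    column="text",
)
_texts = list(_reader)  # each element is the raw value of the "text" column
_reader.save([{"label": "POSITIVE", "score": 0.99}])  # written as CSV with a header row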
class CsvPipelineDataFormat(PipelineDataFormat):
|
| 625 |
+
"""
|
| 626 |
+
Support for pipelines using CSV data format.
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
output_path (`str`): Where to save the outgoing data.
|
| 630 |
+
input_path (`str`): Where to look for the input data.
|
| 631 |
+
column (`str`): The column to read.
|
| 632 |
+
overwrite (`bool`, *optional*, defaults to `False`):
|
| 633 |
+
Whether or not to overwrite the `output_path`.
|
| 634 |
+
"""
|
| 635 |
+
|
| 636 |
+
def __init__(
|
| 637 |
+
self,
|
| 638 |
+
output_path: Optional[str],
|
| 639 |
+
input_path: Optional[str],
|
| 640 |
+
column: Optional[str],
|
| 641 |
+
overwrite=False,
|
| 642 |
+
):
|
| 643 |
+
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
| 644 |
+
|
| 645 |
+
def __iter__(self):
|
| 646 |
+
with open(self.input_path, "r") as f:
|
| 647 |
+
reader = csv.DictReader(f)
|
| 648 |
+
for row in reader:
|
| 649 |
+
if self.is_multi_columns:
|
| 650 |
+
yield {k: row[c] for k, c in self.column}
|
| 651 |
+
else:
|
| 652 |
+
yield row[self.column[0]]
|
| 653 |
+
|
| 654 |
+
def save(self, data: List[dict]):
|
| 655 |
+
"""
|
| 656 |
+
Save the provided data object with the representation for the current [`~pipelines.PipelineDataFormat`].
|
| 657 |
+
|
| 658 |
+
Args:
|
| 659 |
+
data (`List[dict]`): The data to store.
|
| 660 |
+
"""
|
| 661 |
+
with open(self.output_path, "w") as f:
|
| 662 |
+
if len(data) > 0:
|
| 663 |
+
writer = csv.DictWriter(f, list(data[0].keys()))
|
| 664 |
+
writer.writeheader()
|
| 665 |
+
writer.writerows(data)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
class JsonPipelineDataFormat(PipelineDataFormat):
|
| 669 |
+
"""
|
| 670 |
+
Support for pipelines using JSON file format.
|
| 671 |
+
|
| 672 |
+
Args:
|
| 673 |
+
output_path (`str`): Where to save the outgoing data.
|
| 674 |
+
input_path (`str`): Where to look for the input data.
|
| 675 |
+
column (`str`): The column to read.
|
| 676 |
+
overwrite (`bool`, *optional*, defaults to `False`):
|
| 677 |
+
Whether or not to overwrite the `output_path`.
|
| 678 |
+
"""
|
| 679 |
+
|
| 680 |
+
def __init__(
|
| 681 |
+
self,
|
| 682 |
+
output_path: Optional[str],
|
| 683 |
+
input_path: Optional[str],
|
| 684 |
+
column: Optional[str],
|
| 685 |
+
overwrite=False,
|
| 686 |
+
):
|
| 687 |
+
super().__init__(output_path, input_path, column, overwrite=overwrite)
|
| 688 |
+
|
| 689 |
+
with open(input_path, "r") as f:
|
| 690 |
+
self._entries = json.load(f)
|
| 691 |
+
|
| 692 |
+
def __iter__(self):
|
| 693 |
+
for entry in self._entries:
|
| 694 |
+
if self.is_multi_columns:
|
| 695 |
+
yield {k: entry[c] for k, c in self.column}
|
| 696 |
+
else:
|
| 697 |
+
yield entry[self.column[0]]
|
| 698 |
+
|
| 699 |
+
def save(self, data: dict):
|
| 700 |
+
"""
|
| 701 |
+
Save the provided data object in a json file.
|
| 702 |
+
|
| 703 |
+
Args:
|
| 704 |
+
data (`dict`): The data to store.
|
| 705 |
+
"""
|
| 706 |
+
with open(self.output_path, "w") as f:
|
| 707 |
+
json.dump(data, f)
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
class PipedPipelineDataFormat(PipelineDataFormat):
|
| 711 |
+
"""
|
| 712 |
+
Read data from input piped to the Python process. For multi-column data, columns should be separated by \t
|
| 713 |
+
|
| 714 |
+
If columns are provided, then the output will be a dictionary with {column_x: value_x}
|
| 715 |
+
|
| 716 |
+
Args:
|
| 717 |
+
output_path (`str`): Where to save the outgoing data.
|
| 718 |
+
input_path (`str`): Where to look for the input data.
|
| 719 |
+
column (`str`): The column to read.
|
| 720 |
+
overwrite (`bool`, *optional*, defaults to `False`):
|
| 721 |
+
Whether or not to overwrite the `output_path`.
|
| 722 |
+
"""
|
| 723 |
+
|
| 724 |
+
def __iter__(self):
|
| 725 |
+
for line in sys.stdin:
|
| 726 |
+
# Split for multi-columns
|
| 727 |
+
if "\t" in line:
|
| 728 |
+
line = line.split("\t")
|
| 729 |
+
if self.column:
|
| 730 |
+
# Dictionary to map arguments
|
| 731 |
+
yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
|
| 732 |
+
else:
|
| 733 |
+
yield tuple(line)
|
| 734 |
+
|
| 735 |
+
# No dictionary to map arguments
|
| 736 |
+
else:
|
| 737 |
+
yield line
|
| 738 |
+
|
| 739 |
+
def save(self, data: dict):
|
| 740 |
+
"""
|
| 741 |
+
Print the data.
|
| 742 |
+
|
| 743 |
+
Args:
|
| 744 |
+
data (`dict`): The data to store.
|
| 745 |
+
"""
|
| 746 |
+
print(data)
|
| 747 |
+
|
| 748 |
+
def save_binary(self, data: Union[dict, List[dict]]) -> str:
|
| 749 |
+
if self.output_path is None:
|
| 750 |
+
raise KeyError(
|
| 751 |
+
"When using piped input on pipeline outputting large object requires an output file path. "
|
| 752 |
+
"Please provide such output path through --output argument."
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
return super().save_binary(data)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class _ScikitCompat(ABC):
    """
    Interface layer for the Scikit and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def build_pipeline_init_args(
|
| 773 |
+
has_tokenizer: bool = False,
|
| 774 |
+
has_feature_extractor: bool = False,
|
| 775 |
+
has_image_processor: bool = False,
|
| 776 |
+
has_processor: bool = False,
|
| 777 |
+
supports_binary_output: bool = True,
|
| 778 |
+
) -> str:
|
| 779 |
+
docstring = r"""
|
| 780 |
+
Arguments:
|
| 781 |
+
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
|
| 782 |
+
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
|
| 783 |
+
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow."""
|
| 784 |
+
if has_tokenizer:
|
| 785 |
+
docstring += r"""
|
| 786 |
+
tokenizer ([`PreTrainedTokenizer`]):
|
| 787 |
+
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
|
| 788 |
+
[`PreTrainedTokenizer`]."""
|
| 789 |
+
if has_feature_extractor:
|
| 790 |
+
docstring += r"""
|
| 791 |
+
feature_extractor ([`SequenceFeatureExtractor`]):
|
| 792 |
+
The feature extractor that will be used by the pipeline to encode data for the model. This object inherits from
|
| 793 |
+
[`SequenceFeatureExtractor`]."""
|
| 794 |
+
if has_image_processor:
|
| 795 |
+
docstring += r"""
|
| 796 |
+
image_processor ([`BaseImageProcessor`]):
|
| 797 |
+
The image processor that will be used by the pipeline to encode data for the model. This object inherits from
|
| 798 |
+
[`BaseImageProcessor`]."""
|
| 799 |
+
if has_processor:
|
| 800 |
+
docstring += r"""
|
| 801 |
+
processor ([`ProcessorMixin`]):
|
| 802 |
+
The processor that will be used by the pipeline to encode data for the model. This object inherits from
|
| 803 |
+
[`ProcessorMixin`]. Processor is a composite object that might contain `tokenizer`, `feature_extractor`, and
|
| 804 |
+
`image_processor`."""
|
| 805 |
+
docstring += r"""
|
| 806 |
+
modelcard (`str` or [`ModelCard`], *optional*):
|
| 807 |
+
Model card attributed to the model for this pipeline.
|
| 808 |
+
framework (`str`, *optional*):
|
| 809 |
+
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
|
| 810 |
+
installed.
|
| 811 |
+
|
| 812 |
+
If no framework is specified, will default to the one currently installed. If no framework is specified and
|
| 813 |
+
both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
|
| 814 |
+
provided.
|
| 815 |
+
task (`str`, defaults to `""`):
|
| 816 |
+
A task-identifier for the pipeline.
|
| 817 |
+
num_workers (`int`, *optional*, defaults to 8):
|
| 818 |
+
When the pipeline will use *DataLoader* (when passing a dataset, on GPU for a Pytorch model), the number of
|
| 819 |
+
workers to be used.
|
| 820 |
+
batch_size (`int`, *optional*, defaults to 1):
|
| 821 |
+
When the pipeline will use *DataLoader* (when passing a dataset, on GPU for a Pytorch model), the size of
|
| 822 |
+
the batch to use, for inference this is not always beneficial, please read [Batching with
|
| 823 |
+
pipelines](https://huggingface.co/transformers/main_classes/pipelines.html#pipeline-batching) .
|
| 824 |
+
args_parser ([`~pipelines.ArgumentHandler`], *optional*):
|
| 825 |
+
Reference to the object in charge of parsing supplied pipeline parameters.
|
| 826 |
+
device (`int`, *optional*, defaults to -1):
|
| 827 |
+
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, a positive integer will run the model on
|
| 828 |
+
the associated CUDA device id. You can also pass a native `torch.device` or a `str`.
|
| 829 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
| 830 |
+
Sent directly as `model_kwargs` (just a simpler shortcut) to use the available precision for this model
|
| 831 |
+
(`torch.float16`, `torch.bfloat16`, ... or `"auto"`)"""
|
| 832 |
+
if supports_binary_output:
|
| 833 |
+
docstring += r"""
|
| 834 |
+
binary_output (`bool`, *optional*, defaults to `False`):
|
| 835 |
+
Flag indicating if the output of the pipeline should happen in a serialized format (i.e., pickle) or as
|
| 836 |
+
the raw output data e.g. text."""
|
| 837 |
+
return docstring
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
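# Illustrative sketch (not part of the original module): the generated fragment is
# appended to a class docstring through `add_end_docstrings`, exactly as done for the
# pipeline classes below. The class here exists only for the demonstration.
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class _DocsOnlyExample:
    """A hypothetical pipeline."""

# The "tokenizer ([`PreTrainedTokenizer`])" entry now appears in _DocsOnlyExample.__doc__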
PIPELINE_INIT_ARGS = build_pipeline_init_args(
|
| 841 |
+
has_tokenizer=True,
|
| 842 |
+
has_feature_extractor=True,
|
| 843 |
+
has_image_processor=True,
|
| 844 |
+
has_processor=True,
|
| 845 |
+
supports_binary_output=True,
|
| 846 |
+
)
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
if is_torch_available():
|
| 850 |
+
from transformers.pipelines.pt_utils import (
|
| 851 |
+
PipelineChunkIterator,
|
| 852 |
+
PipelineDataset,
|
| 853 |
+
PipelineIterator,
|
| 854 |
+
PipelinePackIterator,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
@add_end_docstrings(
|
| 859 |
+
build_pipeline_init_args(
|
| 860 |
+
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True
|
| 861 |
+
)
|
| 862 |
+
)
|
| 863 |
+
class Pipeline(_ScikitCompat, PushToHubMixin):
|
| 864 |
+
"""
|
| 865 |
+
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
|
| 866 |
+
different pipelines.
|
| 867 |
+
|
| 868 |
+
Base class implementing pipelined operations. Pipeline workflow is defined as a sequence of the following
|
| 869 |
+
operations:
|
| 870 |
+
|
| 871 |
+
Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output
|
| 872 |
+
|
| 873 |
+
Pipeline supports running on CPU or GPU through the device argument (see below).
|
| 874 |
+
|
| 875 |
+
Some pipelines, like [`FeatureExtractionPipeline`] (`'feature-extraction'`), output large tensor objects
|
| 876 |
+
as nested-lists. In order to avoid dumping such large structure as textual data we provide the `binary_output`
|
| 877 |
+
constructor argument. If set to `True`, the output will be stored in the pickle format.
|
| 878 |
+
"""
|
| 879 |
+
|
| 880 |
+
# Historically we have pipelines working with `tokenizer`, `feature_extractor`, and `image_processor`
|
| 881 |
+
# as separate processing components. While we have `processor` class that combines them, some pipelines
|
| 882 |
+
# might still operate with these components separately.
|
| 883 |
+
# With the addition of `processor` to `pipeline`, we want to avoid:
|
| 884 |
+
# - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately;
|
| 885 |
+
# - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`,
|
| 886 |
+
# because `processor` will load required sub-components by itself.
|
| 887 |
+
# Below flags allow granular control over loading components and set to be backward compatible with current
|
| 888 |
+
# pipelines logic. You may override these flags when creating your pipeline. For example, for
|
| 889 |
+
# `zero-shot-object-detection` pipeline which operates with `processor` you should set `_load_processor=True`
|
| 890 |
+
# and all the rest flags to `False` to avoid unnecessary loading of the components.
|
| 891 |
+
_load_processor = False
|
| 892 |
+
_load_image_processor = True
|
| 893 |
+
_load_feature_extractor = True
|
| 894 |
+
_load_tokenizer = True
|
| 895 |
+
|
| 896 |
+
default_input_names = None
|
| 897 |
+
|
| 898 |
+
def __init__(
|
| 899 |
+
self,
|
| 900 |
+
model: Union["PreTrainedModel", "TFPreTrainedModel"],
|
| 901 |
+
tokenizer: Optional[PreTrainedTokenizer] = None,
|
| 902 |
+
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
|
| 903 |
+
image_processor: Optional[BaseImageProcessor] = None,
|
| 904 |
+
processor: Optional[ProcessorMixin] = None,
|
| 905 |
+
modelcard: Optional[ModelCard] = None,
|
| 906 |
+
framework: Optional[str] = None,
|
| 907 |
+
task: str = "",
|
| 908 |
+
args_parser: ArgumentHandler = None,
|
| 909 |
+
device: Union[int, "torch.device"] = None,
|
| 910 |
+
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
|
| 911 |
+
binary_output: bool = False,
|
| 912 |
+
**kwargs,
|
| 913 |
+
):
|
| 914 |
+
if framework is None:
|
| 915 |
+
framework, model = infer_framework_load_model(model, config=model.config)
|
| 916 |
+
|
| 917 |
+
self.task = task
|
| 918 |
+
self.model = model
|
| 919 |
+
self.tokenizer = tokenizer
|
| 920 |
+
self.feature_extractor = feature_extractor
|
| 921 |
+
self.image_processor = image_processor
|
| 922 |
+
self.processor = processor
|
| 923 |
+
self.modelcard = modelcard
|
| 924 |
+
self.framework = framework
|
| 925 |
+
|
| 926 |
+
# `accelerate` device map
|
| 927 |
+
hf_device_map = getattr(self.model, "hf_device_map", None)
|
| 928 |
+
|
| 929 |
+
if hf_device_map is not None and device is not None:
|
| 930 |
+
raise ValueError(
|
| 931 |
+
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
|
| 932 |
+
"discard the `device` argument when creating your pipeline object."
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
if device is None:
|
| 936 |
+
if hf_device_map is not None:
|
| 937 |
+
# Take the first device used by `accelerate`.
|
| 938 |
+
device = next(iter(hf_device_map.values()))
|
| 939 |
+
else:
|
| 940 |
+
device = 0
|
| 941 |
+
|
| 942 |
+
if is_torch_available() and self.framework == "pt":
|
| 943 |
+
if device == -1 and self.model.device is not None:
|
| 944 |
+
device = self.model.device
|
| 945 |
+
if isinstance(device, torch.device):
|
| 946 |
+
if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
|
| 947 |
+
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
|
| 948 |
+
self.device = device
|
| 949 |
+
elif isinstance(device, str):
|
| 950 |
+
if "xpu" in device and not is_torch_xpu_available(check_device=True):
|
| 951 |
+
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
|
| 952 |
+
self.device = torch.device(device)
|
| 953 |
+
elif device < 0:
|
| 954 |
+
self.device = torch.device("cpu")
|
| 955 |
+
elif is_torch_mlu_available():
|
| 956 |
+
self.device = torch.device(f"mlu:{device}")
|
| 957 |
+
elif is_torch_musa_available():
|
| 958 |
+
self.device = torch.device(f"musa:{device}")
|
| 959 |
+
elif is_torch_cuda_available():
|
| 960 |
+
self.device = torch.device(f"cuda:{device}")
|
| 961 |
+
elif is_torch_npu_available():
|
| 962 |
+
self.device = torch.device(f"npu:{device}")
|
| 963 |
+
elif is_torch_xpu_available(check_device=True):
|
| 964 |
+
self.device = torch.device(f"xpu:{device}")
|
| 965 |
+
elif is_torch_mps_available():
|
| 966 |
+
self.device = torch.device(f"mps:{device}")
|
| 967 |
+
else:
|
| 968 |
+
self.device = torch.device("cpu")
|
| 969 |
+
else:
|
| 970 |
+
self.device = device if device is not None else -1
|
| 971 |
+
|
| 972 |
+
logger.warning(f"Device set to use {self.device}")
|
| 973 |
+
|
| 974 |
+
self.binary_output = binary_output
|
| 975 |
+
# We shouldn't call `model.to()` for models loaded with accelerate as well as the case that model is already on device
|
| 976 |
+
if (
|
| 977 |
+
self.framework == "pt"
|
| 978 |
+
and self.model.device != self.device
|
| 979 |
+
and not (isinstance(self.device, int) and self.device < 0)
|
| 980 |
+
and hf_device_map is None
|
| 981 |
+
):
|
| 982 |
+
self.model.to(self.device)
|
| 983 |
+
|
| 984 |
+
# If the model can generate:
|
| 985 |
+
# 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
|
| 986 |
+
# tweaks to the generation config.
|
| 987 |
+
# 2 - load the assistant model if it is passed.
|
| 988 |
+
self.assistant_model, self.assistant_tokenizer = load_assistant_model(
|
| 989 |
+
self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
|
| 990 |
+
)
|
| 991 |
+
if self.model.can_generate():
|
| 992 |
+
self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
|
| 993 |
+
self.generation_config = copy.deepcopy(self.model.generation_config)
|
| 994 |
+
# Update the generation config with task specific params if they exist
|
| 995 |
+
# NOTE: `prefix` is pipeline-specific and doesn't exist in the generation config.
|
| 996 |
+
task_specific_params = self.model.config.task_specific_params
|
| 997 |
+
if task_specific_params is not None and task in task_specific_params:
|
| 998 |
+
this_task_params = task_specific_params.get(task)
|
| 999 |
+
if "prefix" in this_task_params:
|
| 1000 |
+
self.prefix = this_task_params.pop("prefix")
|
| 1001 |
+
self.generation_config.update(**this_task_params)
|
| 1002 |
+
# If the tokenizer has a pad token but the model doesn't, set it so that `generate` is aware of it.
|
| 1003 |
+
if (
|
| 1004 |
+
self.tokenizer is not None
|
| 1005 |
+
and self.tokenizer.pad_token_id is not None
|
| 1006 |
+
and self.generation_config.pad_token_id is None
|
| 1007 |
+
):
|
| 1008 |
+
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
| 1009 |
+
|
| 1010 |
+
self.call_count = 0
|
| 1011 |
+
self._batch_size = kwargs.pop("batch_size", None)
|
| 1012 |
+
self._num_workers = kwargs.pop("num_workers", None)
|
| 1013 |
+
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
| 1014 |
+
|
| 1015 |
+
# In processor only mode, we can get the modality processors from the processor
|
| 1016 |
+
if self.processor is not None and all(
|
| 1017 |
+
[self.tokenizer is None, self.feature_extractor is None, self.image_processor is None]
|
| 1018 |
+
):
|
| 1019 |
+
self.tokenizer = getattr(self.processor, "tokenizer", None)
|
| 1020 |
+
self.feature_extractor = getattr(self.processor, "feature_extractor", None)
|
| 1021 |
+
self.image_processor = getattr(self.processor, "image_processor", None)
|
| 1022 |
+
|
| 1023 |
+
if self.image_processor is None and self.feature_extractor is not None:
|
| 1024 |
+
if isinstance(self.feature_extractor, BaseImageProcessor):
|
| 1025 |
+
# Backward compatible change, if users called
|
| 1026 |
+
# ImageSegmentationPipeline(.., feature_extractor=MyFeatureExtractor())
|
| 1027 |
+
# then we should keep working
|
| 1028 |
+
self.image_processor = self.feature_extractor
|
| 1029 |
+
|
| 1030 |
+
def save_pretrained(
|
| 1031 |
+
self,
|
| 1032 |
+
save_directory: Union[str, os.PathLike],
|
| 1033 |
+
safe_serialization: bool = True,
|
| 1034 |
+
**kwargs,
|
| 1035 |
+
):
|
| 1036 |
+
"""
|
| 1037 |
+
Save the pipeline's model and tokenizer.
|
| 1038 |
+
|
| 1039 |
+
Args:
|
| 1040 |
+
save_directory (`str` or `os.PathLike`):
|
| 1041 |
+
A path to the directory where to save to. It will be created if it doesn't exist.
|
| 1042 |
+
safe_serialization (`bool`):
|
| 1043 |
+
Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow.
|
| 1044 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 1045 |
+
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 1046 |
+
"""
|
| 1047 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 1048 |
+
|
| 1049 |
+
if use_auth_token is not None:
|
| 1050 |
+
warnings.warn(
|
| 1051 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 1052 |
+
FutureWarning,
|
| 1053 |
+
)
|
| 1054 |
+
if kwargs.get("token", None) is not None:
|
| 1055 |
+
raise ValueError(
|
| 1056 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 1057 |
+
)
|
| 1058 |
+
kwargs["token"] = use_auth_token
|
| 1059 |
+
|
| 1060 |
+
if os.path.isfile(save_directory):
|
| 1061 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 1062 |
+
return
|
| 1063 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 1064 |
+
|
| 1065 |
+
if hasattr(self, "_registered_impl"):
|
| 1066 |
+
# Add info to the config
|
| 1067 |
+
pipeline_info = self._registered_impl.copy()
|
| 1068 |
+
custom_pipelines = {}
|
| 1069 |
+
for task, info in pipeline_info.items():
|
| 1070 |
+
if info["impl"] != self.__class__:
|
| 1071 |
+
continue
|
| 1072 |
+
|
| 1073 |
+
info = info.copy()
|
| 1074 |
+
module_name = info["impl"].__module__
|
| 1075 |
+
last_module = module_name.split(".")[-1]
|
| 1076 |
+
# Change classes into their names/full names
|
| 1077 |
+
info["impl"] = f"{last_module}.{info['impl'].__name__}"
|
| 1078 |
+
info["pt"] = tuple(c.__name__ for c in info["pt"])
|
| 1079 |
+
info["tf"] = tuple(c.__name__ for c in info["tf"])
|
| 1080 |
+
|
| 1081 |
+
custom_pipelines[task] = info
|
| 1082 |
+
self.model.config.custom_pipelines = custom_pipelines
|
| 1083 |
+
# Save the pipeline custom code
|
| 1084 |
+
custom_object_save(self, save_directory)
|
| 1085 |
+
|
| 1086 |
+
kwargs["safe_serialization"] = safe_serialization
|
| 1087 |
+
self.model.save_pretrained(save_directory, **kwargs)
|
| 1088 |
+
|
| 1089 |
+
if self.tokenizer is not None:
|
| 1090 |
+
self.tokenizer.save_pretrained(save_directory, **kwargs)
|
| 1091 |
+
|
| 1092 |
+
if self.feature_extractor is not None:
|
| 1093 |
+
self.feature_extractor.save_pretrained(save_directory, **kwargs)
|
| 1094 |
+
|
| 1095 |
+
if self.image_processor is not None:
|
| 1096 |
+
self.image_processor.save_pretrained(save_directory, **kwargs)
|
| 1097 |
+
|
| 1098 |
+
if self.modelcard is not None:
|
| 1099 |
+
self.modelcard.save_pretrained(save_directory)
|
| 1100 |
+
|
| 1101 |
+
def transform(self, X):
|
| 1102 |
+
"""
|
| 1103 |
+
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
| 1104 |
+
"""
|
| 1105 |
+
return self(X)
|
| 1106 |
+
|
| 1107 |
+
def predict(self, X):
|
| 1108 |
+
"""
|
| 1109 |
+
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
|
| 1110 |
+
"""
|
| 1111 |
+
return self(X)
|
| 1112 |
+
|
| 1113 |
+
@property
|
| 1114 |
+
def torch_dtype(self) -> Optional["torch.dtype"]:
|
| 1115 |
+
"""
|
| 1116 |
+
Torch dtype of the model (if it's Pytorch model), `None` otherwise.
|
| 1117 |
+
"""
|
| 1118 |
+
return getattr(self.model, "dtype", None)
|
| 1119 |
+
|
| 1120 |
+
@contextmanager
|
| 1121 |
+
def device_placement(self):
|
| 1122 |
+
"""
|
| 1123 |
+
Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.
|
| 1124 |
+
|
| 1125 |
+
Returns:
|
| 1126 |
+
Context manager
|
| 1127 |
+
|
| 1128 |
+
Examples:
|
| 1129 |
+
|
| 1130 |
+
```python
|
| 1131 |
+
# Explicitly ask for tensor allocation on CUDA device :0
|
| 1132 |
+
pipe = pipeline(..., device=0)
|
| 1133 |
+
with pipe.device_placement():
|
| 1134 |
+
# Every framework specific tensor allocation will be done on the request device
|
| 1135 |
+
output = pipe(...)
|
| 1136 |
+
```"""
|
| 1137 |
+
if self.framework == "tf":
|
| 1138 |
+
with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"):
|
| 1139 |
+
yield
|
| 1140 |
+
else:
|
| 1141 |
+
if self.device.type == "cuda":
|
| 1142 |
+
with torch.cuda.device(self.device):
|
| 1143 |
+
yield
|
| 1144 |
+
elif self.device.type == "mlu":
|
| 1145 |
+
with torch.mlu.device(self.device):
|
| 1146 |
+
yield
|
| 1147 |
+
elif self.device.type == "musa":
|
| 1148 |
+
with torch.musa.device(self.device):
|
| 1149 |
+
yield
|
| 1150 |
+
else:
|
| 1151 |
+
yield
|
| 1152 |
+
|
| 1153 |
+
def ensure_tensor_on_device(self, **inputs):
|
| 1154 |
+
"""
|
| 1155 |
+
Ensure PyTorch tensors are on the specified device.
|
| 1156 |
+
|
| 1157 |
+
Args:
|
| 1158 |
+
inputs (keyword arguments that should be `torch.Tensor`, the rest is ignored):
|
| 1159 |
+
The tensors to place on `self.device`.
|
| 1160 |
+
Recursive on lists **only**.
|
| 1161 |
+
|
| 1162 |
+
Return:
|
| 1163 |
+
`Dict[str, torch.Tensor]`: The same as `inputs` but on the proper device.
|
| 1164 |
+
"""
|
| 1165 |
+
return self._ensure_tensor_on_device(inputs, self.device)
|
| 1166 |
+
|
| 1167 |
+
def _ensure_tensor_on_device(self, inputs, device):
|
| 1168 |
+
if isinstance(inputs, ModelOutput):
|
| 1169 |
+
return ModelOutput(
|
| 1170 |
+
{name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
|
| 1171 |
+
)
|
| 1172 |
+
elif isinstance(inputs, dict):
|
| 1173 |
+
return {name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()}
|
| 1174 |
+
elif isinstance(inputs, UserDict):
|
| 1175 |
+
return UserDict({name: self._ensure_tensor_on_device(tensor, device) for name, tensor in inputs.items()})
|
| 1176 |
+
elif isinstance(inputs, list):
|
| 1177 |
+
return [self._ensure_tensor_on_device(item, device) for item in inputs]
|
| 1178 |
+
elif isinstance(inputs, tuple):
|
| 1179 |
+
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
|
| 1180 |
+
elif isinstance(inputs, torch.Tensor):
|
| 1181 |
+
return inputs.to(device)
|
| 1182 |
+
else:
|
| 1183 |
+
return inputs
|
| 1184 |
+
|
| 1185 |
+
def check_model_type(self, supported_models: Union[List[str], dict]):
|
| 1186 |
+
"""
|
| 1187 |
+
Check if the model class is supported by the pipeline.
|
| 1188 |
+
|
| 1189 |
+
Args:
|
| 1190 |
+
supported_models (`List[str]` or `dict`):
|
| 1191 |
+
The list of models supported by the pipeline, or a dictionary with model class values.
|
| 1192 |
+
"""
|
| 1193 |
+
if not isinstance(supported_models, list): # Create from a model mapping
|
| 1194 |
+
supported_models_names = []
|
| 1195 |
+
for _, model_name in supported_models.items():
|
| 1196 |
+
# Mapping can now contain tuples of models for the same configuration.
|
| 1197 |
+
if isinstance(model_name, tuple):
|
| 1198 |
+
supported_models_names.extend(list(model_name))
|
| 1199 |
+
else:
|
| 1200 |
+
supported_models_names.append(model_name)
|
| 1201 |
+
if hasattr(supported_models, "_model_mapping"):
|
| 1202 |
+
for _, model in supported_models._model_mapping._extra_content.items():
|
| 1203 |
+
if isinstance(model, tuple):
|
| 1204 |
+
supported_models_names.extend([m.__name__ for m in model])
|
| 1205 |
+
else:
|
| 1206 |
+
supported_models_names.append(model.__name__)
|
| 1207 |
+
supported_models = supported_models_names
|
| 1208 |
+
if self.model.__class__.__name__ not in supported_models:
|
| 1209 |
+
logger.error(
|
| 1210 |
+
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are"
|
| 1211 |
+
f" {supported_models}."
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
@abstractmethod
|
| 1215 |
+
def _sanitize_parameters(self, **pipeline_parameters):
|
| 1216 |
+
"""
|
| 1217 |
+
_sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
|
| 1218 |
+
methods. It should return 3 dictionaries of the resolved parameters used by the various `preprocess`,
|
| 1219 |
+
`forward` and `postprocess` methods. Do not fill dictionaries if the caller didn't specify a kwargs. This
|
| 1220 |
+
lets you keep defaults in function signatures, which is more "natural".
|
| 1221 |
+
|
| 1222 |
+
It is not meant to be called directly, it will be automatically called and the final parameters resolved by
|
| 1223 |
+
`__init__` and `__call__`
|
| 1224 |
+
"""
|
| 1225 |
+
raise NotImplementedError("_sanitize_parameters not implemented")
|
| 1226 |
+
|
| 1227 |
+
@abstractmethod
|
| 1228 |
+
def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> Dict[str, GenericTensor]:
|
| 1229 |
+
"""
|
| 1230 |
+
Preprocess will take the `input_` of a specific pipeline and return a dictionary of everything necessary for
|
| 1231 |
+
`_forward` to run properly. It should contain at least one tensor, but might have arbitrary other items.
|
| 1232 |
+
"""
|
| 1233 |
+
raise NotImplementedError("preprocess not implemented")
|
| 1234 |
+
|
| 1235 |
+
@abstractmethod
|
| 1236 |
+
def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict) -> ModelOutput:
|
| 1237 |
+
"""
|
| 1238 |
+
_forward will receive the prepared dictionary from `preprocess` and run it on the model. This method might
|
| 1239 |
+
involve the GPU or the CPU and should be agnostic to it. Isolating this function is the reason for `preprocess`
|
| 1240 |
+
and `postprocess` to exist, so that this method, the hot path, can run as fast as possible.
|
| 1241 |
+
|
| 1242 |
+
It is not meant to be called directly, `forward` is preferred. It is basically the same but contains additional
|
| 1243 |
+
code surrounding `_forward` making sure tensors and models are on the same device, disabling the training part
|
| 1244 |
+
of the code (leading to faster inference).
|
| 1245 |
+
"""
|
| 1246 |
+
raise NotImplementedError("_forward not implemented")
|
| 1247 |
+
|
| 1248 |
+
@abstractmethod
|
| 1249 |
+
def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
|
| 1250 |
+
"""
|
| 1251 |
+
Postprocess will receive the raw outputs of the `_forward` method, generally tensors, and reformat them into
|
| 1252 |
+
something more friendly. Generally it will output a list or a dict of results (containing just strings and
|
| 1253 |
+
numbers).
|
| 1254 |
+
"""
|
| 1255 |
+
raise NotImplementedError("postprocess not implemented")
|
| 1256 |
+
|
| 1257 |
+
def get_inference_context(self):
|
| 1258 |
+
return torch.no_grad
|
| 1259 |
+
|
| 1260 |
+
def forward(self, model_inputs, **forward_params):
|
| 1261 |
+
with self.device_placement():
|
| 1262 |
+
if self.framework == "tf":
|
| 1263 |
+
model_inputs["training"] = False
|
| 1264 |
+
model_outputs = self._forward(model_inputs, **forward_params)
|
| 1265 |
+
elif self.framework == "pt":
|
| 1266 |
+
inference_context = self.get_inference_context()
|
| 1267 |
+
with inference_context():
|
| 1268 |
+
model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
|
| 1269 |
+
model_outputs = self._forward(model_inputs, **forward_params)
|
| 1270 |
+
model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
|
| 1271 |
+
else:
|
| 1272 |
+
raise ValueError(f"Framework {self.framework} is not supported")
|
| 1273 |
+
return model_outputs
|
| 1274 |
+
|
| 1275 |
+
def get_iterator(
|
| 1276 |
+
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
| 1277 |
+
):
|
| 1278 |
+
if isinstance(inputs, collections.abc.Sized):
|
| 1279 |
+
dataset = PipelineDataset(inputs, self.preprocess, preprocess_params)
|
| 1280 |
+
else:
|
| 1281 |
+
if num_workers > 1:
|
| 1282 |
+
logger.warning(
|
| 1283 |
+
"For iterable dataset using num_workers>1 is likely to result"
|
| 1284 |
+
" in errors since everything is iterable, setting `num_workers=1`"
|
| 1285 |
+
" to guarantee correctness."
|
| 1286 |
+
)
|
| 1287 |
+
num_workers = 1
|
| 1288 |
+
dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
|
| 1289 |
+
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
| 1290 |
+
logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
|
| 1291 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 1292 |
+
# TODO hack by collating feature_extractor and image_processor
|
| 1293 |
+
feature_extractor = self.feature_extractor if self.feature_extractor is not None else self.image_processor
|
| 1294 |
+
collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor)
|
| 1295 |
+
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
|
| 1296 |
+
model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
| 1297 |
+
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
| 1298 |
+
return final_iterator
|
| 1299 |
+
|
| 1300 |
+
def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs):
|
| 1301 |
+
if args:
|
| 1302 |
+
logger.warning(f"Ignoring args : {args}")
|
| 1303 |
+
|
| 1304 |
+
if num_workers is None:
|
| 1305 |
+
if self._num_workers is None:
|
| 1306 |
+
num_workers = 0
|
| 1307 |
+
else:
|
| 1308 |
+
num_workers = self._num_workers
|
| 1309 |
+
if batch_size is None:
|
| 1310 |
+
if self._batch_size is None:
|
| 1311 |
+
batch_size = 1
|
| 1312 |
+
else:
|
| 1313 |
+
batch_size = self._batch_size
|
| 1314 |
+
|
| 1315 |
+
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
|
| 1316 |
+
|
| 1317 |
+
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
|
| 1318 |
+
preprocess_params = {**self._preprocess_params, **preprocess_params}
|
| 1319 |
+
forward_params = {**self._forward_params, **forward_params}
|
| 1320 |
+
postprocess_params = {**self._postprocess_params, **postprocess_params}
|
| 1321 |
+
|
| 1322 |
+
self.call_count += 1
|
| 1323 |
+
if self.call_count > 10 and self.framework == "pt" and self.device.type == "cuda":
|
| 1324 |
+
logger.warning_once(
|
| 1325 |
+
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
|
| 1326 |
+
" dataset",
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
is_dataset = Dataset is not None and isinstance(inputs, Dataset)
|
| 1330 |
+
is_generator = isinstance(inputs, types.GeneratorType)
|
| 1331 |
+
is_list = isinstance(inputs, list)
|
| 1332 |
+
|
| 1333 |
+
is_iterable = is_dataset or is_generator or is_list
|
| 1334 |
+
|
| 1335 |
+
# TODO make the get_iterator work also for `tf` (and `flax`).
|
| 1336 |
+
can_use_iterator = self.framework == "pt" and (is_dataset or is_generator or is_list)
|
| 1337 |
+
|
| 1338 |
+
if is_list:
|
| 1339 |
+
if can_use_iterator:
|
| 1340 |
+
final_iterator = self.get_iterator(
|
| 1341 |
+
inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
| 1342 |
+
)
|
| 1343 |
+
outputs = list(final_iterator)
|
| 1344 |
+
return outputs
|
| 1345 |
+
else:
|
| 1346 |
+
return self.run_multi(inputs, preprocess_params, forward_params, postprocess_params)
|
| 1347 |
+
elif can_use_iterator:
|
| 1348 |
+
return self.get_iterator(
|
| 1349 |
+
inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
| 1350 |
+
)
|
| 1351 |
+
elif is_iterable:
|
| 1352 |
+
return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
|
| 1353 |
+
elif self.framework == "pt" and isinstance(self, ChunkPipeline):
|
| 1354 |
+
return next(
|
| 1355 |
+
iter(
|
| 1356 |
+
self.get_iterator(
|
| 1357 |
+
[inputs], num_workers, batch_size, preprocess_params, forward_params, postprocess_params
|
| 1358 |
+
)
|
| 1359 |
+
)
|
| 1360 |
+
)
|
| 1361 |
+
else:
|
| 1362 |
+
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
| 1363 |
+
|
| 1364 |
+
def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):
|
| 1365 |
+
return [self.run_single(item, preprocess_params, forward_params, postprocess_params) for item in inputs]
|
| 1366 |
+
|
| 1367 |
+
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
|
| 1368 |
+
model_inputs = self.preprocess(inputs, **preprocess_params)
|
| 1369 |
+
model_outputs = self.forward(model_inputs, **forward_params)
|
| 1370 |
+
outputs = self.postprocess(model_outputs, **postprocess_params)
|
| 1371 |
+
return outputs
|
| 1372 |
+
|
| 1373 |
+
def iterate(self, inputs, preprocess_params, forward_params, postprocess_params):
|
| 1374 |
+
# This function should become `get_iterator` again, this is a temporary
|
| 1375 |
+
# easy solution.
|
| 1376 |
+
for input_ in inputs:
|
| 1377 |
+
yield self.run_single(input_, preprocess_params, forward_params, postprocess_params)
|
| 1378 |
+
|
| 1379 |
+
|
| 1380 |
+
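# Illustrative sketch (not part of the original module): the four hooks a concrete
# subclass of Pipeline has to provide. The class is hypothetical and the
# post-processing is deliberately trivial.
class _ExampleClassificationPipeline(Pipeline):
    def _sanitize_parameters(self, top_k=None, **kwargs):
        postprocess_kwargs = {}
        if top_k is not None:
            postprocess_kwargs["top_k"] = top_k
        # (preprocess_params, forward_params, postprocess_params)
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs, **kwargs):
        return self.tokenizer(inputs, return_tensors=self.framework)

    def _forward(self, model_inputs, **kwargs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=1, **kwargs):
        # Keep only the first `top_k` raw logits of the single input.
        return model_outputs.logits[0, :top_k].tolist()

# It can be instantiated directly with a loaded model and tokenizer, or plugged into
# `pipeline(..., pipeline_class=_ExampleClassificationPipeline)`.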
Pipeline.push_to_hub = copy_func(Pipeline.push_to_hub)
|
| 1381 |
+
if Pipeline.push_to_hub.__doc__ is not None:
|
| 1382 |
+
Pipeline.push_to_hub.__doc__ = Pipeline.push_to_hub.__doc__.format(
|
| 1383 |
+
object="pipe", object_class="pipeline", object_files="pipeline file"
|
| 1384 |
+
).replace(".from_pretrained", "")
|
| 1385 |
+
|
| 1386 |
+
|
| 1387 |
+
class ChunkPipeline(Pipeline):
|
| 1388 |
+
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
|
| 1389 |
+
all_outputs = []
|
| 1390 |
+
for model_inputs in self.preprocess(inputs, **preprocess_params):
|
| 1391 |
+
model_outputs = self.forward(model_inputs, **forward_params)
|
| 1392 |
+
all_outputs.append(model_outputs)
|
| 1393 |
+
outputs = self.postprocess(all_outputs, **postprocess_params)
|
| 1394 |
+
return outputs
|
| 1395 |
+
|
| 1396 |
+
def get_iterator(
|
| 1397 |
+
self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
|
| 1398 |
+
):
|
| 1399 |
+
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
| 1400 |
+
logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already")
|
| 1401 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 1402 |
+
if num_workers > 1:
|
| 1403 |
+
logger.warning(
|
| 1404 |
+
"For ChunkPipeline using num_workers>0 is likely to result in errors since everything is iterable,"
|
| 1405 |
+
" setting `num_workers=1` to guarantee correctness."
|
| 1406 |
+
)
|
| 1407 |
+
num_workers = 1
|
| 1408 |
+
dataset = PipelineChunkIterator(inputs, self.preprocess, preprocess_params)
|
| 1409 |
+
|
| 1410 |
+
# TODO hack by collating feature_extractor and image_processor
|
| 1411 |
+
feature_extractor = self.feature_extractor if self.feature_extractor is not None else self.image_processor
|
| 1412 |
+
collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor)
|
| 1413 |
+
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
|
| 1414 |
+
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
| 1415 |
+
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
| 1416 |
+
return final_iterator
|
| 1417 |
+
|
| 1418 |
+
|
| 1419 |
+
class PipelineRegistry:
|
| 1420 |
+
def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None:
|
| 1421 |
+
self.supported_tasks = supported_tasks
|
| 1422 |
+
self.task_aliases = task_aliases
|
| 1423 |
+
|
| 1424 |
+
def get_supported_tasks(self) -> List[str]:
|
| 1425 |
+
supported_task = list(self.supported_tasks.keys()) + list(self.task_aliases.keys())
|
| 1426 |
+
supported_task.sort()
|
| 1427 |
+
return supported_task
|
| 1428 |
+
|
| 1429 |
+
def check_task(self, task: str) -> Tuple[str, Dict, Any]:
|
| 1430 |
+
if task in self.task_aliases:
|
| 1431 |
+
task = self.task_aliases[task]
|
| 1432 |
+
if task in self.supported_tasks:
|
| 1433 |
+
targeted_task = self.supported_tasks[task]
|
| 1434 |
+
return task, targeted_task, None
|
| 1435 |
+
|
| 1436 |
+
if task.startswith("translation"):
|
| 1437 |
+
tokens = task.split("_")
|
| 1438 |
+
if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
|
| 1439 |
+
targeted_task = self.supported_tasks["translation"]
|
| 1440 |
+
task = "translation"
|
| 1441 |
+
return task, targeted_task, (tokens[1], tokens[3])
|
| 1442 |
+
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
|
| 1443 |
+
|
| 1444 |
+
raise KeyError(
|
| 1445 |
+
f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
|
| 1446 |
+
)
|
| 1447 |
+
|
| 1448 |
+
def register_pipeline(
|
| 1449 |
+
self,
|
| 1450 |
+
task: str,
|
| 1451 |
+
pipeline_class: type,
|
| 1452 |
+
pt_model: Optional[Union[type, Tuple[type]]] = None,
|
| 1453 |
+
tf_model: Optional[Union[type, Tuple[type]]] = None,
|
| 1454 |
+
default: Optional[Dict] = None,
|
| 1455 |
+
type: Optional[str] = None,
|
| 1456 |
+
) -> None:
|
| 1457 |
+
if task in self.supported_tasks:
|
| 1458 |
+
logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
|
| 1459 |
+
|
| 1460 |
+
if pt_model is None:
|
| 1461 |
+
pt_model = ()
|
| 1462 |
+
elif not isinstance(pt_model, tuple):
|
| 1463 |
+
pt_model = (pt_model,)
|
| 1464 |
+
|
| 1465 |
+
if tf_model is None:
|
| 1466 |
+
tf_model = ()
|
| 1467 |
+
elif not isinstance(tf_model, tuple):
|
| 1468 |
+
tf_model = (tf_model,)
|
| 1469 |
+
|
| 1470 |
+
task_impl = {"impl": pipeline_class, "pt": pt_model, "tf": tf_model}
|
| 1471 |
+
|
| 1472 |
+
if default is not None:
|
| 1473 |
+
if "model" not in default and ("pt" in default or "tf" in default):
|
| 1474 |
+
default = {"model": default}
|
| 1475 |
+
task_impl["default"] = default
|
| 1476 |
+
|
| 1477 |
+
if type is not None:
|
| 1478 |
+
task_impl["type"] = type
|
| 1479 |
+
|
| 1480 |
+
self.supported_tasks[task] = task_impl
|
| 1481 |
+
pipeline_class._registered_impl = {task: task_impl}
|
| 1482 |
+
|
| 1483 |
+
def to_dict(self):
|
| 1484 |
+
return self.supported_tasks
|
.venv/lib/python3.11/site-packages/transformers/pipelines/depth_estimation.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
|
| 3 |
+
from ..utils import (
|
| 4 |
+
add_end_docstrings,
|
| 5 |
+
is_torch_available,
|
| 6 |
+
is_vision_available,
|
| 7 |
+
logging,
|
| 8 |
+
requires_backends,
|
| 9 |
+
)
|
| 10 |
+
from .base import Pipeline, build_pipeline_init_args
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
if is_vision_available():
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from ..image_utils import load_image
|
| 17 |
+
|
| 18 |
+
if is_torch_available():
|
| 19 |
+
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
|
| 25 |
+
class DepthEstimationPipeline(Pipeline):
|
| 26 |
+
"""
|
| 27 |
+
Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.
|
| 28 |
+
|
| 29 |
+
Example:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
>>> from transformers import pipeline
|
| 33 |
+
|
| 34 |
+
>>> depth_estimator = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-base-hf")
|
| 35 |
+
>>> output = depth_estimator("http://images.cocodataset.org/val2017/000000039769.jpg")
|
| 36 |
+
>>> # This is a tensor with the values being the depth expressed in meters for each pixel
|
| 37 |
+
>>> output["predicted_depth"].shape
|
| 38 |
+
torch.Size([1, 384, 384])
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
This depth estimation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
| 45 |
+
`"depth-estimation"`.
|
| 46 |
+
|
| 47 |
+
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=depth-estimation).
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, *args, **kwargs):
|
| 51 |
+
super().__init__(*args, **kwargs)
|
| 52 |
+
requires_backends(self, "vision")
|
| 53 |
+
self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
|
| 54 |
+
|
| 55 |
+
def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
|
| 56 |
+
"""
|
| 57 |
+
Predict the depth(s) of the image(s) passed as inputs.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
| 61 |
+
The pipeline handles three types of images:
|
| 62 |
+
|
| 63 |
+
- A string containing a http link pointing to an image
|
| 64 |
+
- A string containing a local path to an image
|
| 65 |
+
- An image loaded in PIL directly
|
| 66 |
+
|
| 67 |
+
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
|
| 68 |
+
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
|
| 69 |
+
images.
|
| 70 |
+
parameters (`Dict`, *optional*):
|
| 71 |
+
A dictionary of argument names to parameter values, to control pipeline behaviour.
|
| 72 |
+
The only parameter available right now is `timeout`, which is the length of time, in seconds,
|
| 73 |
+
that the pipeline should wait before giving up on trying to download an image.
|
| 74 |
+
timeout (`float`, *optional*, defaults to None):
|
| 75 |
+
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
|
| 76 |
+
the call may block forever.
|
| 77 |
+
|
| 78 |
+
Return:
|
| 79 |
+
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
|
| 80 |
+
dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
|
| 81 |
+
the images.
|
| 82 |
+
|
| 83 |
+
The dictionaries contain the following keys:
|
| 84 |
+
|
| 85 |
+
- **predicted_depth** (`torch.Tensor`) -- The predicted depth by the model as a `torch.Tensor`.
|
| 86 |
+
- **depth** (`PIL.Image`) -- The predicted depth by the model as a `PIL.Image`.
|
| 87 |
+
"""
|
| 88 |
+
# After deprecation of this is completed, remove the default `None` value for `images`
|
| 89 |
+
if "images" in kwargs:
|
| 90 |
+
inputs = kwargs.pop("images")
|
| 91 |
+
if inputs is None:
|
| 92 |
+
raise ValueError("Cannot call the depth-estimation pipeline without an inputs argument!")
|
| 93 |
+
return super().__call__(inputs, **kwargs)
|
| 94 |
+
|
| 95 |
+
def _sanitize_parameters(self, timeout=None, parameters=None, **kwargs):
|
| 96 |
+
preprocess_params = {}
|
| 97 |
+
if timeout is not None:
|
| 98 |
+
preprocess_params["timeout"] = timeout
|
| 99 |
+
if isinstance(parameters, dict) and "timeout" in parameters:
|
| 100 |
+
preprocess_params["timeout"] = parameters["timeout"]
|
| 101 |
+
return preprocess_params, {}, {}
|
| 102 |
+
|
| 103 |
+
def preprocess(self, image, timeout=None):
|
| 104 |
+
image = load_image(image, timeout)
|
| 105 |
+
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
| 106 |
+
if self.framework == "pt":
|
| 107 |
+
model_inputs = model_inputs.to(self.torch_dtype)
|
| 108 |
+
model_inputs["target_size"] = image.size[::-1]
|
| 109 |
+
return model_inputs
|
| 110 |
+
|
| 111 |
+
def _forward(self, model_inputs):
|
| 112 |
+
target_size = model_inputs.pop("target_size")
|
| 113 |
+
model_outputs = self.model(**model_inputs)
|
| 114 |
+
model_outputs["target_size"] = target_size
|
| 115 |
+
return model_outputs
|
| 116 |
+
|
| 117 |
+
def postprocess(self, model_outputs):
|
| 118 |
+
outputs = self.image_processor.post_process_depth_estimation(
|
| 119 |
+
model_outputs,
|
| 120 |
+
# this acts as `source_sizes` for ZoeDepth and as `target_sizes` for the rest of the models so do *not*
|
| 121 |
+
# replace with `target_sizes = [model_outputs["target_size"]]`
|
| 122 |
+
[model_outputs["target_size"]],
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
formatted_outputs = []
|
| 126 |
+
for output in outputs:
|
| 127 |
+
depth = output["predicted_depth"].detach().cpu().numpy()
|
| 128 |
+
depth = (depth - depth.min()) / (depth.max() - depth.min())
|
| 129 |
+
depth = Image.fromarray((depth * 255).astype("uint8"))
|
| 130 |
+
|
| 131 |
+
formatted_outputs.append({"predicted_depth": output["predicted_depth"], "depth": depth})
|
| 132 |
+
|
| 133 |
+
return formatted_outputs[0] if len(outputs) == 1 else formatted_outputs
|
.venv/lib/python3.11/site-packages/transformers/pipelines/document_question_answering.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The Impira Team and the HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from ..utils import (
|
| 21 |
+
ExplicitEnum,
|
| 22 |
+
add_end_docstrings,
|
| 23 |
+
is_pytesseract_available,
|
| 24 |
+
is_torch_available,
|
| 25 |
+
is_vision_available,
|
| 26 |
+
logging,
|
| 27 |
+
)
|
| 28 |
+
from .base import ChunkPipeline, build_pipeline_init_args
|
| 29 |
+
from .question_answering import select_starts_ends
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if is_vision_available():
|
| 33 |
+
from PIL import Image
|
| 34 |
+
|
| 35 |
+
from ..image_utils import load_image
|
| 36 |
+
|
| 37 |
+
if is_torch_available():
|
| 38 |
+
import torch
|
| 39 |
+
|
| 40 |
+
from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
|
| 41 |
+
|
| 42 |
+
TESSERACT_LOADED = False
|
| 43 |
+
if is_pytesseract_available():
|
| 44 |
+
TESSERACT_LOADED = True
|
| 45 |
+
import pytesseract
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py.
|
| 51 |
+
# However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. imported) to avoid creating an
|
| 52 |
+
# unnecessary dependency.
|
| 53 |
+
def normalize_box(box, width, height):
|
| 54 |
+
return [
|
| 55 |
+
int(1000 * (box[0] / width)),
|
| 56 |
+
int(1000 * (box[1] / height)),
|
| 57 |
+
int(1000 * (box[2] / width)),
|
| 58 |
+
int(1000 * (box[3] / height)),
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def apply_tesseract(image: "Image.Image", lang: Optional[str], tesseract_config: Optional[str]):
|
| 63 |
+
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
|
| 64 |
+
# apply OCR
|
| 65 |
+
data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
|
| 66 |
+
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
|
| 67 |
+
|
| 68 |
+
# filter empty words and corresponding coordinates
|
| 69 |
+
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
|
| 70 |
+
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
|
| 71 |
+
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
|
| 72 |
+
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
|
| 73 |
+
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
|
| 74 |
+
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
|
| 75 |
+
|
| 76 |
+
# turn coordinates into (left, top, left+width, top+height) format
|
| 77 |
+
actual_boxes = []
|
| 78 |
+
for x, y, w, h in zip(left, top, width, height):
|
| 79 |
+
actual_box = [x, y, x + w, y + h]
|
| 80 |
+
actual_boxes.append(actual_box)
|
| 81 |
+
|
| 82 |
+
image_width, image_height = image.size
|
| 83 |
+
|
| 84 |
+
# finally, normalize the bounding boxes
|
| 85 |
+
normalized_boxes = []
|
| 86 |
+
for box in actual_boxes:
|
| 87 |
+
normalized_boxes.append(normalize_box(box, image_width, image_height))
|
| 88 |
+
|
| 89 |
+
if len(words) != len(normalized_boxes):
|
| 90 |
+
raise ValueError("Not as many words as there are bounding boxes")
|
| 91 |
+
|
| 92 |
+
return words, normalized_boxes
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class ModelType(ExplicitEnum):
|
| 96 |
+
LayoutLM = "layoutlm"
|
| 97 |
+
LayoutLMv2andv3 = "layoutlmv2andv3"
|
| 98 |
+
VisionEncoderDecoder = "vision_encoder_decoder"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True, has_tokenizer=True))
|
| 102 |
+
class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
| 103 |
+
# TODO: Update task_summary docs to include an example with document QA and then update the first sentence
|
| 104 |
+
"""
|
| 105 |
+
Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are
|
| 106 |
+
similar to the (extractive) question answering pipeline; however, the pipeline takes an image (and optional OCR'd
|
| 107 |
+
words/boxes) as input instead of text context.
|
| 108 |
+
|
| 109 |
+
Example:
|
| 110 |
+
|
| 111 |
+
```python
|
| 112 |
+
>>> from transformers import pipeline
|
| 113 |
+
|
| 114 |
+
>>> document_qa = pipeline(model="impira/layoutlm-document-qa")
|
| 115 |
+
>>> document_qa(
|
| 116 |
+
... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
|
| 117 |
+
... question="What is the invoice number?",
|
| 118 |
+
... )
|
| 119 |
+
[{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
| 123 |
+
|
| 124 |
+
This document question answering pipeline can currently be loaded from [`pipeline`] using the following task
|
| 125 |
+
identifier: `"document-question-answering"`.
|
| 126 |
+
|
| 127 |
+
The models that this pipeline can use are models that have been fine-tuned on a document question answering task.
|
| 128 |
+
See the up-to-date list of available models on
|
| 129 |
+
[huggingface.co/models](https://huggingface.co/models?filter=document-question-answering).
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(self, *args, **kwargs):
|
| 133 |
+
super().__init__(*args, **kwargs)
|
| 134 |
+
if self.tokenizer is not None and not self.tokenizer.__class__.__name__.endswith("Fast"):
|
| 135 |
+
raise ValueError(
|
| 136 |
+
"`DocumentQuestionAnsweringPipeline` requires a fast tokenizer, but a slow tokenizer "
|
| 137 |
+
f"(`{self.tokenizer.__class__.__name__}`) is provided."
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
|
| 141 |
+
self.model_type = ModelType.VisionEncoderDecoder
|
| 142 |
+
if self.model.config.encoder.model_type != "donut-swin":
|
| 143 |
+
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
|
| 144 |
+
else:
|
| 145 |
+
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
|
| 146 |
+
if self.model.config.__class__.__name__ == "LayoutLMConfig":
|
| 147 |
+
self.model_type = ModelType.LayoutLM
|
| 148 |
+
else:
|
| 149 |
+
self.model_type = ModelType.LayoutLMv2andv3
|
| 150 |
+
|
| 151 |
+
def _sanitize_parameters(
|
| 152 |
+
self,
|
| 153 |
+
padding=None,
|
| 154 |
+
doc_stride=None,
|
| 155 |
+
max_question_len=None,
|
| 156 |
+
lang: Optional[str] = None,
|
| 157 |
+
tesseract_config: Optional[str] = None,
|
| 158 |
+
max_answer_len=None,
|
| 159 |
+
max_seq_len=None,
|
| 160 |
+
top_k=None,
|
| 161 |
+
handle_impossible_answer=None,
|
| 162 |
+
timeout=None,
|
| 163 |
+
**kwargs,
|
| 164 |
+
):
|
| 165 |
+
preprocess_params, postprocess_params = {}, {}
|
| 166 |
+
if padding is not None:
|
| 167 |
+
preprocess_params["padding"] = padding
|
| 168 |
+
if doc_stride is not None:
|
| 169 |
+
preprocess_params["doc_stride"] = doc_stride
|
| 170 |
+
if max_question_len is not None:
|
| 171 |
+
preprocess_params["max_question_len"] = max_question_len
|
| 172 |
+
if max_seq_len is not None:
|
| 173 |
+
preprocess_params["max_seq_len"] = max_seq_len
|
| 174 |
+
if lang is not None:
|
| 175 |
+
preprocess_params["lang"] = lang
|
| 176 |
+
if tesseract_config is not None:
|
| 177 |
+
preprocess_params["tesseract_config"] = tesseract_config
|
| 178 |
+
if timeout is not None:
|
| 179 |
+
preprocess_params["timeout"] = timeout
|
| 180 |
+
|
| 181 |
+
if top_k is not None:
|
| 182 |
+
if top_k < 1:
|
| 183 |
+
raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
|
| 184 |
+
postprocess_params["top_k"] = top_k
|
| 185 |
+
if max_answer_len is not None:
|
| 186 |
+
if max_answer_len < 1:
|
| 187 |
+
raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}")
|
| 188 |
+
postprocess_params["max_answer_len"] = max_answer_len
|
| 189 |
+
if handle_impossible_answer is not None:
|
| 190 |
+
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
| 191 |
+
|
| 192 |
+
forward_params = {}
|
| 193 |
+
if self.assistant_model is not None:
|
| 194 |
+
forward_params["assistant_model"] = self.assistant_model
|
| 195 |
+
if self.assistant_tokenizer is not None:
|
| 196 |
+
forward_params["tokenizer"] = self.tokenizer
|
| 197 |
+
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
| 198 |
+
|
| 199 |
+
return preprocess_params, forward_params, postprocess_params
|
| 200 |
+
|
| 201 |
+
def __call__(
|
| 202 |
+
self,
|
| 203 |
+
image: Union["Image.Image", str],
|
| 204 |
+
question: Optional[str] = None,
|
| 205 |
+
word_boxes: Tuple[str, List[float]] = None,
|
| 206 |
+
**kwargs,
|
| 207 |
+
):
|
| 208 |
+
"""
|
| 209 |
+
Answer the question(s) given as inputs by using the document(s). A document is defined as an image and an
|
| 210 |
+
optional list of (word, box) tuples which represent the text in the document. If the `word_boxes` are not
|
| 211 |
+
provided, it will use the Tesseract OCR engine (if available) to extract the words and boxes automatically for
|
| 212 |
+
LayoutLM-like models which require them as input. For Donut, no OCR is run.
|
| 213 |
+
|
| 214 |
+
You can invoke the pipeline several ways:
|
| 215 |
+
|
| 216 |
+
- `pipeline(image=image, question=question)`
|
| 217 |
+
- `pipeline(image=image, question=question, word_boxes=word_boxes)`
|
| 218 |
+
- `pipeline([{"image": image, "question": question}])`
|
| 219 |
+
- `pipeline([{"image": image, "question": question, "word_boxes": word_boxes}])`
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
image (`str` or `PIL.Image`):
|
| 223 |
+
The pipeline handles three types of images:
|
| 224 |
+
|
| 225 |
+
- A string containing a http link pointing to an image
|
| 226 |
+
- A string containing a local path to an image
|
| 227 |
+
- An image loaded in PIL directly
|
| 228 |
+
|
| 229 |
+
The pipeline accepts either a single image or a batch of images. If given a single image, it can be
|
| 230 |
+
broadcasted to multiple questions.
|
| 231 |
+
question (`str`):
|
| 232 |
+
A question to ask of the document.
|
| 233 |
+
word_boxes (`List[str, Tuple[float, float, float, float]]`, *optional*):
|
| 234 |
+
A list of words and bounding boxes (normalized 0->1000). If you provide this optional input, then the
|
| 235 |
+
pipeline will use these words and boxes instead of running OCR on the image to derive them for models
|
| 236 |
+
that need them (e.g. LayoutLM). This allows you to reuse OCR'd results across many invocations of the
|
| 237 |
+
pipeline without having to re-run it each time.
|
| 238 |
+
top_k (`int`, *optional*, defaults to 1):
|
| 239 |
+
The number of answers to return (will be chosen by order of likelihood). Note that we return less than
|
| 240 |
+
top_k answers if there are not enough options available within the context.
|
| 241 |
+
doc_stride (`int`, *optional*, defaults to 128):
|
| 242 |
+
If the words in the document are too long to fit with the question for the model, it will be split in
|
| 243 |
+
several chunks with some overlap. This argument controls the size of that overlap.
|
| 244 |
+
max_answer_len (`int`, *optional*, defaults to 15):
|
| 245 |
+
The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
|
| 246 |
+
max_seq_len (`int`, *optional*, defaults to 384):
|
| 247 |
+
The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
|
| 248 |
+
model. The context will be split in several chunks (using `doc_stride` as overlap) if needed.
|
| 249 |
+
max_question_len (`int`, *optional*, defaults to 64):
|
| 250 |
+
The maximum length of the question after tokenization. It will be truncated if needed.
|
| 251 |
+
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
|
| 252 |
+
Whether or not we accept impossible as an answer.
|
| 253 |
+
lang (`str`, *optional*):
|
| 254 |
+
Language to use while running OCR. Defaults to english.
|
| 255 |
+
tesseract_config (`str`, *optional*):
|
| 256 |
+
Additional flags to pass to tesseract while running OCR.
|
| 257 |
+
timeout (`float`, *optional*, defaults to None):
|
| 258 |
+
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
|
| 259 |
+
the call may block forever.
|
| 260 |
+
|
| 261 |
+
Return:
|
| 262 |
+
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
|
| 263 |
+
|
| 264 |
+
- **score** (`float`) -- The probability associated to the answer.
|
| 265 |
+
- **start** (`int`) -- The start word index of the answer (in the OCR'd version of the input or provided
|
| 266 |
+
`word_boxes`).
|
| 267 |
+
- **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided
|
| 268 |
+
`word_boxes`).
|
| 269 |
+
- **answer** (`str`) -- The answer to the question.
|
| 270 |
+
- **words** (`list[int]`) -- The index of each word/box pair that is in the answer
|
| 271 |
+
"""
|
| 272 |
+
if isinstance(question, str):
|
| 273 |
+
inputs = {"question": question, "image": image}
|
| 274 |
+
if word_boxes is not None:
|
| 275 |
+
inputs["word_boxes"] = word_boxes
|
| 276 |
+
else:
|
| 277 |
+
inputs = image
|
| 278 |
+
return super().__call__(inputs, **kwargs)
|
| 279 |
+
|
| 280 |
+
def preprocess(
|
| 281 |
+
self,
|
| 282 |
+
input,
|
| 283 |
+
padding="do_not_pad",
|
| 284 |
+
doc_stride=None,
|
| 285 |
+
max_seq_len=None,
|
| 286 |
+
word_boxes: Tuple[str, List[float]] = None,
|
| 287 |
+
lang=None,
|
| 288 |
+
tesseract_config="",
|
| 289 |
+
timeout=None,
|
| 290 |
+
):
|
| 291 |
+
# NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
|
| 292 |
+
# to support documents with enough tokens that overflow the model's window
|
| 293 |
+
if max_seq_len is None:
|
| 294 |
+
max_seq_len = self.tokenizer.model_max_length
|
| 295 |
+
|
| 296 |
+
if doc_stride is None:
|
| 297 |
+
doc_stride = min(max_seq_len // 2, 256)
|
| 298 |
+
|
| 299 |
+
image = None
|
| 300 |
+
image_features = {}
|
| 301 |
+
if input.get("image", None) is not None:
|
| 302 |
+
image = load_image(input["image"], timeout=timeout)
|
| 303 |
+
if self.image_processor is not None:
|
| 304 |
+
image_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
| 305 |
+
if self.framework == "pt":
|
| 306 |
+
image_inputs = image_inputs.to(self.torch_dtype)
|
| 307 |
+
image_features.update(image_inputs)
|
| 308 |
+
elif self.feature_extractor is not None:
|
| 309 |
+
image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))
|
| 310 |
+
elif self.model_type == ModelType.VisionEncoderDecoder:
|
| 311 |
+
raise ValueError("If you are using a VisionEncoderDecoderModel, you must provide a feature extractor")
|
| 312 |
+
|
| 313 |
+
words, boxes = None, None
|
| 314 |
+
if not self.model_type == ModelType.VisionEncoderDecoder:
|
| 315 |
+
if "word_boxes" in input:
|
| 316 |
+
words = [x[0] for x in input["word_boxes"]]
|
| 317 |
+
boxes = [x[1] for x in input["word_boxes"]]
|
| 318 |
+
elif "words" in image_features and "boxes" in image_features:
|
| 319 |
+
words = image_features.pop("words")[0]
|
| 320 |
+
boxes = image_features.pop("boxes")[0]
|
| 321 |
+
elif image is not None:
|
| 322 |
+
if not TESSERACT_LOADED:
|
| 323 |
+
raise ValueError(
|
| 324 |
+
"If you provide an image without word_boxes, then the pipeline will run OCR using Tesseract,"
|
| 325 |
+
" but pytesseract is not available"
|
| 326 |
+
)
|
| 327 |
+
if TESSERACT_LOADED:
|
| 328 |
+
words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config)
|
| 329 |
+
else:
|
| 330 |
+
raise ValueError(
|
| 331 |
+
"You must provide an image or word_boxes. If you provide an image, the pipeline will automatically"
|
| 332 |
+
" run OCR to derive words and boxes"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if self.tokenizer.padding_side != "right":
|
| 336 |
+
raise ValueError(
|
| 337 |
+
"Document question answering only supports tokenizers whose padding side is 'right', not"
|
| 338 |
+
f" {self.tokenizer.padding_side}"
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
if self.model_type == ModelType.VisionEncoderDecoder:
|
| 342 |
+
task_prompt = f'<s_docvqa><s_question>{input["question"]}</s_question><s_answer>'
|
| 343 |
+
# Adapted from https://huggingface.co/spaces/nielsr/donut-docvqa/blob/main/app.py
|
| 344 |
+
encoding = {
|
| 345 |
+
"inputs": image_features["pixel_values"],
|
| 346 |
+
"decoder_input_ids": self.tokenizer(
|
| 347 |
+
task_prompt, add_special_tokens=False, return_tensors=self.framework
|
| 348 |
+
).input_ids,
|
| 349 |
+
"return_dict_in_generate": True,
|
| 350 |
+
}
|
| 351 |
+
yield {
|
| 352 |
+
**encoding,
|
| 353 |
+
"p_mask": None,
|
| 354 |
+
"word_ids": None,
|
| 355 |
+
"words": None,
|
| 356 |
+
"output_attentions": True,
|
| 357 |
+
"is_last": True,
|
| 358 |
+
}
|
| 359 |
+
else:
|
| 360 |
+
tokenizer_kwargs = {}
|
| 361 |
+
if self.model_type == ModelType.LayoutLM:
|
| 362 |
+
tokenizer_kwargs["text"] = input["question"].split()
|
| 363 |
+
tokenizer_kwargs["text_pair"] = words
|
| 364 |
+
tokenizer_kwargs["is_split_into_words"] = True
|
| 365 |
+
else:
|
| 366 |
+
tokenizer_kwargs["text"] = [input["question"]]
|
| 367 |
+
tokenizer_kwargs["text_pair"] = [words]
|
| 368 |
+
tokenizer_kwargs["boxes"] = [boxes]
|
| 369 |
+
|
| 370 |
+
encoding = self.tokenizer(
|
| 371 |
+
padding=padding,
|
| 372 |
+
max_length=max_seq_len,
|
| 373 |
+
stride=doc_stride,
|
| 374 |
+
return_token_type_ids=True,
|
| 375 |
+
truncation="only_second",
|
| 376 |
+
return_overflowing_tokens=True,
|
| 377 |
+
**tokenizer_kwargs,
|
| 378 |
+
)
|
| 379 |
+
# TODO: check why slower `LayoutLMTokenizer` and `LayoutLMv2Tokenizer` don't have this key in outputs
|
| 380 |
+
# FIXME: ydshieh and/or Narsil
|
| 381 |
+
encoding.pop("overflow_to_sample_mapping", None) # We do not use this
|
| 382 |
+
|
| 383 |
+
num_spans = len(encoding["input_ids"])
|
| 384 |
+
|
| 385 |
+
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
| 386 |
+
# We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
|
| 387 |
+
# This logic mirrors the logic in the question_answering pipeline
|
| 388 |
+
p_mask = np.array([[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)])
|
| 389 |
+
for span_idx in range(num_spans):
|
| 390 |
+
if self.framework == "pt":
|
| 391 |
+
span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
|
| 392 |
+
if "pixel_values" in image_features:
|
| 393 |
+
span_encoding["image"] = image_features["pixel_values"]
|
| 394 |
+
else:
|
| 395 |
+
raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
|
| 396 |
+
|
| 397 |
+
input_ids_span_idx = encoding["input_ids"][span_idx]
|
| 398 |
+
# keep the cls_token unmasked (some models use it to indicate unanswerable questions)
|
| 399 |
+
if self.tokenizer.cls_token_id is not None:
|
| 400 |
+
cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0]
|
| 401 |
+
for cls_index in cls_indices:
|
| 402 |
+
p_mask[span_idx][cls_index] = 0
|
| 403 |
+
|
| 404 |
+
# For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
|
| 405 |
+
# for SEP tokens, and the word's bounding box for words in the original document.
|
| 406 |
+
if "boxes" not in tokenizer_kwargs:
|
| 407 |
+
bbox = []
|
| 408 |
+
for input_id, sequence_id, word_id in zip(
|
| 409 |
+
encoding.input_ids[span_idx],
|
| 410 |
+
encoding.sequence_ids(span_idx),
|
| 411 |
+
encoding.word_ids(span_idx),
|
| 412 |
+
):
|
| 413 |
+
if sequence_id == 1:
|
| 414 |
+
bbox.append(boxes[word_id])
|
| 415 |
+
elif input_id == self.tokenizer.sep_token_id:
|
| 416 |
+
bbox.append([1000] * 4)
|
| 417 |
+
else:
|
| 418 |
+
bbox.append([0] * 4)
|
| 419 |
+
|
| 420 |
+
if self.framework == "pt":
|
| 421 |
+
span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)
|
| 422 |
+
elif self.framework == "tf":
|
| 423 |
+
raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
|
| 424 |
+
yield {
|
| 425 |
+
**span_encoding,
|
| 426 |
+
"p_mask": p_mask[span_idx],
|
| 427 |
+
"word_ids": encoding.word_ids(span_idx),
|
| 428 |
+
"words": words,
|
| 429 |
+
"is_last": span_idx == num_spans - 1,
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
def _forward(self, model_inputs, **generate_kwargs):
|
| 433 |
+
p_mask = model_inputs.pop("p_mask", None)
|
| 434 |
+
word_ids = model_inputs.pop("word_ids", None)
|
| 435 |
+
words = model_inputs.pop("words", None)
|
| 436 |
+
is_last = model_inputs.pop("is_last", False)
|
| 437 |
+
|
| 438 |
+
if self.model_type == ModelType.VisionEncoderDecoder:
|
| 439 |
+
# User-defined `generation_config` passed to the pipeline call take precedence
|
| 440 |
+
if "generation_config" not in generate_kwargs:
|
| 441 |
+
generate_kwargs["generation_config"] = self.generation_config
|
| 442 |
+
|
| 443 |
+
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
| 444 |
+
else:
|
| 445 |
+
model_outputs = self.model(**model_inputs)
|
| 446 |
+
|
| 447 |
+
model_outputs = dict(model_outputs.items())
|
| 448 |
+
model_outputs["p_mask"] = p_mask
|
| 449 |
+
model_outputs["word_ids"] = word_ids
|
| 450 |
+
model_outputs["words"] = words
|
| 451 |
+
model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
|
| 452 |
+
model_outputs["is_last"] = is_last
|
| 453 |
+
return model_outputs
|
| 454 |
+
|
| 455 |
+
def postprocess(self, model_outputs, top_k=1, **kwargs):
|
| 456 |
+
if self.model_type == ModelType.VisionEncoderDecoder:
|
| 457 |
+
answers = [self.postprocess_encoder_decoder_single(o) for o in model_outputs]
|
| 458 |
+
else:
|
| 459 |
+
answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
|
| 460 |
+
|
| 461 |
+
answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
|
| 462 |
+
return answers
|
| 463 |
+
|
| 464 |
+
def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
|
| 465 |
+
sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0]
|
| 466 |
+
|
| 467 |
+
# TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
|
| 468 |
+
# (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
|
| 469 |
+
sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
|
| 470 |
+
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
| 471 |
+
ret = {
|
| 472 |
+
"answer": None,
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
answer = re.search(r"<s_answer>(.*)</s_answer>", sequence)
|
| 476 |
+
if answer is not None:
|
| 477 |
+
ret["answer"] = answer.group(1).strip()
|
| 478 |
+
return ret
|
| 479 |
+
|
| 480 |
+
def postprocess_extractive_qa(
|
| 481 |
+
self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=15, **kwargs
|
| 482 |
+
):
|
| 483 |
+
min_null_score = 1000000 # large and positive
|
| 484 |
+
answers = []
|
| 485 |
+
for output in model_outputs:
|
| 486 |
+
words = output["words"]
|
| 487 |
+
|
| 488 |
+
starts, ends, scores, min_null_score = select_starts_ends(
|
| 489 |
+
start=output["start_logits"],
|
| 490 |
+
end=output["end_logits"],
|
| 491 |
+
p_mask=output["p_mask"],
|
| 492 |
+
attention_mask=output["attention_mask"].numpy()
|
| 493 |
+
if output.get("attention_mask", None) is not None
|
| 494 |
+
else None,
|
| 495 |
+
min_null_score=min_null_score,
|
| 496 |
+
top_k=top_k,
|
| 497 |
+
handle_impossible_answer=handle_impossible_answer,
|
| 498 |
+
max_answer_len=max_answer_len,
|
| 499 |
+
)
|
| 500 |
+
word_ids = output["word_ids"]
|
| 501 |
+
for start, end, score in zip(starts, ends, scores):
|
| 502 |
+
word_start, word_end = word_ids[start], word_ids[end]
|
| 503 |
+
if word_start is not None and word_end is not None:
|
| 504 |
+
answers.append(
|
| 505 |
+
{
|
| 506 |
+
"score": float(score),
|
| 507 |
+
"answer": " ".join(words[word_start : word_end + 1]),
|
| 508 |
+
"start": word_start,
|
| 509 |
+
"end": word_end,
|
| 510 |
+
}
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
if handle_impossible_answer:
|
| 514 |
+
answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0})
|
| 515 |
+
|
| 516 |
+
return answers
|
.venv/lib/python3.11/site-packages/transformers/pipelines/feature_extraction.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from ..utils import add_end_docstrings
|
| 4 |
+
from .base import GenericTensor, Pipeline, build_pipeline_init_args
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@add_end_docstrings(
|
| 8 |
+
build_pipeline_init_args(has_tokenizer=True, supports_binary_output=False),
|
| 9 |
+
r"""
|
| 10 |
+
tokenize_kwargs (`dict`, *optional*):
|
| 11 |
+
Additional dictionary of keyword arguments passed along to the tokenizer.
|
| 12 |
+
return_tensors (`bool`, *optional*):
|
| 13 |
+
If `True`, returns a tensor according to the specified framework, otherwise returns a list.""",
|
| 14 |
+
)
|
| 15 |
+
class FeatureExtractionPipeline(Pipeline):
|
| 16 |
+
"""
|
| 17 |
+
Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
|
| 18 |
+
transformer, which can be used as features in downstream tasks.
|
| 19 |
+
|
| 20 |
+
Example:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
>>> from transformers import pipeline
|
| 24 |
+
|
| 25 |
+
>>> extractor = pipeline(model="google-bert/bert-base-uncased", task="feature-extraction")
|
| 26 |
+
>>> result = extractor("This is a simple test.", return_tensors=True)
|
| 27 |
+
>>> result.shape # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input string.
|
| 28 |
+
torch.Size([1, 8, 768])
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
| 32 |
+
|
| 33 |
+
This feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
|
| 34 |
+
`"feature-extraction"`.
|
| 35 |
+
|
| 36 |
+
All models may be used for this pipeline. See a list of all models, including community-contributed models on
|
| 37 |
+
[huggingface.co/models](https://huggingface.co/models).
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
|
| 41 |
+
if tokenize_kwargs is None:
|
| 42 |
+
tokenize_kwargs = {}
|
| 43 |
+
|
| 44 |
+
if truncation is not None:
|
| 45 |
+
if "truncation" in tokenize_kwargs:
|
| 46 |
+
raise ValueError(
|
| 47 |
+
"truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
|
| 48 |
+
)
|
| 49 |
+
tokenize_kwargs["truncation"] = truncation
|
| 50 |
+
|
| 51 |
+
preprocess_params = tokenize_kwargs
|
| 52 |
+
|
| 53 |
+
postprocess_params = {}
|
| 54 |
+
if return_tensors is not None:
|
| 55 |
+
postprocess_params["return_tensors"] = return_tensors
|
| 56 |
+
|
| 57 |
+
return preprocess_params, {}, postprocess_params
|
| 58 |
+
|
| 59 |
+
def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
|
| 60 |
+
model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs)
|
| 61 |
+
return model_inputs
|
| 62 |
+
|
| 63 |
+
def _forward(self, model_inputs):
|
| 64 |
+
model_outputs = self.model(**model_inputs)
|
| 65 |
+
return model_outputs
|
| 66 |
+
|
| 67 |
+
def postprocess(self, model_outputs, return_tensors=False):
|
| 68 |
+
# [0] is the first available tensor, logits or last_hidden_state.
|
| 69 |
+
if return_tensors:
|
| 70 |
+
return model_outputs[0]
|
| 71 |
+
if self.framework == "pt":
|
| 72 |
+
return model_outputs[0].tolist()
|
| 73 |
+
elif self.framework == "tf":
|
| 74 |
+
return model_outputs[0].numpy().tolist()
|
| 75 |
+
|
| 76 |
+
def __call__(self, *args, **kwargs):
|
| 77 |
+
"""
|
| 78 |
+
Extract the features of the input(s).
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.
|
| 82 |
+
|
| 83 |
+
Return:
|
| 84 |
+
A nested list of `float`: The features computed by the model.
|
| 85 |
+
"""
|
| 86 |
+
return super().__call__(*args, **kwargs)
|
.venv/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
| 6 |
+
from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if is_tf_available():
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
|
| 12 |
+
from ..tf_utils import stable_softmax
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
if is_torch_available():
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@add_end_docstrings(
|
| 23 |
+
build_pipeline_init_args(has_tokenizer=True),
|
| 24 |
+
r"""
|
| 25 |
+
top_k (`int`, *optional*, defaults to 5):
|
| 26 |
+
The number of predictions to return.
|
| 27 |
+
targets (`str` or `List[str]`, *optional*):
|
| 28 |
+
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
|
| 29 |
+
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
|
| 30 |
+
token will be used (with a warning, and that might be slower).
|
| 31 |
+
tokenizer_kwargs (`dict`, *optional*):
|
| 32 |
+
Additional dictionary of keyword arguments passed along to the tokenizer.""",
|
| 33 |
+
)
|
| 34 |
+
class FillMaskPipeline(Pipeline):
|
| 35 |
+
"""
|
| 36 |
+
Masked language modeling prediction pipeline using any `ModelWithLMHead`. See the [masked language modeling
|
| 37 |
+
examples](../task_summary#masked-language-modeling) for more information.
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
>>> from transformers import pipeline
|
| 43 |
+
|
| 44 |
+
>>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
|
| 45 |
+
>>> fill_masker("This is a simple [MASK].")
|
| 46 |
+
[{'score': 0.042, 'token': 3291, 'token_str': 'problem', 'sequence': 'this is a simple problem.'}, {'score': 0.031, 'token': 3160, 'token_str': 'question', 'sequence': 'this is a simple question.'}, {'score': 0.03, 'token': 8522, 'token_str': 'equation', 'sequence': 'this is a simple equation.'}, {'score': 0.027, 'token': 2028, 'token_str': 'one', 'sequence': 'this is a simple one.'}, {'score': 0.024, 'token': 3627, 'token_str': 'rule', 'sequence': 'this is a simple rule.'}]
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
| 50 |
+
|
| 51 |
+
This mask filling pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
| 52 |
+
`"fill-mask"`.
|
| 53 |
+
|
| 54 |
+
The models that this pipeline can use are models that have been trained with a masked language modeling objective,
|
| 55 |
+
which includes the bi-directional models in the library. See the up-to-date list of available models on
|
| 56 |
+
[huggingface.co/models](https://huggingface.co/models?filter=fill-mask).
|
| 57 |
+
|
| 58 |
+
<Tip>
|
| 59 |
+
|
| 60 |
+
This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
|
| 61 |
+
masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
|
| 62 |
+
joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).
|
| 63 |
+
|
| 64 |
+
</Tip>
|
| 65 |
+
|
| 66 |
+
<Tip>
|
| 67 |
+
|
| 68 |
+
This pipeline now supports tokenizer_kwargs. For example try:
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
>>> from transformers import pipeline
|
| 72 |
+
|
| 73 |
+
>>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
|
| 74 |
+
>>> tokenizer_kwargs = {"truncation": True}
|
| 75 |
+
>>> fill_masker(
|
| 76 |
+
... "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
|
| 77 |
+
... tokenizer_kwargs=tokenizer_kwargs,
|
| 78 |
+
... )
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
</Tip>
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
|
| 88 |
+
if self.framework == "tf":
|
| 89 |
+
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
|
| 90 |
+
elif self.framework == "pt":
|
| 91 |
+
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError("Unsupported framework")
|
| 94 |
+
return masked_index
|
| 95 |
+
|
| 96 |
+
def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
|
| 97 |
+
masked_index = self.get_masked_index(input_ids)
|
| 98 |
+
numel = np.prod(masked_index.shape)
|
| 99 |
+
if numel < 1:
|
| 100 |
+
raise PipelineException(
|
| 101 |
+
"fill-mask",
|
| 102 |
+
self.model.base_model_prefix,
|
| 103 |
+
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
|
| 107 |
+
if isinstance(model_inputs, list):
|
| 108 |
+
for model_input in model_inputs:
|
| 109 |
+
self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
|
| 110 |
+
else:
|
| 111 |
+
for input_ids in model_inputs["input_ids"]:
|
| 112 |
+
self._ensure_exactly_one_mask_token(input_ids)
|
| 113 |
+
|
| 114 |
+
def preprocess(
|
| 115 |
+
self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
|
| 116 |
+
) -> Dict[str, GenericTensor]:
|
| 117 |
+
if return_tensors is None:
|
| 118 |
+
return_tensors = self.framework
|
| 119 |
+
if tokenizer_kwargs is None:
|
| 120 |
+
tokenizer_kwargs = {}
|
| 121 |
+
|
| 122 |
+
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
|
| 123 |
+
self.ensure_exactly_one_mask_token(model_inputs)
|
| 124 |
+
return model_inputs
|
| 125 |
+
|
| 126 |
+
def _forward(self, model_inputs):
|
| 127 |
+
model_outputs = self.model(**model_inputs)
|
| 128 |
+
model_outputs["input_ids"] = model_inputs["input_ids"]
|
| 129 |
+
return model_outputs
|
| 130 |
+
|
| 131 |
+
def postprocess(self, model_outputs, top_k=5, target_ids=None):
|
| 132 |
+
# Cap top_k if there are targets
|
| 133 |
+
if target_ids is not None and target_ids.shape[0] < top_k:
|
| 134 |
+
top_k = target_ids.shape[0]
|
| 135 |
+
input_ids = model_outputs["input_ids"][0]
|
| 136 |
+
outputs = model_outputs["logits"]
|
| 137 |
+
|
| 138 |
+
if self.framework == "tf":
|
| 139 |
+
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]
|
| 140 |
+
|
| 141 |
+
outputs = outputs.numpy()
|
| 142 |
+
|
| 143 |
+
logits = outputs[0, masked_index, :]
|
| 144 |
+
probs = stable_softmax(logits, axis=-1)
|
| 145 |
+
if target_ids is not None:
|
| 146 |
+
probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
|
| 147 |
+
probs = tf.expand_dims(probs, 0)
|
| 148 |
+
|
| 149 |
+
topk = tf.math.top_k(probs, k=top_k)
|
| 150 |
+
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
| 151 |
+
else:
|
| 152 |
+
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
|
| 153 |
+
# Fill mask pipeline supports only one ${mask_token} per sample
|
| 154 |
+
|
| 155 |
+
logits = outputs[0, masked_index, :]
|
| 156 |
+
probs = logits.softmax(dim=-1)
|
| 157 |
+
if target_ids is not None:
|
| 158 |
+
probs = probs[..., target_ids]
|
| 159 |
+
|
| 160 |
+
values, predictions = probs.topk(top_k)
|
| 161 |
+
|
| 162 |
+
result = []
|
| 163 |
+
single_mask = values.shape[0] == 1
|
| 164 |
+
for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
|
| 165 |
+
row = []
|
| 166 |
+
for v, p in zip(_values, _predictions):
|
| 167 |
+
# Copy is important since we're going to modify this array in place
|
| 168 |
+
tokens = input_ids.numpy().copy()
|
| 169 |
+
if target_ids is not None:
|
| 170 |
+
p = target_ids[p].tolist()
|
| 171 |
+
|
| 172 |
+
tokens[masked_index[i]] = p
|
| 173 |
+
# Filter padding out:
|
| 174 |
+
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
| 175 |
+
# Originally we skip special tokens to give readable output.
|
| 176 |
+
# For multi masks though, the other [MASK] would be removed otherwise
|
| 177 |
+
# making the output look odd, so we add them back
|
| 178 |
+
sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
|
| 179 |
+
proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
|
| 180 |
+
row.append(proposition)
|
| 181 |
+
result.append(row)
|
| 182 |
+
if single_mask:
|
| 183 |
+
return result[0]
|
| 184 |
+
return result
|
| 185 |
+
|
| 186 |
+
def get_target_ids(self, targets, top_k=None):
|
| 187 |
+
if isinstance(targets, str):
|
| 188 |
+
targets = [targets]
|
| 189 |
+
try:
|
| 190 |
+
vocab = self.tokenizer.get_vocab()
|
| 191 |
+
except Exception:
|
| 192 |
+
vocab = {}
|
| 193 |
+
target_ids = []
|
| 194 |
+
for target in targets:
|
| 195 |
+
id_ = vocab.get(target, None)
|
| 196 |
+
if id_ is None:
|
| 197 |
+
input_ids = self.tokenizer(
|
| 198 |
+
target,
|
| 199 |
+
add_special_tokens=False,
|
| 200 |
+
return_attention_mask=False,
|
| 201 |
+
return_token_type_ids=False,
|
| 202 |
+
max_length=1,
|
| 203 |
+
truncation=True,
|
| 204 |
+
)["input_ids"]
|
| 205 |
+
if len(input_ids) == 0:
|
| 206 |
+
logger.warning(
|
| 207 |
+
f"The specified target token `{target}` does not exist in the model vocabulary. "
|
| 208 |
+
"We cannot replace it with anything meaningful, ignoring it"
|
| 209 |
+
)
|
| 210 |
+
continue
|
| 211 |
+
id_ = input_ids[0]
|
| 212 |
+
# XXX: If users encounter this pass
|
| 213 |
+
# it becomes pretty slow, so let's make sure
|
| 214 |
+
# The warning enables them to fix the input to
|
| 215 |
+
# get faster performance.
|
| 216 |
+
logger.warning(
|
| 217 |
+
f"The specified target token `{target}` does not exist in the model vocabulary. "
|
| 218 |
+
f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
|
| 219 |
+
)
|
| 220 |
+
target_ids.append(id_)
|
| 221 |
+
target_ids = list(set(target_ids))
|
| 222 |
+
if len(target_ids) == 0:
|
| 223 |
+
raise ValueError("At least one target must be provided when passed.")
|
| 224 |
+
target_ids = np.array(target_ids)
|
| 225 |
+
return target_ids
|
| 226 |
+
|
| 227 |
+
def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
|
| 228 |
+
preprocess_params = {}
|
| 229 |
+
|
| 230 |
+
if tokenizer_kwargs is not None:
|
| 231 |
+
preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
|
| 232 |
+
|
| 233 |
+
postprocess_params = {}
|
| 234 |
+
|
| 235 |
+
if targets is not None:
|
| 236 |
+
target_ids = self.get_target_ids(targets, top_k)
|
| 237 |
+
postprocess_params["target_ids"] = target_ids
|
| 238 |
+
|
| 239 |
+
if top_k is not None:
|
| 240 |
+
postprocess_params["top_k"] = top_k
|
| 241 |
+
|
| 242 |
+
if self.tokenizer.mask_token_id is None:
|
| 243 |
+
raise PipelineException(
|
| 244 |
+
"fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
|
| 245 |
+
)
|
| 246 |
+
return preprocess_params, {}, postprocess_params
|
| 247 |
+
|
| 248 |
+
def __call__(self, inputs, **kwargs):
|
| 249 |
+
"""
|
| 250 |
+
Fill the masked token in the text(s) given as inputs.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
inputs (`str` or `List[str]`):
|
| 254 |
+
One or several texts (or one list of prompts) with masked tokens.
|
| 255 |
+
targets (`str` or `List[str]`, *optional*):
|
| 256 |
+
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
|
| 257 |
+
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
|
| 258 |
+
resulting token will be used (with a warning, and that might be slower).
|
| 259 |
+
top_k (`int`, *optional*):
|
| 260 |
+
When passed, overrides the number of predictions to return.
|
| 261 |
+
|
| 262 |
+
Return:
|
| 263 |
+
A list or a list of list of `dict`: Each result comes as list of dictionaries with the following keys:
|
| 264 |
+
|
| 265 |
+
- **sequence** (`str`) -- The corresponding input with the mask token prediction.
|
| 266 |
+
- **score** (`float`) -- The corresponding probability.
|
| 267 |
+
- **token** (`int`) -- The predicted token id (to replace the masked one).
|
| 268 |
+
- **token_str** (`str`) -- The predicted token (to replace the masked one).
|
| 269 |
+
"""
|
| 270 |
+
outputs = super().__call__(inputs, **kwargs)
|
| 271 |
+
if isinstance(inputs, list) and len(inputs) == 1:
|
| 272 |
+
return outputs[0]
|
| 273 |
+
return outputs
|
.venv/lib/python3.11/site-packages/transformers/pipelines/image_classification.py
ADDED
|
@@ -0,0 +1,226 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from ..utils import (
    ExplicitEnum,
    add_end_docstrings,
    is_tf_available,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_tf_available():
    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES

logger = logging.get_logger(__name__)


# Copied from transformers.pipelines.text_classification.sigmoid
def sigmoid(_outputs):
    return 1.0 / (1.0 + np.exp(-_outputs))


# Copied from transformers.pipelines.text_classification.softmax
def softmax(_outputs):
    maxes = np.max(_outputs, axis=-1, keepdims=True)
    shifted_exp = np.exp(_outputs - maxes)
    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)


# Copied from transformers.pipelines.text_classification.ClassificationFunction
class ClassificationFunction(ExplicitEnum):
    SIGMOID = "sigmoid"
    SOFTMAX = "softmax"
    NONE = "none"


@add_end_docstrings(
    build_pipeline_init_args(has_image_processor=True),
    r"""
        function_to_apply (`str`, *optional*, defaults to `"default"`):
            The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:

            - `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
              has several labels, will apply the softmax function on the output.
            - `"sigmoid"`: Applies the sigmoid function on the output.
            - `"softmax"`: Applies the softmax function on the output.
            - `"none"`: Does not apply any function on the output.""",
)
class ImageClassificationPipeline(Pipeline):
    """
    Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
    image.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> classifier = pipeline(model="microsoft/beit-base-patch16-224-pt22k-ft22k")
    >>> classifier("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    [{'score': 0.442, 'label': 'macaw'}, {'score': 0.088, 'label': 'popinjay'}, {'score': 0.075, 'label': 'parrot'}, {'score': 0.073, 'label': 'parodist, lampooner'}, {'score': 0.046, 'label': 'poll, poll_parrot'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"image-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=image-classification).
    """

    function_to_apply: ClassificationFunction = ClassificationFunction.NONE

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        self.check_model_type(
            TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
        )

    def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
        preprocess_params = {}
        if timeout is not None:
            preprocess_params["timeout"] = timeout
        postprocess_params = {}
        if top_k is not None:
            postprocess_params["top_k"] = top_k
        if isinstance(function_to_apply, str):
            function_to_apply = ClassificationFunction(function_to_apply.lower())
        if function_to_apply is not None:
            postprocess_params["function_to_apply"] = function_to_apply
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
        """
        Assign labels to the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.
            function_to_apply (`str`, *optional*, defaults to `"default"`):
                The function to apply to the model outputs in order to retrieve the scores. Accepts four different
                values:

                If this argument is not specified, then it will apply the following functions according to the number
                of labels:

                - If the model has a single label, will apply the sigmoid function on the output.
                - If the model has several labels, will apply the softmax function on the output.

                Possible values are:

                - `"sigmoid"`: Applies the sigmoid function on the output.
                - `"softmax"`: Applies the softmax function on the output.
                - `"none"`: Does not apply any function on the output.
            top_k (`int`, *optional*, defaults to 5):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
            dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
            the images.

            The dictionaries contain the following keys:

            - **label** (`str`) -- The label identified by the model.
            - **score** (`float`) -- The score attributed by the model for that label.
        """
        # After deprecation of this is completed, remove the default `None` value for `images`
        if "images" in kwargs:
            inputs = kwargs.pop("images")
        if inputs is None:
            raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
        return super().__call__(inputs, **kwargs)

    def preprocess(self, image, timeout=None):
        image = load_image(image, timeout=timeout)
        model_inputs = self.image_processor(images=image, return_tensors=self.framework)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
        if function_to_apply is None:
            if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1:
                function_to_apply = ClassificationFunction.SIGMOID
            elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1:
                function_to_apply = ClassificationFunction.SOFTMAX
            elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
                function_to_apply = self.model.config.function_to_apply
            else:
                function_to_apply = ClassificationFunction.NONE

        if top_k > self.model.config.num_labels:
            top_k = self.model.config.num_labels

        outputs = model_outputs["logits"][0]
        if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
            outputs = outputs.to(torch.float32).numpy()
        else:
            outputs = outputs.numpy()

        if function_to_apply == ClassificationFunction.SIGMOID:
            scores = sigmoid(outputs)
        elif function_to_apply == ClassificationFunction.SOFTMAX:
            scores = softmax(outputs)
        elif function_to_apply == ClassificationFunction.NONE:
            scores = outputs
        else:
            raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")

        dict_scores = [
            {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
        ]
        dict_scores.sort(key=lambda x: x["score"], reverse=True)
        if top_k is not None:
            dict_scores = dict_scores[:top_k]

        return dict_scores
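The `top_k` and `function_to_apply` arguments above are routed through `_sanitize_parameters` into `postprocess`. A short sketch of that flow, reusing the checkpoint and image URL from the class docstring (values shown in comments are placeholders):

```python
from transformers import pipeline

classifier = pipeline("image-classification", model="microsoft/beit-base-patch16-224-pt22k-ft22k")
preds = classifier(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    top_k=3,                      # becomes postprocess_params["top_k"]
    function_to_apply="softmax",  # converted to ClassificationFunction.SOFTMAX
)
print(preds)  # e.g. [{'score': ..., 'label': ...}, ...] sorted by score, at most 3 entries
```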
.venv/lib/python3.11/site-packages/transformers/pipelines/image_feature_extraction.py
ADDED
@@ -0,0 +1,112 @@
from typing import Dict

from ..utils import add_end_docstrings, is_vision_available
from .base import GenericTensor, Pipeline, build_pipeline_init_args


if is_vision_available():
    from ..image_utils import load_image


@add_end_docstrings(
    build_pipeline_init_args(has_image_processor=True),
    """
        image_processor_kwargs (`dict`, *optional*):
                Additional dictionary of keyword arguments passed along to the image processor e.g.
                {"size": {"height": 100, "width": 100}}
        pool (`bool`, *optional*, defaults to `False`):
                Whether or not to return the pooled output. If `False`, the model will return the raw hidden states.
    """,
)
class ImageFeatureExtractionPipeline(Pipeline):
    """
    Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
    >>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
    >>> result.shape  # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input image.
    torch.Size([1, 197, 768])
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
    `"image-feature-extraction"`.

    All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
    [huggingface.co/models](https://huggingface.co/models).
    """

    def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
        preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs

        postprocess_params = {}
        if pool is not None:
            postprocess_params["pool"] = pool
        if return_tensors is not None:
            postprocess_params["return_tensors"] = return_tensors

        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]

        return preprocess_params, {}, postprocess_params

    def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
        image = load_image(image, timeout=timeout)
        model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, pool=None, return_tensors=False):
        pool = pool if pool is not None else False

        if pool:
            if "pooler_output" not in model_outputs:
                raise ValueError(
                    "No pooled output was returned. Make sure the model has a `pooler` layer when using the `pool` option."
                )
            outputs = model_outputs["pooler_output"]
        else:
            # [0] is the first available tensor, logits or last_hidden_state.
            outputs = model_outputs[0]

        if return_tensors:
            return outputs
        if self.framework == "pt":
            return outputs.tolist()
        elif self.framework == "tf":
            return outputs.numpy().tolist()

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
                the call may block forever.
        Return:
            A nested list of `float`: The features computed by the model.
        """
        return super().__call__(*args, **kwargs)
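The `pool` and `return_tensors` options handled in `postprocess` above select between the pooled output, the raw hidden states as a framework tensor, and nested Python lists. A small sketch, reusing the ViT checkpoint and image URL from the class docstring (behaviour described in the comments follows the code above, not extra API guarantees):

```python
from transformers import pipeline

extractor = pipeline(task="image-feature-extraction", model="google/vit-base-patch16-224")

url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
hidden = extractor(url, return_tensors=True)  # last_hidden_state, shape [1, seq_len, hidden_dim]
features = extractor(url)                     # same tensor converted to nested Python lists
pooled = extractor(url, pool=True)            # pooler_output; raises if the model has no pooler layer
```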
.venv/lib/python3.11/site-packages/transformers/pipelines/image_segmentation.py
ADDED
@@ -0,0 +1,220 @@
from typing import Any, Dict, List, Union

import numpy as np

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    from ..models.auto.modeling_auto import (
        MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
    )


logger = logging.get_logger(__name__)


Prediction = Dict[str, Any]
Predictions = List[Prediction]


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageSegmentationPipeline(Pipeline):
    """
    Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
    their classes.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
    >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    >>> len(segments)
    2

    >>> segments[0]["label"]
    'bird'

    >>> segments[1]["label"]
    'bird'

    >>> type(segments[0]["mask"])  # This is a black and white mask showing where is the bird on the original image.
    <class 'PIL.Image.Image'>

    >>> segments[0]["mask"].size
    (768, 512)
    ```


    This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"image-segmentation"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
        mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
        mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
        mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
        self.check_model_type(mapping)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "subtask" in kwargs:
            postprocess_kwargs["subtask"] = kwargs["subtask"]
            preprocess_kwargs["subtask"] = kwargs["subtask"]
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        if "mask_threshold" in kwargs:
            postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
        if "overlap_mask_area_threshold" in kwargs:
            postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
        if "timeout" in kwargs:
            preprocess_kwargs["timeout"] = kwargs["timeout"]

        return preprocess_kwargs, {}, postprocess_kwargs

    def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
        """
        Perform segmentation (detect masks & classes) in the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            subtask (`str`, *optional*):
                Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
                capabilities. If not set, the pipeline will attempt to resolve in the following order:
                  `panoptic`, `instance`, `semantic`.
            threshold (`float`, *optional*, defaults to 0.9):
                Probability threshold to filter out predicted masks.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
                Mask overlap threshold to eliminate small, disconnected segments.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
            list of dictionaries, if the input is a list of several images, will return a list of list of dictionaries
            corresponding to each image.

            The dictionaries contain the mask, label and score (where applicable) of each detected object and contain
            the following keys:

            - **label** (`str`) -- The class label identified by the model.
            - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape (width, height) of
              the original image. Returns a mask filled with zeros if no object is found.
            - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
              "object" described by the label and the mask.
        """
        # After deprecation of this is completed, remove the default `None` value for `images`
        if "images" in kwargs:
            inputs = kwargs.pop("images")
        if inputs is None:
            raise ValueError("Cannot call the image-segmentation pipeline without an inputs argument!")
        return super().__call__(inputs, **kwargs)

    def preprocess(self, image, subtask=None, timeout=None):
        image = load_image(image, timeout=timeout)
        target_size = [(image.height, image.width)]
        if self.model.config.__class__.__name__ == "OneFormerConfig":
            if subtask is None:
                kwargs = {}
            else:
                kwargs = {"task_inputs": [subtask]}
            inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
            inputs["task_inputs"] = self.tokenizer(
                inputs["task_inputs"],
                padding="max_length",
                max_length=self.model.config.task_seq_len,
                return_tensors=self.framework,
            )["input_ids"]
        else:
            inputs = self.image_processor(images=[image], return_tensors="pt")
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
        inputs["target_size"] = target_size
        return inputs

    def _forward(self, model_inputs):
        target_size = model_inputs.pop("target_size")
        model_outputs = self.model(**model_inputs)
        model_outputs["target_size"] = target_size
        return model_outputs

    def postprocess(
        self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
    ):
        fn = None
        if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
            fn = self.image_processor.post_process_panoptic_segmentation
        elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
            fn = self.image_processor.post_process_instance_segmentation

        if fn is not None:
            outputs = fn(
                model_outputs,
                threshold=threshold,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                target_sizes=model_outputs["target_size"],
            )[0]

            annotation = []
            segmentation = outputs["segmentation"]

            for segment in outputs["segments_info"]:
                mask = (segmentation == segment["id"]) * 255
                mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
                label = self.model.config.id2label[segment["label_id"]]
                score = segment["score"]
                annotation.append({"score": score, "label": label, "mask": mask})

        elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
            outputs = self.image_processor.post_process_semantic_segmentation(
                model_outputs, target_sizes=model_outputs["target_size"]
            )[0]

            annotation = []
            segmentation = outputs.numpy()
            labels = np.unique(segmentation)

            for label in labels:
                mask = (segmentation == label) * 255
                mask = Image.fromarray(mask.astype(np.uint8), mode="L")
                label = self.model.config.id2label[label]
                annotation.append({"score": None, "label": label, "mask": mask})
        else:
            raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
        return annotation
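The `subtask` and threshold parameters above are split between `preprocess` and `postprocess` by `_sanitize_parameters`. A brief sketch of a call that exercises them, reusing the checkpoint and image URL from the class docstring (the parameter values are illustrative defaults, not recommendations):

```python
from transformers import pipeline

segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")
segments = segmenter(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    subtask="panoptic",  # forwarded to both preprocess and postprocess
    threshold=0.9,       # probability threshold used when filtering predicted masks
)
for segment in segments:
    # Each entry carries the label, an optional confidence score, and a PIL mask.
    print(segment["label"], segment["score"], segment["mask"].size)
```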