Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- .venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd +3 -0
- .venv/Lib/site-packages/transformers/__init__.py +0 -0
- .venv/Lib/site-packages/transformers/agents/__init__.py +69 -0
- .venv/Lib/site-packages/transformers/agents/agent_types.py +260 -0
- .venv/Lib/site-packages/transformers/agents/agents.py +1278 -0
- .venv/Lib/site-packages/transformers/agents/default_tools.py +187 -0
- .venv/Lib/site-packages/transformers/agents/document_question_answering.py +89 -0
- .venv/Lib/site-packages/transformers/agents/evaluate_agent.py +414 -0
- .venv/Lib/site-packages/transformers/agents/image_question_answering.py +58 -0
- .venv/Lib/site-packages/transformers/agents/llm_engine.py +238 -0
- .venv/Lib/site-packages/transformers/agents/monitoring.py +117 -0
- .venv/Lib/site-packages/transformers/agents/prompts.py +789 -0
- .venv/Lib/site-packages/transformers/agents/python_interpreter.py +908 -0
- .venv/Lib/site-packages/transformers/agents/search.py +77 -0
- .venv/Lib/site-packages/transformers/agents/speech_to_text.py +39 -0
- .venv/Lib/site-packages/transformers/agents/text_to_speech.py +67 -0
- .venv/Lib/site-packages/transformers/agents/tools.py +1003 -0
- .venv/Lib/site-packages/transformers/agents/translation.py +279 -0
- .venv/Lib/site-packages/transformers/benchmark/benchmark.py +270 -0
- .venv/Lib/site-packages/transformers/benchmark/benchmark_args.py +124 -0
- .venv/Lib/site-packages/transformers/benchmark/benchmark_args_tf.py +136 -0
- .venv/Lib/site-packages/transformers/commands/__init__.py +27 -0
- .venv/Lib/site-packages/transformers/commands/run.py +110 -0
- .venv/Lib/site-packages/transformers/commands/serving.py +228 -0
- .venv/Lib/site-packages/transformers/commands/train.py +158 -0
- .venv/Lib/site-packages/transformers/commands/transformers_cli.py +57 -0
- .venv/Lib/site-packages/transformers/commands/user.py +197 -0
- .venv/Lib/site-packages/transformers/data/__init__.py +45 -0
- .venv/Lib/site-packages/transformers/data/data_collator.py +1653 -0
- .venv/Lib/site-packages/transformers/data/datasets/__init__.py +23 -0
- .venv/Lib/site-packages/transformers/data/datasets/glue.py +161 -0
- .venv/Lib/site-packages/transformers/data/datasets/language_modeling.py +530 -0
- .venv/Lib/site-packages/transformers/data/datasets/squad.py +229 -0
- .venv/Lib/site-packages/transformers/data/metrics/__init__.py +98 -0
- .venv/Lib/site-packages/transformers/data/metrics/squad_metrics.py +779 -0
- .venv/Lib/site-packages/transformers/data/processors/__init__.py +18 -0
- .venv/Lib/site-packages/transformers/data/processors/glue.py +643 -0
- .venv/Lib/site-packages/transformers/data/processors/squad.py +845 -0
- .venv/Lib/site-packages/transformers/data/processors/utils.py +349 -0
- .venv/Lib/site-packages/transformers/data/processors/xnli.py +96 -0
- .venv/Lib/site-packages/transformers/generation/__init__.py +352 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/beam_search.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/logits_process.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/generation/__pycache__/utils.cpython-39.pyc +0 -0
.gitattributes CHANGED

@@ -83,3 +83,4 @@ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs
 .venv/Lib/site-packages/torchaudio/lib/libtorchaudio.pyd filter=lfs diff=lfs merge=lfs -text
 .venv/Lib/site-packages/torchvision/nvjpeg64_12.dll filter=lfs diff=lfs merge=lfs -text
 .venv/Lib/site-packages/torchvision/_C.pyd filter=lfs diff=lfs merge=lfs -text
+.venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd filter=lfs diff=lfs merge=lfs -text
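Note: each rule above pairs a path with `filter=lfs diff=lfs merge=lfs -text`, which is how Git LFS marks a file as pointer-tracked. Below is a minimal illustrative sketch (not part of this commit; the helper name is made up) for listing the LFS-tracked entries of a `.gitattributes` file:

# Illustrative helper: list the LFS-tracked patterns in a .gitattributes file.
def lfs_tracked_patterns(gitattributes_path=".gitattributes"):
    patterns = []
    with open(gitattributes_path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            # An LFS rule looks like: <pattern> filter=lfs diff=lfs merge=lfs -text
            if len(parts) > 1 and "filter=lfs" in parts[1:]:
                patterns.append(parts[0])
    return patterns

# Expected to include the rule added in this commit:
# '.venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd'
print(lfs_tracked_patterns())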
.venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7abc7280260cda768b24d26ab52f7f1409d073b921bb57b52ffde627d2200bb5
+size 1094656
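What is stored in the repository is not the binary itself but a Git LFS pointer: a `version` line, the `oid sha256:...` of the real object, and its `size` in bytes. A small sketch of reading such a pointer (illustrative only; the function name is made up):

# Illustrative: parse a Git LFS pointer file into a dict of its three fields.
def read_lfs_pointer(path):
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    fields["size"] = int(fields["size"])
    return fields

# For the pointer added above this would yield:
# {'version': 'https://git-lfs.github.com/spec/v1',
#  'oid': 'sha256:7abc7280260cda768b24d26ab52f7f1409d073b921bb57b52ffde627d2200bb5',
#  'size': 1094656}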
.venv/Lib/site-packages/transformers/__init__.py ADDED

The diff for this file is too large to render. See raw diff.
.venv/Lib/site-packages/transformers/agents/__init__.py ADDED

@@ -0,0 +1,69 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
    "llm_engine": ["HfApiEngine", "TransformersEngine"],
    "monitoring": ["stream_to_gradio"],
    "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
    _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
    _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
    _import_structure["search"] = ["DuckDuckGoSearchTool", "VisitWebpageTool"]
    _import_structure["speech_to_text"] = ["SpeechToTextTool"]
    _import_structure["text_to_speech"] = ["TextToSpeechTool"]
    _import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
    from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
    from .llm_engine import HfApiEngine, TransformersEngine
    from .monitoring import stream_to_gradio
    from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .default_tools import FinalAnswerTool, PythonInterpreterTool
        from .document_question_answering import DocumentQuestionAnsweringTool
        from .image_question_answering import ImageQuestionAnsweringTool
        from .search import DuckDuckGoSearchTool, VisitWebpageTool
        from .speech_to_text import SpeechToTextTool
        from .text_to_speech import TextToSpeechTool
        from .translation import TranslationTool
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
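The file builds `_import_structure` so that the `_LazyModule` registered in `sys.modules` can defer submodule imports until an attribute is first accessed, and it gates the tool classes behind `is_torch_available()`. A short usage sketch, assuming this vendored `transformers` and `torch` are importable:

# These imports resolve lazily through the _LazyModule set up above.
from transformers.agents import CodeAgent, Toolbox, load_tool

# Torch-gated names (the default tools) are only exposed when torch is installed;
# otherwise accessing them raises at import time.
from transformers.agents import PythonInterpreterTool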
.venv/Lib/site-packages/transformers/agents/agent_types.py ADDED

@@ -0,0 +1,260 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import tempfile
import uuid

import numpy as np

from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging


logger = logging.get_logger(__name__)

if is_vision_available():
    from PIL import Image
    from PIL.Image import Image as ImageType
else:
    ImageType = object

if is_torch_available():
    import torch
    from torch import Tensor
else:
    Tensor = object

if is_soundfile_availble():
    import soundfile as sf


class AgentType:
    """
    Abstract class to be reimplemented to define types that can be returned by agents.

    These objects serve three purposes:

    - They behave as they were the type they're meant to be, e.g., a string for text, a PIL.Image for images
    - They can be stringified: str(object) in order to return a string defining the object
    - They should be displayed correctly in ipython notebooks/colab/jupyter
    """

    def __init__(self, value):
        self._value = value

    def __str__(self):
        return self.to_string()

    def to_raw(self):
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )
        return self._value

    def to_string(self) -> str:
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )
        return str(self._value)


class AgentText(AgentType, str):
    """
    Text type returned by the agent. Behaves as a string.
    """

    def to_raw(self):
        return self._value

    def to_string(self):
        return str(self._value)


class AgentImage(AgentType, ImageType):
    """
    Image type returned by the agent. Behaves as a PIL.Image.
    """

    def __init__(self, value):
        AgentType.__init__(self, value)
        ImageType.__init__(self)

        if not is_vision_available():
            raise ImportError("PIL must be installed in order to handle images.")

        self._path = None
        self._raw = None
        self._tensor = None

        if isinstance(value, ImageType):
            self._raw = value
        elif isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif isinstance(value, torch.Tensor):
            self._tensor = value
        elif isinstance(value, np.ndarray):
            self._tensor = torch.from_numpy(value)
        else:
            raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Image, display

        display(Image(self.to_string()))

    def to_raw(self):
        """
        Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
        """
        if self._raw is not None:
            return self._raw

        if self._path is not None:
            self._raw = Image.open(self._path)
            return self._raw

        if self._tensor is not None:
            array = self._tensor.cpu().detach().numpy()
            return Image.fromarray((255 - array * 255).astype(np.uint8))

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
        version of the image.
        """
        if self._path is not None:
            return self._path

        if self._raw is not None:
            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
            self._raw.save(self._path)
            return self._path

        if self._tensor is not None:
            array = self._tensor.cpu().detach().numpy()

            # There is likely simpler than load into image into save
            img = Image.fromarray((255 - array * 255).astype(np.uint8))

            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")

            img.save(self._path)

            return self._path

    def save(self, output_bytes, format, **params):
        """
        Saves the image to a file.
        Args:
            output_bytes (bytes): The output bytes to save the image to.
            format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
            **params: Additional parameters to pass to PIL.Image.save.
        """
        img = self.to_raw()
        img.save(output_bytes, format, **params)


class AgentAudio(AgentType, str):
    """
    Audio type returned by the agent.
    """

    def __init__(self, value, samplerate=16_000):
        super().__init__(value)

        if not is_soundfile_availble():
            raise ImportError("soundfile must be installed in order to handle audio.")

        self._path = None
        self._tensor = None

        self.samplerate = samplerate
        if isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif is_torch_available() and isinstance(value, torch.Tensor):
            self._tensor = value
        elif isinstance(value, tuple):
            self.samplerate = value[0]
            if isinstance(value[1], np.ndarray):
                self._tensor = torch.from_numpy(value[1])
            else:
                self._tensor = torch.tensor(value[1])
        else:
            raise ValueError(f"Unsupported audio type: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Audio, display

        display(Audio(self.to_string(), rate=self.samplerate))

    def to_raw(self):
        """
        Returns the "raw" version of that object. It is a `torch.Tensor` object.
        """
        if self._tensor is not None:
            return self._tensor

        if self._path is not None:
            tensor, self.samplerate = sf.read(self._path)
            self._tensor = torch.tensor(tensor)
            return self._tensor

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
        version of the audio.
        """
        if self._path is not None:
            return self._path

        if self._tensor is not None:
            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
            sf.write(self._path, self._tensor, samplerate=self.samplerate)
            return self._path


AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}

if is_torch_available():
    INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio


def handle_agent_inputs(*args, **kwargs):
    args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
    kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
    return args, kwargs


def handle_agent_outputs(output, output_type=None):
    if output_type in AGENT_TYPE_MAPPING:
        # If the class has defined outputs, we can map directly according to the class definition
        decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
        return decoded_outputs
    else:
        # If the class does not have defined output, then we map according to the type
        for _k, _v in INSTANCE_TYPE_MAPPING.items():
            if isinstance(output, _k):
                return _v(output)
    return output
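`handle_agent_inputs` and `handle_agent_outputs` are the glue between plain Python values and these wrapper types: inputs are unwrapped via `to_raw()`, and outputs are wrapped either by the declared `output_type` or by their runtime type. A usage sketch, assuming Pillow is installed (the sample image is made up):

from PIL import Image
from transformers.agents.agent_types import AgentImage, handle_agent_outputs

# Wrap a tool output declared as "image": it behaves like a PIL.Image,
# but stringifying it serializes the image to a temporary .png path.
raw = Image.new("RGB", (64, 64), color="red")  # illustrative sample image
wrapped = handle_agent_outputs(raw, output_type="image")

assert isinstance(wrapped, AgentImage)
print(wrapped.to_string())    # e.g. /tmp/<random>/<uuid>.png
print(wrapped.to_raw().size)  # (64, 64): the underlying PIL.Image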
.venv/Lib/site-packages/transformers/agents/agents.py
ADDED
|
@@ -0,0 +1,1278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import re
|
| 20 |
+
import time
|
| 21 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
from .. import is_torch_available
|
| 24 |
+
from ..utils import logging as transformers_logging
|
| 25 |
+
from ..utils.import_utils import is_pygments_available
|
| 26 |
+
from .agent_types import AgentAudio, AgentImage
|
| 27 |
+
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
| 28 |
+
from .llm_engine import HfApiEngine, MessageRole
|
| 29 |
+
from .monitoring import Monitor
|
| 30 |
+
from .prompts import (
|
| 31 |
+
DEFAULT_CODE_SYSTEM_PROMPT,
|
| 32 |
+
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
| 33 |
+
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
| 34 |
+
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
| 35 |
+
PROMPTS_FOR_INITIAL_PLAN,
|
| 36 |
+
PROMPTS_FOR_PLAN_UPDATE,
|
| 37 |
+
SUPPORTED_PLAN_TYPES,
|
| 38 |
+
SYSTEM_PROMPT_FACTS,
|
| 39 |
+
SYSTEM_PROMPT_FACTS_UPDATE,
|
| 40 |
+
USER_PROMPT_FACTS_UPDATE,
|
| 41 |
+
)
|
| 42 |
+
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
| 43 |
+
from .tools import (
|
| 44 |
+
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
| 45 |
+
Tool,
|
| 46 |
+
get_tool_description_with_args,
|
| 47 |
+
load_tool,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if is_pygments_available():
|
| 52 |
+
from pygments import highlight
|
| 53 |
+
from pygments.formatters import Terminal256Formatter
|
| 54 |
+
from pygments.lexers import PythonLexer
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CustomFormatter(logging.Formatter):
|
| 58 |
+
grey = "\x1b[38;20m"
|
| 59 |
+
bold_yellow = "\x1b[33;1m"
|
| 60 |
+
red = "\x1b[31;20m"
|
| 61 |
+
green = "\x1b[32;20m"
|
| 62 |
+
bold_green = "\x1b[32;20;1m"
|
| 63 |
+
bold_red = "\x1b[31;1m"
|
| 64 |
+
bold_white = "\x1b[37;1m"
|
| 65 |
+
orange = "\x1b[38;5;214m"
|
| 66 |
+
bold_orange = "\x1b[38;5;214;1m"
|
| 67 |
+
reset = "\x1b[0m"
|
| 68 |
+
format = "%(message)s"
|
| 69 |
+
|
| 70 |
+
FORMATS = {
|
| 71 |
+
logging.DEBUG: grey + format + reset,
|
| 72 |
+
logging.INFO: format,
|
| 73 |
+
logging.WARNING: bold_yellow + format + reset,
|
| 74 |
+
logging.ERROR: red + format + reset,
|
| 75 |
+
logging.CRITICAL: bold_red + format + reset,
|
| 76 |
+
31: reset + format + reset,
|
| 77 |
+
32: green + format + reset,
|
| 78 |
+
33: bold_green + format + reset,
|
| 79 |
+
34: bold_white + format + reset,
|
| 80 |
+
35: orange + format + reset,
|
| 81 |
+
36: bold_orange + format + reset,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
def format(self, record):
|
| 85 |
+
log_fmt = self.FORMATS.get(record.levelno)
|
| 86 |
+
formatter = logging.Formatter(log_fmt)
|
| 87 |
+
return formatter.format(record)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
logger = transformers_logging.get_logger(__name__)
|
| 91 |
+
logger.propagate = False
|
| 92 |
+
ch = logging.StreamHandler()
|
| 93 |
+
ch.setFormatter(CustomFormatter())
|
| 94 |
+
logger.addHandler(ch)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
| 98 |
+
try:
|
| 99 |
+
first_accolade_index = json_blob.find("{")
|
| 100 |
+
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
|
| 101 |
+
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
|
| 102 |
+
json_data = json.loads(json_blob, strict=False)
|
| 103 |
+
return json_data
|
| 104 |
+
except json.JSONDecodeError as e:
|
| 105 |
+
place = e.pos
|
| 106 |
+
if json_blob[place - 1 : place + 2] == "},\n":
|
| 107 |
+
raise ValueError(
|
| 108 |
+
"JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
|
| 109 |
+
)
|
| 110 |
+
raise ValueError(
|
| 111 |
+
f"The JSON blob you used is invalid due to the following error: {e}.\n"
|
| 112 |
+
f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
|
| 113 |
+
f"'{json_blob[place-4:place+5]}'."
|
| 114 |
+
)
|
| 115 |
+
except Exception as e:
|
| 116 |
+
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def parse_code_blob(code_blob: str) -> str:
|
| 120 |
+
try:
|
| 121 |
+
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
| 122 |
+
match = re.search(pattern, code_blob, re.DOTALL)
|
| 123 |
+
return match.group(1).strip()
|
| 124 |
+
except Exception as e:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"""
|
| 127 |
+
The code blob you used is invalid: due to the following error: {e}
|
| 128 |
+
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
|
| 129 |
+
Thoughts: Your thoughts
|
| 130 |
+
Code:
|
| 131 |
+
```py
|
| 132 |
+
# Your python code here
|
| 133 |
+
```<end_action>"""
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
| 138 |
+
json_blob = json_blob.replace("```json", "").replace("```", "")
|
| 139 |
+
tool_call = parse_json_blob(json_blob)
|
| 140 |
+
if "action" in tool_call and "action_input" in tool_call:
|
| 141 |
+
return tool_call["action"], tool_call["action_input"]
|
| 142 |
+
elif "action" in tool_call:
|
| 143 |
+
return tool_call["action"], None
|
| 144 |
+
else:
|
| 145 |
+
raise ValueError(
|
| 146 |
+
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
|
| 151 |
+
"""
|
| 152 |
+
Expects a text in the format: 'Action:', 'Action input:', 'Observation:'. 'Action input:' contains a json string with input arguments.
|
| 153 |
+
"""
|
| 154 |
+
try:
|
| 155 |
+
if "Observation:" in text:
|
| 156 |
+
text = text.split("Observation:")[0]
|
| 157 |
+
if "Action:" in text:
|
| 158 |
+
text = text.split("Action:")[1]
|
| 159 |
+
tool_name, tool_input = text.split("Action input:")
|
| 160 |
+
if "{" in tool_input:
|
| 161 |
+
tool_input = parse_json_blob(tool_input)
|
| 162 |
+
else:
|
| 163 |
+
tool_input = tool_input.strip().replace('"', "")
|
| 164 |
+
return tool_name.strip().replace('"', "").replace("\\", ""), tool_input
|
| 165 |
+
except Exception as e:
|
| 166 |
+
raise ValueError(
|
| 167 |
+
f"Error in parsing the text tool call: {e}. Be sure to provide the correct format. DO NOT repeat your previous incorrect tool call."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def to_text(input: Union[List[Dict[str, str]], Dict[str, str], str]) -> str:
|
| 172 |
+
if isinstance(input, list):
|
| 173 |
+
return "\n".join([m["content"] for m in input])
|
| 174 |
+
elif isinstance(input, dict):
|
| 175 |
+
return input["content"]
|
| 176 |
+
else:
|
| 177 |
+
return input
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
HUGGINGFACE_DEFAULT_TOOLS = {}
|
| 181 |
+
_tools_are_initialized = False
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class Toolbox:
|
| 185 |
+
"""
|
| 186 |
+
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
|
| 187 |
+
manage them.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
tools (`List[Tool]`):
|
| 191 |
+
The list of tools to instantiate the toolbox with
|
| 192 |
+
add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 193 |
+
Whether to add the tools available within `transformers` to the toolbox.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
def __init__(self, tools: List[Tool], add_base_tools: bool = False):
|
| 197 |
+
self._tools = {tool.name: tool for tool in tools}
|
| 198 |
+
if add_base_tools:
|
| 199 |
+
self.add_base_tools()
|
| 200 |
+
self._load_tools_if_needed()
|
| 201 |
+
|
| 202 |
+
def add_base_tools(self, add_python_interpreter: bool = False):
|
| 203 |
+
global _tools_are_initialized
|
| 204 |
+
global HUGGINGFACE_DEFAULT_TOOLS
|
| 205 |
+
if not _tools_are_initialized:
|
| 206 |
+
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
|
| 207 |
+
_tools_are_initialized = True
|
| 208 |
+
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
|
| 209 |
+
if tool.name != "python_interpreter" or add_python_interpreter:
|
| 210 |
+
self.add_tool(tool)
|
| 211 |
+
self._load_tools_if_needed()
|
| 212 |
+
|
| 213 |
+
@property
|
| 214 |
+
def tools(self) -> Dict[str, Tool]:
|
| 215 |
+
"""Get all tools currently in the toolbox"""
|
| 216 |
+
return self._tools
|
| 217 |
+
|
| 218 |
+
def show_tool_descriptions(self, tool_description_template: str = None) -> str:
|
| 219 |
+
"""
|
| 220 |
+
Returns the description of all tools in the toolbox
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
tool_description_template (`str`, *optional*):
|
| 224 |
+
The template to use to describe the tools. If not provided, the default template will be used.
|
| 225 |
+
"""
|
| 226 |
+
return "\n".join(
|
| 227 |
+
[get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
def add_tool(self, tool: Tool):
|
| 231 |
+
"""
|
| 232 |
+
Adds a tool to the toolbox
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
tool (`Tool`):
|
| 236 |
+
The tool to add to the toolbox.
|
| 237 |
+
"""
|
| 238 |
+
if tool.name in self._tools:
|
| 239 |
+
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
|
| 240 |
+
self._tools[tool.name] = tool
|
| 241 |
+
|
| 242 |
+
def remove_tool(self, tool_name: str):
|
| 243 |
+
"""
|
| 244 |
+
Removes a tool from the toolbox
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
tool_name (`str`):
|
| 248 |
+
The tool to remove from the toolbox.
|
| 249 |
+
"""
|
| 250 |
+
if tool_name not in self._tools:
|
| 251 |
+
raise KeyError(
|
| 252 |
+
f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
|
| 253 |
+
)
|
| 254 |
+
del self._tools[tool_name]
|
| 255 |
+
|
| 256 |
+
def update_tool(self, tool: Tool):
|
| 257 |
+
"""
|
| 258 |
+
Updates a tool in the toolbox according to its name.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
tool (`Tool`):
|
| 262 |
+
The tool to update to the toolbox.
|
| 263 |
+
"""
|
| 264 |
+
if tool.name not in self._tools:
|
| 265 |
+
raise KeyError(
|
| 266 |
+
f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
|
| 267 |
+
)
|
| 268 |
+
self._tools[tool.name] = tool
|
| 269 |
+
|
| 270 |
+
def clear_toolbox(self):
|
| 271 |
+
"""Clears the toolbox"""
|
| 272 |
+
self._tools = {}
|
| 273 |
+
|
| 274 |
+
def _load_tools_if_needed(self):
|
| 275 |
+
for name, tool in self._tools.items():
|
| 276 |
+
if not isinstance(tool, Tool):
|
| 277 |
+
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
|
| 278 |
+
self._tools[name] = load_tool(task_or_repo_id)
|
| 279 |
+
|
| 280 |
+
def __repr__(self):
|
| 281 |
+
toolbox_description = "Toolbox contents:\n"
|
| 282 |
+
for tool in self._tools.values():
|
| 283 |
+
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
| 284 |
+
return toolbox_description
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class AgentError(Exception):
|
| 288 |
+
"""Base class for other agent-related exceptions"""
|
| 289 |
+
|
| 290 |
+
def __init__(self, message):
|
| 291 |
+
super().__init__(message)
|
| 292 |
+
self.message = message
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class AgentParsingError(AgentError):
|
| 296 |
+
"""Exception raised for errors in parsing in the agent"""
|
| 297 |
+
|
| 298 |
+
pass
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class AgentExecutionError(AgentError):
|
| 302 |
+
"""Exception raised for errors in execution in the agent"""
|
| 303 |
+
|
| 304 |
+
pass
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class AgentMaxIterationsError(AgentError):
|
| 308 |
+
"""Exception raised for errors in execution in the agent"""
|
| 309 |
+
|
| 310 |
+
pass
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class AgentGenerationError(AgentError):
|
| 314 |
+
"""Exception raised for errors in generation in the agent"""
|
| 315 |
+
|
| 316 |
+
pass
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
| 320 |
+
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
| 321 |
+
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
|
| 322 |
+
|
| 323 |
+
if "<<tool_names>>" in prompt:
|
| 324 |
+
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
|
| 325 |
+
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
|
| 326 |
+
|
| 327 |
+
return prompt
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def show_agents_descriptions(managed_agents: list):
|
| 331 |
+
managed_agents_descriptions = """
|
| 332 |
+
You can also give requests to team members.
|
| 333 |
+
Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
|
| 334 |
+
Given that this team member is a real human, you should be very verbose in your request.
|
| 335 |
+
Here is a list of the team members that you can call:"""
|
| 336 |
+
for agent in managed_agents.values():
|
| 337 |
+
managed_agents_descriptions += f"\n- {agent.name}: {agent.description}"
|
| 338 |
+
return managed_agents_descriptions
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
|
| 342 |
+
if managed_agents is not None:
|
| 343 |
+
return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
|
| 344 |
+
else:
|
| 345 |
+
return prompt_template.replace("<<managed_agents_descriptions>>", "")
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
|
| 349 |
+
if "<<authorized_imports>>" not in prompt_template:
|
| 350 |
+
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
|
| 351 |
+
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class Agent:
|
| 355 |
+
def __init__(
|
| 356 |
+
self,
|
| 357 |
+
tools: Union[List[Tool], Toolbox],
|
| 358 |
+
llm_engine: Callable = None,
|
| 359 |
+
system_prompt: Optional[str] = None,
|
| 360 |
+
tool_description_template: Optional[str] = None,
|
| 361 |
+
additional_args: Dict = {},
|
| 362 |
+
max_iterations: int = 6,
|
| 363 |
+
tool_parser: Optional[Callable] = None,
|
| 364 |
+
add_base_tools: bool = False,
|
| 365 |
+
verbose: int = 0,
|
| 366 |
+
grammar: Optional[Dict[str, str]] = None,
|
| 367 |
+
managed_agents: Optional[List] = None,
|
| 368 |
+
step_callbacks: Optional[List[Callable]] = None,
|
| 369 |
+
monitor_metrics: bool = True,
|
| 370 |
+
):
|
| 371 |
+
if system_prompt is None:
|
| 372 |
+
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
| 373 |
+
if tool_parser is None:
|
| 374 |
+
tool_parser = parse_json_tool_call
|
| 375 |
+
self.agent_name = self.__class__.__name__
|
| 376 |
+
self.llm_engine = llm_engine
|
| 377 |
+
self.system_prompt_template = system_prompt
|
| 378 |
+
self.tool_description_template = (
|
| 379 |
+
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
| 380 |
+
)
|
| 381 |
+
self.additional_args = additional_args
|
| 382 |
+
self.max_iterations = max_iterations
|
| 383 |
+
self.logger = logger
|
| 384 |
+
self.tool_parser = tool_parser
|
| 385 |
+
self.grammar = grammar
|
| 386 |
+
|
| 387 |
+
self.managed_agents = None
|
| 388 |
+
if managed_agents is not None:
|
| 389 |
+
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
| 390 |
+
|
| 391 |
+
if isinstance(tools, Toolbox):
|
| 392 |
+
self._toolbox = tools
|
| 393 |
+
if add_base_tools:
|
| 394 |
+
if not is_torch_available():
|
| 395 |
+
raise ImportError("Using the base tools requires torch to be installed.")
|
| 396 |
+
|
| 397 |
+
self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == ReactJsonAgent))
|
| 398 |
+
else:
|
| 399 |
+
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
| 400 |
+
self._toolbox.add_tool(FinalAnswerTool())
|
| 401 |
+
|
| 402 |
+
self.system_prompt = format_prompt_with_tools(
|
| 403 |
+
self._toolbox, self.system_prompt_template, self.tool_description_template
|
| 404 |
+
)
|
| 405 |
+
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
| 406 |
+
self.prompt = None
|
| 407 |
+
self.logs = []
|
| 408 |
+
self.task = None
|
| 409 |
+
|
| 410 |
+
if verbose == 0:
|
| 411 |
+
logger.setLevel(logging.WARNING)
|
| 412 |
+
elif verbose == 1:
|
| 413 |
+
logger.setLevel(logging.INFO)
|
| 414 |
+
elif verbose == 2:
|
| 415 |
+
logger.setLevel(logging.DEBUG)
|
| 416 |
+
|
| 417 |
+
# Initialize step callbacks
|
| 418 |
+
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
| 419 |
+
|
| 420 |
+
# Initialize Monitor if monitor_metrics is True
|
| 421 |
+
self.monitor = None
|
| 422 |
+
if monitor_metrics:
|
| 423 |
+
self.monitor = Monitor(self.llm_engine)
|
| 424 |
+
self.step_callbacks.append(self.monitor.update_metrics)
|
| 425 |
+
|
| 426 |
+
@property
|
| 427 |
+
def toolbox(self) -> Toolbox:
|
| 428 |
+
"""Get the toolbox currently available to the agent"""
|
| 429 |
+
return self._toolbox
|
| 430 |
+
|
| 431 |
+
def initialize_for_run(self):
|
| 432 |
+
self.token_count = 0
|
| 433 |
+
self.system_prompt = format_prompt_with_tools(
|
| 434 |
+
self._toolbox,
|
| 435 |
+
self.system_prompt_template,
|
| 436 |
+
self.tool_description_template,
|
| 437 |
+
)
|
| 438 |
+
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
| 439 |
+
if hasattr(self, "authorized_imports"):
|
| 440 |
+
self.system_prompt = format_prompt_with_imports(
|
| 441 |
+
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
| 442 |
+
)
|
| 443 |
+
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
| 444 |
+
self.logger.log(33, "======== New task ========")
|
| 445 |
+
self.logger.log(34, self.task)
|
| 446 |
+
self.logger.debug("System prompt is as follows:")
|
| 447 |
+
self.logger.debug(self.system_prompt)
|
| 448 |
+
|
| 449 |
+
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
| 450 |
+
"""
|
| 451 |
+
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
| 452 |
+
that can be used as input to the LLM.
|
| 453 |
+
"""
|
| 454 |
+
prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0]["system_prompt"]}
|
| 455 |
+
task_message = {
|
| 456 |
+
"role": MessageRole.USER,
|
| 457 |
+
"content": "Task: " + self.logs[0]["task"],
|
| 458 |
+
}
|
| 459 |
+
if summary_mode:
|
| 460 |
+
memory = [task_message]
|
| 461 |
+
else:
|
| 462 |
+
memory = [prompt_message, task_message]
|
| 463 |
+
for i, step_log in enumerate(self.logs[1:]):
|
| 464 |
+
if "llm_output" in step_log and not summary_mode:
|
| 465 |
+
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
|
| 466 |
+
memory.append(thought_message)
|
| 467 |
+
if "facts" in step_log:
|
| 468 |
+
thought_message = {
|
| 469 |
+
"role": MessageRole.ASSISTANT,
|
| 470 |
+
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
|
| 471 |
+
}
|
| 472 |
+
memory.append(thought_message)
|
| 473 |
+
|
| 474 |
+
if "plan" in step_log and not summary_mode:
|
| 475 |
+
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
|
| 476 |
+
memory.append(thought_message)
|
| 477 |
+
|
| 478 |
+
if "tool_call" in step_log and summary_mode:
|
| 479 |
+
tool_call_message = {
|
| 480 |
+
"role": MessageRole.ASSISTANT,
|
| 481 |
+
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
|
| 482 |
+
}
|
| 483 |
+
memory.append(tool_call_message)
|
| 484 |
+
|
| 485 |
+
if "task" in step_log:
|
| 486 |
+
tool_call_message = {
|
| 487 |
+
"role": MessageRole.USER,
|
| 488 |
+
"content": "New task:\n" + step_log["task"],
|
| 489 |
+
}
|
| 490 |
+
memory.append(tool_call_message)
|
| 491 |
+
|
| 492 |
+
if "error" in step_log or "observation" in step_log:
|
| 493 |
+
if "error" in step_log:
|
| 494 |
+
message_content = (
|
| 495 |
+
f"[OUTPUT OF STEP {i}] -> Error:\n"
|
| 496 |
+
+ str(step_log["error"])
|
| 497 |
+
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
| 498 |
+
)
|
| 499 |
+
elif "observation" in step_log:
|
| 500 |
+
message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log['observation']}"
|
| 501 |
+
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
| 502 |
+
memory.append(tool_response_message)
|
| 503 |
+
|
| 504 |
+
return memory
|
| 505 |
+
|
| 506 |
+
def get_succinct_logs(self):
|
| 507 |
+
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
|
| 508 |
+
|
| 509 |
+
def extract_action(self, llm_output: str, split_token: str) -> str:
|
| 510 |
+
"""
|
| 511 |
+
Parse action from the LLM output
|
| 512 |
+
|
| 513 |
+
Args:
|
| 514 |
+
llm_output (`str`): Output of the LLM
|
| 515 |
+
split_token (`str`): Separator for the action. Should match the example in the system prompt.
|
| 516 |
+
"""
|
| 517 |
+
try:
|
| 518 |
+
split = llm_output.split(split_token)
|
| 519 |
+
rationale, action = (
|
| 520 |
+
split[-2],
|
| 521 |
+
split[-1],
|
| 522 |
+
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
| 523 |
+
except Exception as e:
|
| 524 |
+
self.logger.error(e, exc_info=1)
|
| 525 |
+
raise AgentParsingError(
|
| 526 |
+
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
| 527 |
+
)
|
| 528 |
+
return rationale.strip(), action.strip()
|
| 529 |
+
|
| 530 |
+
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
| 531 |
+
"""
|
| 532 |
+
Execute tool with the provided input and returns the result.
|
| 533 |
+
This method replaces arguments with the actual values from the state if they refer to state variables.
|
| 534 |
+
|
| 535 |
+
Args:
|
| 536 |
+
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
| 537 |
+
arguments (Dict[str, str]): Arguments passed to the Tool.
|
| 538 |
+
"""
|
| 539 |
+
available_tools = self.toolbox.tools
|
| 540 |
+
if self.managed_agents is not None:
|
| 541 |
+
available_tools = {**available_tools, **self.managed_agents}
|
| 542 |
+
if tool_name not in available_tools:
|
| 543 |
+
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
| 544 |
+
self.logger.error(error_msg, exc_info=1)
|
| 545 |
+
raise AgentExecutionError(error_msg)
|
| 546 |
+
|
| 547 |
+
try:
|
| 548 |
+
if isinstance(arguments, str):
|
| 549 |
+
observation = available_tools[tool_name](arguments)
|
| 550 |
+
elif isinstance(arguments, dict):
|
| 551 |
+
for key, value in arguments.items():
|
| 552 |
+
# if the value is the name of a state variable like "image.png", replace it with the actual value
|
| 553 |
+
if isinstance(value, str) and value in self.state:
|
| 554 |
+
arguments[key] = self.state[value]
|
| 555 |
+
observation = available_tools[tool_name](**arguments)
|
| 556 |
+
else:
|
| 557 |
+
raise AgentExecutionError(
|
| 558 |
+
f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
| 559 |
+
)
|
| 560 |
+
return observation
|
| 561 |
+
except Exception as e:
|
| 562 |
+
if tool_name in self.toolbox.tools:
|
| 563 |
+
raise AgentExecutionError(
|
| 564 |
+
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
| 565 |
+
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
|
| 566 |
+
)
|
| 567 |
+
elif tool_name in self.managed_agents:
|
| 568 |
+
raise AgentExecutionError(
|
| 569 |
+
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
| 570 |
+
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
|
| 574 |
+
self.logger.warning("=== Agent thoughts:")
|
| 575 |
+
self.logger.log(31, rationale)
|
| 576 |
+
self.logger.warning(">>> Agent is executing the code below:")
|
| 577 |
+
if is_pygments_available():
|
| 578 |
+
self.logger.log(
|
| 579 |
+
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
|
| 580 |
+
)
|
| 581 |
+
else:
|
| 582 |
+
self.logger.log(31, code_action)
|
| 583 |
+
self.logger.warning("====")
|
| 584 |
+
|
| 585 |
+
def run(self, **kwargs):
|
| 586 |
+
"""To be implemented in the child class"""
|
| 587 |
+
raise NotImplementedError
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class CodeAgent(Agent):
|
| 591 |
+
"""
|
| 592 |
+
A class for an agent that solves the given task using a single block of code. It plans all its actions, then executes all in one shot.
|
| 593 |
+
"""
|
| 594 |
+
|
| 595 |
+
def __init__(
|
| 596 |
+
self,
|
| 597 |
+
tools: List[Tool],
|
| 598 |
+
llm_engine: Optional[Callable] = None,
|
| 599 |
+
system_prompt: Optional[str] = None,
|
| 600 |
+
tool_description_template: Optional[str] = None,
|
| 601 |
+
grammar: Optional[Dict[str, str]] = None,
|
| 602 |
+
additional_authorized_imports: Optional[List[str]] = None,
|
| 603 |
+
**kwargs,
|
| 604 |
+
):
|
| 605 |
+
if llm_engine is None:
|
| 606 |
+
llm_engine = HfApiEngine()
|
| 607 |
+
if system_prompt is None:
|
| 608 |
+
system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
|
| 609 |
+
if tool_description_template is None:
|
| 610 |
+
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
| 611 |
+
super().__init__(
|
| 612 |
+
tools=tools,
|
| 613 |
+
llm_engine=llm_engine,
|
| 614 |
+
system_prompt=system_prompt,
|
| 615 |
+
tool_description_template=tool_description_template,
|
| 616 |
+
grammar=grammar,
|
| 617 |
+
**kwargs,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
if not is_pygments_available():
|
| 621 |
+
transformers_logging.warning_once(
|
| 622 |
+
logger,
|
| 623 |
+
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
|
| 624 |
+
"CodeAgent.",
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
self.python_evaluator = evaluate_python_code
|
| 628 |
+
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
| 629 |
+
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
| 630 |
+
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
| 631 |
+
|
| 632 |
+
def parse_code_blob(self, result: str) -> str:
|
| 633 |
+
"""
|
| 634 |
+
Override this method if you want to change the way the code is
|
| 635 |
+
cleaned in the `run` method.
|
| 636 |
+
"""
|
| 637 |
+
return parse_code_blob(result)
|
| 638 |
+
|
| 639 |
+
def run(self, task: str, return_generated_code: bool = False, **kwargs):
|
| 640 |
+
"""
|
| 641 |
+
Runs the agent for the given task.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
task (`str`): The task to perform
|
| 645 |
+
return_generated_code (`bool`, *optional*, defaults to `False`): Whether to return the generated code instead of running it
|
| 646 |
+
kwargs (additional keyword arguments, *optional*):
|
| 647 |
+
Any keyword argument to send to the agent when evaluating the code.
|
| 648 |
+
|
| 649 |
+
Example:
|
| 650 |
+
|
| 651 |
+
```py
|
| 652 |
+
from transformers.agents import CodeAgent
|
| 653 |
+
|
| 654 |
+
agent = CodeAgent(tools=[])
|
| 655 |
+
agent.run("What is the result of 2 power 3.7384?")
|
| 656 |
+
```
|
| 657 |
+
"""
|
| 658 |
+
self.task = task
|
| 659 |
+
if len(kwargs) > 0:
|
| 660 |
+
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
| 661 |
+
self.state = kwargs.copy()
|
| 662 |
+
self.initialize_for_run()
|
| 663 |
+
|
| 664 |
+
# Run LLM
|
| 665 |
+
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
|
| 666 |
+
task_message = {
|
| 667 |
+
"role": MessageRole.USER,
|
| 668 |
+
"content": "Task: " + self.task,
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
self.prompt = [prompt_message, task_message]
|
| 672 |
+
self.logger.info("====Executing with this prompt====")
|
| 673 |
+
self.logger.info(self.prompt)
|
| 674 |
+
|
| 675 |
+
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
| 676 |
+
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
|
| 677 |
+
|
| 678 |
+
if return_generated_code:
|
| 679 |
+
return llm_output
|
| 680 |
+
|
| 681 |
+
# Parse
|
| 682 |
+
try:
|
| 683 |
+
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
| 684 |
+
except Exception as e:
|
| 685 |
+
self.logger.debug(
|
| 686 |
+
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
|
| 687 |
+
)
|
| 688 |
+
rationale, code_action = "", llm_output
|
| 689 |
+
|
| 690 |
+
try:
|
| 691 |
+
code_action = self.parse_code_blob(code_action)
|
| 692 |
+
except Exception as e:
|
| 693 |
+
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
|
| 694 |
+
self.logger.error(error_msg, exc_info=1)
|
| 695 |
+
return error_msg
|
| 696 |
+
|
| 697 |
+
# Execute
|
| 698 |
+
self.log_rationale_code_action(rationale, code_action)
|
| 699 |
+
try:
|
| 700 |
+
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
| 701 |
+
output = self.python_evaluator(
|
| 702 |
+
code_action,
|
| 703 |
+
static_tools=available_tools,
|
| 704 |
+
custom_tools={},
|
| 705 |
+
state=self.state,
|
| 706 |
+
authorized_imports=self.authorized_imports,
|
| 707 |
+
)
|
| 708 |
+
self.logger.info(self.state["print_outputs"])
|
| 709 |
+
return output
|
| 710 |
+
except Exception as e:
|
| 711 |
+
error_msg = f"Error in execution: {e}. Be sure to provide correct code."
|
| 712 |
+
self.logger.error(error_msg, exc_info=1)
|
| 713 |
+
return error_msg
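
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: the CodeAgent docstrings above describe a one-shot plan-then-execute flow; this shows
# what a call looks like end to end. It assumes a Hugging Face API token is configured so the
# default HfApiEngine can be used; the task string and extra keyword argument are placeholders.
def _example_code_agent_run():
    from transformers.agents import CodeAgent

    agent = CodeAgent(tools=[], additional_authorized_imports=["math"])
    # Extra keyword arguments are copied into `agent.state` and mentioned in the task prompt.
    return agent.run("What is the result of 2 power 3.7384?", precision=4)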
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
class ReactAgent(Agent):
|
| 717 |
+
"""
|
| 718 |
+
This agent solves the given task step by step, using the ReAct framework:
|
| 719 |
+
While the objective is not reached, the agent will perform a cycle of thinking and acting.
|
| 720 |
+
The action will be parsed from the LLM output: it consists of calls to tools from the toolbox, with arguments chosen by the LLM engine.
|
| 721 |
+
"""
|
| 722 |
+
|
| 723 |
+
def __init__(
|
| 724 |
+
self,
|
| 725 |
+
tools: List[Tool],
|
| 726 |
+
llm_engine: Optional[Callable] = None,
|
| 727 |
+
system_prompt: Optional[str] = None,
|
| 728 |
+
tool_description_template: Optional[str] = None,
|
| 729 |
+
grammar: Optional[Dict[str, str]] = None,
|
| 730 |
+
plan_type: Optional[str] = None,
|
| 731 |
+
planning_interval: Optional[int] = None,
|
| 732 |
+
**kwargs,
|
| 733 |
+
):
|
| 734 |
+
if llm_engine is None:
|
| 735 |
+
llm_engine = HfApiEngine()
|
| 736 |
+
if system_prompt is None:
|
| 737 |
+
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
| 738 |
+
if tool_description_template is None:
|
| 739 |
+
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
| 740 |
+
if plan_type is None:
|
| 741 |
+
plan_type = SUPPORTED_PLAN_TYPES[0]
|
| 742 |
+
else:
|
| 743 |
+
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
|
| 744 |
+
super().__init__(
|
| 745 |
+
tools=tools,
|
| 746 |
+
llm_engine=llm_engine,
|
| 747 |
+
system_prompt=system_prompt,
|
| 748 |
+
tool_description_template=tool_description_template,
|
| 749 |
+
grammar=grammar,
|
| 750 |
+
**kwargs,
|
| 751 |
+
)
|
| 752 |
+
self.planning_interval = planning_interval
|
| 753 |
+
self.plan_type = plan_type
|
| 754 |
+
|
| 755 |
+
def provide_final_answer(self, task) -> str:
|
| 756 |
+
"""
|
| 757 |
+
This method provides a final answer to the task, based on the logs of the agent's interactions.
|
| 758 |
+
"""
|
| 759 |
+
self.prompt = [
|
| 760 |
+
{
|
| 761 |
+
"role": MessageRole.SYSTEM,
|
| 762 |
+
"content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
| 763 |
+
}
|
| 764 |
+
]
|
| 765 |
+
self.prompt += self.write_inner_memory_from_logs()[1:]
|
| 766 |
+
self.prompt += [
|
| 767 |
+
{
|
| 768 |
+
"role": MessageRole.USER,
|
| 769 |
+
"content": f"Based on the above, please provide an answer to the following user request:\n{task}",
|
| 770 |
+
}
|
| 771 |
+
]
|
| 772 |
+
try:
|
| 773 |
+
return self.llm_engine(self.prompt)
|
| 774 |
+
except Exception as e:
|
| 775 |
+
return f"Error in generating final llm output: {e}."
|
| 776 |
+
|
| 777 |
+
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
|
| 778 |
+
"""
|
| 779 |
+
Runs the agent for the given task.
|
| 780 |
+
|
| 781 |
+
Args:
|
| 782 |
+
task (`str`): The task to perform
|
| 783 |
+
|
| 784 |
+
Example:
|
| 785 |
+
```py
|
| 786 |
+
from transformers.agents import ReactCodeAgent
|
| 787 |
+
agent = ReactCodeAgent(tools=[])
|
| 788 |
+
agent.run("What is the result of 2 power 3.7384?")
|
| 789 |
+
```
|
| 790 |
+
"""
|
| 791 |
+
self.task = task
|
| 792 |
+
if len(kwargs) > 0:
|
| 793 |
+
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
| 794 |
+
self.state = kwargs.copy()
|
| 795 |
+
if reset:
|
| 796 |
+
self.initialize_for_run()
|
| 797 |
+
else:
|
| 798 |
+
self.logs.append({"task": task})
|
| 799 |
+
if stream:
|
| 800 |
+
return self.stream_run(task)
|
| 801 |
+
else:
|
| 802 |
+
return self.direct_run(task)
|
| 803 |
+
|
| 804 |
+
def stream_run(self, task: str):
|
| 805 |
+
"""
|
| 806 |
+
Runs the agent in streaming mode, yielding steps as they are executed: it should only be called from within the `run` method.
|
| 807 |
+
"""
|
| 808 |
+
final_answer = None
|
| 809 |
+
iteration = 0
|
| 810 |
+
while final_answer is None and iteration < self.max_iterations:
|
| 811 |
+
step_start_time = time.time()
|
| 812 |
+
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
|
| 813 |
+
try:
|
| 814 |
+
self.step(step_log_entry)
|
| 815 |
+
if "final_answer" in step_log_entry:
|
| 816 |
+
final_answer = step_log_entry["final_answer"]
|
| 817 |
+
except AgentError as e:
|
| 818 |
+
self.logger.error(e, exc_info=1)
|
| 819 |
+
step_log_entry["error"] = e
|
| 820 |
+
finally:
|
| 821 |
+
step_end_time = time.time()
|
| 822 |
+
step_log_entry["step_end_time"] = step_end_time
|
| 823 |
+
step_log_entry["step_duration"] = step_end_time - step_start_time
|
| 824 |
+
self.logs.append(step_log_entry)
|
| 825 |
+
for callback in self.step_callbacks:
|
| 826 |
+
callback(step_log_entry)
|
| 827 |
+
iteration += 1
|
| 828 |
+
yield step_log_entry
|
| 829 |
+
|
| 830 |
+
if final_answer is None and iteration == self.max_iterations:
|
| 831 |
+
error_message = "Reached max iterations."
|
| 832 |
+
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
| 833 |
+
self.logs.append(final_step_log)
|
| 834 |
+
self.logger.error(error_message, exc_info=1)
|
| 835 |
+
final_answer = self.provide_final_answer(task)
|
| 836 |
+
final_step_log["final_answer"] = final_answer
|
| 837 |
+
final_step_log["step_duration"] = 0
|
| 838 |
+
for callback in self.step_callbacks:
|
| 839 |
+
callback(final_step_log)
|
| 840 |
+
yield final_step_log
|
| 841 |
+
|
| 842 |
+
yield final_answer
|
| 843 |
+
|
| 844 |
+
def direct_run(self, task: str):
|
| 845 |
+
"""
|
| 846 |
+
Runs the agent in direct mode, returning outputs only at the end: it should only be called from within the `run` method.
|
| 847 |
+
"""
|
| 848 |
+
final_answer = None
|
| 849 |
+
iteration = 0
|
| 850 |
+
while final_answer is None and iteration < self.max_iterations:
|
| 851 |
+
step_start_time = time.time()
|
| 852 |
+
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
|
| 853 |
+
try:
|
| 854 |
+
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
| 855 |
+
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
| 856 |
+
self.step(step_log_entry)
|
| 857 |
+
if "final_answer" in step_log_entry:
|
| 858 |
+
final_answer = step_log_entry["final_answer"]
|
| 859 |
+
except AgentError as e:
|
| 860 |
+
self.logger.error(e, exc_info=1)
|
| 861 |
+
step_log_entry["error"] = e
|
| 862 |
+
finally:
|
| 863 |
+
step_end_time = time.time()
|
| 864 |
+
step_log_entry["step_end_time"] = step_end_time
|
| 865 |
+
step_log_entry["step_duration"] = step_end_time - step_start_time
|
| 866 |
+
self.logs.append(step_log_entry)
|
| 867 |
+
for callback in self.step_callbacks:
|
| 868 |
+
callback(step_log_entry)
|
| 869 |
+
iteration += 1
|
| 870 |
+
|
| 871 |
+
if final_answer is None and iteration == self.max_iterations:
|
| 872 |
+
error_message = "Reached max iterations."
|
| 873 |
+
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
| 874 |
+
self.logs.append(final_step_log)
|
| 875 |
+
self.logger.error(error_message, exc_info=1)
|
| 876 |
+
final_answer = self.provide_final_answer(task)
|
| 877 |
+
final_step_log["final_answer"] = final_answer
|
| 878 |
+
final_step_log["step_duration"] = 0
|
| 879 |
+
for callback in self.step_callbacks:
|
| 880 |
+
callback(final_step_log)
|
| 881 |
+
|
| 882 |
+
return final_answer
|
| 883 |
+
|
| 884 |
+
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
|
| 885 |
+
"""
|
| 886 |
+
Used periodically by the agent to plan the next steps to reach the objective.
|
| 887 |
+
|
| 888 |
+
Args:
|
| 889 |
+
task (`str`): The task to perform
|
| 890 |
+
is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
|
| 891 |
+
iteration (`int`): The number of the current step, used as an indication for the LLM.
|
| 892 |
+
"""
|
| 893 |
+
if is_first_step:
|
| 894 |
+
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
|
| 895 |
+
message_prompt_task = {
|
| 896 |
+
"role": MessageRole.USER,
|
| 897 |
+
"content": f"""Here is the task:
|
| 898 |
+
```
|
| 899 |
+
{task}
|
| 900 |
+
```
|
| 901 |
+
Now begin!""",
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
|
| 905 |
+
|
| 906 |
+
message_system_prompt_plan = {
|
| 907 |
+
"role": MessageRole.SYSTEM,
|
| 908 |
+
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["system"],
|
| 909 |
+
}
|
| 910 |
+
message_user_prompt_plan = {
|
| 911 |
+
"role": MessageRole.USER,
|
| 912 |
+
"content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format(
|
| 913 |
+
task=task,
|
| 914 |
+
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
| 915 |
+
managed_agents_descriptions=(
|
| 916 |
+
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
|
| 917 |
+
),
|
| 918 |
+
answer_facts=answer_facts,
|
| 919 |
+
),
|
| 920 |
+
}
|
| 921 |
+
answer_plan = self.llm_engine(
|
| 922 |
+
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
| 926 |
+
```
|
| 927 |
+
{answer_plan}
|
| 928 |
+
```"""
|
| 929 |
+
final_facts_redaction = f"""Here are the facts that I know so far:
|
| 930 |
+
```
|
| 931 |
+
{answer_facts}
|
| 932 |
+
```""".strip()
|
| 933 |
+
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
| 934 |
+
self.logger.log(36, "===== Initial plan =====")
|
| 935 |
+
self.logger.log(35, final_plan_redaction)
|
| 936 |
+
else: # update plan
|
| 937 |
+
agent_memory = self.write_inner_memory_from_logs(
|
| 938 |
+
summary_mode=False
|
| 939 |
+
) # This will not log the plan but will log facts
|
| 940 |
+
|
| 941 |
+
# Redact updated facts
|
| 942 |
+
facts_update_system_prompt = {
|
| 943 |
+
"role": MessageRole.SYSTEM,
|
| 944 |
+
"content": SYSTEM_PROMPT_FACTS_UPDATE,
|
| 945 |
+
}
|
| 946 |
+
facts_update_message = {
|
| 947 |
+
"role": MessageRole.USER,
|
| 948 |
+
"content": USER_PROMPT_FACTS_UPDATE,
|
| 949 |
+
}
|
| 950 |
+
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
|
| 951 |
+
|
| 952 |
+
# Redact updated plan
|
| 953 |
+
plan_update_message = {
|
| 954 |
+
"role": MessageRole.SYSTEM,
|
| 955 |
+
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["system"].format(task=task),
|
| 956 |
+
}
|
| 957 |
+
plan_update_message_user = {
|
| 958 |
+
"role": MessageRole.USER,
|
| 959 |
+
"content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format(
|
| 960 |
+
task=task,
|
| 961 |
+
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
| 962 |
+
managed_agents_descriptions=(
|
| 963 |
+
show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
|
| 964 |
+
),
|
| 965 |
+
facts_update=facts_update,
|
| 966 |
+
remaining_steps=(self.max_iterations - iteration),
|
| 967 |
+
),
|
| 968 |
+
}
|
| 969 |
+
plan_update = self.llm_engine(
|
| 970 |
+
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
+
# Log final facts and plan
|
| 974 |
+
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
|
| 975 |
+
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
| 976 |
+
```
|
| 977 |
+
{facts_update}
|
| 978 |
+
```"""
|
| 979 |
+
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
| 980 |
+
self.logger.log(36, "===== Updated plan =====")
|
| 981 |
+
self.logger.log(35, final_plan_redaction)
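
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: ReactAgent.run supports both direct and streaming execution, and planning_step fires
# every `planning_interval` iterations. The sketch uses the concrete ReactCodeAgent subclass,
# since ReactAgent leaves step() abstract; the task strings are placeholders and a configured
# HfApiEngine is assumed.
def _example_react_agent_modes():
    from transformers.agents import ReactCodeAgent

    agent = ReactCodeAgent(tools=[], planning_interval=3)  # re-plan every 3 steps
    # Direct mode: blocks until a final answer (or max_iterations) and returns it.
    answer = agent.run("What is the result of 2 power 3.7384?")
    # Streaming mode: yields one log entry per step, then the final answer last.
    for step_log in agent.run("What is the result of 2 power 3.7384?", stream=True, reset=True):
        print(step_log)
    return answer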
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
class ReactJsonAgent(ReactAgent):
|
| 985 |
+
"""
|
| 986 |
+
This agent solves the given task step by step, using the ReAct framework:
|
| 987 |
+
While the objective is not reached, the agent will perform a cycle of thinking and acting.
|
| 988 |
+
The tool calls will be formulated by the LLM in JSON format, then parsed and executed.
|
| 989 |
+
"""
|
| 990 |
+
|
| 991 |
+
def __init__(
|
| 992 |
+
self,
|
| 993 |
+
tools: List[Tool],
|
| 994 |
+
llm_engine: Optional[Callable] = None,
|
| 995 |
+
system_prompt: Optional[str] = None,
|
| 996 |
+
tool_description_template: Optional[str] = None,
|
| 997 |
+
grammar: Optional[Dict[str, str]] = None,
|
| 998 |
+
planning_interval: Optional[int] = None,
|
| 999 |
+
**kwargs,
|
| 1000 |
+
):
|
| 1001 |
+
if llm_engine is None:
|
| 1002 |
+
llm_engine = HfApiEngine()
|
| 1003 |
+
if system_prompt is None:
|
| 1004 |
+
system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
| 1005 |
+
if tool_description_template is None:
|
| 1006 |
+
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
| 1007 |
+
super().__init__(
|
| 1008 |
+
tools=tools,
|
| 1009 |
+
llm_engine=llm_engine,
|
| 1010 |
+
system_prompt=system_prompt,
|
| 1011 |
+
tool_description_template=tool_description_template,
|
| 1012 |
+
grammar=grammar,
|
| 1013 |
+
planning_interval=planning_interval,
|
| 1014 |
+
**kwargs,
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
def step(self, log_entry: Dict[str, Any]):
|
| 1018 |
+
"""
|
| 1019 |
+
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
| 1020 |
+
The errors are raised here, they are caught and logged in the run() method.
|
| 1021 |
+
"""
|
| 1022 |
+
agent_memory = self.write_inner_memory_from_logs()
|
| 1023 |
+
|
| 1024 |
+
self.prompt = agent_memory
|
| 1025 |
+
self.logger.debug("===== New step =====")
|
| 1026 |
+
|
| 1027 |
+
# Add new step in logs
|
| 1028 |
+
log_entry["agent_memory"] = agent_memory.copy()
|
| 1029 |
+
|
| 1030 |
+
self.logger.info("===== Calling LLM with this last message: =====")
|
| 1031 |
+
self.logger.info(self.prompt[-1])
|
| 1032 |
+
|
| 1033 |
+
try:
|
| 1034 |
+
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
| 1035 |
+
llm_output = self.llm_engine(
|
| 1036 |
+
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
| 1037 |
+
)
|
| 1038 |
+
except Exception as e:
|
| 1039 |
+
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
| 1040 |
+
self.logger.debug("===== Output message of the LLM: =====")
|
| 1041 |
+
self.logger.debug(llm_output)
|
| 1042 |
+
log_entry["llm_output"] = llm_output
|
| 1043 |
+
|
| 1044 |
+
# Parse
|
| 1045 |
+
self.logger.debug("===== Extracting action =====")
|
| 1046 |
+
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
| 1047 |
+
|
| 1048 |
+
try:
|
| 1049 |
+
tool_name, arguments = self.tool_parser(action)
|
| 1050 |
+
except Exception as e:
|
| 1051 |
+
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
| 1052 |
+
|
| 1053 |
+
log_entry["rationale"] = rationale
|
| 1054 |
+
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
| 1055 |
+
|
| 1056 |
+
# Execute
|
| 1057 |
+
self.logger.warning("=== Agent thoughts:")
|
| 1058 |
+
self.logger.log(31, rationale)
|
| 1059 |
+
self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
|
| 1060 |
+
if tool_name == "final_answer":
|
| 1061 |
+
if isinstance(arguments, dict):
|
| 1062 |
+
if "answer" in arguments:
|
| 1063 |
+
answer = arguments["answer"]
|
| 1064 |
+
if (
|
| 1065 |
+
isinstance(answer, str) and answer in self.state.keys()
|
| 1066 |
+
): # if the answer is a state variable, return the value
|
| 1067 |
+
answer = self.state[answer]
|
| 1068 |
+
else:
|
| 1069 |
+
answer = arguments
|
| 1070 |
+
else:
|
| 1071 |
+
answer = arguments
|
| 1072 |
+
log_entry["final_answer"] = answer
|
| 1073 |
+
return answer
|
| 1074 |
+
else:
|
| 1075 |
+
if arguments is None:
|
| 1076 |
+
arguments = {}
|
| 1077 |
+
observation = self.execute_tool_call(tool_name, arguments)
|
| 1078 |
+
observation_type = type(observation)
|
| 1079 |
+
if observation_type in [AgentImage, AgentAudio]:
|
| 1080 |
+
if observation_type == AgentImage:
|
| 1081 |
+
observation_name = "image.png"
|
| 1082 |
+
elif observation_type == AgentAudio:
|
| 1083 |
+
observation_name = "audio.mp3"
|
| 1084 |
+
# TODO: observation naming could allow for different names of same type
|
| 1085 |
+
|
| 1086 |
+
self.state[observation_name] = observation
|
| 1087 |
+
updated_information = f"Stored '{observation_name}' in memory."
|
| 1088 |
+
else:
|
| 1089 |
+
updated_information = str(observation).strip()
|
| 1090 |
+
self.logger.info(updated_information)
|
| 1091 |
+
log_entry["observation"] = updated_information
|
| 1092 |
+
return log_entry
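
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: ReactJsonAgent.step expects the LLM to emit a JSON tool call after "Action:"; this
# shows wiring a single tool into the agent. PythonInterpreterTool from default_tools is
# used as an example tool, and the max_iterations value is an assumed placeholder.
def _example_react_json_agent():
    from transformers.agents import ReactJsonAgent
    from transformers.agents.default_tools import PythonInterpreterTool

    agent = ReactJsonAgent(tools=[PythonInterpreterTool()], max_iterations=5)
    return agent.run("Use python to compute 2 ** 3.7384.")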
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
class ReactCodeAgent(ReactAgent):
|
| 1096 |
+
"""
|
| 1097 |
+
This agent solves the given task step by step, using the ReAct framework:
|
| 1098 |
+
While the objective is not reached, the agent will perform a cycle of thinking and acting.
|
| 1099 |
+
The tool calls will be formulated by the LLM in code format, then parsed and executed.
|
| 1100 |
+
"""
|
| 1101 |
+
|
| 1102 |
+
def __init__(
|
| 1103 |
+
self,
|
| 1104 |
+
tools: List[Tool],
|
| 1105 |
+
llm_engine: Optional[Callable] = None,
|
| 1106 |
+
system_prompt: Optional[str] = None,
|
| 1107 |
+
tool_description_template: Optional[str] = None,
|
| 1108 |
+
grammar: Optional[Dict[str, str]] = None,
|
| 1109 |
+
additional_authorized_imports: Optional[List[str]] = None,
|
| 1110 |
+
planning_interval: Optional[int] = None,
|
| 1111 |
+
**kwargs,
|
| 1112 |
+
):
|
| 1113 |
+
if llm_engine is None:
|
| 1114 |
+
llm_engine = HfApiEngine()
|
| 1115 |
+
if system_prompt is None:
|
| 1116 |
+
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
| 1117 |
+
if tool_description_template is None:
|
| 1118 |
+
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
| 1119 |
+
super().__init__(
|
| 1120 |
+
tools=tools,
|
| 1121 |
+
llm_engine=llm_engine,
|
| 1122 |
+
system_prompt=system_prompt,
|
| 1123 |
+
tool_description_template=tool_description_template,
|
| 1124 |
+
grammar=grammar,
|
| 1125 |
+
planning_interval=planning_interval,
|
| 1126 |
+
**kwargs,
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
if not is_pygments_available():
|
| 1130 |
+
transformers_logging.warning_once(
|
| 1131 |
+
logger,
|
| 1132 |
+
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
|
| 1133 |
+
"ReactCodeAgent.",
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
self.python_evaluator = evaluate_python_code
|
| 1137 |
+
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
| 1138 |
+
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
| 1139 |
+
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
| 1140 |
+
self.custom_tools = {}
|
| 1141 |
+
|
| 1142 |
+
def step(self, log_entry: Dict[str, Any]):
|
| 1143 |
+
"""
|
| 1144 |
+
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
| 1145 |
+
The errors are raised here, they are caught and logged in the run() method.
|
| 1146 |
+
"""
|
| 1147 |
+
agent_memory = self.write_inner_memory_from_logs()
|
| 1148 |
+
|
| 1149 |
+
self.prompt = agent_memory.copy()
|
| 1150 |
+
self.logger.debug("===== New step =====")
|
| 1151 |
+
|
| 1152 |
+
# Add new step in logs
|
| 1153 |
+
log_entry["agent_memory"] = agent_memory.copy()
|
| 1154 |
+
|
| 1155 |
+
self.logger.info("===== Calling LLM with these last messages: =====")
|
| 1156 |
+
self.logger.info(self.prompt[-2:])
|
| 1157 |
+
|
| 1158 |
+
try:
|
| 1159 |
+
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
| 1160 |
+
llm_output = self.llm_engine(
|
| 1161 |
+
self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
|
| 1162 |
+
)
|
| 1163 |
+
except Exception as e:
|
| 1164 |
+
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
| 1165 |
+
|
| 1166 |
+
self.logger.debug("=== Output message of the LLM:")
|
| 1167 |
+
self.logger.debug(llm_output)
|
| 1168 |
+
log_entry["llm_output"] = llm_output
|
| 1169 |
+
|
| 1170 |
+
# Parse
|
| 1171 |
+
self.logger.debug("=== Extracting action ===")
|
| 1172 |
+
try:
|
| 1173 |
+
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
| 1174 |
+
except Exception as e:
|
| 1175 |
+
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
|
| 1176 |
+
rationale, raw_code_action = llm_output, llm_output
|
| 1177 |
+
|
| 1178 |
+
try:
|
| 1179 |
+
code_action = parse_code_blob(raw_code_action)
|
| 1180 |
+
except Exception as e:
|
| 1181 |
+
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
| 1182 |
+
raise AgentParsingError(error_msg)
|
| 1183 |
+
|
| 1184 |
+
log_entry["rationale"] = rationale
|
| 1185 |
+
log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
| 1186 |
+
|
| 1187 |
+
# Execute
|
| 1188 |
+
self.log_rationale_code_action(rationale, code_action)
|
| 1189 |
+
try:
|
| 1190 |
+
static_tools = {
|
| 1191 |
+
**BASE_PYTHON_TOOLS.copy(),
|
| 1192 |
+
**self.toolbox.tools,
|
| 1193 |
+
}
|
| 1194 |
+
if self.managed_agents is not None:
|
| 1195 |
+
static_tools = {**static_tools, **self.managed_agents}
|
| 1196 |
+
result = self.python_evaluator(
|
| 1197 |
+
code_action,
|
| 1198 |
+
static_tools=static_tools,
|
| 1199 |
+
custom_tools=self.custom_tools,
|
| 1200 |
+
state=self.state,
|
| 1201 |
+
authorized_imports=self.authorized_imports,
|
| 1202 |
+
)
|
| 1203 |
+
self.logger.warning("Print outputs:")
|
| 1204 |
+
self.logger.log(32, self.state["print_outputs"])
|
| 1205 |
+
observation = "Print outputs:\n" + self.state["print_outputs"]
|
| 1206 |
+
if result is not None:
|
| 1207 |
+
self.logger.warning("Last output from code snippet:")
|
| 1208 |
+
self.logger.log(32, str(result))
|
| 1209 |
+
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
| 1210 |
+
log_entry["observation"] = observation
|
| 1211 |
+
except Exception as e:
|
| 1212 |
+
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
| 1213 |
+
if "'dict' object has no attribute 'read'" in str(e):
|
| 1214 |
+
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
|
| 1215 |
+
raise AgentExecutionError(error_msg)
|
| 1216 |
+
for line in code_action.split("\n"):
|
| 1217 |
+
if line[: len("final_answer")] == "final_answer":
|
| 1218 |
+
self.logger.log(33, "Final answer:")
|
| 1219 |
+
self.logger.log(32, result)
|
| 1220 |
+
log_entry["final_answer"] = result
|
| 1221 |
+
return result
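
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: ReactCodeAgent.step runs LLM-written Python through evaluate_python_code, so any extra
# modules must be whitelisted via additional_authorized_imports. The import list and task
# below are placeholder choices.
def _example_react_code_agent_imports():
    from transformers.agents import ReactCodeAgent

    agent = ReactCodeAgent(tools=[], additional_authorized_imports=["numpy", "datetime"])
    return agent.run("How many days are left until the end of the current year?")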


LENGTH_TRUNCATE_REPORTS = 1000


class ManagedAgent:
    def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
        self.agent = agent
        self.name = name
        self.description = description
        self.additional_prompting = additional_prompting
        self.provide_run_summary = provide_run_summary

    def write_full_task(self, task):
        full_task = f"""You're a helpful agent named '{self.name}'.
You have been submitted this task by your manager.
---
Task:
{task}
---
You're helping your manager solve a wider task: so make sure to not provide a one-line answer, but give as much information as possible so that they have a clear understanding of the answer.

Your final_answer WILL HAVE to contain these parts:
### 1. Task outcome (short version):
### 2. Task outcome (extremely detailed version):
### 3. Additional context (if relevant):

Put all these in your final_answer tool, everything that you do not pass as an argument to final_answer will be lost.
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
<<additional_prompting>>"""
        if self.additional_prompting:
            full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
        else:
            full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
        return full_task

    def __call__(self, request, **kwargs):
        full_task = self.write_full_task(request)
        output = self.agent.run(full_task, **kwargs)
        if self.provide_run_summary:
            answer = f"Here is the final answer from your managed agent '{self.name}':\n"
            answer += str(output)
            answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
            for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
                content = message["content"]
                if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
                    answer += "\n" + str(content) + "\n---"
                else:
                    answer += (
                        "\n"
                        + str(content)[:LENGTH_TRUNCATE_REPORTS]
                        + "\n(...Step was truncated because too long)...\n---"
                    )
            answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
            return answer
        else:
            return output
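
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: ManagedAgent wraps a sub-agent so a manager agent can call it like a tool and receive
# its final_answer (plus an optional work summary). The agent names, descriptions, and the
# managed_agents wiring shown here are placeholder choices inferred from how the class is
# used above, not a documented recipe.
def _example_managed_agent():
    from transformers.agents import ReactCodeAgent
    from transformers.agents.agents import ManagedAgent

    sub_agent = ReactCodeAgent(tools=[])
    managed_sub_agent = ManagedAgent(
        agent=sub_agent,
        name="researcher",
        description="Answers factual questions. Give it your question as a plain-text request.",
        provide_run_summary=True,
    )
    manager = ReactCodeAgent(tools=[], managed_agents=[managed_sub_agent])
    return manager.run("Find the release year of the first Transformer paper.")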
.venv/Lib/site-packages/transformers/agents/default_tools.py
ADDED
|
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import math
from dataclasses import dataclass
from math import sqrt
from typing import Dict

from huggingface_hub import hf_hub_download, list_spaces

from ..utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool


def custom_print(*args):
    return None


BASE_PYTHON_TOOLS = {
    "print": custom_print,
    "isinstance": isinstance,
    "range": range,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
    "set": set,
    "list": list,
    "dict": dict,
    "tuple": tuple,
    "round": round,
    "ceil": math.ceil,
    "floor": math.floor,
    "log": math.log,
    "exp": math.exp,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "degrees": math.degrees,
    "radians": math.radians,
    "pow": math.pow,
    "sqrt": sqrt,
    "len": len,
    "sum": sum,
    "max": max,
    "min": min,
    "abs": abs,
    "enumerate": enumerate,
    "zip": zip,
    "reversed": reversed,
    "sorted": sorted,
    "all": all,
    "any": any,
    "map": map,
    "filter": filter,
    "ord": ord,
    "chr": chr,
    "next": next,
    "iter": iter,
    "divmod": divmod,
    "callable": callable,
    "getattr": getattr,
    "hasattr": hasattr,
    "setattr": setattr,
    "issubclass": issubclass,
    "type": type,
}


@dataclass
class PreTool:
    name: str
    inputs: Dict[str, str]
    output_type: type
    task: str
    description: str
    repo_id: str


HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
    "image-transformation",
    "text-to-image",
]


def get_remote_tools(logger, organization="huggingface-tools"):
    if is_offline_mode():
        logger.info("You are in offline mode, so remote tools are not available.")
        return {}

    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)
        task = repo_id.split("/")[-1]
        tools[config["name"]] = PreTool(
            task=task,
            description=config["description"],
            repo_id=repo_id,
            name=task,
            inputs=config["inputs"],
            output_type=config["output_type"],
        )

    return tools


def setup_default_tools(logger):
    default_tools = {}
    main_module = importlib.import_module("transformers")
    tools_module = main_module.agents

    for task_name, tool_class_name in TOOL_MAPPING.items():
        tool_class = getattr(tools_module, tool_class_name)
        tool_instance = tool_class()
        default_tools[tool_class.name] = PreTool(
            name=tool_instance.name,
            inputs=tool_instance.inputs,
            output_type=tool_instance.output_type,
            task=task_name,
            description=tool_instance.description,
            repo_id=None,
        )

    return default_tools


class PythonInterpreterTool(Tool):
    name = "python_interpreter"
    description = "This is a tool that evaluates python code. It can be used to perform calculations."

    output_type = "string"

    def __init__(self, *args, authorized_imports=None, **kwargs):
        if authorized_imports is None:
            self.authorized_imports = list(set(LIST_SAFE_MODULES))
        else:
            self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
        self.inputs = {
            "code": {
                "type": "string",
                "description": (
                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                    f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
                ),
            }
        }
        super().__init__(*args, **kwargs)

    def forward(self, code):
        output = str(
            evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
        )
        return output


class FinalAnswerTool(Tool):
    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
    output_type = "any"

    def forward(self, answer):
        return answer
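
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: the two tools above are plain Tool subclasses, so they can be exercised directly,
# without an agent, by calling them. The code snippet and the expected "4.0" result are
# illustrative assumptions.
def _example_default_tools():
    interpreter = PythonInterpreterTool(authorized_imports=["math"])
    result = interpreter(code="import math\nmath.sqrt(16)")  # expected to return the string "4.0"
    final = FinalAnswerTool()
    return final(answer=result)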
.venv/Lib/site-packages/transformers/agents/document_question_answering.py
ADDED
|
@@ -0,0 +1,89 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import numpy as np
import torch

from ..models.auto import AutoProcessor
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
from ..utils import is_vision_available
from .tools import PipelineTool


if is_vision_available():
    from PIL import Image


class DocumentQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
    description = "This is a tool that answers a question about a document (pdf). It returns a string that contains the answer to the question."
    name = "document_qa"
    pre_processor_class = AutoProcessor
    model_class = VisionEncoderDecoderModel

    inputs = {
        "document": {
            "type": "image",
            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
        },
        "question": {"type": "string", "description": "The question in English"},
    }
    output_type = "string"

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)

    def encode(self, document: "Image", question: str):
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = self.pre_processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        if isinstance(document, str):
            img = Image.open(document).convert("RGB")
            img_array = np.array(img).transpose(2, 0, 1)
            document = torch.from_numpy(img_array)
        pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values

        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}

    def forward(self, inputs):
        return self.model.generate(
            inputs["pixel_values"].to(self.device),
            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        ).sequences

    def decode(self, outputs):
        sequence = self.pre_processor.batch_decode(outputs)[0]
        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        sequence = self.pre_processor.token2json(sequence)

        return sequence["answer"]
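
# --- Illustrative usage sketch (assumptions labelled, not part of the library source) ---
# Why: the tool above is a PipelineTool, so calling it runs encode -> forward -> decode with
# the Donut checkpoint declared in the class. The file path and question are placeholders.
def _example_document_qa():
    document_qa = DocumentQuestionAnsweringTool()
    return document_qa(document="invoice.png", question="What is the total amount due?")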
.venv/Lib/site-packages/transformers/agents/evaluate_agent.py
ADDED
|
@@ -0,0 +1,414 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
from .agents import BASE_PYTHON_TOOLS
|
| 18 |
+
from .python_interpreter import InterpreterError, evaluate
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
### Fake tools for test
|
| 22 |
+
def classifier(text, labels):
|
| 23 |
+
return f"This is the classification of {text} along {labels}."
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def translator(text, src_lang, tgt_lang):
|
| 27 |
+
return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def speaker(text):
|
| 31 |
+
return f"This is actually a sound reading {text}."
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def transcriber(audio):
|
| 35 |
+
if "sound" not in audio:
|
| 36 |
+
raise ValueError(f"`audio` ({audio}) is not a sound.")
|
| 37 |
+
return f"This is the transcribed text from {audio}."
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def image_generator(prompt):
|
| 41 |
+
return f"This is actually an image representing {prompt}."
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def image_captioner(image):
|
| 45 |
+
if "image" not in image:
|
| 46 |
+
raise ValueError(f"`image` ({image}) is not an image.")
|
| 47 |
+
return f"This is a description of {image}."
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def image_transformer(image, prompt):
|
| 51 |
+
if "image" not in image:
|
| 52 |
+
raise ValueError(f"`image` ({image}) is not an image.")
|
| 53 |
+
return f"This is a transformation of {image} according to {prompt}."
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def question_answerer(text, question):
|
| 57 |
+
return f"This is the answer to {question} from {text}."
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def image_qa(image, question):
|
| 61 |
+
if "image" not in image:
|
| 62 |
+
raise ValueError(f"`image` ({image}) is not an image.")
|
| 63 |
+
return f"This is the answer to {question} from {image}."
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def text_downloader(url):
|
| 67 |
+
return f"This is the content of {url}."
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def summarizer(text):
|
| 71 |
+
return f"This is a summary of {text}."
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def video_generator(prompt, seconds=2):
|
| 75 |
+
return f"A video of {prompt}"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def document_qa(image, question):
|
| 79 |
+
return f"This is the answer to {question} from the document {image}."
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def image_segmenter(image, prompt):
|
| 83 |
+
return f"This is the mask of {prompt} in {image}"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
TEST_TOOLS = {
|
| 87 |
+
"text_classifier": classifier,
|
| 88 |
+
"translator": translator,
|
| 89 |
+
"text_reader": speaker,
|
| 90 |
+
"summarizer": summarizer,
|
| 91 |
+
"transcriber": transcriber,
|
| 92 |
+
"image_generator": image_generator,
|
| 93 |
+
"image_captioner": image_captioner,
|
| 94 |
+
"image_transformer": image_transformer,
|
| 95 |
+
"text_qa": question_answerer,
|
| 96 |
+
"text_downloader": text_downloader,
|
| 97 |
+
"image_qa": image_qa,
|
| 98 |
+
"video_generator": video_generator,
|
| 99 |
+
"document_qa": document_qa,
|
| 100 |
+
"image_segmenter": image_segmenter,
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Problem:
|
| 105 |
+
"""
|
| 106 |
+
A class regrouping all the information to solve a problem on which we will evaluate agents.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
task (`str` or `list[str]`):
|
| 110 |
+
One or several descriptions of the task to perform. If a list, it should contain variations on the
|
| 111 |
+
phrasing, but for the same task.
|
| 112 |
+
inputs (`list[str]` or `dict[str, str]`):
|
| 113 |
+
The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
|
| 114 |
+
values. Pass along a dictionary when you want to specify the values of each inputs, or just the list of
|
| 115 |
+
inputs expected (the value used will be `<<input_name>>` in this case).
|
| 116 |
+
answer (`str` or `list[str]`):
|
| 117 |
+
The theoretical answer (or list of possible valid answers) to the problem, as code.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, task, inputs, answer):
|
| 121 |
+
self.task = task
|
| 122 |
+
self.inputs = inputs
|
| 123 |
+
self.answer = answer
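
# --- Illustrative sketch (assumptions labelled, not part of the library source) ---
# Why: shows the shape of a Problem entry as described by the docstring above. The task
# phrasings, inputs and answer below are placeholders in the style of EVALUATION_TASKS.
_example_problem = Problem(
    task=[
        "Caption the `image` and translate the caption to French.",
        "Describe the `image`, then translate the description into French.",
    ],
    inputs=["image"],
    answer="translator(image_captioner(image), src_lang='English', tgt_lang='French')",
)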
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
### The list of problems the agent will be evaluated on.
|
| 127 |
+
EVALUATION_TASKS = [
|
| 128 |
+
Problem(
|
| 129 |
+
task=[
|
| 130 |
+
"Is the following `text` (in Spanish) positive or negative?",
|
| 131 |
+
"Is the text in the variable `text` (in Spanish) positive or negative?",
|
| 132 |
+
"Translate the following `text` from Spanish to English then tell me if its positive or negative.",
|
| 133 |
+
],
|
| 134 |
+
inputs=["text"],
|
| 135 |
+
answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
|
| 136 |
+
),
|
| 137 |
+
Problem(
|
| 138 |
+
task=[
|
| 139 |
+
"Tell me out loud what the `image` contains.",
|
| 140 |
+
"Describe the following `image` out loud.",
|
| 141 |
+
"Find what is in the picture stored in `image` then read it out loud.",
|
| 142 |
+
],
|
| 143 |
+
inputs=["image"],
|
| 144 |
+
answer=[
|
| 145 |
+
"text_reader(image_captioner(image))",
|
| 146 |
+
"text_reader(image_qa(image, question='What is in the image?'))",
|
| 147 |
+
],
|
| 148 |
+
),
|
| 149 |
+
Problem(
|
| 150 |
+
task=[
|
| 151 |
+
"Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
|
| 152 |
+
"Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
|
| 153 |
+
],
|
| 154 |
+
inputs=["text_input", "prompt"],
|
| 155 |
+
answer="image_transformer(image_generator(text_input), prompt)",
|
| 156 |
+
),
|
| 157 |
+
Problem(
|
| 158 |
+
task=[
|
| 159 |
+
"Download the content of `url`, summarize it then generate an image from its content.",
|
| 160 |
+
"Use a summary of the web page at `url` to generate an image.",
|
| 161 |
+
"Summarize the content of the web page at `url`, and use the result to generate an image.",
|
| 162 |
+
],
|
| 163 |
+
inputs=["url"],
|
| 164 |
+
answer="image_generator(summarizer(text_downloader(url)))",
|
| 165 |
+
),
|
| 166 |
+
Problem(
|
| 167 |
+
task=[
|
| 168 |
+
"Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
|
| 169 |
+
"Use the text prompt in `text` (in Spanish) to transform the following `image`.",
|
| 170 |
+
"Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
|
| 171 |
+
],
|
| 172 |
+
inputs=["text", "image"],
|
| 173 |
+
answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
|
| 174 |
+
),
|
| 175 |
+
Problem(
|
| 176 |
+
task=[
|
| 177 |
+
"Download the content of `url`, summarize it then read it out loud to me.",
|
| 178 |
+
"Read me a summary of the web page at `url`.",
|
| 179 |
+
],
|
| 180 |
+
inputs=["url"],
|
| 181 |
+
answer="text_reader(summarizer(text_downloader(url)))",
|
| 182 |
+
),
|
| 183 |
+
Problem(
|
| 184 |
+
task=[
|
| 185 |
+
"Generate an image from the text given in `text_input`.",
|
| 186 |
+
],
|
| 187 |
+
inputs=["text_input"],
|
| 188 |
+
answer="image_generator(text_input)",
|
| 189 |
+
),
|
| 190 |
+
Problem(
|
| 191 |
+
task=[
|
| 192 |
+
"Replace the beaver in the `image` by the `prompt`.",
|
| 193 |
+
"Transform the `image` so that it contains the `prompt`.",
|
| 194 |
+
"Use `prompt` to transform this `image`.",
|
| 195 |
+
],
|
| 196 |
+
inputs=["image", "prompt"],
|
| 197 |
+
answer="image_transformer(image, prompt)",
|
| 198 |
+
),
|
| 199 |
+
Problem(
|
| 200 |
+
task=[
|
| 201 |
+
"Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
|
| 202 |
+
"Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
|
| 203 |
+
"Read me a summary of the `text` out loud. Transcribe this and translate it in French.",
|
| 204 |
+
],
|
| 205 |
+
inputs=["text"],
|
| 206 |
+
answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
|
| 207 |
+
),
|
| 208 |
+
Problem(
|
| 209 |
+
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
|
| 210 |
+
inputs={"prompt": "A lobster swimming"},
|
| 211 |
+
answer="video_generator('A lobster swimming')",
|
| 212 |
+
),
|
| 213 |
+
Problem(
|
| 214 |
+
task=[
|
| 215 |
+
"Download the following file `url`, summarize it in a few words and generate a video from it."
|
| 216 |
+
"Fetch the file at this `url`, summarize it, and create an animation out of it."
|
| 217 |
+
],
|
| 218 |
+
inputs=["url"],
|
| 219 |
+
answer="video_generator(summarizer(text_downloader(url)))",
|
| 220 |
+
),
|
| 221 |
+
]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
|
| 225 |
+
if not isinstance(theoretical_answer, list):
|
| 226 |
+
return {name for name in TEST_TOOLS if name in code_answer}
|
| 227 |
+
|
| 228 |
+
if isinstance(agent_answer, dict):
|
| 229 |
+
for one_answer, one_code in zip(theoretical_answer, code_answer):
|
| 230 |
+
if one_answer in agent_answer.values():
|
| 231 |
+
return {name for name in TEST_TOOLS if name in one_code}
|
| 232 |
+
|
| 233 |
+
for one_answer, one_code in zip(theoretical_answer, code_answer):
|
| 234 |
+
if agent_answer == one_answer:
|
| 235 |
+
return {name for name in TEST_TOOLS if name in one_code}
|
| 236 |
+
|
| 237 |
+
return {name for name in TEST_TOOLS if name in code_answer[0]}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
|
| 241 |
+
tools = BASE_PYTHON_TOOLS.copy()
|
| 242 |
+
for name, tool in TEST_TOOLS.items():
|
| 243 |
+
if name not in code:
|
| 244 |
+
continue
|
| 245 |
+
tools[name] = tool
|
| 246 |
+
|
| 247 |
+
if isinstance(inputs, dict):
|
| 248 |
+
inputs = inputs.copy()
|
| 249 |
+
elif inputs is not None:
|
| 250 |
+
inputs = {inp: f"<<{inp}>>" for inp in inputs}
|
| 251 |
+
|
| 252 |
+
if state is not None:
|
| 253 |
+
state.update(inputs)
|
| 254 |
+
else:
|
| 255 |
+
state = inputs
|
| 256 |
+
|
| 257 |
+
try:
|
| 258 |
+
return evaluate(code, tools, state)
|
| 259 |
+
except InterpreterError as e:
|
| 260 |
+
return str(e)
|
| 261 |
+
except Exception as e:
|
| 262 |
+
if verbose:
|
| 263 |
+
print(e)
|
| 264 |
+
return None
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
|
| 268 |
+
if verbose:
|
| 269 |
+
print(agent_answer, theoretical_answer)
|
| 270 |
+
theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
|
| 271 |
+
|
| 272 |
+
if agent_answer in theoretical_answer:
|
| 273 |
+
if verbose:
|
| 274 |
+
print("Perfect!")
|
| 275 |
+
return 1
|
| 276 |
+
elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
|
| 277 |
+
if verbose:
|
| 278 |
+
print("Almsot perfect, result in state!")
|
| 279 |
+
return 0.75
|
| 280 |
+
else:
|
| 281 |
+
if verbose:
|
| 282 |
+
print("Result is not the right one but code executed.")
|
| 283 |
+
return 0.3
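
# --- Illustrative sketch (assumptions labelled, not part of the library source) ---
# Why: a worked example of the scoring rules above. An exact match scores 1, a match found
# among the values of a dict answer scores 0.75, and any other executed result scores 0.3.
def _example_score_code():
    theoretical = "This is a summary of <<text>>."
    assert score_code(theoretical, theoretical) == 1
    assert score_code({"summary": theoretical}, theoretical) == 0.75
    assert score_code("something else", theoretical) == 0.3
    return True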
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def evaluate_one_result(code, agent_answer, theoretical_answer, answer, verbose=False):
|
| 287 |
+
tools_in_code = {name for name in TEST_TOOLS if f"`{name}`" in code}
|
| 288 |
+
theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
|
| 289 |
+
if tools_in_code == theoretical_tools:
|
| 290 |
+
tool_selection_score = 1.0
|
| 291 |
+
tool_selection_errors = None
|
| 292 |
+
else:
|
| 293 |
+
missing_tools = len(theoretical_tools - tools_in_code)
|
| 294 |
+
unexpected_tools = len(tools_in_code - theoretical_tools)
|
| 295 |
+
tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
|
| 296 |
+
|
| 297 |
+
tool_selection_errors = {
|
| 298 |
+
"selected_tools": tools_in_code,
|
| 299 |
+
"theoretical_tools": theoretical_tools,
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
tools_in_code = {name for name in TEST_TOOLS if name in code}
|
| 303 |
+
if tools_in_code == theoretical_tools:
|
| 304 |
+
tool_used_score = 1.0
|
| 305 |
+
tool_used_errors = None
|
| 306 |
+
else:
|
| 307 |
+
missing_tools = len(theoretical_tools - tools_in_code)
|
| 308 |
+
unexpected_tools = len(tools_in_code - theoretical_tools)
|
| 309 |
+
tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
|
| 310 |
+
|
| 311 |
+
tool_used_errors = {
|
| 312 |
+
"selected_tools": tools_in_code,
|
| 313 |
+
"theoretical_tools": theoretical_tools,
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
score = score_code(agent_answer, theoretical_answer, verbose=verbose)
|
| 317 |
+
if score < 1.0:
|
| 318 |
+
code_errors = {
|
| 319 |
+
"code_produced": code,
|
| 320 |
+
"evaluation": agent_answer,
|
| 321 |
+
"theoretical_answer": theoretical_answer,
|
| 322 |
+
}
|
| 323 |
+
else:
|
| 324 |
+
code_errors = None
|
| 325 |
+
|
| 326 |
+
return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
|
| 330 |
+
"""
|
| 331 |
+
Evaluates a new agent on all `EVALUATION_TASKS`.
|
| 332 |
+
|
| 333 |
+
Example:
|
| 334 |
+
|
| 335 |
+
```py
|
| 336 |
+
agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
|
| 337 |
+
bads = new_evaluate_agent(agent)
|
| 338 |
+
for bad in bads:
|
| 339 |
+
print(bad)
|
| 340 |
+
```
|
| 341 |
+
"""
|
| 342 |
+
# Sanity check
|
| 343 |
+
agent_tools = set(agent.toolbox.keys())
|
| 344 |
+
if agent_tools != set(TEST_TOOLS):
|
| 345 |
+
missing_tools = set(TEST_TOOLS) - agent_tools
|
| 346 |
+
unexpected_tools = set(agent_tools) - TEST_TOOLS
|
| 347 |
+
raise ValueError(
|
| 348 |
+
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
eval_tasks = []
|
| 352 |
+
eval_idx = []
|
| 353 |
+
for idx, pb in enumerate(EVALUATION_TASKS):
|
| 354 |
+
if isinstance(pb.task, list):
|
| 355 |
+
eval_tasks.extend(pb.task)
|
| 356 |
+
eval_idx.extend([idx] * len(pb.task))
|
| 357 |
+
else:
|
| 358 |
+
eval_tasks.append(pb.task)
|
| 359 |
+
eval_idx.append(idx)
|
| 360 |
+
|
| 361 |
+
tool_selection_score = 0
|
| 362 |
+
tool_used_score = 0
|
| 363 |
+
code_score = 0
|
| 364 |
+
|
| 365 |
+
if return_errors:
|
| 366 |
+
tool_selection_errors = {}
|
| 367 |
+
tool_used_errors = {}
|
| 368 |
+
code_errors = {}
|
| 369 |
+
|
| 370 |
+
for start_idx in range(0, len(eval_tasks), batch_size):
|
| 371 |
+
end_idx = min(start_idx + batch_size, len(eval_tasks))
|
| 372 |
+
batch_tasks = eval_tasks[start_idx:end_idx]
|
| 373 |
+
|
| 374 |
+
results = [agent.run(task, return_generated_code=True) for task in batch_tasks]
|
| 375 |
+
|
| 376 |
+
for idx, result in enumerate(results):
|
| 377 |
+
problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
|
| 378 |
+
if verbose:
|
| 379 |
+
print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
|
| 380 |
+
code = agent.extract_action(result, split_token="Answer:")
|
| 381 |
+
|
| 382 |
+
# Evaluate agent answer and code answer
|
| 383 |
+
agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
|
| 384 |
+
if isinstance(problem.answer, list):
|
| 385 |
+
theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
|
| 386 |
+
else:
|
| 387 |
+
theoretical_answer = evaluate_code(problem.answer, problem.inputs)
|
| 388 |
+
|
| 389 |
+
scores, errors = evaluate_one_result(
|
| 390 |
+
code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
tool_selection_score += scores[0]
|
| 394 |
+
tool_used_score += scores[1]
|
| 395 |
+
code_score += scores[2]
|
| 396 |
+
|
| 397 |
+
if return_errors:
|
| 398 |
+
if errors[0] is not None:
|
| 399 |
+
tool_selection_errors[batch_tasks[idx]] = errors[0]
|
| 400 |
+
if errors[1] is not None:
|
| 401 |
+
tool_used_errors[batch_tasks[idx]] = errors[1]
|
| 402 |
+
if errors[2] is not None:
|
| 403 |
+
code_errors[batch_tasks[idx]] = errors[2]
|
| 404 |
+
|
| 405 |
+
scores = {
|
| 406 |
+
"tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
|
| 407 |
+
"tool used score": 100 * (tool_used_score / len(eval_tasks)),
|
| 408 |
+
"code score": 100 * (code_score / len(eval_tasks)),
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
if return_errors:
|
| 412 |
+
return scores, tool_selection_errors, tool_used_errors, code_errors
|
| 413 |
+
else:
|
| 414 |
+
return scores
|
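A minimal sketch (not part of the diff) of how the three-tier grading in `score_code` behaves; the function body below is a condensed restatement for illustration and the inputs are made-up examples, not entries from `EVALUATION_TASKS`.

```py
# Condensed illustration of the score_code rubric: 1 for an exact match,
# 0.75 when the right value is left in the interpreter state, 0.3 otherwise.
def score_code(agent_answer, theoretical_answer):
    theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
    if agent_answer in theoretical_answer:
        return 1
    elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
        return 0.75
    return 0.3

print(score_code("Shanghai", "Shanghai"))                # 1
print(score_code({"answer": "Shanghai"}, ["Shanghai"]))  # 0.75
print(score_code("Guangzhou", "Shanghai"))               # 0.3
```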
.venv/Lib/site-packages/transformers/agents/image_question_answering.py
ADDED
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from PIL import Image

from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from .tools import PipelineTool


class ImageQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    description = (
        "This is a tool that answers a question about an image. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVisualQuestionAnswering

    inputs = {
        "image": {
            "type": "image",
            "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
        },
        "question": {"type": "string", "description": "The question in English"},
    }
    output_type = "string"

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        return self.pre_processor(image, question, return_tensors="pt")

    def forward(self, inputs):
        with torch.no_grad():
            return self.model(**inputs).logits

    def decode(self, outputs):
        idx = outputs.argmax(-1).item()
        return self.model.config.id2label[idx]
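A hedged usage sketch (not part of the diff) for this tool: the image path is a placeholder, and the ViLT checkpoint is downloaded on first use.

```py
# Exercising ImageQuestionAnsweringTool end to end; the class is defined in the module above.
from PIL import Image
from transformers.agents.image_question_answering import ImageQuestionAnsweringTool

tool = ImageQuestionAnsweringTool()
image = Image.open("photo_of_a_dog.png")  # hypothetical local file
# Calling the tool runs encode -> forward -> decode and returns the predicted answer label.
print(tool(image=image, question="What animal is in the picture?"))
```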
.venv/Lib/site-packages/transformers/agents/llm_engine.py
ADDED
@@ -0,0 +1,238 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient

from .. import AutoTokenizer
from ..pipelines.base import Pipeline
from ..utils import logging


logger = logging.get_logger(__name__)


class MessageRole(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        return [r.value for r in cls]


def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
    """
    Subsequent messages with the same role will be concatenated to a single message.

    Args:
        message_list (`List[Dict[str, str]]`): List of chat messages.
    """
    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
        if not set(message.keys()) == {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in MessageRole.roles():
            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list


llama_role_conversions = {
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
}


class HfEngine:
    def __init__(self, model_id: Optional[str] = None):
        self.last_input_token_count = None
        self.last_output_token_count = None
        if model_id is None:
            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
            logger.warning(f"Using default model for token counting: '{model_id}'")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        except Exception as e:
            logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
            self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")

    def get_token_counts(self):
        return {
            "input_token_count": self.last_input_token_count,
            "output_token_count": self.last_output_token_count,
        }

    def generate(
        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
    ):
        raise NotImplementedError

    def __call__(
        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
    ) -> str:
        """Process the input messages and return the model's response.

        This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization.

        Parameters:
            messages (`List[Dict[str, str]]`):
                A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
            stop_sequences (`List[str]`, *optional*):
                A list of strings that will stop the generation if encountered in the model's output.
            grammar (`str`, *optional*):
                The grammar or formatting structure to use in the model's response.

        Returns:
            `str`: The text content of the model's response.

        Example:
        ```python
        >>> engine = HfApiEngine(
        ...     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        ...     token="your_hf_token_here",
        ...     max_tokens=2000
        ... )
        >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
        >>> response = engine(messages, stop_sequences=["END"])
        >>> print(response)
        "Quantum mechanics is the branch of physics that studies..."
        ```
        """
        if not isinstance(messages, List):
            raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
        if stop_sequences is None:
            stop_sequences = []
        response = self.generate(messages, stop_sequences, grammar)
        self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
        self.last_output_token_count = len(self.tokenizer.encode(response))

        # Remove stop sequences from LLM output
        for stop_seq in stop_sequences:
            if response[-len(stop_seq) :] == stop_seq:
                response = response[: -len(stop_seq)]
        return response


class HfApiEngine(HfEngine):
    """A class to interact with Hugging Face's Inference API for language model interaction.

    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used either in serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.

    Parameters:
        model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
            The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
        token (`str`, *optional*):
            Token used by the Hugging Face API for authentication.
            If not provided, the class will use the token stored in the Hugging Face CLI configuration.
        max_tokens (`int`, *optional*, defaults to 1500):
            The maximum number of tokens allowed in the output.
        timeout (`int`, *optional*, defaults to 120):
            Timeout for the API request, in seconds.

    Raises:
        ValueError:
            If the model name is not provided.
    """

    def __init__(
        self,
        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
        token: Optional[str] = None,
        max_tokens: Optional[int] = 1500,
        timeout: Optional[int] = 120,
    ):
        super().__init__(model_id=model)
        self.model = model
        self.client = InferenceClient(self.model, token=token, timeout=timeout)
        self.max_tokens = max_tokens

    def generate(
        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Send messages to the Hugging Face Inference API
        if grammar is not None:
            response = self.client.chat_completion(
                messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
            )
        else:
            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)

        response = response.choices[0].message.content
        return response


class TransformersEngine(HfEngine):
    """This engine uses a pre-initialized local text-generation pipeline."""

    def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
        super().__init__(model_id)
        self.pipeline = pipeline

    def generate(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        max_length: int = 1500,
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get LLM output
        if stop_sequences is not None and len(stop_sequences) > 0:
            stop_strings = stop_sequences
        else:
            stop_strings = None

        output = self.pipeline(
            messages,
            stop_strings=stop_strings,
            max_length=max_length,
            tokenizer=self.pipeline.tokenizer,
        )

        response = output[0]["generated_text"][-1]["content"]
        return response


DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
    "type": "regex",
    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
}

DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
    "type": "regex",
    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
}
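A hedged usage sketch (not part of the diff) for driving the local `TransformersEngine` defined above; the checkpoint name is just an example of a chat-tuned model with a chat template, and any comparable model should work.

```py
# Wrap a local text-generation pipeline in TransformersEngine and call it like an LLM engine.
from transformers import pipeline
from transformers.agents.llm_engine import TransformersEngine

model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # example checkpoint
chat_pipeline = pipeline("text-generation", model=model_id)
engine = TransformersEngine(chat_pipeline, model_id=model_id)

messages = [{"role": "user", "content": "Say hello in one short sentence."}]
# __call__ cleans the message list, generates, counts tokens, and strips stop sequences.
print(engine(messages, stop_sequences=["<end_action>"]))
```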
.venv/Lib/site-packages/transformers/agents/monitoring.py
ADDED
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import logging
from .agent_types import AgentAudio, AgentImage, AgentText


logger = logging.get_logger(__name__)


def pull_message(step_log: dict, test_mode: bool = True):
    try:
        from gradio import ChatMessage
    except ImportError:
        if test_mode:

            class ChatMessage:
                def __init__(self, role, content, metadata=None):
                    self.role = role
                    self.content = content
                    self.metadata = metadata
        else:
            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    if step_log.get("rationale"):
        yield ChatMessage(role="assistant", content=step_log["rationale"])
    if step_log.get("tool_call"):
        used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
        content = step_log["tool_call"]["tool_arguments"]
        if used_code:
            content = f"```py\n{content}\n```"
        yield ChatMessage(
            role="assistant",
            metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
            content=str(content),
        )
    if step_log.get("observation"):
        yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
    if step_log.get("error"):
        yield ChatMessage(
            role="assistant",
            content=str(step_log["error"]),
            metadata={"title": "💥 Error"},
        )


def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

    try:
        from gradio import ChatMessage
    except ImportError:
        if test_mode:

            class ChatMessage:
                def __init__(self, role, content, metadata=None):
                    self.role = role
                    self.content = content
                    self.metadata = metadata
        else:
            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    for step_log in agent.run(task, stream=True, **kwargs):
        if isinstance(step_log, dict):
            for message in pull_message(step_log, test_mode=test_mode):
                yield message

    final_answer = step_log  # Last log is the run's final_answer

    if isinstance(final_answer, AgentText):
        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
    elif isinstance(final_answer, AgentImage):
        yield ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield ChatMessage(role="assistant", content=str(final_answer))


class Monitor:
    def __init__(self, tracked_llm_engine):
        self.step_durations = []
        self.tracked_llm_engine = tracked_llm_engine
        if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found":
            self.total_input_token_count = 0
            self.total_output_token_count = 0

    def update_metrics(self, step_log):
        step_duration = step_log["step_duration"]
        self.step_durations.append(step_duration)
        logger.info(f"Step {len(self.step_durations)}:")
        logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")

        if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
            self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
            self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
            logger.info(f"- Input tokens: {self.total_input_token_count}")
            logger.info(f"- Output tokens: {self.total_output_token_count}")
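A hedged sketch (not part of the diff) showing how `pull_message` renders a step log; with `test_mode=True` it falls back to a plain `ChatMessage` stand-in when Gradio is not installed, so this runs offline. The step log contents are made-up example values.

```py
# Render a single agent step log as chat messages without launching a Gradio demo.
from transformers.agents.monitoring import pull_message

step_log = {
    "rationale": "I will add the two numbers.",
    "tool_call": {"tool_name": "code interpreter", "tool_arguments": "print(2 + 2)"},
    "observation": "4",
}
for message in pull_message(step_log, test_mode=True):
    print(message.role, "->", message.content)
```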
.venv/Lib/site-packages/transformers/agents/prompts.py
ADDED
@@ -0,0 +1,789 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
from ..utils import cached_file
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# docstyle-ignore
|
| 23 |
+
CHAT_MESSAGE_PROMPT = """
|
| 24 |
+
Human: <<task>>
|
| 25 |
+
|
| 26 |
+
Assistant: """
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
|
| 30 |
+
PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
|
| 34 |
+
"""
|
| 35 |
+
Downloads and caches the prompt from a repo and returns it contents (if necessary).
|
| 36 |
+
"""
|
| 37 |
+
if prompt_or_repo_id is None:
|
| 38 |
+
prompt_or_repo_id = DEFAULT_PROMPTS_REPO
|
| 39 |
+
|
| 40 |
+
# prompt is considered a repo ID when it does not contain any kind of space
|
| 41 |
+
if re.search("\\s", prompt_or_repo_id) is not None:
|
| 42 |
+
return prompt_or_repo_id
|
| 43 |
+
|
| 44 |
+
prompt_file = cached_file(
|
| 45 |
+
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
|
| 46 |
+
)
|
| 47 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
| 48 |
+
return f.read()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
|
| 52 |
+
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
|
| 53 |
+
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
|
| 54 |
+
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
|
| 55 |
+
In the end, use tool 'final_answer' to return your answer, its argument will be what gets returned.
|
| 56 |
+
You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
| 57 |
+
Be sure to provide a 'Code:' token, else the run will fail.
|
| 58 |
+
|
| 59 |
+
Tools:
|
| 60 |
+
<<tool_descriptions>>
|
| 61 |
+
|
| 62 |
+
Examples:
|
| 63 |
+
---
|
| 64 |
+
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
|
| 65 |
+
|
| 66 |
+
Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
|
| 67 |
+
Code:
|
| 68 |
+
```py
|
| 69 |
+
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
| 70 |
+
print(f"The translated question is {translated_question}.")
|
| 71 |
+
answer = image_qa(image=image, question=translated_question)
|
| 72 |
+
final_answer(f"The answer is {answer}")
|
| 73 |
+
```<end_action>
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
|
| 77 |
+
|
| 78 |
+
Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
| 79 |
+
Code:
|
| 80 |
+
```py
|
| 81 |
+
answer = document_qa(document, question="What is the oldest person?")
|
| 82 |
+
print(f"The answer is {answer}.")
|
| 83 |
+
image = image_generator(answer)
|
| 84 |
+
final_answer(image)
|
| 85 |
+
```<end_action>
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
Task: "Generate an image using the text given in the variable `caption`."
|
| 89 |
+
|
| 90 |
+
Thought: I will use the following tool: `image_generator` to generate an image.
|
| 91 |
+
Code:
|
| 92 |
+
```py
|
| 93 |
+
image = image_generator(prompt=caption)
|
| 94 |
+
final_answer(image)
|
| 95 |
+
```<end_action>
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
Task: "Summarize the text given in the variable `text` and read it out loud."
|
| 99 |
+
|
| 100 |
+
Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
|
| 101 |
+
Code:
|
| 102 |
+
```py
|
| 103 |
+
summarized_text = summarizer(text)
|
| 104 |
+
print(f"Summary: {summarized_text}")
|
| 105 |
+
audio_summary = text_reader(summarized_text)
|
| 106 |
+
final_answer(audio_summary)
|
| 107 |
+
```<end_action>
|
| 108 |
+
|
| 109 |
+
---
|
| 110 |
+
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
|
| 111 |
+
|
| 112 |
+
Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
|
| 113 |
+
Code:
|
| 114 |
+
```py
|
| 115 |
+
answer = text_qa(text=text, question=question)
|
| 116 |
+
print(f"The answer is {answer}.")
|
| 117 |
+
image = image_generator(answer)
|
| 118 |
+
final_answer(image)
|
| 119 |
+
```<end_action>
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
Task: "Caption the following `image`."
|
| 123 |
+
|
| 124 |
+
Thought: I will use the following tool: `image_captioner` to generate a caption for the image.
|
| 125 |
+
Code:
|
| 126 |
+
```py
|
| 127 |
+
caption = image_captioner(image)
|
| 128 |
+
final_answer(caption)
|
| 129 |
+
```<end_action>
|
| 130 |
+
|
| 131 |
+
---
|
| 132 |
+
Above example were using tools that might not exist for you. You only have acces to those Tools:
|
| 133 |
+
<<tool_names>>
|
| 134 |
+
|
| 135 |
+
Remember to make sure that variables you use are all defined.
|
| 136 |
+
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
|
| 137 |
+
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
| 138 |
+
|
| 139 |
+
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
|
| 144 |
+
To do so, you have been given access to the following tools: <<tool_names>>
|
| 145 |
+
The way you use the tools is by specifying a json blob, ending with '<end_action>'.
|
| 146 |
+
Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
|
| 147 |
+
|
| 148 |
+
The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
|
| 149 |
+
{
|
| 150 |
+
"action": $TOOL_NAME,
|
| 151 |
+
"action_input": $INPUT
|
| 152 |
+
}<end_action>
|
| 153 |
+
|
| 154 |
+
Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
|
| 155 |
+
|
| 156 |
+
You should ALWAYS use the following format:
|
| 157 |
+
|
| 158 |
+
Thought: you should always think about one action to take. Then use the action as follows:
|
| 159 |
+
Action:
|
| 160 |
+
$ACTION_JSON_BLOB
|
| 161 |
+
Observation: the result of the action
|
| 162 |
+
... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The $ACTION_JSON_BLOB must only use a SINGLE action at a time.)
|
| 163 |
+
|
| 164 |
+
You can use the result of the previous action as input for the next action.
|
| 165 |
+
The observation will always be a string: it can represent a file, like "image_1.jpg".
|
| 166 |
+
Then you can use it as input for the next action. You can do it for instance as follows:
|
| 167 |
+
|
| 168 |
+
Observation: "image_1.jpg"
|
| 169 |
+
|
| 170 |
+
Thought: I need to transform the image that I received in the previous observation to make it green.
|
| 171 |
+
Action:
|
| 172 |
+
{
|
| 173 |
+
"action": "image_transformer",
|
| 174 |
+
"action_input": {"image": "image_1.jpg"}
|
| 175 |
+
}<end_action>
|
| 176 |
+
|
| 177 |
+
To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
|
| 178 |
+
Action:
|
| 179 |
+
{
|
| 180 |
+
"action": "final_answer",
|
| 181 |
+
"action_input": {"answer": "insert your final answer here"}
|
| 182 |
+
}<end_action>
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
Here are a few examples using notional tools:
|
| 186 |
+
---
|
| 187 |
+
Task: "Generate an image of the oldest person in this document."
|
| 188 |
+
|
| 189 |
+
Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
| 190 |
+
Action:
|
| 191 |
+
{
|
| 192 |
+
"action": "document_qa",
|
| 193 |
+
"action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
|
| 194 |
+
}<end_action>
|
| 195 |
+
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
Thought: I will now generate an image showcasing the oldest person.
|
| 199 |
+
Action:
|
| 200 |
+
{
|
| 201 |
+
"action": "image_generator",
|
| 202 |
+
"action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
|
| 203 |
+
}<end_action>
|
| 204 |
+
Observation: "image.png"
|
| 205 |
+
|
| 206 |
+
Thought: I will now return the generated image.
|
| 207 |
+
Action:
|
| 208 |
+
{
|
| 209 |
+
"action": "final_answer",
|
| 210 |
+
"action_input": "image.png"
|
| 211 |
+
}<end_action>
|
| 212 |
+
|
| 213 |
+
---
|
| 214 |
+
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
| 215 |
+
|
| 216 |
+
Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool
|
| 217 |
+
Action:
|
| 218 |
+
{
|
| 219 |
+
"action": "python_interpreter",
|
| 220 |
+
"action_input": {"code": "5 + 3 + 1294.678"}
|
| 221 |
+
}<end_action>
|
| 222 |
+
Observation: 1302.678
|
| 223 |
+
|
| 224 |
+
Thought: Now that I know the result, I will now return it.
|
| 225 |
+
Action:
|
| 226 |
+
{
|
| 227 |
+
"action": "final_answer",
|
| 228 |
+
"action_input": "1302.678"
|
| 229 |
+
}<end_action>
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
Task: "Which city has the highest population , Guangzhou or Shanghai?"
|
| 233 |
+
|
| 234 |
+
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
| 235 |
+
Action:
|
| 236 |
+
{
|
| 237 |
+
"action": "search",
|
| 238 |
+
"action_input": "Population Guangzhou"
|
| 239 |
+
}<end_action>
|
| 240 |
+
Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
Thought: Now let's get the population of Shanghai using the tool 'search'.
|
| 244 |
+
Action:
|
| 245 |
+
{
|
| 246 |
+
"action": "search",
|
| 247 |
+
"action_input": "Population Shanghai"
|
| 248 |
+
}
|
| 249 |
+
Observation: '26 million (2019)'
|
| 250 |
+
|
| 251 |
+
Thought: Now I know that Shanghai has a larger population. Let's return the result.
|
| 252 |
+
Action:
|
| 253 |
+
{
|
| 254 |
+
"action": "final_answer",
|
| 255 |
+
"action_input": "Shanghai"
|
| 256 |
+
}<end_action>
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
Above example were using notional tools that might not exist for you. You only have acces to those tools:
|
| 260 |
+
<<tool_descriptions>>
|
| 261 |
+
|
| 262 |
+
Here are the rules you should always follow to solve your task:
|
| 263 |
+
1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
|
| 264 |
+
2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
|
| 265 |
+
3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
| 266 |
+
4. Never re-do a tool call that you previously did with the exact same parameters.
|
| 267 |
+
|
| 268 |
+
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
|
| 273 |
+
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
|
| 274 |
+
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
| 275 |
+
|
| 276 |
+
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
|
| 277 |
+
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_action>' sequence.
|
| 278 |
+
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
|
| 279 |
+
These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
|
| 280 |
+
In the end you have to return a final answer using the `final_answer` tool.
|
| 281 |
+
|
| 282 |
+
Here are a few examples using notional tools:
|
| 283 |
+
---
|
| 284 |
+
Task: "Generate an image of the oldest person in this document."
|
| 285 |
+
|
| 286 |
+
Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
| 287 |
+
Code:
|
| 288 |
+
```py
|
| 289 |
+
answer = document_qa(document=document, question="Who is the oldest person mentioned?")
|
| 290 |
+
print(answer)
|
| 291 |
+
```<end_action>
|
| 292 |
+
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
| 293 |
+
|
| 294 |
+
Thought: I will now generate an image showcasing the oldest person.
|
| 295 |
+
Code:
|
| 296 |
+
```py
|
| 297 |
+
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
| 298 |
+
final_answer(image)
|
| 299 |
+
```<end_action>
|
| 300 |
+
|
| 301 |
+
---
|
| 302 |
+
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
| 303 |
+
|
| 304 |
+
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
|
| 305 |
+
Code:
|
| 306 |
+
```py
|
| 307 |
+
result = 5 + 3 + 1294.678
|
| 308 |
+
final_answer(result)
|
| 309 |
+
```<end_action>
|
| 310 |
+
|
| 311 |
+
---
|
| 312 |
+
Task: "Which city has the highest population: Guangzhou or Shanghai?"
|
| 313 |
+
|
| 314 |
+
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
| 315 |
+
Code:
|
| 316 |
+
```py
|
| 317 |
+
population_guangzhou = search("Guangzhou population")
|
| 318 |
+
print("Population Guangzhou:", population_guangzhou)
|
| 319 |
+
population_shanghai = search("Shanghai population")
|
| 320 |
+
print("Population Shanghai:", population_shanghai)
|
| 321 |
+
```<end_action>
|
| 322 |
+
Observation:
|
| 323 |
+
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
| 324 |
+
Population Shanghai: '26 million (2019)'
|
| 325 |
+
|
| 326 |
+
Thought: Now I know that Shanghai has the highest population.
|
| 327 |
+
Code:
|
| 328 |
+
```py
|
| 329 |
+
final_answer("Shanghai")
|
| 330 |
+
```<end_action>
|
| 331 |
+
|
| 332 |
+
---
|
| 333 |
+
Task: "What is the current age of the pope, raised to the power 0.36?"
|
| 334 |
+
|
| 335 |
+
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
|
| 336 |
+
Code:
|
| 337 |
+
```py
|
| 338 |
+
pope_age = wiki(query="current pope age")
|
| 339 |
+
print("Pope age:", pope_age)
|
| 340 |
+
```<end_action>
|
| 341 |
+
Observation:
|
| 342 |
+
Pope age: "The pope Francis is currently 85 years old."
|
| 343 |
+
|
| 344 |
+
Thought: I know that the pope is 85 years old. Let's compute the result using python code.
|
| 345 |
+
Code:
|
| 346 |
+
```py
|
| 347 |
+
pope_current_age = 85 ** 0.36
|
| 348 |
+
final_answer(pope_current_age)
|
| 349 |
+
```<end_action>
|
| 350 |
+
|
| 351 |
+
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool):
|
| 352 |
+
|
| 353 |
+
<<tool_descriptions>>
|
| 354 |
+
|
| 355 |
+
<<managed_agents_descriptions>>
|
| 356 |
+
|
| 357 |
+
Here are the rules you should always follow to solve your task:
|
| 358 |
+
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
| 359 |
+
2. Use only variables that you have defined!
|
| 360 |
+
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
|
| 361 |
+
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
| 362 |
+
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
| 363 |
+
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
| 364 |
+
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
| 365 |
+
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
| 366 |
+
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
| 367 |
+
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
| 368 |
+
|
| 369 |
+
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
|
| 373 |
+
|
| 374 |
+
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
|
| 375 |
+
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
|
| 376 |
+
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
|
| 377 |
+
|
| 378 |
+
---
|
| 379 |
+
### 1. Facts given in the task
|
| 380 |
+
List here the specific facts given in the task that could help you (there might be nothing here).
|
| 381 |
+
|
| 382 |
+
### 2. Facts to look up
|
| 383 |
+
List here any facts that we may need to look up.
|
| 384 |
+
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
|
| 385 |
+
|
| 386 |
+
### 3. Facts to derive
|
| 387 |
+
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
|
| 388 |
+
|
| 389 |
+
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
|
| 390 |
+
### 1. Facts given in the task
|
| 391 |
+
### 2. Facts to look up
|
| 392 |
+
### 3. Facts to derive
|
| 393 |
+
Do not add anything else."""
|
| 394 |
+
|
| 395 |
+
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
| 396 |
+
|
| 397 |
+
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
|
| 398 |
+
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
|
| 399 |
+
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
|
| 400 |
+
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
|
| 401 |
+
|
| 402 |
+
USER_PROMPT_PLAN = """
|
| 403 |
+
Here is your task:
|
| 404 |
+
|
| 405 |
+
Task:
|
| 406 |
+
```
|
| 407 |
+
{task}
|
| 408 |
+
```
|
| 409 |
+
|
| 410 |
+
Your plan can leverage any of these tools:
|
| 411 |
+
{tool_descriptions}
|
| 412 |
+
|
| 413 |
+
{managed_agents_descriptions}
|
| 414 |
+
|
| 415 |
+
List of facts that you know:
|
| 416 |
+
```
|
| 417 |
+
{answer_facts}
|
| 418 |
+
```
|
| 419 |
+
|
| 420 |
+
Now begin! Write your plan below."""
|
| 421 |
+
|
| 422 |
+
SYSTEM_PROMPT_FACTS_UPDATE = """
|
| 423 |
+
You are a world expert at gathering known and unknown facts based on a conversation.
|
| 424 |
+
Below you will find a task, and ahistory of attempts made to solve the task. You will have to produce a list of these:
|
| 425 |
+
### 1. Facts given in the task
|
| 426 |
+
### 2. Facts that we have learned
|
| 427 |
+
### 3. Facts still to look up
|
| 428 |
+
### 4. Facts still to derive
|
| 429 |
+
Find the task and history below."""
|
| 430 |
+
|
| 431 |
+
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
|
| 432 |
+
But since in your previous steps you may have learned useful new facts or invalidated some false ones.
|
| 433 |
+
Please update your list of facts based on the previous history, and provide these headings:
|
| 434 |
+
### 1. Facts given in the task
|
| 435 |
+
### 2. Facts that we have learned
|
| 436 |
+
### 3. Facts still to look up
|
| 437 |
+
### 4. Facts still to derive
|
| 438 |
+
|
| 439 |
+
Now write your new list of facts below."""
|
| 440 |
+
|
| 441 |
+
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
| 442 |
+
|
| 443 |
+
You have been given a task:
|
| 444 |
+
```
|
| 445 |
+
{task}
|
| 446 |
+
```
|
| 447 |
+
|
| 448 |
+
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
|
| 449 |
+
If the previous tries so far have met some success, you can make an updated plan based on these actions.
|
| 450 |
+
If you are stalled, you can make a completely new plan starting from scratch.
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
|
| 454 |
+
```
|
| 455 |
+
{task}
|
| 456 |
+
```
|
| 457 |
+
|
| 458 |
+
You have access to these tools and only these:
|
| 459 |
+
{tool_descriptions}
|
| 460 |
+
|
| 461 |
+
{managed_agents_descriptions}
|
| 462 |
+
|
| 463 |
+
Here is the up to date list of facts that you know:
|
| 464 |
+
```
|
| 465 |
+
{facts_update}
|
| 466 |
+
```
|
| 467 |
+
|
| 468 |
+
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
|
| 469 |
+
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
|
| 470 |
+
Beware that you have {remaining_steps} steps remaining.
|
| 471 |
+
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
|
| 472 |
+
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
|
| 473 |
+
|
| 474 |
+
Now write your new plan below."""
|
| 475 |
+
|
| 476 |
+
SYSTEM_PROMPT_PLAN_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
|
| 477 |
+
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer. Each step should be structured as follows:
|
| 478 |
+
Step #n: {
|
| 479 |
+
"description": <description of what the step does and its output>
|
| 480 |
+
"tool": <tool to use>,
|
| 481 |
+
"params": {
|
| 482 |
+
<parameters to pass to the tool as a valid dict>
|
| 483 |
+
}
|
| 484 |
+
"output_var": <output variable name>
|
| 485 |
+
}
|
| 486 |
+
Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
|
| 487 |
+
|
| 488 |
+
Below are some examples:
|
| 489 |
+
|
| 490 |
+
Example 1:
|
| 491 |
+
------
|
| 492 |
+
Inputs:
|
| 493 |
+
---
|
| 494 |
+
Task:
|
| 495 |
+
How many encoder blocks were in the first attention-only ML architecture published?
|
| 496 |
+
|
| 497 |
+
[FACTS LIST]:
|
| 498 |
+
### 1. Facts given in the task
|
| 499 |
+
- The paper first introduced an attention-only ML architecture.
|
| 500 |
+
- The specific information required is the page number where the number of encoder blocks is stated.
|
| 501 |
+
- No local files are provided for access.
|
| 502 |
+
|
| 503 |
+
### 2. Facts to look up
|
| 504 |
+
- The title and authors of the paper that first introduced an attention-only ML architecture.
|
| 505 |
+
- Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
|
| 506 |
+
- The full text of the identified paper.
|
| 507 |
+
- Source: Online academic repositories (e.g., arXiv, journal websites)
|
| 508 |
+
- The specific page number in the paper where the number of encoder blocks is mentioned.
|
| 509 |
+
- Source: The content of the identified paper
|
| 510 |
+
|
| 511 |
+
### 3. Facts to derive
|
| 512 |
+
- By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
|
| 513 |
+
- Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
|
| 514 |
+
```
|
| 515 |
+
|
| 516 |
+
[STEP 1 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}
|
| 517 |
+
[OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
|
| 518 |
+
**Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
|
| 519 |
+
[STEP 2 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}
|
| 520 |
+
[OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
|
| 521 |
+
---
|
| 522 |
+
|
| 523 |
+
Output plan:
|
| 524 |
+
---
|
| 525 |
+
Step #1: {
|
| 526 |
+
"description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of "encoder blocks"",
|
| 527 |
+
"tool": "inspect_file_as_text",
|
| 528 |
+
"params": {
|
| 529 |
+
"file_path": "https://arxiv.org/pdf/1706.03762.pdf",
|
| 530 |
+
"question": "On which page is the number of encoder blocks mentioned?"
|
| 531 |
+
},
|
| 532 |
+
"output_var": "page_number"
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
Step #2: {
|
| 536 |
+
"description": "Provide the final answer",
|
| 537 |
+
"tool": "final_answer",
|
| 538 |
+
"params": {
|
| 539 |
+
"answer": "{page_number}"
|
| 540 |
+
},
|
| 541 |
+
"output_var": ""
|
| 542 |
+
}
|
| 543 |
+
------
|
| 544 |
+
|
| 545 |
+
Example 2:
|
| 546 |
+
------
|
| 547 |
+
Inputs:
|
| 548 |
+
---
|
| 549 |
+
Task:
|
| 550 |
+
How many golf balls fit into a Boeing-747?
|
| 551 |
+
|
| 552 |
+
[FACTS LIST]:
|
| 553 |
+
### 1. Facts given in the task
|
| 554 |
+
- The task requires calculating the number of golf balls that fit into a Boeing-747
|
| 555 |
+
### 2. Facts to look up
|
| 556 |
+
- The volume of a golf ball
|
| 557 |
+
- The volume of a Boeing-747
|
| 558 |
+
### 3. Facts to derive
|
| 559 |
+
- Once the volumes are known, the final answer can be calculated
|
| 560 |
+
---
|
| 561 |
+
Output plan:
|
| 562 |
+
---
|
| 563 |
+
Step #1: {
|
| 564 |
+
"description": "Find the volume of a Boeing-747",
|
| 565 |
+
"tool": "web_search",
|
| 566 |
+
"params": {
|
| 567 |
+
"query": "What is the internal volume of a Boeing-747 in cubic meters?"
|
| 568 |
+
},
|
| 569 |
+
"output_var": "boeing_volume"
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
Step #2: {
|
| 573 |
+
"description": "Find the volume of a standard golf ball",
|
| 574 |
+
"tool": "ask_search_agent",
|
| 575 |
+
"params": {
|
| 576 |
+
"query": "What is the volume of a standard golf ball in cubic centimeters?"
|
| 577 |
+
},
|
| 578 |
+
"output_var": "golf_ball_volume"
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
Step #3: {
|
| 582 |
+
"description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
|
| 583 |
+
"tool": "python_code",
|
| 584 |
+
"params": {
|
| 585 |
+
"code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
|
| 586 |
+
},
|
| 587 |
+
"output_var": "number_of_golf_balls"
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
Step #4: {
|
| 591 |
+
"description": "Provide the final answer",
|
| 592 |
+
"tool": "final_answer",
|
| 593 |
+
"params": {
|
| 594 |
+
"answer": "{number_of_golf_balls}"
|
| 595 |
+
},
|
| 596 |
+
"output_var": ""
|
| 597 |
+
}
|
| 598 |
+
------
|
| 599 |
+
The above examples used tools that might not exist for you.
|
| 600 |
+
Your goal is to create a plan to solve the task."""
|
| 601 |
+
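# Illustrative sketch (not part of the shipped prompt text): one step of the
# structured plan format described above, written out as a plain Python dict.
# The "output_var" of an earlier step (here "page_number") is what later
# steps reference, e.g. via "{page_number}" in the final_answer params.
example_step = {
    "description": "Provide the final answer",
    "tool": "final_answer",
    "params": {"answer": "{page_number}"},
    "output_var": "",
}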
|
| 602 |
+
USER_PROMPT_PLAN_STRUCTURED = """
|
| 603 |
+
Here are your inputs:
|
| 604 |
+
|
| 605 |
+
Task:
|
| 606 |
+
```
|
| 607 |
+
{task}
|
| 608 |
+
```
|
| 609 |
+
|
| 610 |
+
Your plan can leverage any of these tools:
|
| 611 |
+
{tool_descriptions}
|
| 612 |
+
These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
|
| 613 |
+
|
| 614 |
+
List of facts that you know:
|
| 615 |
+
```
|
| 616 |
+
{answer_facts}
|
| 617 |
+
```
|
| 618 |
+
|
| 619 |
+
Now for the given task, create a plan taking into account the list of facts.
|
| 620 |
+
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. Output the plan only and nothing else."""
|
| 621 |
+
|
| 622 |
+
SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
|
| 623 |
+
This plan should involve individual tasks based on the available tools that, if executed correctly, will yield the correct answer. Each step should be structured as follows:
|
| 624 |
+
Step #n: {{
|
| 625 |
+
"description": <description of what the step does and its output>
|
| 626 |
+
"tool": <tool to use>,
|
| 627 |
+
"params": {{
|
| 628 |
+
<parameters to pass to the tool as a valid dict>
|
| 629 |
+
}}
|
| 630 |
+
"output_var": <output variable name>
|
| 631 |
+
}}
|
| 632 |
+
Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
|
| 633 |
+
|
| 634 |
+
Below are some examples:
|
| 635 |
+
|
| 636 |
+
Example 1:
|
| 637 |
+
------
|
| 638 |
+
Inputs:
|
| 639 |
+
---
|
| 640 |
+
Task:
|
| 641 |
+
How many encoder blocks were in the first attention-only ML architecture published?
|
| 642 |
+
|
| 643 |
+
[FACTS LIST]:
|
| 644 |
+
### 1. Facts given in the task
|
| 645 |
+
- The paper first introduced an attention-only ML architecture.
|
| 646 |
+
- The specific information required is the page number where the number of encoder blocks is stated.
|
| 647 |
+
- No local files are provided for access.
|
| 648 |
+
|
| 649 |
+
### 2. Facts to look up
|
| 650 |
+
- The title and authors of the paper that first introduced an attention-only ML architecture.
|
| 651 |
+
- Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
|
| 652 |
+
- The full text of the identified paper.
|
| 653 |
+
- Source: Online academic repositories (e.g., arXiv, journal websites)
|
| 654 |
+
- The specific page number in the paper where the number of encoder blocks is mentioned.
|
| 655 |
+
- Source: The content of the identified paper
|
| 656 |
+
|
| 657 |
+
### 3. Facts to derive
|
| 658 |
+
- By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
|
| 659 |
+
- Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
|
| 660 |
+
```
|
| 661 |
+
|
| 662 |
+
[STEP 1 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}}
|
| 663 |
+
[OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
|
| 664 |
+
**Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
|
| 665 |
+
[STEP 2 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}}
|
| 666 |
+
[OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
|
| 667 |
+
---
|
| 668 |
+
|
| 669 |
+
Output plan:
|
| 670 |
+
---
|
| 671 |
+
Step #1: {{
|
| 672 |
+
"description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of "encoder blocks"",
|
| 673 |
+
"tool": "inspect_file_as_text",
|
| 674 |
+
"params": {{
|
| 675 |
+
"file_path": "https://arxiv.org/pdf/1706.03762.pdf",
|
| 676 |
+
"question": "On which page is the number of encoder blocks mentioned?"
|
| 677 |
+
}},
|
| 678 |
+
"output_var": "page_number"
|
| 679 |
+
}}
|
| 680 |
+
|
| 681 |
+
Step #2: {{
|
| 682 |
+
"description": "Provide the final answer",
|
| 683 |
+
"tool": "final_answer",
|
| 684 |
+
"params": {{
|
| 685 |
+
"answer": "{{page_number}}"
|
| 686 |
+
}},
|
| 687 |
+
"output_var": ""
|
| 688 |
+
}}
|
| 689 |
+
------
|
| 690 |
+
|
| 691 |
+
Example 2:
|
| 692 |
+
------
|
| 693 |
+
Inputs:
|
| 694 |
+
---
|
| 695 |
+
Task:
|
| 696 |
+
How many golf balls fit into a Boeing-747?
|
| 697 |
+
|
| 698 |
+
[FACTS LIST]:
|
| 699 |
+
### 1. Facts given in the task
|
| 700 |
+
- The task requires calculating the number of golf balls that fit into a Boeing-747
|
| 701 |
+
### 2. Facts to look up
|
| 702 |
+
- The volume of a golf ball
|
| 703 |
+
- The volume of a Boeing-747
|
| 704 |
+
### 3. Facts to derive
|
| 705 |
+
- Once the volumes are known, the final answer can be calculated
|
| 706 |
+
---
|
| 707 |
+
Output plan:
|
| 708 |
+
---
|
| 709 |
+
Step #1: {{
|
| 710 |
+
"description": "Find the volume of a Boeing-747",
|
| 711 |
+
"tool": "web_search",
|
| 712 |
+
"params": {{
|
| 713 |
+
"query": "What is the internal volume of a Boeing-747 in cubic meters?"
|
| 714 |
+
}},
|
| 715 |
+
"output_var": "boeing_volume"
|
| 716 |
+
}}
|
| 717 |
+
|
| 718 |
+
Step #2: {{
|
| 719 |
+
"description": "Find the volume of a standard golf ball",
|
| 720 |
+
"tool": "ask_search_agent",
|
| 721 |
+
"params": {{
|
| 722 |
+
"query": "What is the volume of a standard golf ball in cubic centimeters?"
|
| 723 |
+
}},
|
| 724 |
+
"output_var": "golf_ball_volume"
|
| 725 |
+
}}
|
| 726 |
+
|
| 727 |
+
Step #3: {{
|
| 728 |
+
"description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
|
| 729 |
+
"tool": "python_code",
|
| 730 |
+
"params": {{
|
| 731 |
+
"code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
|
| 732 |
+
}},
|
| 733 |
+
"output_var": "number_of_golf_balls"
|
| 734 |
+
}}
|
| 735 |
+
|
| 736 |
+
Step #4: {{
|
| 737 |
+
"description": "Provide the final answer",
|
| 738 |
+
"tool": "final_answer",
|
| 739 |
+
"params": {{
|
| 740 |
+
"answer": "{{number_of_golf_balls}}"
|
| 741 |
+
}},
|
| 742 |
+
"output_var": ""
|
| 743 |
+
}}
|
| 744 |
+
------
|
| 745 |
+
The above examples used tools that might not exist for you.
|
| 746 |
+
Find below the record of what has been tried so far to solve it. Your goal is to create an updated plan to solve the task."""
|
| 747 |
+
|
| 748 |
+
USER_PROMPT_PLAN_UPDATE_STRUCTURED = """
|
| 749 |
+
Here are your inputs:
|
| 750 |
+
|
| 751 |
+
Task:
|
| 752 |
+
```
|
| 753 |
+
{task}
|
| 754 |
+
```
|
| 755 |
+
|
| 756 |
+
Your plan can leverage any of these tools:
|
| 757 |
+
{tool_descriptions}
|
| 758 |
+
These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
|
| 759 |
+
|
| 760 |
+
List of facts that you know:
|
| 761 |
+
```
|
| 762 |
+
{facts_update}
|
| 763 |
+
```
|
| 764 |
+
|
| 765 |
+
Now for the given task, create a plan taking into account the above inputs and list of facts.
|
| 766 |
+
Beware that you have {remaining_steps} steps remaining.
|
| 767 |
+
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. Output the plan only and nothing else."""
|
| 768 |
+
|
| 769 |
+
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
|
| 770 |
+
```
|
| 771 |
+
{task}
|
| 772 |
+
```
|
| 773 |
+
|
| 774 |
+
Here is my new/updated plan of action to solve the task:
|
| 775 |
+
```
|
| 776 |
+
{plan_update}
|
| 777 |
+
```"""
|
| 778 |
+
|
| 779 |
+
SUPPORTED_PLAN_TYPES = ["default", "structured"]
|
| 780 |
+
|
| 781 |
+
PROMPTS_FOR_INITIAL_PLAN = {
|
| 782 |
+
"default": {"system": SYSTEM_PROMPT_PLAN, "user": USER_PROMPT_PLAN},
|
| 783 |
+
"structured": {"system": SYSTEM_PROMPT_PLAN_STRUCTURED, "user": USER_PROMPT_PLAN_STRUCTURED},
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
PROMPTS_FOR_PLAN_UPDATE = {
|
| 787 |
+
"default": {"system": SYSTEM_PROMPT_PLAN_UPDATE, "user": USER_PROMPT_PLAN_UPDATE},
|
| 788 |
+
"structured": {"system": SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED, "user": USER_PROMPT_PLAN_UPDATE_STRUCTURED},
|
| 789 |
+
}
|
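# Minimal usage sketch (annotation, not part of prompts.py): how these
# templates are presumably filled in by the planning step that consumes them.
# The task, tool descriptions and facts below are made-up placeholders.
from transformers.agents.prompts import PROMPTS_FOR_PLAN_UPDATE

plan_type = "default"  # must be one of SUPPORTED_PLAN_TYPES
prompts = PROMPTS_FOR_PLAN_UPDATE[plan_type]
task = "How many golf balls fit into a Boeing-747?"
messages = [
    {"role": "system", "content": prompts["system"].format(task=task)},
    {
        "role": "user",
        "content": prompts["user"].format(
            task=task,
            tool_descriptions="- web_search\n- final_answer",
            managed_agents_descriptions="",
            facts_update="(no facts collected yet)",
            remaining_steps=4,
        ),
    },
]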
.venv/Lib/site-packages/transformers/agents/python_interpreter.py
ADDED
|
@@ -0,0 +1,908 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
import ast
|
| 18 |
+
import builtins
|
| 19 |
+
import difflib
|
| 20 |
+
from collections.abc import Mapping
|
| 21 |
+
from importlib import import_module
|
| 22 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from ..utils import is_pandas_available
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if is_pandas_available():
|
| 30 |
+
import pandas as pd
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class InterpreterError(ValueError):
|
| 34 |
+
"""
|
| 35 |
+
An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
|
| 36 |
+
operations.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
ERRORS = {
|
| 43 |
+
name: getattr(builtins, name)
|
| 44 |
+
for name in dir(builtins)
|
| 45 |
+
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
LIST_SAFE_MODULES = [
|
| 50 |
+
"random",
|
| 51 |
+
"collections",
|
| 52 |
+
"math",
|
| 53 |
+
"time",
|
| 54 |
+
"queue",
|
| 55 |
+
"itertools",
|
| 56 |
+
"re",
|
| 57 |
+
"stat",
|
| 58 |
+
"statistics",
|
| 59 |
+
"unicodedata",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
| 63 |
+
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class BreakException(Exception):
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ContinueException(Exception):
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ReturnException(Exception):
|
| 75 |
+
def __init__(self, value):
|
| 76 |
+
self.value = value
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_iterable(obj):
|
| 80 |
+
if isinstance(obj, list):
|
| 81 |
+
return obj
|
| 82 |
+
elif hasattr(obj, "__iter__"):
|
| 83 |
+
return list(obj)
|
| 84 |
+
else:
|
| 85 |
+
raise InterpreterError("Object is not iterable")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
| 89 |
+
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
|
| 90 |
+
if isinstance(expression.op, ast.USub):
|
| 91 |
+
return -operand
|
| 92 |
+
elif isinstance(expression.op, ast.UAdd):
|
| 93 |
+
return operand
|
| 94 |
+
elif isinstance(expression.op, ast.Not):
|
| 95 |
+
return not operand
|
| 96 |
+
elif isinstance(expression.op, ast.Invert):
|
| 97 |
+
return ~operand
|
| 98 |
+
else:
|
| 99 |
+
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
|
| 103 |
+
args = [arg.arg for arg in lambda_expression.args.args]
|
| 104 |
+
|
| 105 |
+
def lambda_func(*values):
|
| 106 |
+
new_state = state.copy()
|
| 107 |
+
for arg, value in zip(args, values):
|
| 108 |
+
new_state[arg] = value
|
| 109 |
+
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
|
| 110 |
+
|
| 111 |
+
return lambda_func
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def evaluate_while(while_loop, state, static_tools, custom_tools):
|
| 115 |
+
max_iterations = 1000
|
| 116 |
+
iterations = 0
|
| 117 |
+
while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
|
| 118 |
+
for node in while_loop.body:
|
| 119 |
+
try:
|
| 120 |
+
evaluate_ast(node, state, static_tools, custom_tools)
|
| 121 |
+
except BreakException:
|
| 122 |
+
return None
|
| 123 |
+
except ContinueException:
|
| 124 |
+
break
|
| 125 |
+
iterations += 1
|
| 126 |
+
if iterations > max_iterations:
|
| 127 |
+
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
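# --- Usage sketch (annotation, not part of the original file) -------------
# The 1000-iteration guard in evaluate_while above keeps runaway
# LLM-generated loops from hanging the agent. Standalone example, importing
# the helpers from this very module:
import ast
from transformers.agents.python_interpreter import InterpreterError, evaluate_while

loop = ast.parse("while True:\n    x = x + 1").body[0]
try:
    evaluate_while(loop, {"x": 0}, {}, {})
except InterpreterError as err:
    print(err)  # Maximum number of 1000 iterations in While loop exceeded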
|
| 131 |
+
def create_function(func_def, state, static_tools, custom_tools):
|
| 132 |
+
def new_func(*args, **kwargs):
|
| 133 |
+
func_state = state.copy()
|
| 134 |
+
arg_names = [arg.arg for arg in func_def.args.args]
|
| 135 |
+
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
|
| 136 |
+
|
| 137 |
+
# Apply default values
|
| 138 |
+
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
| 139 |
+
|
| 140 |
+
# Set positional arguments
|
| 141 |
+
for name, value in zip(arg_names, args):
|
| 142 |
+
func_state[name] = value
|
| 143 |
+
|
| 144 |
+
# Set keyword arguments
|
| 145 |
+
for name, value in kwargs.items():
|
| 146 |
+
func_state[name] = value
|
| 147 |
+
|
| 148 |
+
# Handle variable arguments
|
| 149 |
+
if func_def.args.vararg:
|
| 150 |
+
vararg_name = func_def.args.vararg.arg
|
| 151 |
+
func_state[vararg_name] = args
|
| 152 |
+
|
| 153 |
+
if func_def.args.kwarg:
|
| 154 |
+
kwarg_name = func_def.args.kwarg.arg
|
| 155 |
+
func_state[kwarg_name] = kwargs
|
| 156 |
+
|
| 157 |
+
# Set default values for arguments that were not provided
|
| 158 |
+
for name, value in defaults.items():
|
| 159 |
+
if name not in func_state:
|
| 160 |
+
func_state[name] = value
|
| 161 |
+
|
| 162 |
+
# Update function state with self and __class__
|
| 163 |
+
if func_def.args.args and func_def.args.args[0].arg == "self":
|
| 164 |
+
if args:
|
| 165 |
+
func_state["self"] = args[0]
|
| 166 |
+
func_state["__class__"] = args[0].__class__
|
| 167 |
+
|
| 168 |
+
result = None
|
| 169 |
+
try:
|
| 170 |
+
for stmt in func_def.body:
|
| 171 |
+
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
|
| 172 |
+
except ReturnException as e:
|
| 173 |
+
result = e.value
|
| 174 |
+
return result
|
| 175 |
+
|
| 176 |
+
return new_func
|
| 177 |
+
|
| 178 |
+
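# --- Usage sketch (annotation, not part of the original file) -------------
# A `def` evaluated by this interpreter is wrapped by create_function (via
# evaluate_function_def below) and registered in custom_tools; the wrapper
# copies the caller's state at call time, so assignments inside the body
# never leak back out.
import ast
from transformers.agents.python_interpreter import evaluate_ast

state, custom_tools = {}, {}
src = "def double(x):\n    return x * 2\n\nresult = double(21)"
for node in ast.parse(src).body:
    evaluate_ast(node, state, {}, custom_tools)
print(state["result"])           # 42
print("double" in custom_tools)  # True: the function lives in custom_tools, not state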
|
| 179 |
+
def create_class(class_name, class_bases, class_body):
|
| 180 |
+
class_dict = {}
|
| 181 |
+
for key, value in class_body.items():
|
| 182 |
+
class_dict[key] = value
|
| 183 |
+
return type(class_name, tuple(class_bases), class_dict)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def evaluate_function_def(func_def, state, static_tools, custom_tools):
|
| 187 |
+
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
|
| 188 |
+
return custom_tools[func_def.name]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def evaluate_class_def(class_def, state, static_tools, custom_tools):
|
| 192 |
+
class_name = class_def.name
|
| 193 |
+
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
|
| 194 |
+
class_dict = {}
|
| 195 |
+
|
| 196 |
+
for stmt in class_def.body:
|
| 197 |
+
if isinstance(stmt, ast.FunctionDef):
|
| 198 |
+
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
|
| 199 |
+
elif isinstance(stmt, ast.Assign):
|
| 200 |
+
for target in stmt.targets:
|
| 201 |
+
if isinstance(target, ast.Name):
|
| 202 |
+
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
| 203 |
+
elif isinstance(target, ast.Attribute):
|
| 204 |
+
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
| 205 |
+
else:
|
| 206 |
+
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
| 207 |
+
|
| 208 |
+
new_class = type(class_name, tuple(bases), class_dict)
|
| 209 |
+
state[class_name] = new_class
|
| 210 |
+
return new_class
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def evaluate_augassign(expression, state, static_tools, custom_tools):
|
| 214 |
+
# Helper function to get current value and set new value based on the target type
|
| 215 |
+
def get_current_value(target):
|
| 216 |
+
if isinstance(target, ast.Name):
|
| 217 |
+
return state.get(target.id, 0)
|
| 218 |
+
elif isinstance(target, ast.Subscript):
|
| 219 |
+
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
| 220 |
+
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
| 221 |
+
return obj[key]
|
| 222 |
+
elif isinstance(target, ast.Attribute):
|
| 223 |
+
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
| 224 |
+
return getattr(obj, target.attr)
|
| 225 |
+
elif isinstance(target, ast.Tuple):
|
| 226 |
+
return tuple(get_current_value(elt) for elt in target.elts)
|
| 227 |
+
elif isinstance(target, ast.List):
|
| 228 |
+
return [get_current_value(elt) for elt in target.elts]
|
| 229 |
+
else:
|
| 230 |
+
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
| 231 |
+
|
| 232 |
+
current_value = get_current_value(expression.target)
|
| 233 |
+
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
| 234 |
+
|
| 235 |
+
# Determine the operation and apply it
|
| 236 |
+
if isinstance(expression.op, ast.Add):
|
| 237 |
+
if isinstance(current_value, list):
|
| 238 |
+
if not isinstance(value_to_add, list):
|
| 239 |
+
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
|
| 240 |
+
updated_value = current_value + value_to_add
|
| 241 |
+
else:
|
| 242 |
+
updated_value = current_value + value_to_add
|
| 243 |
+
elif isinstance(expression.op, ast.Sub):
|
| 244 |
+
updated_value = current_value - value_to_add
|
| 245 |
+
elif isinstance(expression.op, ast.Mult):
|
| 246 |
+
updated_value = current_value * value_to_add
|
| 247 |
+
elif isinstance(expression.op, ast.Div):
|
| 248 |
+
updated_value = current_value / value_to_add
|
| 249 |
+
elif isinstance(expression.op, ast.Mod):
|
| 250 |
+
updated_value = current_value % value_to_add
|
| 251 |
+
elif isinstance(expression.op, ast.Pow):
|
| 252 |
+
updated_value = current_value**value_to_add
|
| 253 |
+
elif isinstance(expression.op, ast.FloorDiv):
|
| 254 |
+
updated_value = current_value // value_to_add
|
| 255 |
+
elif isinstance(expression.op, ast.BitAnd):
|
| 256 |
+
updated_value = current_value & value_to_add
|
| 257 |
+
elif isinstance(expression.op, ast.BitOr):
|
| 258 |
+
updated_value = current_value | value_to_add
|
| 259 |
+
elif isinstance(expression.op, ast.BitXor):
|
| 260 |
+
updated_value = current_value ^ value_to_add
|
| 261 |
+
elif isinstance(expression.op, ast.LShift):
|
| 262 |
+
updated_value = current_value << value_to_add
|
| 263 |
+
elif isinstance(expression.op, ast.RShift):
|
| 264 |
+
updated_value = current_value >> value_to_add
|
| 265 |
+
else:
|
| 266 |
+
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
| 267 |
+
|
| 268 |
+
# Update the state
|
| 269 |
+
set_value(expression.target, updated_value, state, static_tools, custom_tools)
|
| 270 |
+
|
| 271 |
+
return updated_value
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def evaluate_boolop(node, state, static_tools, custom_tools):
|
| 275 |
+
if isinstance(node.op, ast.And):
|
| 276 |
+
for value in node.values:
|
| 277 |
+
if not evaluate_ast(value, state, static_tools, custom_tools):
|
| 278 |
+
return False
|
| 279 |
+
return True
|
| 280 |
+
elif isinstance(node.op, ast.Or):
|
| 281 |
+
for value in node.values:
|
| 282 |
+
if evaluate_ast(value, state, static_tools, custom_tools):
|
| 283 |
+
return True
|
| 284 |
+
return False
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def evaluate_binop(binop, state, static_tools, custom_tools):
|
| 288 |
+
# Recursively evaluate the left and right operands
|
| 289 |
+
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
|
| 290 |
+
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
|
| 291 |
+
|
| 292 |
+
# Determine the operation based on the type of the operator in the BinOp
|
| 293 |
+
if isinstance(binop.op, ast.Add):
|
| 294 |
+
return left_val + right_val
|
| 295 |
+
elif isinstance(binop.op, ast.Sub):
|
| 296 |
+
return left_val - right_val
|
| 297 |
+
elif isinstance(binop.op, ast.Mult):
|
| 298 |
+
return left_val * right_val
|
| 299 |
+
elif isinstance(binop.op, ast.Div):
|
| 300 |
+
return left_val / right_val
|
| 301 |
+
elif isinstance(binop.op, ast.Mod):
|
| 302 |
+
return left_val % right_val
|
| 303 |
+
elif isinstance(binop.op, ast.Pow):
|
| 304 |
+
return left_val**right_val
|
| 305 |
+
elif isinstance(binop.op, ast.FloorDiv):
|
| 306 |
+
return left_val // right_val
|
| 307 |
+
elif isinstance(binop.op, ast.BitAnd):
|
| 308 |
+
return left_val & right_val
|
| 309 |
+
elif isinstance(binop.op, ast.BitOr):
|
| 310 |
+
return left_val | right_val
|
| 311 |
+
elif isinstance(binop.op, ast.BitXor):
|
| 312 |
+
return left_val ^ right_val
|
| 313 |
+
elif isinstance(binop.op, ast.LShift):
|
| 314 |
+
return left_val << right_val
|
| 315 |
+
elif isinstance(binop.op, ast.RShift):
|
| 316 |
+
return left_val >> right_val
|
| 317 |
+
else:
|
| 318 |
+
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def evaluate_assign(assign, state, static_tools, custom_tools):
|
| 322 |
+
result = evaluate_ast(assign.value, state, static_tools, custom_tools)
|
| 323 |
+
if len(assign.targets) == 1:
|
| 324 |
+
target = assign.targets[0]
|
| 325 |
+
set_value(target, result, state, static_tools, custom_tools)
|
| 326 |
+
else:
|
| 327 |
+
if len(assign.targets) != len(result):
|
| 328 |
+
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
| 329 |
+
expanded_values = []
|
| 330 |
+
for tgt in assign.targets:
|
| 331 |
+
if isinstance(tgt, ast.Starred):
|
| 332 |
+
expanded_values.extend(result)
|
| 333 |
+
else:
|
| 334 |
+
expanded_values.append(result)
|
| 335 |
+
for tgt, val in zip(assign.targets, expanded_values):
|
| 336 |
+
set_value(tgt, val, state, static_tools, custom_tools)
|
| 337 |
+
return result
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def set_value(target, value, state, static_tools, custom_tools):
|
| 341 |
+
if isinstance(target, ast.Name):
|
| 342 |
+
if target.id in static_tools:
|
| 343 |
+
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
| 344 |
+
state[target.id] = value
|
| 345 |
+
elif isinstance(target, ast.Tuple):
|
| 346 |
+
if not isinstance(value, tuple):
|
| 347 |
+
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
| 348 |
+
value = tuple(value)
|
| 349 |
+
else:
|
| 350 |
+
raise InterpreterError("Cannot unpack non-tuple value")
|
| 351 |
+
if len(target.elts) != len(value):
|
| 352 |
+
raise InterpreterError("Cannot unpack tuple of wrong size")
|
| 353 |
+
for i, elem in enumerate(target.elts):
|
| 354 |
+
set_value(elem, value[i], state, static_tools, custom_tools)
|
| 355 |
+
elif isinstance(target, ast.Subscript):
|
| 356 |
+
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
| 357 |
+
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
| 358 |
+
obj[key] = value
|
| 359 |
+
elif isinstance(target, ast.Attribute):
|
| 360 |
+
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
| 361 |
+
setattr(obj, target.attr, value)
|
| 362 |
+
|
| 363 |
+
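# --- Usage sketch (annotation, not part of the original file) -------------
# set_value refuses assignments that would shadow a static tool, so generated
# code cannot overwrite e.g. final_answer. Standalone example (the builtin
# `print` passed as a stand-in tool is arbitrary):
import ast
from transformers.agents.python_interpreter import InterpreterError, evaluate_ast

assign = ast.parse("final_answer = 42").body[0]
try:
    evaluate_ast(assign, {}, {"final_answer": print}, {})
except InterpreterError as err:
    print(err)  # Cannot assign to name 'final_answer': doing this would erase the existing tool!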
|
| 364 |
+
def evaluate_call(call, state, static_tools, custom_tools):
|
| 365 |
+
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
| 366 |
+
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
| 367 |
+
if isinstance(call.func, ast.Attribute):
|
| 368 |
+
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
|
| 369 |
+
func_name = call.func.attr
|
| 370 |
+
if not hasattr(obj, func_name):
|
| 371 |
+
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
| 372 |
+
func = getattr(obj, func_name)
|
| 373 |
+
|
| 374 |
+
elif isinstance(call.func, ast.Name):
|
| 375 |
+
func_name = call.func.id
|
| 376 |
+
if func_name in state:
|
| 377 |
+
func = state[func_name]
|
| 378 |
+
elif func_name in static_tools:
|
| 379 |
+
func = static_tools[func_name]
|
| 380 |
+
elif func_name in custom_tools:
|
| 381 |
+
func = custom_tools[func_name]
|
| 382 |
+
elif func_name in ERRORS:
|
| 383 |
+
func = ERRORS[func_name]
|
| 384 |
+
else:
|
| 385 |
+
raise InterpreterError(
|
| 386 |
+
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
| 387 |
+
)
|
| 388 |
+
|
| 396 |
+
args = []
|
| 397 |
+
for arg in call.args:
|
| 398 |
+
if isinstance(arg, ast.Starred):
|
| 399 |
+
unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
|
| 400 |
+
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
|
| 401 |
+
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
|
| 402 |
+
args.extend(unpacked)
|
| 403 |
+
else:
|
| 404 |
+
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
| 405 |
+
|
| 406 |
+
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
|
| 407 |
+
|
| 408 |
+
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
| 409 |
+
# Instantiate the class using its constructor
|
| 410 |
+
obj = func.__new__(func) # Create a new instance of the class
|
| 411 |
+
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
|
| 412 |
+
obj.__init__(*args, **kwargs) # Call the __init__ method correctly
|
| 413 |
+
return obj
|
| 414 |
+
else:
|
| 415 |
+
if func_name == "super":
|
| 416 |
+
if not args:
|
| 417 |
+
if "__class__" in state and "self" in state:
|
| 418 |
+
return super(state["__class__"], state["self"])
|
| 419 |
+
else:
|
| 420 |
+
raise InterpreterError("super() needs at least one argument")
|
| 421 |
+
cls = args[0]
|
| 422 |
+
if not isinstance(cls, type):
|
| 423 |
+
raise InterpreterError("super() argument 1 must be type")
|
| 424 |
+
if len(args) == 1:
|
| 425 |
+
return super(cls)
|
| 426 |
+
elif len(args) == 2:
|
| 427 |
+
instance = args[1]
|
| 428 |
+
return super(cls, instance)
|
| 429 |
+
else:
|
| 430 |
+
raise InterpreterError("super() takes at most 2 arguments")
|
| 431 |
+
else:
|
| 432 |
+
if func_name == "print":
|
| 433 |
+
output = " ".join(map(str, args))
|
| 434 |
+
global PRINT_OUTPUTS
|
| 435 |
+
PRINT_OUTPUTS += output + "\n"
|
| 436 |
+
# cap the number of lines
|
| 437 |
+
return None
|
| 438 |
+
else: # Assume it's a callable object
|
| 439 |
+
output = func(*args, **kwargs)
|
| 440 |
+
return output
|
| 441 |
+
|
| 442 |
+
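# --- Usage sketch (annotation, not part of the original file) -------------
# evaluate_call only resolves names found in the state, the provided tools,
# or the builtin exception table (ERRORS); anything else, such as `open`,
# is rejected before its arguments are even evaluated.
import ast
from transformers.agents.python_interpreter import InterpreterError, evaluate_ast

call = ast.parse("open('secrets.txt')", mode="eval").body
try:
    evaluate_ast(call, {}, {}, {})
except InterpreterError as err:
    print(err)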
|
| 443 |
+
def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
| 444 |
+
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
| 445 |
+
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
| 446 |
+
|
| 447 |
+
if isinstance(value, str) and isinstance(index, str):
|
| 448 |
+
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
|
| 449 |
+
if isinstance(value, pd.core.indexing._LocIndexer):
|
| 450 |
+
parent_object = value.obj
|
| 451 |
+
return parent_object.loc[index]
|
| 452 |
+
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
|
| 453 |
+
return value[index]
|
| 454 |
+
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
|
| 455 |
+
return value[index]
|
| 456 |
+
elif isinstance(index, slice):
|
| 457 |
+
return value[index]
|
| 458 |
+
elif isinstance(value, (list, tuple)):
|
| 459 |
+
if not (-len(value) <= index < len(value)):
|
| 460 |
+
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
| 461 |
+
return value[int(index)]
|
| 462 |
+
elif isinstance(value, str):
|
| 463 |
+
if not (-len(value) <= index < len(value)):
|
| 464 |
+
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
| 465 |
+
return value[index]
|
| 466 |
+
elif index in value:
|
| 467 |
+
return value[index]
|
| 468 |
+
elif isinstance(index, str) and isinstance(value, Mapping):
|
| 469 |
+
close_matches = difflib.get_close_matches(index, list(value.keys()))
|
| 470 |
+
if len(close_matches) > 0:
|
| 471 |
+
return value[close_matches[0]]
|
| 472 |
+
raise InterpreterError(f"Could not index {value} with '{index}'.")
|
| 473 |
+
|
| 474 |
+
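# --- Usage sketch (annotation, not part of the original file; requires
# pandas to be installed, since evaluate_subscript above references `pd`) ---
# Dict lookups with a slightly misspelled string key fall back to the closest
# existing key via difflib instead of raising a KeyError.
import ast
from transformers.agents.python_interpreter import evaluate_ast

sub = ast.parse("payload['naem']", mode="eval").body
print(evaluate_ast(sub, {"payload": {"name": "Ada"}}, {}, {}))  # Ada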
|
| 475 |
+
def evaluate_name(name, state, static_tools, custom_tools):
|
| 476 |
+
if name.id in state:
|
| 477 |
+
return state[name.id]
|
| 478 |
+
elif name.id in static_tools:
|
| 479 |
+
return static_tools[name.id]
|
| 480 |
+
elif name.id in ERRORS:
|
| 481 |
+
return ERRORS[name.id]
|
| 482 |
+
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
| 483 |
+
if len(close_matches) > 0:
|
| 484 |
+
return state[close_matches[0]]
|
| 485 |
+
raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
| 486 |
+
|
| 487 |
+
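# --- Usage sketch (annotation, not part of the original file) -------------
# evaluate_name applies the same leniency to bare variable names: a small
# typo resolves to the closest defined name rather than failing outright.
import ast
from transformers.agents.python_interpreter import evaluate_ast

name = ast.parse("reslut", mode="eval").body  # typo for "result"
print(evaluate_ast(name, {"result": 7}, {}, {}))  # 7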
|
| 488 |
+
def evaluate_condition(condition, state, static_tools, custom_tools):
|
| 489 |
+
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
| 490 |
+
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
|
| 491 |
+
ops = [type(op) for op in condition.ops]
|
| 492 |
+
|
| 493 |
+
result = True
|
| 494 |
+
current_left = left
|
| 495 |
+
|
| 496 |
+
for op, comparator in zip(ops, comparators):
|
| 497 |
+
if op == ast.Eq:
|
| 498 |
+
current_result = current_left == comparator
|
| 499 |
+
elif op == ast.NotEq:
|
| 500 |
+
current_result = current_left != comparator
|
| 501 |
+
elif op == ast.Lt:
|
| 502 |
+
current_result = current_left < comparator
|
| 503 |
+
elif op == ast.LtE:
|
| 504 |
+
current_result = current_left <= comparator
|
| 505 |
+
elif op == ast.Gt:
|
| 506 |
+
current_result = current_left > comparator
|
| 507 |
+
elif op == ast.GtE:
|
| 508 |
+
current_result = current_left >= comparator
|
| 509 |
+
elif op == ast.Is:
|
| 510 |
+
current_result = current_left is comparator
|
| 511 |
+
elif op == ast.IsNot:
|
| 512 |
+
current_result = current_left is not comparator
|
| 513 |
+
elif op == ast.In:
|
| 514 |
+
current_result = current_left in comparator
|
| 515 |
+
elif op == ast.NotIn:
|
| 516 |
+
current_result = current_left not in comparator
|
| 517 |
+
else:
|
| 518 |
+
raise InterpreterError(f"Operator not supported: {op}")
|
| 519 |
+
|
| 520 |
+
result = result & current_result
|
| 521 |
+
current_left = comparator
|
| 522 |
+
|
| 523 |
+
if isinstance(result, bool) and not result:
|
| 524 |
+
break
|
| 525 |
+
|
| 526 |
+
return result if isinstance(result, (bool, pd.Series)) else result.all()
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def evaluate_if(if_statement, state, static_tools, custom_tools):
|
| 530 |
+
result = None
|
| 531 |
+
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
|
| 532 |
+
if test_result:
|
| 533 |
+
for line in if_statement.body:
|
| 534 |
+
line_result = evaluate_ast(line, state, static_tools, custom_tools)
|
| 535 |
+
if line_result is not None:
|
| 536 |
+
result = line_result
|
| 537 |
+
else:
|
| 538 |
+
for line in if_statement.orelse:
|
| 539 |
+
line_result = evaluate_ast(line, state, static_tools, custom_tools)
|
| 540 |
+
if line_result is not None:
|
| 541 |
+
result = line_result
|
| 542 |
+
return result
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def evaluate_for(for_loop, state, static_tools, custom_tools):
|
| 546 |
+
result = None
|
| 547 |
+
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
|
| 548 |
+
for counter in iterator:
|
| 549 |
+
set_value(for_loop.target, counter, state, static_tools, custom_tools)
|
| 550 |
+
for node in for_loop.body:
|
| 551 |
+
try:
|
| 552 |
+
line_result = evaluate_ast(node, state, static_tools, custom_tools)
|
| 553 |
+
if line_result is not None:
|
| 554 |
+
result = line_result
|
| 555 |
+
except BreakException:
|
| 556 |
+
break
|
| 557 |
+
except ContinueException:
|
| 558 |
+
continue
|
| 559 |
+
else:
|
| 560 |
+
continue
|
| 561 |
+
break
|
| 562 |
+
return result
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
| 566 |
+
def inner_evaluate(generators, index, current_state):
|
| 567 |
+
if index >= len(generators):
|
| 568 |
+
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
|
| 569 |
+
generator = generators[index]
|
| 570 |
+
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
|
| 571 |
+
result = []
|
| 572 |
+
for value in iter_value:
|
| 573 |
+
new_state = current_state.copy()
|
| 574 |
+
if isinstance(generator.target, ast.Tuple):
|
| 575 |
+
for idx, elem in enumerate(generator.target.elts):
|
| 576 |
+
new_state[elem.id] = value[idx]
|
| 577 |
+
else:
|
| 578 |
+
new_state[generator.target.id] = value
|
| 579 |
+
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
|
| 580 |
+
result.extend(inner_evaluate(generators, index + 1, new_state))
|
| 581 |
+
return result
|
| 582 |
+
|
| 583 |
+
return inner_evaluate(listcomp.generators, 0, state)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def evaluate_try(try_node, state, static_tools, custom_tools):
|
| 587 |
+
try:
|
| 588 |
+
for stmt in try_node.body:
|
| 589 |
+
evaluate_ast(stmt, state, static_tools, custom_tools)
|
| 590 |
+
except Exception as e:
|
| 591 |
+
matched = False
|
| 592 |
+
for handler in try_node.handlers:
|
| 593 |
+
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
|
| 594 |
+
matched = True
|
| 595 |
+
if handler.name:
|
| 596 |
+
state[handler.name] = e
|
| 597 |
+
for stmt in handler.body:
|
| 598 |
+
evaluate_ast(stmt, state, static_tools, custom_tools)
|
| 599 |
+
break
|
| 600 |
+
if not matched:
|
| 601 |
+
raise e
|
| 602 |
+
else:
|
| 603 |
+
if try_node.orelse:
|
| 604 |
+
for stmt in try_node.orelse:
|
| 605 |
+
evaluate_ast(stmt, state, static_tools, custom_tools)
|
| 606 |
+
finally:
|
| 607 |
+
if try_node.finalbody:
|
| 608 |
+
for stmt in try_node.finalbody:
|
| 609 |
+
evaluate_ast(stmt, state, static_tools, custom_tools)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def evaluate_raise(raise_node, state, static_tools, custom_tools):
|
| 613 |
+
if raise_node.exc is not None:
|
| 614 |
+
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
|
| 615 |
+
else:
|
| 616 |
+
exc = None
|
| 617 |
+
if raise_node.cause is not None:
|
| 618 |
+
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
|
| 619 |
+
else:
|
| 620 |
+
cause = None
|
| 621 |
+
if exc is not None:
|
| 622 |
+
if cause is not None:
|
| 623 |
+
raise exc from cause
|
| 624 |
+
else:
|
| 625 |
+
raise exc
|
| 626 |
+
else:
|
| 627 |
+
raise InterpreterError("Re-raise is not supported without an active exception")
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def evaluate_assert(assert_node, state, static_tools, custom_tools):
|
| 631 |
+
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
|
| 632 |
+
if not test_result:
|
| 633 |
+
if assert_node.msg:
|
| 634 |
+
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
|
| 635 |
+
raise AssertionError(msg)
|
| 636 |
+
else:
|
| 637 |
+
# Include the failing condition in the assertion message
|
| 638 |
+
test_code = ast.unparse(assert_node.test)
|
| 639 |
+
raise AssertionError(f"Assertion failed: {test_code}")
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def evaluate_with(with_node, state, static_tools, custom_tools):
|
| 643 |
+
contexts = []
|
| 644 |
+
for item in with_node.items:
|
| 645 |
+
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
|
| 646 |
+
if item.optional_vars:
|
| 647 |
+
state[item.optional_vars.id] = context_expr.__enter__()
|
| 648 |
+
contexts.append(state[item.optional_vars.id])
|
| 649 |
+
else:
|
| 650 |
+
context_var = context_expr.__enter__()
|
| 651 |
+
contexts.append(context_var)
|
| 652 |
+
|
| 653 |
+
try:
|
| 654 |
+
for stmt in with_node.body:
|
| 655 |
+
evaluate_ast(stmt, state, static_tools, custom_tools)
|
| 656 |
+
except Exception as e:
|
| 657 |
+
for context in reversed(contexts):
|
| 658 |
+
context.__exit__(type(e), e, e.__traceback__)
|
| 659 |
+
raise
|
| 660 |
+
else:
|
| 661 |
+
for context in reversed(contexts):
|
| 662 |
+
context.__exit__(None, None, None)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def import_modules(expression, state, authorized_imports):
|
| 666 |
+
def check_module_authorized(module_name):
|
| 667 |
+
module_path = module_name.split(".")
|
| 668 |
+
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
| 669 |
+
return any(subpath in authorized_imports for subpath in module_subpaths)
|
| 670 |
+
|
| 671 |
+
if isinstance(expression, ast.Import):
|
| 672 |
+
for alias in expression.names:
|
| 673 |
+
if check_module_authorized(alias.name):
|
| 674 |
+
module = import_module(alias.name)
|
| 675 |
+
state[alias.asname or alias.name] = module
|
| 676 |
+
else:
|
| 677 |
+
raise InterpreterError(
|
| 678 |
+
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
| 679 |
+
)
|
| 680 |
+
return None
|
| 681 |
+
elif isinstance(expression, ast.ImportFrom):
|
| 682 |
+
if check_module_authorized(expression.module):
|
| 683 |
+
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
| 684 |
+
for alias in expression.names:
|
| 685 |
+
state[alias.asname or alias.name] = getattr(module, alias.name)
|
| 686 |
+
else:
|
| 687 |
+
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
| 688 |
+
return None
|
| 689 |
+
|
| 690 |
+
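# --- Usage sketch (annotation, not part of the original file) -------------
# Imports are whitelisted by module prefix: authorizing "math" covers "math"
# itself, and authorizing a package also covers its submodules; anything
# outside the list raises InterpreterError.
import ast
from transformers.agents.python_interpreter import import_modules

state = {}
node = ast.parse("import math").body[0]
import_modules(node, state, ["math"])
print(state["math"].sqrt(16.0))  # 4.0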
|
| 691 |
+
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
|
| 692 |
+
result = {}
|
| 693 |
+
for gen in dictcomp.generators:
|
| 694 |
+
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
|
| 695 |
+
for value in iter_value:
|
| 696 |
+
new_state = state.copy()
|
| 697 |
+
set_value(gen.target, value, new_state, static_tools, custom_tools)
|
| 698 |
+
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
|
| 699 |
+
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
|
| 700 |
+
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
|
| 701 |
+
result[key] = val
|
| 702 |
+
return result
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def evaluate_ast(
|
| 706 |
+
expression: ast.AST,
|
| 707 |
+
state: Dict[str, Any],
|
| 708 |
+
static_tools: Dict[str, Callable],
|
| 709 |
+
custom_tools: Dict[str, Callable],
|
| 710 |
+
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
| 711 |
+
):
|
| 712 |
+
"""
|
| 713 |
+
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
| 714 |
+
set of functions.
|
| 715 |
+
|
| 716 |
+
This function will recurse trough the nodes of the tree provided.
|
| 717 |
+
|
| 718 |
+
Args:
|
| 719 |
+
expression (`ast.AST`):
|
| 720 |
+
The code to evaluate, as an abstract syntax tree.
|
| 721 |
+
state (`Dict[str, Any]`):
|
| 722 |
+
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
| 723 |
+
encounters assignments.
|
| 724 |
+
static_tools (`Dict[str, Callable]`):
|
| 725 |
+
Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
|
| 726 |
+
custom_tools (`Dict[str, Callable]`):
|
| 727 |
+
Functions that may be called during the evaluation. These custom_tools can be overwritten.
|
| 728 |
+
authorized_imports (`List[str]`):
|
| 729 |
+
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
| 730 |
+
Add more at your own risk!
|
| 731 |
+
"""
|
| 732 |
+
global OPERATIONS_COUNT
|
| 733 |
+
if OPERATIONS_COUNT >= MAX_OPERATIONS:
|
| 734 |
+
raise InterpreterError(
|
| 735 |
+
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
|
| 736 |
+
)
|
| 737 |
+
OPERATIONS_COUNT += 1
|
| 738 |
+
if isinstance(expression, ast.Assign):
|
| 739 |
+
# Assignment -> we evaluate the assignment, which should update the state
|
| 740 |
+
# We return the variable assigned as it may be used to determine the final result.
|
| 741 |
+
return evaluate_assign(expression, state, static_tools, custom_tools)
|
| 742 |
+
elif isinstance(expression, ast.AugAssign):
|
| 743 |
+
return evaluate_augassign(expression, state, static_tools, custom_tools)
|
| 744 |
+
elif isinstance(expression, ast.Call):
|
| 745 |
+
# Function call -> we return the value of the function call
|
| 746 |
+
return evaluate_call(expression, state, static_tools, custom_tools)
|
| 747 |
+
elif isinstance(expression, ast.Constant):
|
| 748 |
+
# Constant -> just return the value
|
| 749 |
+
return expression.value
|
| 750 |
+
elif isinstance(expression, ast.Tuple):
|
| 751 |
+
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
|
| 752 |
+
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
| 753 |
+
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
| 754 |
+
elif isinstance(expression, ast.UnaryOp):
|
| 755 |
+
return evaluate_unaryop(expression, state, static_tools, custom_tools)
|
| 756 |
+
elif isinstance(expression, ast.Starred):
|
| 757 |
+
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
| 758 |
+
elif isinstance(expression, ast.BoolOp):
|
| 759 |
+
# Boolean operation -> evaluate the operation
|
| 760 |
+
return evaluate_boolop(expression, state, static_tools, custom_tools)
|
| 761 |
+
elif isinstance(expression, ast.Break):
|
| 762 |
+
raise BreakException()
|
| 763 |
+
elif isinstance(expression, ast.Continue):
|
| 764 |
+
raise ContinueException()
|
| 765 |
+
elif isinstance(expression, ast.BinOp):
|
| 766 |
+
# Binary operation -> execute operation
|
| 767 |
+
return evaluate_binop(expression, state, static_tools, custom_tools)
|
| 768 |
+
elif isinstance(expression, ast.Compare):
|
| 769 |
+
# Comparison -> evaluate the comparison
|
| 770 |
+
return evaluate_condition(expression, state, static_tools, custom_tools)
|
| 771 |
+
elif isinstance(expression, ast.Lambda):
|
| 772 |
+
return evaluate_lambda(expression, state, static_tools, custom_tools)
|
| 773 |
+
elif isinstance(expression, ast.FunctionDef):
|
| 774 |
+
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
| 775 |
+
elif isinstance(expression, ast.Dict):
|
| 776 |
+
# Dict -> evaluate all keys and values
|
| 777 |
+
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
|
| 778 |
+
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
|
| 779 |
+
return dict(zip(keys, values))
|
| 780 |
+
elif isinstance(expression, ast.Expr):
|
| 781 |
+
# Expression -> evaluate the content
|
| 782 |
+
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
| 783 |
+
elif isinstance(expression, ast.For):
|
| 784 |
+
# For loop -> execute the loop
|
| 785 |
+
return evaluate_for(expression, state, static_tools, custom_tools)
|
| 786 |
+
elif isinstance(expression, ast.FormattedValue):
|
| 787 |
+
# Formatted value (part of f-string) -> evaluate the content and return
|
| 788 |
+
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
| 789 |
+
elif isinstance(expression, ast.If):
|
| 790 |
+
# If -> execute the right branch
|
| 791 |
+
return evaluate_if(expression, state, static_tools, custom_tools)
|
| 792 |
+
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
| 793 |
+
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
| 794 |
+
elif isinstance(expression, ast.JoinedStr):
|
| 795 |
+
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
|
| 796 |
+
elif isinstance(expression, ast.List):
|
| 797 |
+
# List -> evaluate all elements
|
| 798 |
+
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
|
| 799 |
+
elif isinstance(expression, ast.Name):
|
| 800 |
+
# Name -> pick up the value in the state
|
| 801 |
+
return evaluate_name(expression, state, static_tools, custom_tools)
|
| 802 |
+
elif isinstance(expression, ast.Subscript):
|
| 803 |
+
# Subscript -> return the value of the indexing
|
| 804 |
+
return evaluate_subscript(expression, state, static_tools, custom_tools)
|
| 805 |
+
elif isinstance(expression, ast.IfExp):
|
| 806 |
+
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
|
| 807 |
+
if test_val:
|
| 808 |
+
return evaluate_ast(expression.body, state, static_tools, custom_tools)
|
| 809 |
+
else:
|
| 810 |
+
return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
|
| 811 |
+
elif isinstance(expression, ast.Attribute):
|
| 812 |
+
value = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
| 813 |
+
return getattr(value, expression.attr)
|
| 814 |
+
elif isinstance(expression, ast.Slice):
|
| 815 |
+
return slice(
|
| 816 |
+
evaluate_ast(expression.lower, state, static_tools, custom_tools)
|
| 817 |
+
if expression.lower is not None
|
| 818 |
+
else None,
|
| 819 |
+
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
| 820 |
+
if expression.upper is not None
|
| 821 |
+
else None,
|
| 822 |
+
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
|
| 823 |
+
)
|
| 824 |
+
elif isinstance(expression, ast.DictComp):
|
| 825 |
+
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
| 826 |
+
elif isinstance(expression, ast.While):
|
| 827 |
+
return evaluate_while(expression, state, static_tools, custom_tools)
|
| 828 |
+
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
| 829 |
+
return import_modules(expression, state, authorized_imports)
|
| 830 |
+
elif isinstance(expression, ast.ClassDef):
|
| 831 |
+
return evaluate_class_def(expression, state, static_tools, custom_tools)
|
| 832 |
+
elif isinstance(expression, ast.Try):
|
| 833 |
+
return evaluate_try(expression, state, static_tools, custom_tools)
|
| 834 |
+
elif isinstance(expression, ast.Raise):
|
| 835 |
+
return evaluate_raise(expression, state, static_tools, custom_tools)
|
| 836 |
+
elif isinstance(expression, ast.Assert):
|
| 837 |
+
return evaluate_assert(expression, state, static_tools, custom_tools)
|
| 838 |
+
elif isinstance(expression, ast.With):
|
| 839 |
+
return evaluate_with(expression, state, static_tools, custom_tools)
|
| 840 |
+
elif isinstance(expression, ast.Set):
|
| 841 |
+
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
|
| 842 |
+
elif isinstance(expression, ast.Return):
|
| 843 |
+
raise ReturnException(
|
| 844 |
+
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
|
| 845 |
+
)
|
| 846 |
+
else:
|
| 847 |
+
# For now we refuse anything else. Let's add things as we need them.
|
| 848 |
+
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
|
| 852 |
+
if len(print_outputs) < max_len_outputs:
|
| 853 |
+
return print_outputs
|
| 854 |
+
else:
|
| 855 |
+
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
def evaluate_python_code(
|
| 859 |
+
code: str,
|
| 860 |
+
static_tools: Optional[Dict[str, Callable]] = None,
|
| 861 |
+
custom_tools: Optional[Dict[str, Callable]] = None,
|
| 862 |
+
state: Optional[Dict[str, Any]] = None,
|
| 863 |
+
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
| 864 |
+
):
|
| 865 |
+
"""
|
| 866 |
+
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
| 867 |
+
of functions.
|
| 868 |
+
|
| 869 |
+
This function will recurse through the nodes of the tree provided.
|
| 870 |
+
|
| 871 |
+
Args:
|
| 872 |
+
code (`str`):
|
| 873 |
+
The code to evaluate.
|
| 874 |
+
static_tools (`Dict[str, Callable]`):
|
| 875 |
+
The functions that may be called during the evaluation.
|
| 876 |
+
These tools cannot be overwritten in the code: any assignment to their name will raise an error.
|
| 877 |
+
custom_tools (`Dict[str, Callable]`):
|
| 878 |
+
The functions that may be called during the evaluation.
|
| 879 |
+
These tools can be overwritten in the code: any assignment to their name will overwrite them.
|
| 880 |
+
state (`Dict[str, Any]`):
|
| 881 |
+
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
| 882 |
+
updated by this function to contain all variables as they are evaluated.
|
| 883 |
+
The print outputs will be stored in the state under the key 'print_outputs'.
|
| 884 |
+
"""
|
| 885 |
+
try:
|
| 886 |
+
expression = ast.parse(code)
|
| 887 |
+
except SyntaxError as e:
|
| 888 |
+
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
| 889 |
+
if state is None:
|
| 890 |
+
state = {}
|
| 891 |
+
if static_tools is None:
|
| 892 |
+
static_tools = {}
|
| 893 |
+
if custom_tools is None:
|
| 894 |
+
custom_tools = {}
|
| 895 |
+
result = None
|
| 896 |
+
global PRINT_OUTPUTS
|
| 897 |
+
PRINT_OUTPUTS = ""
|
| 898 |
+
global OPERATIONS_COUNT
|
| 899 |
+
OPERATIONS_COUNT = 0
|
| 900 |
+
try:
|
| 901 |
+
for node in expression.body:
|
| 902 |
+
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
| 903 |
+
state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
|
| 904 |
+
return result
|
| 905 |
+
except InterpreterError as e:
|
| 906 |
+
msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
|
| 907 |
+
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
| 908 |
+
raise InterpreterError(msg)
|
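A minimal usage sketch for `evaluate_python_code` above (not part of the diff): the import path and the exact return/state behavior are assumptions read off the code shown, where the value of the last evaluated node is returned and variables plus `print_outputs` are recorded in `state`.

```
from transformers.agents.python_interpreter import evaluate_python_code

state = {}
result = evaluate_python_code("x = 21\ny = x * 2\ny", state=state)
print(result)      # 42 -- value of the last evaluated node
print(state["x"])  # 21 -- assignments are kept in the shared state
```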
.venv/Lib/site-packages/transformers/agents/search.py
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
import requests
|
| 20 |
+
from requests.exceptions import RequestException
|
| 21 |
+
|
| 22 |
+
from .tools import Tool
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DuckDuckGoSearchTool(Tool):
|
| 26 |
+
name = "web_search"
|
| 27 |
+
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
| 28 |
+
Each result has keys 'title', 'href' and 'body'."""
|
| 29 |
+
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
|
| 30 |
+
output_type = "any"
|
| 31 |
+
|
| 32 |
+
def forward(self, query: str) -> str:
|
| 33 |
+
try:
|
| 34 |
+
from duckduckgo_search import DDGS
|
| 35 |
+
except ImportError:
|
| 36 |
+
raise ImportError(
|
| 37 |
+
"You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
|
| 38 |
+
)
|
| 39 |
+
results = DDGS().text(query, max_results=7)
|
| 40 |
+
return results
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class VisitWebpageTool(Tool):
|
| 44 |
+
name = "visit_webpage"
|
| 45 |
+
description = "Visits a webpage at the given url and returns its content as a markdown string."
|
| 46 |
+
inputs = {
|
| 47 |
+
"url": {
|
| 48 |
+
"type": "string",
|
| 49 |
+
"description": "The url of the webpage to visit.",
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
output_type = "string"
|
| 53 |
+
|
| 54 |
+
def forward(self, url: str) -> str:
|
| 55 |
+
try:
|
| 56 |
+
from markdownify import markdownify
|
| 57 |
+
except ImportError:
|
| 58 |
+
raise ImportError(
|
| 59 |
+
"You must install package `markdownify` to run this tool: for instance run `pip install markdownify`."
|
| 60 |
+
)
|
| 61 |
+
try:
|
| 62 |
+
# Send a GET request to the URL
|
| 63 |
+
response = requests.get(url)
|
| 64 |
+
response.raise_for_status() # Raise an exception for bad status codes
|
| 65 |
+
|
| 66 |
+
# Convert the HTML content to Markdown
|
| 67 |
+
markdown_content = markdownify(response.text).strip()
|
| 68 |
+
|
| 69 |
+
# Remove multiple line breaks
|
| 70 |
+
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
|
| 71 |
+
|
| 72 |
+
return markdown_content
|
| 73 |
+
|
| 74 |
+
except RequestException as e:
|
| 75 |
+
return f"Error fetching the webpage: {str(e)}"
|
| 76 |
+
except Exception as e:
|
| 77 |
+
return f"An unexpected error occurred: {str(e)}"
|
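A hedged usage sketch for the two search tools defined above; it assumes the optional dependencies `duckduckgo-search` and `markdownify` are installed and that calling a tool instance routes through `Tool.__call__` to `forward`, as in the base class.

```
from transformers.agents.search import DuckDuckGoSearchTool, VisitWebpageTool

search = DuckDuckGoSearchTool()
results = search("huggingface transformers agents")      # list of dicts with title/href/body
page_markdown = VisitWebpageTool()(results[0]["href"])   # webpage converted to markdown
print(page_markdown[:200])
```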
.venv/Lib/site-packages/transformers/agents/speech_to_text.py
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
|
| 19 |
+
from .tools import PipelineTool
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SpeechToTextTool(PipelineTool):
|
| 23 |
+
default_checkpoint = "distil-whisper/distil-large-v3"
|
| 24 |
+
description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
|
| 25 |
+
name = "transcriber"
|
| 26 |
+
pre_processor_class = WhisperProcessor
|
| 27 |
+
model_class = WhisperForConditionalGeneration
|
| 28 |
+
|
| 29 |
+
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
|
| 30 |
+
output_type = "string"
|
| 31 |
+
|
| 32 |
+
def encode(self, audio):
|
| 33 |
+
return self.pre_processor(audio, return_tensors="pt")
|
| 34 |
+
|
| 35 |
+
def forward(self, inputs):
|
| 36 |
+
return self.model.generate(inputs["input_features"])
|
| 37 |
+
|
| 38 |
+
def decode(self, outputs):
|
| 39 |
+
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
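An illustrative sketch for `SpeechToTextTool` (assumptions: `torch` is installed, the default checkpoint can be downloaded, and the Whisper processor accepts 16 kHz mono float audio passed as a NumPy array):

```
import numpy as np
from transformers.agents.speech_to_text import SpeechToTextTool

transcriber = SpeechToTextTool()
audio = np.zeros(16000, dtype=np.float32)  # placeholder: one second of 16 kHz silence
print(transcriber(audio))                  # transcribed text (empty-ish for silence)
```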
.venv/Lib/site-packages/transformers/agents/text_to_speech.py
ADDED
|
@@ -0,0 +1,67 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
|
| 21 |
+
from ..utils import is_datasets_available
|
| 22 |
+
from .tools import PipelineTool
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if is_datasets_available():
|
| 26 |
+
from datasets import load_dataset
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TextToSpeechTool(PipelineTool):
|
| 30 |
+
default_checkpoint = "microsoft/speecht5_tts"
|
| 31 |
+
description = (
|
| 32 |
+
"This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
|
| 33 |
+
)
|
| 34 |
+
name = "text_to_speech"
|
| 35 |
+
pre_processor_class = SpeechT5Processor
|
| 36 |
+
model_class = SpeechT5ForTextToSpeech
|
| 37 |
+
post_processor_class = SpeechT5HifiGan
|
| 38 |
+
|
| 39 |
+
inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
|
| 40 |
+
output_type = "audio"
|
| 41 |
+
|
| 42 |
+
def setup(self):
|
| 43 |
+
if self.post_processor is None:
|
| 44 |
+
self.post_processor = "microsoft/speecht5_hifigan"
|
| 45 |
+
super().setup()
|
| 46 |
+
|
| 47 |
+
def encode(self, text, speaker_embeddings=None):
|
| 48 |
+
inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
|
| 49 |
+
|
| 50 |
+
if speaker_embeddings is None:
|
| 51 |
+
if not is_datasets_available():
|
| 52 |
+
raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
|
| 53 |
+
|
| 54 |
+
embeddings_dataset = load_dataset(
|
| 55 |
+
"Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
|
| 56 |
+
)
|
| 57 |
+
speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
|
| 58 |
+
|
| 59 |
+
return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
|
| 60 |
+
|
| 61 |
+
def forward(self, inputs):
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
return self.model.generate_speech(**inputs)
|
| 64 |
+
|
| 65 |
+
def decode(self, outputs):
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
return self.post_processor(outputs).cpu().detach()
|
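A hedged sketch for `TextToSpeechTool`; it assumes `torch` and `datasets` are installed (the latter is needed for the default speaker embedding loaded in `encode`) and that the returned value is a waveform tensor that can be written out at SpeechT5's usual 16 kHz rate.

```
from transformers.agents.text_to_speech import TextToSpeechTool

tts = TextToSpeechTool()
waveform = tts("Hello from the agents toolkit.")
# Saving the result (soundfile and the 16 kHz rate are assumptions):
# import soundfile as sf; sf.write("out.wav", waveform.numpy(), 16000)
```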
.venv/Lib/site-packages/transformers/agents/tools.py
ADDED
|
@@ -0,0 +1,1003 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
import ast
|
| 18 |
+
import base64
|
| 19 |
+
import importlib
|
| 20 |
+
import inspect
|
| 21 |
+
import io
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import tempfile
|
| 25 |
+
from functools import lru_cache, wraps
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 28 |
+
|
| 29 |
+
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
|
| 30 |
+
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
| 31 |
+
from packaging import version
|
| 32 |
+
|
| 33 |
+
from ..dynamic_module_utils import (
|
| 34 |
+
custom_object_save,
|
| 35 |
+
get_class_from_dynamic_module,
|
| 36 |
+
get_imports,
|
| 37 |
+
)
|
| 38 |
+
from ..models.auto import AutoProcessor
|
| 39 |
+
from ..utils import (
|
| 40 |
+
CONFIG_NAME,
|
| 41 |
+
TypeHintParsingException,
|
| 42 |
+
cached_file,
|
| 43 |
+
get_json_schema,
|
| 44 |
+
is_accelerate_available,
|
| 45 |
+
is_torch_available,
|
| 46 |
+
is_vision_available,
|
| 47 |
+
logging,
|
| 48 |
+
)
|
| 49 |
+
from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
logger = logging.get_logger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if is_torch_available():
|
| 56 |
+
import torch
|
| 57 |
+
|
| 58 |
+
if is_accelerate_available():
|
| 59 |
+
from accelerate import PartialState
|
| 60 |
+
from accelerate.utils import send_to_device
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
TOOL_CONFIG_FILE = "tool_config.json"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
| 67 |
+
if repo_type is not None:
|
| 68 |
+
return repo_type
|
| 69 |
+
try:
|
| 70 |
+
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
|
| 71 |
+
return "space"
|
| 72 |
+
except RepositoryNotFoundError:
|
| 73 |
+
try:
|
| 74 |
+
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
|
| 75 |
+
return "model"
|
| 76 |
+
except RepositoryNotFoundError:
|
| 77 |
+
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
|
| 78 |
+
except Exception:
|
| 79 |
+
return "model"
|
| 80 |
+
except Exception:
|
| 81 |
+
return "space"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# docstyle-ignore
|
| 85 |
+
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
|
| 86 |
+
from {module_name} import {class_name}
|
| 87 |
+
|
| 88 |
+
launch_gradio_demo({class_name})
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def validate_after_init(cls, do_validate_forward: bool = True):
|
| 93 |
+
original_init = cls.__init__
|
| 94 |
+
|
| 95 |
+
@wraps(original_init)
|
| 96 |
+
def new_init(self, *args, **kwargs):
|
| 97 |
+
original_init(self, *args, **kwargs)
|
| 98 |
+
if not isinstance(self, PipelineTool):
|
| 99 |
+
self.validate_arguments(do_validate_forward=do_validate_forward)
|
| 100 |
+
|
| 101 |
+
cls.__init__ = new_init
|
| 102 |
+
return cls
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Tool:
|
| 109 |
+
"""
|
| 110 |
+
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
|
| 111 |
+
following class attributes:
|
| 112 |
+
|
| 113 |
+
- **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
|
| 114 |
+
will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
|
| 115 |
+
returns the text contained in the file'.
|
| 116 |
+
- **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
|
| 117 |
+
`"text-classifier"` or `"image_generator"`.
|
| 118 |
+
- **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
|
| 119 |
+
It has one `type` key and a `description` key.
|
| 120 |
+
This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
|
| 121 |
+
description for your tool.
|
| 122 |
+
- **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
|
| 123 |
+
or to make a nice space from your tool, and also can be used in the generated description for your tool.
|
| 124 |
+
|
| 125 |
+
You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
|
| 126 |
+
usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
|
| 127 |
+
instantiation.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
name: str
|
| 131 |
+
description: str
|
| 132 |
+
inputs: Dict[str, Dict[str, Union[str, type]]]
|
| 133 |
+
output_type: type
|
| 134 |
+
|
| 135 |
+
def __init__(self, *args, **kwargs):
|
| 136 |
+
self.is_initialized = False
|
| 137 |
+
|
| 138 |
+
def __init_subclass__(cls, **kwargs):
|
| 139 |
+
super().__init_subclass__(**kwargs)
|
| 140 |
+
validate_after_init(cls, do_validate_forward=False)
|
| 141 |
+
|
| 142 |
+
def validate_arguments(self, do_validate_forward: bool = True):
|
| 143 |
+
required_attributes = {
|
| 144 |
+
"description": str,
|
| 145 |
+
"name": str,
|
| 146 |
+
"inputs": dict,
|
| 147 |
+
"output_type": str,
|
| 148 |
+
}
|
| 149 |
+
authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
|
| 150 |
+
|
| 151 |
+
for attr, expected_type in required_attributes.items():
|
| 152 |
+
attr_value = getattr(self, attr, None)
|
| 153 |
+
if attr_value is None:
|
| 154 |
+
raise TypeError(f"You must set an attribute {attr}.")
|
| 155 |
+
if not isinstance(attr_value, expected_type):
|
| 156 |
+
raise TypeError(
|
| 157 |
+
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
|
| 158 |
+
)
|
| 159 |
+
for input_name, input_content in self.inputs.items():
|
| 160 |
+
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
|
| 161 |
+
assert (
|
| 162 |
+
"type" in input_content and "description" in input_content
|
| 163 |
+
), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
|
| 164 |
+
if input_content["type"] not in authorized_types:
|
| 165 |
+
raise Exception(
|
| 166 |
+
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
assert getattr(self, "output_type", None) in authorized_types
|
| 170 |
+
if do_validate_forward:
|
| 171 |
+
if not isinstance(self, PipelineTool):
|
| 172 |
+
signature = inspect.signature(self.forward)
|
| 173 |
+
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
| 174 |
+
raise Exception(
|
| 175 |
+
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def forward(self, *args, **kwargs):
|
| 179 |
+
return NotImplemented("Write this method in your subclass of `Tool`.")
|
| 180 |
+
|
| 181 |
+
def __call__(self, *args, **kwargs):
|
| 182 |
+
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
| 183 |
+
outputs = self.forward(*args, **kwargs)
|
| 184 |
+
return handle_agent_outputs(outputs, self.output_type)
|
| 185 |
+
|
| 186 |
+
def setup(self):
|
| 187 |
+
"""
|
| 188 |
+
Overwrite this method here for any operation that is expensive and needs to be executed before you start using
|
| 189 |
+
your tool. Such as loading a big model.
|
| 190 |
+
"""
|
| 191 |
+
self.is_initialized = True
|
| 192 |
+
|
| 193 |
+
def save(self, output_dir):
|
| 194 |
+
"""
|
| 195 |
+
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
|
| 196 |
+
tool in `output_dir` as well as autogenerate:
|
| 197 |
+
|
| 198 |
+
- a config file named `tool_config.json`
|
| 199 |
+
- an `app.py` file so that your tool can be converted to a space
|
| 200 |
+
- a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
|
| 201 |
+
code)
|
| 202 |
+
|
| 203 |
+
You should only use this method to save tools that are defined in a separate module (not `__main__`).
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
output_dir (`str`): The folder in which you want to save your tool.
|
| 207 |
+
"""
|
| 208 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 209 |
+
# Save module file
|
| 210 |
+
if self.__module__ == "__main__":
|
| 211 |
+
raise ValueError(
|
| 212 |
+
f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
|
| 213 |
+
"have to put this code in a separate module so we can include it in the saved folder."
|
| 214 |
+
)
|
| 215 |
+
module_files = custom_object_save(self, output_dir)
|
| 216 |
+
|
| 217 |
+
module_name = self.__class__.__module__
|
| 218 |
+
last_module = module_name.split(".")[-1]
|
| 219 |
+
full_name = f"{last_module}.{self.__class__.__name__}"
|
| 220 |
+
|
| 221 |
+
# Save config file
|
| 222 |
+
config_file = os.path.join(output_dir, "tool_config.json")
|
| 223 |
+
if os.path.isfile(config_file):
|
| 224 |
+
with open(config_file, "r", encoding="utf-8") as f:
|
| 225 |
+
tool_config = json.load(f)
|
| 226 |
+
else:
|
| 227 |
+
tool_config = {}
|
| 228 |
+
|
| 229 |
+
tool_config = {
|
| 230 |
+
"tool_class": full_name,
|
| 231 |
+
"description": self.description,
|
| 232 |
+
"name": self.name,
|
| 233 |
+
"inputs": self.inputs,
|
| 234 |
+
"output_type": str(self.output_type),
|
| 235 |
+
}
|
| 236 |
+
with open(config_file, "w", encoding="utf-8") as f:
|
| 237 |
+
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
|
| 238 |
+
|
| 239 |
+
# Save app file
|
| 240 |
+
app_file = os.path.join(output_dir, "app.py")
|
| 241 |
+
with open(app_file, "w", encoding="utf-8") as f:
|
| 242 |
+
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
|
| 243 |
+
|
| 244 |
+
# Save requirements file
|
| 245 |
+
requirements_file = os.path.join(output_dir, "requirements.txt")
|
| 246 |
+
imports = []
|
| 247 |
+
for module in module_files:
|
| 248 |
+
imports.extend(get_imports(module))
|
| 249 |
+
imports = list(set(imports))
|
| 250 |
+
with open(requirements_file, "w", encoding="utf-8") as f:
|
| 251 |
+
f.write("\n".join(imports) + "\n")
|
| 252 |
+
|
| 253 |
+
@classmethod
|
| 254 |
+
def from_hub(
|
| 255 |
+
cls,
|
| 256 |
+
repo_id: str,
|
| 257 |
+
token: Optional[str] = None,
|
| 258 |
+
**kwargs,
|
| 259 |
+
):
|
| 260 |
+
"""
|
| 261 |
+
Loads a tool defined on the Hub.
|
| 262 |
+
|
| 263 |
+
<Tip warning={true}>
|
| 264 |
+
|
| 265 |
+
Loading a tool from the Hub means that you'll download the tool and execute it locally.
|
| 266 |
+
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
|
| 267 |
+
installing a package using pip/npm/apt.
|
| 268 |
+
|
| 269 |
+
</Tip>
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
repo_id (`str`):
|
| 273 |
+
The name of the repo on the Hub where your tool is defined.
|
| 274 |
+
token (`str`, *optional*):
|
| 275 |
+
The token to identify you on hf.co. If unset, will use the token generated when running
|
| 276 |
+
`huggingface-cli login` (stored in `~/.huggingface`).
|
| 277 |
+
kwargs (additional keyword arguments, *optional*):
|
| 278 |
+
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
|
| 279 |
+
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
| 280 |
+
others will be passed along to its init.
|
| 281 |
+
"""
|
| 282 |
+
hub_kwargs_names = [
|
| 283 |
+
"cache_dir",
|
| 284 |
+
"force_download",
|
| 285 |
+
"resume_download",
|
| 286 |
+
"proxies",
|
| 287 |
+
"revision",
|
| 288 |
+
"repo_type",
|
| 289 |
+
"subfolder",
|
| 290 |
+
"local_files_only",
|
| 291 |
+
]
|
| 292 |
+
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
|
| 293 |
+
|
| 294 |
+
# Try to get the tool config first.
|
| 295 |
+
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
|
| 296 |
+
resolved_config_file = cached_file(
|
| 297 |
+
repo_id,
|
| 298 |
+
TOOL_CONFIG_FILE,
|
| 299 |
+
token=token,
|
| 300 |
+
**hub_kwargs,
|
| 301 |
+
_raise_exceptions_for_gated_repo=False,
|
| 302 |
+
_raise_exceptions_for_missing_entries=False,
|
| 303 |
+
_raise_exceptions_for_connection_errors=False,
|
| 304 |
+
)
|
| 305 |
+
is_tool_config = resolved_config_file is not None
|
| 306 |
+
if resolved_config_file is None:
|
| 307 |
+
resolved_config_file = cached_file(
|
| 308 |
+
repo_id,
|
| 309 |
+
CONFIG_NAME,
|
| 310 |
+
token=token,
|
| 311 |
+
**hub_kwargs,
|
| 312 |
+
_raise_exceptions_for_gated_repo=False,
|
| 313 |
+
_raise_exceptions_for_missing_entries=False,
|
| 314 |
+
_raise_exceptions_for_connection_errors=False,
|
| 315 |
+
)
|
| 316 |
+
if resolved_config_file is None:
|
| 317 |
+
raise EnvironmentError(
|
| 318 |
+
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
with open(resolved_config_file, encoding="utf-8") as reader:
|
| 322 |
+
config = json.load(reader)
|
| 323 |
+
|
| 324 |
+
if not is_tool_config:
|
| 325 |
+
if "custom_tool" not in config:
|
| 326 |
+
raise EnvironmentError(
|
| 327 |
+
f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
|
| 328 |
+
)
|
| 329 |
+
custom_tool = config["custom_tool"]
|
| 330 |
+
else:
|
| 331 |
+
custom_tool = config
|
| 332 |
+
|
| 333 |
+
tool_class = custom_tool["tool_class"]
|
| 334 |
+
tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
|
| 335 |
+
|
| 336 |
+
if len(tool_class.name) == 0:
|
| 337 |
+
tool_class.name = custom_tool["name"]
|
| 338 |
+
if tool_class.name != custom_tool["name"]:
|
| 339 |
+
logger.warning(
|
| 340 |
+
f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
|
| 341 |
+
"configuration name."
|
| 342 |
+
)
|
| 343 |
+
tool_class.name = custom_tool["name"]
|
| 344 |
+
|
| 345 |
+
if len(tool_class.description) == 0:
|
| 346 |
+
tool_class.description = custom_tool["description"]
|
| 347 |
+
if tool_class.description != custom_tool["description"]:
|
| 348 |
+
logger.warning(
|
| 349 |
+
f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
|
| 350 |
+
"tool configuration description."
|
| 351 |
+
)
|
| 352 |
+
tool_class.description = custom_tool["description"]
|
| 353 |
+
|
| 354 |
+
if tool_class.inputs != custom_tool["inputs"]:
|
| 355 |
+
tool_class.inputs = custom_tool["inputs"]
|
| 356 |
+
if tool_class.output_type != custom_tool["output_type"]:
|
| 357 |
+
tool_class.output_type = custom_tool["output_type"]
|
| 358 |
+
|
| 359 |
+
if not isinstance(tool_class.inputs, dict):
|
| 360 |
+
tool_class.inputs = ast.literal_eval(tool_class.inputs)
|
| 361 |
+
|
| 362 |
+
return tool_class(**kwargs)
|
| 363 |
+
|
| 364 |
+
def push_to_hub(
|
| 365 |
+
self,
|
| 366 |
+
repo_id: str,
|
| 367 |
+
commit_message: str = "Upload tool",
|
| 368 |
+
private: Optional[bool] = None,
|
| 369 |
+
token: Optional[Union[bool, str]] = None,
|
| 370 |
+
create_pr: bool = False,
|
| 371 |
+
) -> str:
|
| 372 |
+
"""
|
| 373 |
+
Upload the tool to the Hub.
|
| 374 |
+
|
| 375 |
+
For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
|
| 376 |
+
For instance:
|
| 377 |
+
```
|
| 378 |
+
from my_tool_module import MyTool
|
| 379 |
+
my_tool = MyTool()
|
| 380 |
+
my_tool.push_to_hub("my-username/my-space")
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
Parameters:
|
| 384 |
+
repo_id (`str`):
|
| 385 |
+
The name of the repository you want to push your tool to. It should contain your organization name when
|
| 386 |
+
pushing to a given organization.
|
| 387 |
+
commit_message (`str`, *optional*, defaults to `"Upload tool"`):
|
| 388 |
+
Message to commit while pushing.
|
| 389 |
+
private (`bool`, *optional*):
|
| 390 |
+
Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
|
| 391 |
+
token (`bool` or `str`, *optional*):
|
| 392 |
+
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
|
| 393 |
+
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
| 394 |
+
create_pr (`bool`, *optional*, defaults to `False`):
|
| 395 |
+
Whether or not to create a PR with the uploaded files or directly commit.
|
| 396 |
+
"""
|
| 397 |
+
repo_url = create_repo(
|
| 398 |
+
repo_id=repo_id,
|
| 399 |
+
token=token,
|
| 400 |
+
private=private,
|
| 401 |
+
exist_ok=True,
|
| 402 |
+
repo_type="space",
|
| 403 |
+
space_sdk="gradio",
|
| 404 |
+
)
|
| 405 |
+
repo_id = repo_url.repo_id
|
| 406 |
+
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
|
| 407 |
+
|
| 408 |
+
with tempfile.TemporaryDirectory() as work_dir:
|
| 409 |
+
# Save all files.
|
| 410 |
+
self.save(work_dir)
|
| 411 |
+
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
|
| 412 |
+
return upload_folder(
|
| 413 |
+
repo_id=repo_id,
|
| 414 |
+
commit_message=commit_message,
|
| 415 |
+
folder_path=work_dir,
|
| 416 |
+
token=token,
|
| 417 |
+
create_pr=create_pr,
|
| 418 |
+
repo_type="space",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
@staticmethod
|
| 422 |
+
def from_space(
|
| 423 |
+
space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
|
| 424 |
+
):
|
| 425 |
+
"""
|
| 426 |
+
Creates a [`Tool`] from a Space given its id on the Hub.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
space_id (`str`):
|
| 430 |
+
The id of the Space on the Hub.
|
| 431 |
+
name (`str`):
|
| 432 |
+
The name of the tool.
|
| 433 |
+
description (`str`):
|
| 434 |
+
The description of the tool.
|
| 435 |
+
api_name (`str`, *optional*):
|
| 436 |
+
The specific api_name to use, if the space has several tabs. If not specified, will default to the first available API.
|
| 437 |
+
token (`str`, *optional*):
|
| 438 |
+
Add your token to access private spaces or increase your GPU quotas.
|
| 439 |
+
Returns:
|
| 440 |
+
[`Tool`]:
|
| 441 |
+
The Space, as a tool.
|
| 442 |
+
|
| 443 |
+
Examples:
|
| 444 |
+
```
|
| 445 |
+
image_generator = Tool.from_space(
|
| 446 |
+
space_id="black-forest-labs/FLUX.1-schnell",
|
| 447 |
+
name="image-generator",
|
| 448 |
+
description="Generate an image from a prompt"
|
| 449 |
+
)
|
| 450 |
+
image = image_generator("Generate an image of a cool surfer in Tahiti")
|
| 451 |
+
```
|
| 452 |
+
```
|
| 453 |
+
face_swapper = Tool.from_space(
|
| 454 |
+
"tuan2308/face-swap",
|
| 455 |
+
"face_swapper",
|
| 456 |
+
"Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
|
| 457 |
+
)
|
| 458 |
+
image = face_swapper('./aymeric.jpeg', './ruth.jpg')
|
| 459 |
+
```
|
| 460 |
+
"""
|
| 461 |
+
from gradio_client import Client, handle_file
|
| 462 |
+
from gradio_client.utils import is_http_url_like
|
| 463 |
+
|
| 464 |
+
class SpaceToolWrapper(Tool):
|
| 465 |
+
def __init__(
|
| 466 |
+
self,
|
| 467 |
+
space_id: str,
|
| 468 |
+
name: str,
|
| 469 |
+
description: str,
|
| 470 |
+
api_name: Optional[str] = None,
|
| 471 |
+
token: Optional[str] = None,
|
| 472 |
+
):
|
| 473 |
+
self.client = Client(space_id, hf_token=token)
|
| 474 |
+
self.name = name
|
| 475 |
+
self.description = description
|
| 476 |
+
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
|
| 477 |
+
|
| 478 |
+
# If api_name is not defined, take the first of the available APIs for this space
|
| 479 |
+
if api_name is None:
|
| 480 |
+
api_name = list(space_description.keys())[0]
|
| 481 |
+
logger.warning(
|
| 482 |
+
f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`."
|
| 483 |
+
)
|
| 484 |
+
self.api_name = api_name
|
| 485 |
+
|
| 486 |
+
try:
|
| 487 |
+
space_description_api = space_description[api_name]
|
| 488 |
+
except KeyError:
|
| 489 |
+
raise KeyError(f"Could not find specified {api_name=} among available api names.")
|
| 490 |
+
|
| 491 |
+
self.inputs = {}
|
| 492 |
+
for parameter in space_description_api["parameters"]:
|
| 493 |
+
if not parameter["parameter_has_default"]:
|
| 494 |
+
parameter_type = parameter["type"]["type"]
|
| 495 |
+
if parameter_type == "object":
|
| 496 |
+
parameter_type = "any"
|
| 497 |
+
self.inputs[parameter["parameter_name"]] = {
|
| 498 |
+
"type": parameter_type,
|
| 499 |
+
"description": parameter["python_type"]["description"],
|
| 500 |
+
}
|
| 501 |
+
output_component = space_description_api["returns"][0]["component"]
|
| 502 |
+
if output_component == "Image":
|
| 503 |
+
self.output_type = "image"
|
| 504 |
+
elif output_component == "Audio":
|
| 505 |
+
self.output_type = "audio"
|
| 506 |
+
else:
|
| 507 |
+
self.output_type = "any"
|
| 508 |
+
|
| 509 |
+
def sanitize_argument_for_prediction(self, arg):
|
| 510 |
+
if isinstance(arg, ImageType):
|
| 511 |
+
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
| 512 |
+
arg.save(temp_file.name)
|
| 513 |
+
arg = temp_file.name
|
| 514 |
+
if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
|
| 515 |
+
arg
|
| 516 |
+
):
|
| 517 |
+
arg = handle_file(arg)
|
| 518 |
+
return arg
|
| 519 |
+
|
| 520 |
+
def forward(self, *args, **kwargs):
|
| 521 |
+
# Preprocess args and kwargs:
|
| 522 |
+
args = list(args)
|
| 523 |
+
for i, arg in enumerate(args):
|
| 524 |
+
args[i] = self.sanitize_argument_for_prediction(arg)
|
| 525 |
+
for arg_name, arg in kwargs.items():
|
| 526 |
+
kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
|
| 527 |
+
|
| 528 |
+
output = self.client.predict(*args, api_name=self.api_name, **kwargs)
|
| 529 |
+
if isinstance(output, tuple) or isinstance(output, list):
|
| 530 |
+
return output[
|
| 531 |
+
0
|
| 532 |
+
] # Sometimes the space also returns the generation seed, in which case the result is at index 0
|
| 533 |
+
return output
|
| 534 |
+
|
| 535 |
+
return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)
|
| 536 |
+
|
| 537 |
+
@staticmethod
|
| 538 |
+
def from_gradio(gradio_tool):
|
| 539 |
+
"""
|
| 540 |
+
Creates a [`Tool`] from a gradio tool.
|
| 541 |
+
"""
|
| 542 |
+
import inspect
|
| 543 |
+
|
| 544 |
+
class GradioToolWrapper(Tool):
|
| 545 |
+
def __init__(self, _gradio_tool):
|
| 546 |
+
self.name = _gradio_tool.name
|
| 547 |
+
self.description = _gradio_tool.description
|
| 548 |
+
self.output_type = "string"
|
| 549 |
+
self._gradio_tool = _gradio_tool
|
| 550 |
+
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
|
| 551 |
+
self.inputs = {
|
| 552 |
+
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
|
| 553 |
+
}
|
| 554 |
+
self.forward = self._gradio_tool.run
|
| 555 |
+
|
| 556 |
+
return GradioToolWrapper(gradio_tool)
|
| 557 |
+
|
| 558 |
+
@staticmethod
|
| 559 |
+
def from_langchain(langchain_tool):
|
| 560 |
+
"""
|
| 561 |
+
Creates a [`Tool`] from a langchain tool.
|
| 562 |
+
"""
|
| 563 |
+
|
| 564 |
+
class LangChainToolWrapper(Tool):
|
| 565 |
+
def __init__(self, _langchain_tool):
|
| 566 |
+
self.name = _langchain_tool.name.lower()
|
| 567 |
+
self.description = _langchain_tool.description
|
| 568 |
+
self.inputs = _langchain_tool.args.copy()
|
| 569 |
+
for input_content in self.inputs.values():
|
| 570 |
+
if "title" in input_content:
|
| 571 |
+
input_content.pop("title")
|
| 572 |
+
input_content["description"] = ""
|
| 573 |
+
self.output_type = "string"
|
| 574 |
+
self.langchain_tool = _langchain_tool
|
| 575 |
+
|
| 576 |
+
def forward(self, *args, **kwargs):
|
| 577 |
+
tool_input = kwargs.copy()
|
| 578 |
+
for index, argument in enumerate(args):
|
| 579 |
+
if index < len(self.inputs):
|
| 580 |
+
input_key = next(iter(self.inputs))
|
| 581 |
+
tool_input[input_key] = argument
|
| 582 |
+
return self.langchain_tool.run(tool_input)
|
| 583 |
+
|
| 584 |
+
return LangChainToolWrapper(langchain_tool)
|
| 585 |
+
|
| 586 |
+
|
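An illustrative custom tool (not in the diff) that follows the attribute contract checked by `validate_arguments` above: a `name`, a `description`, `inputs` entries with `type`/`description` keys, an authorized `output_type`, and a `forward` whose parameters match the input keys.

```
from transformers.agents.tools import Tool

class UpperCaseTool(Tool):
    name = "upper_case"
    description = "Returns the input text converted to upper case."
    inputs = {"text": {"type": "string", "description": "Text to convert."}}
    output_type = "string"

    def forward(self, text: str) -> str:
        return text.upper()

tool = UpperCaseTool()
print(tool(text="hello agents"))  # HELLO AGENTS
```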
| 587 |
+
DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
|
| 588 |
+
- {{ tool.name }}: {{ tool.description }}
|
| 589 |
+
Takes inputs: {{tool.inputs}}
|
| 590 |
+
Returns an output of type: {{tool.output_type}}
|
| 591 |
+
"""
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
|
| 595 |
+
compiled_template = compile_jinja_template(description_template)
|
| 596 |
+
rendered = compiled_template.render(
|
| 597 |
+
tool=tool,
|
| 598 |
+
)
|
| 599 |
+
return rendered
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
@lru_cache
|
| 603 |
+
def compile_jinja_template(template):
|
| 604 |
+
try:
|
| 605 |
+
import jinja2
|
| 606 |
+
from jinja2.exceptions import TemplateError
|
| 607 |
+
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
| 608 |
+
except ImportError:
|
| 609 |
+
raise ImportError("template requires jinja2 to be installed.")
|
| 610 |
+
|
| 611 |
+
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
| 612 |
+
raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
|
| 613 |
+
|
| 614 |
+
def raise_exception(message):
|
| 615 |
+
raise TemplateError(message)
|
| 616 |
+
|
| 617 |
+
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
| 618 |
+
jinja_env.globals["raise_exception"] = raise_exception
|
| 619 |
+
return jinja_env.from_string(template)
|
| 620 |
+
|
| 621 |
+
|
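A short sketch of rendering a tool description with the default Jinja template above (requires `jinja2>=3.1.0`, as the version check enforces; the tool used here is the search tool from this diff):

```
from transformers.agents.search import DuckDuckGoSearchTool
from transformers.agents.tools import get_tool_description_with_args

print(get_tool_description_with_args(DuckDuckGoSearchTool()))
# - web_search: Perform a web search ...
#   Takes inputs: {...}
#   Returns an output of type: any
```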
| 622 |
+
class PipelineTool(Tool):
|
| 623 |
+
"""
|
| 624 |
+
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
|
| 625 |
+
need to specify:
|
| 626 |
+
|
| 627 |
+
- **model_class** (`type`) -- The class to use to load the model in this tool.
|
| 628 |
+
- **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
|
| 629 |
+
- **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
|
| 630 |
+
pre-processor
|
| 631 |
+
- **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
|
| 632 |
+
post-processor (when different from the pre-processor).
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
model (`str` or [`PreTrainedModel`], *optional*):
|
| 636 |
+
The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
|
| 637 |
+
value of the class attribute `default_checkpoint`.
|
| 638 |
+
pre_processor (`str` or `Any`, *optional*):
|
| 639 |
+
The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
|
| 640 |
+
tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
|
| 641 |
+
unset.
|
| 642 |
+
post_processor (`str` or `Any`, *optional*):
|
| 643 |
+
The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
|
| 644 |
+
tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
|
| 645 |
+
unset.
|
| 646 |
+
device (`int`, `str` or `torch.device`, *optional*):
|
| 647 |
+
The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
|
| 648 |
+
CPU otherwise.
|
| 649 |
+
device_map (`str` or `dict`, *optional*):
|
| 650 |
+
If passed along, will be used to instantiate the model.
|
| 651 |
+
model_kwargs (`dict`, *optional*):
|
| 652 |
+
Any keyword argument to send to the model instantiation.
|
| 653 |
+
token (`str`, *optional*):
|
| 654 |
+
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
|
| 655 |
+
running `huggingface-cli login` (stored in `~/.huggingface`).
|
| 656 |
+
hub_kwargs (additional keyword arguments, *optional*):
|
| 657 |
+
Any additional keyword argument to send to the methods that will load the data from the Hub.
|
| 658 |
+
"""
|
| 659 |
+
|
| 660 |
+
pre_processor_class = AutoProcessor
|
| 661 |
+
model_class = None
|
| 662 |
+
post_processor_class = AutoProcessor
|
| 663 |
+
default_checkpoint = None
|
| 664 |
+
description = "This is a pipeline tool"
|
| 665 |
+
name = "pipeline"
|
| 666 |
+
inputs = {"prompt": str}
|
| 667 |
+
output_type = str
|
| 668 |
+
|
| 669 |
+
def __init__(
|
| 670 |
+
self,
|
| 671 |
+
model=None,
|
| 672 |
+
pre_processor=None,
|
| 673 |
+
post_processor=None,
|
| 674 |
+
device=None,
|
| 675 |
+
device_map=None,
|
| 676 |
+
model_kwargs=None,
|
| 677 |
+
token=None,
|
| 678 |
+
**hub_kwargs,
|
| 679 |
+
):
|
| 680 |
+
if not is_torch_available():
|
| 681 |
+
raise ImportError("Please install torch in order to use this tool.")
|
| 682 |
+
|
| 683 |
+
if not is_accelerate_available():
|
| 684 |
+
raise ImportError("Please install accelerate in order to use this tool.")
|
| 685 |
+
|
| 686 |
+
if model is None:
|
| 687 |
+
if self.default_checkpoint is None:
|
| 688 |
+
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
|
| 689 |
+
model = self.default_checkpoint
|
| 690 |
+
if pre_processor is None:
|
| 691 |
+
pre_processor = model
|
| 692 |
+
|
| 693 |
+
self.model = model
|
| 694 |
+
self.pre_processor = pre_processor
|
| 695 |
+
self.post_processor = post_processor
|
| 696 |
+
self.device = device
|
| 697 |
+
self.device_map = device_map
|
| 698 |
+
self.model_kwargs = {} if model_kwargs is None else model_kwargs
|
| 699 |
+
if device_map is not None:
|
| 700 |
+
self.model_kwargs["device_map"] = device_map
|
| 701 |
+
self.hub_kwargs = hub_kwargs
|
| 702 |
+
self.hub_kwargs["token"] = token
|
| 703 |
+
|
| 704 |
+
super().__init__()
|
| 705 |
+
|
| 706 |
+
def setup(self):
|
| 707 |
+
"""
|
| 708 |
+
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
| 709 |
+
"""
|
| 710 |
+
if isinstance(self.pre_processor, str):
|
| 711 |
+
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
|
| 712 |
+
|
| 713 |
+
if isinstance(self.model, str):
|
| 714 |
+
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
|
| 715 |
+
|
| 716 |
+
if self.post_processor is None:
|
| 717 |
+
self.post_processor = self.pre_processor
|
| 718 |
+
elif isinstance(self.post_processor, str):
|
| 719 |
+
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
|
| 720 |
+
|
| 721 |
+
if self.device is None:
|
| 722 |
+
if self.device_map is not None:
|
| 723 |
+
self.device = list(self.model.hf_device_map.values())[0]
|
| 724 |
+
else:
|
| 725 |
+
self.device = PartialState().default_device
|
| 726 |
+
|
| 727 |
+
if self.device_map is None:
|
| 728 |
+
self.model.to(self.device)
|
| 729 |
+
|
| 730 |
+
super().setup()
|
| 731 |
+
|
| 732 |
+
def encode(self, raw_inputs):
|
| 733 |
+
"""
|
| 734 |
+
Uses the `pre_processor` to prepare the inputs for the `model`.
|
| 735 |
+
"""
|
| 736 |
+
return self.pre_processor(raw_inputs)
|
| 737 |
+
|
| 738 |
+
def forward(self, inputs):
|
| 739 |
+
"""
|
| 740 |
+
Sends the inputs through the `model`.
|
| 741 |
+
"""
|
| 742 |
+
with torch.no_grad():
|
| 743 |
+
return self.model(**inputs)
|
| 744 |
+
|
| 745 |
+
def decode(self, outputs):
|
| 746 |
+
"""
|
| 747 |
+
Uses the `post_processor` to decode the model output.
|
| 748 |
+
"""
|
| 749 |
+
return self.post_processor(outputs)
|
| 750 |
+
|
| 751 |
+
def __call__(self, *args, **kwargs):
|
| 752 |
+
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
| 753 |
+
|
| 754 |
+
if not self.is_initialized:
|
| 755 |
+
self.setup()
|
| 756 |
+
|
| 757 |
+
encoded_inputs = self.encode(*args, **kwargs)
|
| 758 |
+
|
| 759 |
+
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
|
| 760 |
+
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
|
| 761 |
+
|
| 762 |
+
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
| 763 |
+
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
| 764 |
+
outputs = send_to_device(outputs, "cpu")
|
| 765 |
+
decoded_outputs = self.decode(outputs)
|
| 766 |
+
|
| 767 |
+
return handle_agent_outputs(decoded_outputs, self.output_type)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def launch_gradio_demo(tool_class: Tool):
|
| 771 |
+
"""
|
| 772 |
+
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
|
| 773 |
+
`inputs` and `output_type`.
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
tool_class (`type`): The class of the tool for which to launch the demo.
|
| 777 |
+
"""
|
| 778 |
+
try:
|
| 779 |
+
import gradio as gr
|
| 780 |
+
except ImportError:
|
| 781 |
+
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
| 782 |
+
|
| 783 |
+
tool = tool_class()
|
| 784 |
+
|
| 785 |
+
def fn(*args, **kwargs):
|
| 786 |
+
return tool(*args, **kwargs)
|
| 787 |
+
|
| 788 |
+
TYPE_TO_COMPONENT_CLASS_MAPPING = {
|
| 789 |
+
"image": gr.Image,
|
| 790 |
+
"audio": gr.Audio,
|
| 791 |
+
"string": gr.Textbox,
|
| 792 |
+
"integer": gr.Textbox,
|
| 793 |
+
"number": gr.Textbox,
|
| 794 |
+
}
|
| 795 |
+
|
| 796 |
+
gradio_inputs = []
|
| 797 |
+
for input_name, input_details in tool_class.inputs.items():
|
| 798 |
+
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
|
| 799 |
+
new_component = input_gradio_component_class(label=input_name)
|
| 800 |
+
gradio_inputs.append(new_component)
|
| 801 |
+
|
| 802 |
+
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
|
| 803 |
+
gradio_output = output_gradio_componentclass(label=input_name)
|
| 804 |
+
|
| 805 |
+
gr.Interface(
|
| 806 |
+
fn=fn,
|
| 807 |
+
inputs=gradio_inputs,
|
| 808 |
+
outputs=gradio_output,
|
| 809 |
+
title=tool_class.__name__,
|
| 810 |
+
article=tool.description,
|
| 811 |
+
).launch()
|
| 812 |
+
|
| 813 |
+
|
| 814 |
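A hedged sketch of launching a Gradio demo for a tool; note that `launch_gradio_demo` takes the tool *class*, not an instance, and `gradio` must be installed. The top-level import mirrors the `app.py` template generated by `Tool.save`.

```
from transformers import launch_gradio_demo
from transformers.agents.text_to_speech import TextToSpeechTool

launch_gradio_demo(TextToSpeechTool)  # "string" input -> gr.Textbox, "audio" output -> gr.Audio
```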
+
TOOL_MAPPING = {
    "document_question_answering": "DocumentQuestionAnsweringTool",
    "image_question_answering": "ImageQuestionAnsweringTool",
    "speech_to_text": "SpeechToTextTool",
    "text_to_speech": "TextToSpeechTool",
    "translation": "TranslationTool",
    "python_interpreter": "PythonInterpreterTool",
    "web_search": "DuckDuckGoSearchTool",
}


def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
    """
    Main function to quickly load a tool, be it on the Hub or in the Transformers library.

    <Tip warning={true}>

    Loading a tool means that you'll download the tool and execute it locally.
    ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
    installing a package using pip/npm/apt.

    </Tip>

    Args:
        task_or_repo_id (`str`):
            The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
            are:

            - `"document_question_answering"`
            - `"image_question_answering"`
            - `"speech_to_text"`
            - `"text_to_speech"`
            - `"translation"`

        model_repo_id (`str`, *optional*):
            Use this argument to use a different model than the default one for the tool you selected.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
            login` (stored in `~/.huggingface`).
        kwargs (additional keyword arguments, *optional*):
            Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
            `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
            will be passed along to its init.
    """
    if task_or_repo_id in TOOL_MAPPING:
        tool_class_name = TOOL_MAPPING[task_or_repo_id]
        main_module = importlib.import_module("transformers")
        tools_module = main_module.agents
        tool_class = getattr(tools_module, tool_class_name)
        return tool_class(model_repo_id, token=token, **kwargs)
    else:
        logger.warning_once(
            f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
            f"trust as the code within that tool will be executed on your machine. Always verify the code of "
            f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
            f"code that you have checked."
        )
        return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
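A minimal sketch of both code paths in `load_tool`: a task name listed in `TOOL_MAPPING` resolves to a class shipped in `transformers.agents`, while anything else is treated as a Hub repo id and goes through `Tool.from_hub`. The repo id and token below are placeholders, not real resources.

```py
from transformers.agents.tools import load_tool

# Built-in task: instantiates transformers.agents.TranslationTool locally.
translator = load_tool("translation")
print(translator("Hello", src_lang="English", tgt_lang="French"))

# Hub path: downloads and runs code from the repo, so pin a revision you have reviewed.
# "some-user/my-tool" is a placeholder repo id.
remote_tool = load_tool("some-user/my-tool", revision="main", token="hf_...")
```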
def add_description(description):
    """
    A decorator that adds a description to a function.
    """

    def inner(func):
        func.description = description
        func.name = func.__name__
        return func

    return inner
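`add_description` only attaches metadata to the wrapped function; a short example (the `text_length` function is invented for illustration):

```py
@add_description("Returns the length of a text in characters.")
def text_length(text: str) -> int:
    return len(text)

print(text_length.name)         # "text_length"
print(text_length.description)  # "Returns the length of a text in characters."
```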
## Will move to the Hub
class EndpointClient:
    def __init__(self, endpoint_url: str, token: Optional[str] = None):
        self.headers = {
            **build_hf_headers(token=token),
            "Content-Type": "application/json",
        }
        self.endpoint_url = endpoint_url

    @staticmethod
    def encode_image(image):
        _bytes = io.BytesIO()
        image.save(_bytes, format="PNG")
        b64 = base64.b64encode(_bytes.getvalue())
        return b64.decode("utf-8")

    @staticmethod
    def decode_image(raw_image):
        if not is_vision_available():
            raise ImportError(
                "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
            )

        from PIL import Image

        b64 = base64.b64decode(raw_image)
        _bytes = io.BytesIO(b64)
        return Image.open(_bytes)

    def __call__(
        self,
        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
        params: Optional[Dict] = None,
        data: Optional[bytes] = None,
        output_image: bool = False,
    ) -> Any:
        # Build payload
        payload = {}
        if inputs:
            payload["inputs"] = inputs
        if params:
            payload["parameters"] = params

        # Make API call
        response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)

        # By default, parse the response for the user.
        if output_image:
            return self.decode_image(response.content)
        else:
            return response.json()
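A hedged sketch of how `EndpointClient` might be called against an Inference Endpoint; the URL and token are placeholders, and `output_image=True` only makes sense for an endpoint that returns base64-encoded image data.

```py
from transformers.agents.tools import EndpointClient

# Placeholder endpoint URL and token.
client = EndpointClient("https://my-endpoint.endpoints.huggingface.cloud", token="hf_...")

# Text-style call: payload is {"inputs": ..., "parameters": ...} and the JSON response is returned.
result = client(inputs="A cat sat on the mat.", params={"max_new_tokens": 20})

# Image-style call: the base64 response content is decoded into a PIL.Image via decode_image.
image = client(inputs="a watercolor of rivers and lakes", output_image=True)
```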
class ToolCollection:
    """
    Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.

    > [!NOTE]
    > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
    > like for this collection to showcase them.

    Args:
        collection_slug (str):
            The collection slug referencing the collection.
        token (str, *optional*):
            The authentication token if the collection is private.

    Example:

    ```py
    >>> from transformers import ToolCollection, ReactCodeAgent

    >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
    >>> agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)

    >>> agent.run("Please draw me a picture of rivers and lakes.")
    ```
    """

    def __init__(self, collection_slug: str, token: Optional[str] = None):
        self._collection = get_collection(collection_slug, token=token)
        self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
        self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}


def tool(tool_function: Callable) -> Tool:
    """
    Converts a function into an instance of a Tool subclass.

    Args:
        tool_function: Your function. Should have type hints for each input and a type hint for the output.
            Should also have a docstring description including an 'Args:' part where each argument is described.
    """
    parameters = get_json_schema(tool_function)["function"]
    if "return" not in parameters:
        raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
    class_name = f"{parameters['name'].capitalize()}Tool"

    class SpecificTool(Tool):
        name = parameters["name"]
        description = parameters["description"]
        inputs = parameters["parameters"]["properties"]
        output_type = parameters["return"]["type"]

        @wraps(tool_function)
        def forward(self, *args, **kwargs):
            return tool_function(*args, **kwargs)

    original_signature = inspect.signature(tool_function)
    new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
        original_signature.parameters.values()
    )
    new_signature = original_signature.replace(parameters=new_parameters)
    SpecificTool.forward.__signature__ = new_signature

    SpecificTool.__name__ = class_name
    return SpecificTool()
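To make the `tool` decorator concrete, here is a sketch of a typed, docstringed function being turned into a `Tool` instance; the `character_count` function is invented for the example.

```py
from transformers.agents.tools import tool


@tool
def character_count(text: str) -> int:
    """
    Counts the number of characters in a piece of text.

    Args:
        text: The text to measure.
    """
    return len(text)


# `character_count` is now an instance of a dynamically created `Character_countTool`
# subclass of Tool, with `inputs` and `output_type` filled in from the type hints.
print(character_count.name, character_count.output_type)
print(character_count(text="hello"))
```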
.venv/Lib/site-packages/transformers/agents/translation.py
ADDED
|
@@ -0,0 +1,279 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
|
| 4 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 18 |
+
from .tools import PipelineTool
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
LANGUAGE_CODES = {
|
| 22 |
+
"Acehnese Arabic": "ace_Arab",
|
| 23 |
+
"Acehnese Latin": "ace_Latn",
|
| 24 |
+
"Mesopotamian Arabic": "acm_Arab",
|
| 25 |
+
"Ta'izzi-Adeni Arabic": "acq_Arab",
|
| 26 |
+
"Tunisian Arabic": "aeb_Arab",
|
| 27 |
+
"Afrikaans": "afr_Latn",
|
| 28 |
+
"South Levantine Arabic": "ajp_Arab",
|
| 29 |
+
"Akan": "aka_Latn",
|
| 30 |
+
"Amharic": "amh_Ethi",
|
| 31 |
+
"North Levantine Arabic": "apc_Arab",
|
| 32 |
+
"Modern Standard Arabic": "arb_Arab",
|
| 33 |
+
"Modern Standard Arabic Romanized": "arb_Latn",
|
| 34 |
+
"Najdi Arabic": "ars_Arab",
|
| 35 |
+
"Moroccan Arabic": "ary_Arab",
|
| 36 |
+
"Egyptian Arabic": "arz_Arab",
|
| 37 |
+
"Assamese": "asm_Beng",
|
| 38 |
+
"Asturian": "ast_Latn",
|
| 39 |
+
"Awadhi": "awa_Deva",
|
| 40 |
+
"Central Aymara": "ayr_Latn",
|
| 41 |
+
"South Azerbaijani": "azb_Arab",
|
| 42 |
+
"North Azerbaijani": "azj_Latn",
|
| 43 |
+
"Bashkir": "bak_Cyrl",
|
| 44 |
+
"Bambara": "bam_Latn",
|
| 45 |
+
"Balinese": "ban_Latn",
|
| 46 |
+
"Belarusian": "bel_Cyrl",
|
| 47 |
+
"Bemba": "bem_Latn",
|
| 48 |
+
"Bengali": "ben_Beng",
|
| 49 |
+
"Bhojpuri": "bho_Deva",
|
| 50 |
+
"Banjar Arabic": "bjn_Arab",
|
| 51 |
+
"Banjar Latin": "bjn_Latn",
|
| 52 |
+
"Standard Tibetan": "bod_Tibt",
|
| 53 |
+
"Bosnian": "bos_Latn",
|
| 54 |
+
"Buginese": "bug_Latn",
|
| 55 |
+
"Bulgarian": "bul_Cyrl",
|
| 56 |
+
"Catalan": "cat_Latn",
|
| 57 |
+
"Cebuano": "ceb_Latn",
|
| 58 |
+
"Czech": "ces_Latn",
|
| 59 |
+
"Chokwe": "cjk_Latn",
|
| 60 |
+
"Central Kurdish": "ckb_Arab",
|
| 61 |
+
"Crimean Tatar": "crh_Latn",
|
| 62 |
+
"Welsh": "cym_Latn",
|
| 63 |
+
"Danish": "dan_Latn",
|
| 64 |
+
"German": "deu_Latn",
|
| 65 |
+
"Southwestern Dinka": "dik_Latn",
|
| 66 |
+
"Dyula": "dyu_Latn",
|
| 67 |
+
"Dzongkha": "dzo_Tibt",
|
| 68 |
+
"Greek": "ell_Grek",
|
| 69 |
+
"English": "eng_Latn",
|
| 70 |
+
"Esperanto": "epo_Latn",
|
| 71 |
+
"Estonian": "est_Latn",
|
| 72 |
+
"Basque": "eus_Latn",
|
| 73 |
+
"Ewe": "ewe_Latn",
|
| 74 |
+
"Faroese": "fao_Latn",
|
| 75 |
+
"Fijian": "fij_Latn",
|
| 76 |
+
"Finnish": "fin_Latn",
|
| 77 |
+
"Fon": "fon_Latn",
|
| 78 |
+
"French": "fra_Latn",
|
| 79 |
+
"Friulian": "fur_Latn",
|
| 80 |
+
"Nigerian Fulfulde": "fuv_Latn",
|
| 81 |
+
"Scottish Gaelic": "gla_Latn",
|
| 82 |
+
"Irish": "gle_Latn",
|
| 83 |
+
"Galician": "glg_Latn",
|
| 84 |
+
"Guarani": "grn_Latn",
|
| 85 |
+
"Gujarati": "guj_Gujr",
|
| 86 |
+
"Haitian Creole": "hat_Latn",
|
| 87 |
+
"Hausa": "hau_Latn",
|
| 88 |
+
"Hebrew": "heb_Hebr",
|
| 89 |
+
"Hindi": "hin_Deva",
|
| 90 |
+
"Chhattisgarhi": "hne_Deva",
|
| 91 |
+
"Croatian": "hrv_Latn",
|
| 92 |
+
"Hungarian": "hun_Latn",
|
| 93 |
+
"Armenian": "hye_Armn",
|
| 94 |
+
"Igbo": "ibo_Latn",
|
| 95 |
+
"Ilocano": "ilo_Latn",
|
| 96 |
+
"Indonesian": "ind_Latn",
|
| 97 |
+
"Icelandic": "isl_Latn",
|
| 98 |
+
"Italian": "ita_Latn",
|
| 99 |
+
"Javanese": "jav_Latn",
|
| 100 |
+
"Japanese": "jpn_Jpan",
|
| 101 |
+
"Kabyle": "kab_Latn",
|
| 102 |
+
"Jingpho": "kac_Latn",
|
| 103 |
+
"Kamba": "kam_Latn",
|
| 104 |
+
"Kannada": "kan_Knda",
|
| 105 |
+
"Kashmiri Arabic": "kas_Arab",
|
| 106 |
+
"Kashmiri Devanagari": "kas_Deva",
|
| 107 |
+
"Georgian": "kat_Geor",
|
| 108 |
+
"Central Kanuri Arabic": "knc_Arab",
|
| 109 |
+
"Central Kanuri Latin": "knc_Latn",
|
| 110 |
+
"Kazakh": "kaz_Cyrl",
|
| 111 |
+
"Kabiyè": "kbp_Latn",
|
| 112 |
+
"Kabuverdianu": "kea_Latn",
|
| 113 |
+
"Khmer": "khm_Khmr",
|
| 114 |
+
"Kikuyu": "kik_Latn",
|
| 115 |
+
"Kinyarwanda": "kin_Latn",
|
| 116 |
+
"Kyrgyz": "kir_Cyrl",
|
| 117 |
+
"Kimbundu": "kmb_Latn",
|
| 118 |
+
"Northern Kurdish": "kmr_Latn",
|
| 119 |
+
"Kikongo": "kon_Latn",
|
| 120 |
+
"Korean": "kor_Hang",
|
| 121 |
+
"Lao": "lao_Laoo",
|
| 122 |
+
"Ligurian": "lij_Latn",
|
| 123 |
+
"Limburgish": "lim_Latn",
|
| 124 |
+
"Lingala": "lin_Latn",
|
| 125 |
+
"Lithuanian": "lit_Latn",
|
| 126 |
+
"Lombard": "lmo_Latn",
|
| 127 |
+
"Latgalian": "ltg_Latn",
|
| 128 |
+
"Luxembourgish": "ltz_Latn",
|
| 129 |
+
"Luba-Kasai": "lua_Latn",
|
| 130 |
+
"Ganda": "lug_Latn",
|
| 131 |
+
"Luo": "luo_Latn",
|
| 132 |
+
"Mizo": "lus_Latn",
|
| 133 |
+
"Standard Latvian": "lvs_Latn",
|
| 134 |
+
"Magahi": "mag_Deva",
|
| 135 |
+
"Maithili": "mai_Deva",
|
| 136 |
+
"Malayalam": "mal_Mlym",
|
| 137 |
+
"Marathi": "mar_Deva",
|
| 138 |
+
"Minangkabau Arabic ": "min_Arab",
|
| 139 |
+
"Minangkabau Latin": "min_Latn",
|
| 140 |
+
"Macedonian": "mkd_Cyrl",
|
| 141 |
+
"Plateau Malagasy": "plt_Latn",
|
| 142 |
+
"Maltese": "mlt_Latn",
|
| 143 |
+
"Meitei Bengali": "mni_Beng",
|
| 144 |
+
"Halh Mongolian": "khk_Cyrl",
|
| 145 |
+
"Mossi": "mos_Latn",
|
| 146 |
+
"Maori": "mri_Latn",
|
| 147 |
+
"Burmese": "mya_Mymr",
|
| 148 |
+
"Dutch": "nld_Latn",
|
| 149 |
+
"Norwegian Nynorsk": "nno_Latn",
|
| 150 |
+
"Norwegian Bokmål": "nob_Latn",
|
| 151 |
+
"Nepali": "npi_Deva",
|
| 152 |
+
"Northern Sotho": "nso_Latn",
|
| 153 |
+
"Nuer": "nus_Latn",
|
| 154 |
+
"Nyanja": "nya_Latn",
|
| 155 |
+
"Occitan": "oci_Latn",
|
| 156 |
+
"West Central Oromo": "gaz_Latn",
|
| 157 |
+
"Odia": "ory_Orya",
|
| 158 |
+
"Pangasinan": "pag_Latn",
|
| 159 |
+
"Eastern Panjabi": "pan_Guru",
|
| 160 |
+
"Papiamento": "pap_Latn",
|
| 161 |
+
"Western Persian": "pes_Arab",
|
| 162 |
+
"Polish": "pol_Latn",
|
| 163 |
+
"Portuguese": "por_Latn",
|
| 164 |
+
"Dari": "prs_Arab",
|
| 165 |
+
"Southern Pashto": "pbt_Arab",
|
| 166 |
+
"Ayacucho Quechua": "quy_Latn",
|
| 167 |
+
"Romanian": "ron_Latn",
|
| 168 |
+
"Rundi": "run_Latn",
|
| 169 |
+
"Russian": "rus_Cyrl",
|
| 170 |
+
"Sango": "sag_Latn",
|
| 171 |
+
"Sanskrit": "san_Deva",
|
| 172 |
+
"Santali": "sat_Olck",
|
| 173 |
+
"Sicilian": "scn_Latn",
|
| 174 |
+
"Shan": "shn_Mymr",
|
| 175 |
+
"Sinhala": "sin_Sinh",
|
| 176 |
+
"Slovak": "slk_Latn",
|
| 177 |
+
"Slovenian": "slv_Latn",
|
| 178 |
+
"Samoan": "smo_Latn",
|
| 179 |
+
"Shona": "sna_Latn",
|
| 180 |
+
"Sindhi": "snd_Arab",
|
| 181 |
+
"Somali": "som_Latn",
|
| 182 |
+
"Southern Sotho": "sot_Latn",
|
| 183 |
+
"Spanish": "spa_Latn",
|
| 184 |
+
"Tosk Albanian": "als_Latn",
|
| 185 |
+
"Sardinian": "srd_Latn",
|
| 186 |
+
"Serbian": "srp_Cyrl",
|
| 187 |
+
"Swati": "ssw_Latn",
|
| 188 |
+
"Sundanese": "sun_Latn",
|
| 189 |
+
"Swedish": "swe_Latn",
|
| 190 |
+
"Swahili": "swh_Latn",
|
| 191 |
+
"Silesian": "szl_Latn",
|
| 192 |
+
"Tamil": "tam_Taml",
|
| 193 |
+
"Tatar": "tat_Cyrl",
|
| 194 |
+
"Telugu": "tel_Telu",
|
| 195 |
+
"Tajik": "tgk_Cyrl",
|
| 196 |
+
"Tagalog": "tgl_Latn",
|
| 197 |
+
"Thai": "tha_Thai",
|
| 198 |
+
"Tigrinya": "tir_Ethi",
|
| 199 |
+
"Tamasheq Latin": "taq_Latn",
|
| 200 |
+
"Tamasheq Tifinagh": "taq_Tfng",
|
| 201 |
+
"Tok Pisin": "tpi_Latn",
|
| 202 |
+
"Tswana": "tsn_Latn",
|
| 203 |
+
"Tsonga": "tso_Latn",
|
| 204 |
+
"Turkmen": "tuk_Latn",
|
| 205 |
+
"Tumbuka": "tum_Latn",
|
| 206 |
+
"Turkish": "tur_Latn",
|
| 207 |
+
"Twi": "twi_Latn",
|
| 208 |
+
"Central Atlas Tamazight": "tzm_Tfng",
|
| 209 |
+
"Uyghur": "uig_Arab",
|
| 210 |
+
"Ukrainian": "ukr_Cyrl",
|
| 211 |
+
"Umbundu": "umb_Latn",
|
| 212 |
+
"Urdu": "urd_Arab",
|
| 213 |
+
"Northern Uzbek": "uzn_Latn",
|
| 214 |
+
"Venetian": "vec_Latn",
|
| 215 |
+
"Vietnamese": "vie_Latn",
|
| 216 |
+
"Waray": "war_Latn",
|
| 217 |
+
"Wolof": "wol_Latn",
|
| 218 |
+
"Xhosa": "xho_Latn",
|
| 219 |
+
"Eastern Yiddish": "ydd_Hebr",
|
| 220 |
+
"Yoruba": "yor_Latn",
|
| 221 |
+
"Yue Chinese": "yue_Hant",
|
| 222 |
+
"Chinese Simplified": "zho_Hans",
|
| 223 |
+
"Chinese Traditional": "zho_Hant",
|
| 224 |
+
"Standard Malay": "zsm_Latn",
|
| 225 |
+
"Zulu": "zul_Latn",
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class TranslationTool(PipelineTool):
|
| 230 |
+
"""
|
| 231 |
+
Example:
|
| 232 |
+
|
| 233 |
+
```py
|
| 234 |
+
from transformers.agents import TranslationTool
|
| 235 |
+
|
| 236 |
+
translator = TranslationTool()
|
| 237 |
+
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
|
| 238 |
+
```
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
lang_to_code = LANGUAGE_CODES
|
| 242 |
+
default_checkpoint = "facebook/nllb-200-distilled-600M"
|
| 243 |
+
description = (
|
| 244 |
+
"This is a tool that translates text from a language to another."
|
| 245 |
+
f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
|
| 246 |
+
)
|
| 247 |
+
name = "translator"
|
| 248 |
+
pre_processor_class = AutoTokenizer
|
| 249 |
+
model_class = AutoModelForSeq2SeqLM
|
| 250 |
+
|
| 251 |
+
inputs = {
|
| 252 |
+
"text": {"type": "string", "description": "The text to translate"},
|
| 253 |
+
"src_lang": {
|
| 254 |
+
"type": "string",
|
| 255 |
+
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
|
| 256 |
+
},
|
| 257 |
+
"tgt_lang": {
|
| 258 |
+
"type": "string",
|
| 259 |
+
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
|
| 260 |
+
},
|
| 261 |
+
}
|
| 262 |
+
output_type = "string"
|
| 263 |
+
|
| 264 |
+
def encode(self, text, src_lang, tgt_lang):
|
| 265 |
+
if src_lang not in self.lang_to_code:
|
| 266 |
+
raise ValueError(f"{src_lang} is not a supported language.")
|
| 267 |
+
if tgt_lang not in self.lang_to_code:
|
| 268 |
+
raise ValueError(f"{tgt_lang} is not a supported language.")
|
| 269 |
+
src_lang = self.lang_to_code[src_lang]
|
| 270 |
+
tgt_lang = self.lang_to_code[tgt_lang]
|
| 271 |
+
return self.pre_processor._build_translation_inputs(
|
| 272 |
+
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
def forward(self, inputs):
|
| 276 |
+
return self.model.generate(**inputs)
|
| 277 |
+
|
| 278 |
+
def decode(self, outputs):
|
| 279 |
+
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
|
.venv/Lib/site-packages/transformers/benchmark/benchmark.py
ADDED
|
@@ -0,0 +1,270 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Benchmarking the library on inference and training in PyTorch.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import timeit
|
| 21 |
+
from typing import Callable, Optional
|
| 22 |
+
|
| 23 |
+
from ..configuration_utils import PretrainedConfig
|
| 24 |
+
from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
|
| 25 |
+
from ..utils import is_py3nvml_available, is_torch_available, logging
|
| 26 |
+
from .benchmark_utils import (
|
| 27 |
+
Benchmark,
|
| 28 |
+
Memory,
|
| 29 |
+
MemorySummary,
|
| 30 |
+
measure_peak_memory_cpu,
|
| 31 |
+
start_memory_tracing,
|
| 32 |
+
stop_memory_tracing,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if is_torch_available():
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
from .benchmark_args import PyTorchBenchmarkArguments
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if is_py3nvml_available():
|
| 43 |
+
import py3nvml.py3nvml as nvml
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class PyTorchBenchmark(Benchmark):
|
| 50 |
+
args: PyTorchBenchmarkArguments
|
| 51 |
+
configs: PretrainedConfig
|
| 52 |
+
framework: str = "PyTorch"
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def framework_version(self):
|
| 56 |
+
return torch.__version__
|
| 57 |
+
|
| 58 |
+
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
| 59 |
+
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
|
| 60 |
+
return self._measure_speed(_inference)
|
| 61 |
+
|
| 62 |
+
def _inference_memory(
|
| 63 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
| 64 |
+
) -> [Memory, Optional[MemorySummary]]:
|
| 65 |
+
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
|
| 66 |
+
return self._measure_memory(_inference)
|
| 67 |
+
|
| 68 |
+
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
| 69 |
+
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
|
| 70 |
+
return self._measure_speed(_train)
|
| 71 |
+
|
| 72 |
+
def _train_memory(
|
| 73 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
| 74 |
+
) -> [Memory, Optional[MemorySummary]]:
|
| 75 |
+
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
|
| 76 |
+
return self._measure_memory(_train)
|
| 77 |
+
|
| 78 |
+
def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
|
| 79 |
+
config = self.config_dict[model_name]
|
| 80 |
+
|
| 81 |
+
if self.args.torchscript:
|
| 82 |
+
config.torchscript = True
|
| 83 |
+
|
| 84 |
+
has_model_class_in_config = (
|
| 85 |
+
hasattr(config, "architectures")
|
| 86 |
+
and isinstance(config.architectures, list)
|
| 87 |
+
and len(config.architectures) > 0
|
| 88 |
+
)
|
| 89 |
+
if not self.args.only_pretrain_model and has_model_class_in_config:
|
| 90 |
+
try:
|
| 91 |
+
model_class = config.architectures[0]
|
| 92 |
+
transformers_module = __import__("transformers", fromlist=[model_class])
|
| 93 |
+
model_cls = getattr(transformers_module, model_class)
|
| 94 |
+
model = model_cls(config)
|
| 95 |
+
except ImportError:
|
| 96 |
+
raise ImportError(
|
| 97 |
+
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
|
| 98 |
+
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
|
| 99 |
+
)
|
| 100 |
+
else:
|
| 101 |
+
model = MODEL_MAPPING[config.__class__](config)
|
| 102 |
+
|
| 103 |
+
model.eval()
|
| 104 |
+
model.to(self.args.device)
|
| 105 |
+
|
| 106 |
+
# encoder-decoder has vocab size saved differently
|
| 107 |
+
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
| 108 |
+
input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
|
| 109 |
+
|
| 110 |
+
if self.args.fp16:
|
| 111 |
+
logger.info("Running training in Mixed Precision...")
|
| 112 |
+
if not self.args.is_gpu:
|
| 113 |
+
raise ValueError("Mixed precision is possible only for GPU.")
|
| 114 |
+
# amp seems to have memory leaks so that memory usage
|
| 115 |
+
# is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
|
| 116 |
+
model.half()
|
| 117 |
+
|
| 118 |
+
if self.args.torchscript:
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
inference_model = torch.jit.trace(model, input_ids)
|
| 121 |
+
else:
|
| 122 |
+
inference_model = model
|
| 123 |
+
|
| 124 |
+
def encoder_decoder_forward():
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
outputs = inference_model(input_ids, decoder_input_ids=input_ids)
|
| 127 |
+
return outputs
|
| 128 |
+
|
| 129 |
+
def encoder_forward():
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
outputs = inference_model(input_ids)
|
| 132 |
+
return outputs
|
| 133 |
+
|
| 134 |
+
_forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
|
| 135 |
+
return _forward
|
| 136 |
+
|
| 137 |
+
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
|
| 138 |
+
config = self.config_dict[model_name]
|
| 139 |
+
|
| 140 |
+
has_model_class_in_config = (
|
| 141 |
+
hasattr(config, "architectures")
|
| 142 |
+
and isinstance(config.architectures, list)
|
| 143 |
+
and len(config.architectures) > 0
|
| 144 |
+
)
|
| 145 |
+
if not self.args.only_pretrain_model and has_model_class_in_config:
|
| 146 |
+
try:
|
| 147 |
+
model_class = config.architectures[0]
|
| 148 |
+
transformers_module = __import__("transformers", fromlist=[model_class])
|
| 149 |
+
model_cls = getattr(transformers_module, model_class)
|
| 150 |
+
model = model_cls(config)
|
| 151 |
+
except ImportError:
|
| 152 |
+
raise ImportError(
|
| 153 |
+
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
|
| 154 |
+
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
|
| 155 |
+
)
|
| 156 |
+
else:
|
| 157 |
+
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
|
| 158 |
+
|
| 159 |
+
if self.args.torchscript:
|
| 160 |
+
raise NotImplementedError("Training for torchscript is currently not implemented")
|
| 161 |
+
else:
|
| 162 |
+
train_model = model
|
| 163 |
+
|
| 164 |
+
model.train()
|
| 165 |
+
model.to(self.args.device)
|
| 166 |
+
|
| 167 |
+
# encoder-decoder has vocab size saved differently
|
| 168 |
+
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
| 169 |
+
input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
|
| 170 |
+
|
| 171 |
+
if self.args.fp16:
|
| 172 |
+
logger.info("Running training in Mixed Precision...")
|
| 173 |
+
if not self.args.is_gpu:
|
| 174 |
+
raise ValueError("Mixed precision is possible only for GPU.")
|
| 175 |
+
|
| 176 |
+
# amp seems to have memory leaks so that memory usage
|
| 177 |
+
# is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
|
| 178 |
+
model.half()
|
| 179 |
+
|
| 180 |
+
def compute_loss_and_backprob_encoder():
|
| 181 |
+
loss = train_model(input_ids, labels=input_ids)[0]
|
| 182 |
+
loss.backward()
|
| 183 |
+
return loss
|
| 184 |
+
|
| 185 |
+
def compute_loss_and_backprob_encoder_decoder():
|
| 186 |
+
loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
|
| 187 |
+
loss.backward()
|
| 188 |
+
return loss
|
| 189 |
+
|
| 190 |
+
_train = (
|
| 191 |
+
compute_loss_and_backprob_encoder_decoder
|
| 192 |
+
if config.is_encoder_decoder
|
| 193 |
+
else compute_loss_and_backprob_encoder
|
| 194 |
+
)
|
| 195 |
+
return _train
|
| 196 |
+
|
| 197 |
+
def _measure_speed(self, func) -> float:
|
| 198 |
+
try:
|
| 199 |
+
if self.args.is_tpu or self.args.torchscript:
|
| 200 |
+
# run additional 10 times to stabilize compilation for tpu and torchscript
|
| 201 |
+
logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
|
| 202 |
+
timeit.repeat(
|
| 203 |
+
func,
|
| 204 |
+
repeat=1,
|
| 205 |
+
number=5,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
| 209 |
+
runtimes = timeit.repeat(
|
| 210 |
+
func,
|
| 211 |
+
repeat=self.args.repeat,
|
| 212 |
+
number=10,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
|
| 216 |
+
import torch_xla.debug.metrics as met
|
| 217 |
+
|
| 218 |
+
self.print_fn(met.metrics_report())
|
| 219 |
+
|
| 220 |
+
return min(runtimes) / 10.0
|
| 221 |
+
except RuntimeError as e:
|
| 222 |
+
self.print_fn(f"Doesn't fit on GPU. {e}")
|
| 223 |
+
return "N/A"
|
| 224 |
+
|
| 225 |
+
def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
|
| 226 |
+
try:
|
| 227 |
+
if self.args.trace_memory_line_by_line:
|
| 228 |
+
trace = start_memory_tracing("transformers")
|
| 229 |
+
|
| 230 |
+
if self.args.is_tpu:
|
| 231 |
+
# tpu
|
| 232 |
+
raise NotImplementedError(
|
| 233 |
+
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
|
| 234 |
+
" `--no-memory` or `args.memory=False`"
|
| 235 |
+
)
|
| 236 |
+
elif self.args.is_gpu:
|
| 237 |
+
if not is_py3nvml_available():
|
| 238 |
+
logger.warning(
|
| 239 |
+
"py3nvml not installed, we won't log GPU memory usage. "
|
| 240 |
+
"Install py3nvml (pip install py3nvml) to log information about GPU."
|
| 241 |
+
)
|
| 242 |
+
memory = "N/A"
|
| 243 |
+
else:
|
| 244 |
+
logger.info(
|
| 245 |
+
"Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
|
| 246 |
+
" on the same GPU."
|
| 247 |
+
)
|
| 248 |
+
# init nvml
|
| 249 |
+
nvml.nvmlInit()
|
| 250 |
+
func()
|
| 251 |
+
handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
|
| 252 |
+
meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
|
| 253 |
+
max_bytes_in_use = meminfo.used
|
| 254 |
+
memory = Memory(max_bytes_in_use)
|
| 255 |
+
# shutdown nvml
|
| 256 |
+
nvml.nvmlShutdown()
|
| 257 |
+
else:
|
| 258 |
+
# cpu
|
| 259 |
+
memory_bytes = measure_peak_memory_cpu(func)
|
| 260 |
+
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
|
| 261 |
+
|
| 262 |
+
if self.args.trace_memory_line_by_line:
|
| 263 |
+
summary = stop_memory_tracing(trace)
|
| 264 |
+
else:
|
| 265 |
+
summary = None
|
| 266 |
+
|
| 267 |
+
return memory, summary
|
| 268 |
+
except RuntimeError as e:
|
| 269 |
+
self.print_fn(f"Doesn't fit on GPU. {e}")
|
| 270 |
+
return "N/A", None
|
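For context, the `PyTorchBenchmark` class above is driven by `PyTorchBenchmarkArguments`; a minimal sketch of how the two are typically wired together (the model name and sizes are example values, and this benchmark suite is a legacy/deprecated utility in recent releases):

```py
from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["google-bert/bert-base-uncased"],  # any model config on the Hub
    batch_sizes=[8],
    sequence_lengths=[128],
    inference=True,
    training=False,
)
benchmark = PyTorchBenchmark(args)
results = benchmark.run()  # reports speed (min over timeit repeats / 10) and peak memory
```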
.venv/Lib/site-packages/transformers/benchmark/benchmark_args.py
ADDED
|
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Tuple

from ..utils import (
    cached_property,
    is_torch_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    logging,
    requires_backends,
)
from .benchmark_args_utils import BenchmarkArguments


if is_torch_available():
    import torch

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm


logger = logging.get_logger(__name__)


@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
    deprecated_args = [
        "no_inference",
        "no_cuda",
        "no_tpu",
        "no_speed",
        "no_memory",
        "no_env_print",
        "no_multi_process",
    ]

    def __init__(self, **kwargs):
        """
        This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
        deleted
        """
        for deprecated_arg in self.deprecated_args:
            if deprecated_arg in kwargs:
                positive_arg = deprecated_arg[3:]
                setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
                logger.warning(
                    f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
                    f" {positive_arg}={kwargs[positive_arg]}"
                )

        self.torchscript = kwargs.pop("torchscript", self.torchscript)
        self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
        self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
        super().__init__(**kwargs)

    torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
    torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
    fp16_opt_level: str = field(
        default="O1",
        metadata={
            "help": (
                "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                "See details at https://nvidia.github.io/apex/amp.html"
            )
        },
    )

    @cached_property
    def _setup_devices(self) -> Tuple["torch.device", int]:
        requires_backends(self, ["torch"])
        logger.info("PyTorch: setting up devices")
        if not self.cuda:
            device = torch.device("cpu")
            n_gpu = 0
        elif is_torch_xla_available():
            device = xm.xla_device()
            n_gpu = 0
        elif is_torch_xpu_available():
            device = torch.device("xpu")
            n_gpu = torch.xpu.device_count()
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            n_gpu = torch.cuda.device_count()
        return device, n_gpu

    @property
    def is_tpu(self):
        return is_torch_xla_available() and self.tpu

    @property
    def device_idx(self) -> int:
        requires_backends(self, ["torch"])
        # TODO(PVP): currently only single GPU is supported
        return torch.cuda.current_device()

    @property
    def device(self) -> "torch.device":
        requires_backends(self, ["torch"])
        return self._setup_devices[0]

    @property
    def n_gpu(self):
        requires_backends(self, ["torch"])
        return self._setup_devices[1]

    @property
    def is_gpu(self):
        return self.n_gpu > 0
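A short sketch of how `PyTorchBenchmarkArguments` is used on its own; the model name and sizes are illustrative values only.

```py
from transformers import PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["google-bert/bert-base-uncased"],
    batch_sizes=[1],
    sequence_lengths=[32],
    cuda=False,  # the positive flag that the deprecated `no_cuda` argument maps onto
)

# Device resolution order in `_setup_devices`: CPU if cuda is disabled,
# otherwise XLA, then XPU, then CUDA (falling back to CPU).
print(args.device, args.n_gpu, args.is_gpu)
```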
.venv/Lib/site-packages/transformers/benchmark/benchmark_args_tf.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
from ..utils import cached_property, is_tf_available, logging, requires_backends
|
| 21 |
+
from .benchmark_args_utils import BenchmarkArguments
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if is_tf_available():
|
| 25 |
+
import tensorflow as tf
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
| 33 |
+
deprecated_args = [
|
| 34 |
+
"no_inference",
|
| 35 |
+
"no_cuda",
|
| 36 |
+
"no_tpu",
|
| 37 |
+
"no_speed",
|
| 38 |
+
"no_memory",
|
| 39 |
+
"no_env_print",
|
| 40 |
+
"no_multi_process",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
def __init__(self, **kwargs):
|
| 44 |
+
"""
|
| 45 |
+
This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
|
| 46 |
+
deleted
|
| 47 |
+
"""
|
| 48 |
+
for deprecated_arg in self.deprecated_args:
|
| 49 |
+
if deprecated_arg in kwargs:
|
| 50 |
+
positive_arg = deprecated_arg[3:]
|
| 51 |
+
kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
|
| 52 |
+
logger.warning(
|
| 53 |
+
f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or"
|
| 54 |
+
f" {positive_arg}={kwargs[positive_arg]}"
|
| 55 |
+
)
|
| 56 |
+
self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
|
| 57 |
+
self.device_idx = kwargs.pop("device_idx", self.device_idx)
|
| 58 |
+
self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
|
| 59 |
+
self.use_xla = kwargs.pop("use_xla", self.use_xla)
|
| 60 |
+
super().__init__(**kwargs)
|
| 61 |
+
|
| 62 |
+
tpu_name: str = field(
|
| 63 |
+
default=None,
|
| 64 |
+
metadata={"help": "Name of TPU"},
|
| 65 |
+
)
|
| 66 |
+
device_idx: int = field(
|
| 67 |
+
default=0,
|
| 68 |
+
metadata={"help": "CPU / GPU device index. Defaults to 0."},
|
| 69 |
+
)
|
| 70 |
+
eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager model."})
|
| 71 |
+
use_xla: bool = field(
|
| 72 |
+
default=False,
|
| 73 |
+
metadata={
|
| 74 |
+
"help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`."
|
| 75 |
+
},
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
@cached_property
|
| 79 |
+
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
| 80 |
+
requires_backends(self, ["tf"])
|
| 81 |
+
tpu = None
|
| 82 |
+
if self.tpu:
|
| 83 |
+
try:
|
| 84 |
+
if self.tpu_name:
|
| 85 |
+
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
|
| 86 |
+
else:
|
| 87 |
+
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
| 88 |
+
except ValueError:
|
| 89 |
+
tpu = None
|
| 90 |
+
return tpu
|
| 91 |
+
|
| 92 |
+
@cached_property
|
| 93 |
+
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
| 94 |
+
requires_backends(self, ["tf"])
|
| 95 |
+
if self.is_tpu:
|
| 96 |
+
tf.config.experimental_connect_to_cluster(self._setup_tpu)
|
| 97 |
+
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
|
| 98 |
+
|
| 99 |
+
strategy = tf.distribute.TPUStrategy(self._setup_tpu)
|
| 100 |
+
else:
|
| 101 |
+
# currently no multi gpu is allowed
|
| 102 |
+
if self.is_gpu:
|
| 103 |
+
# TODO: Currently only single GPU is supported
|
| 104 |
+
tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
|
| 105 |
+
strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
|
| 106 |
+
else:
|
| 107 |
+
tf.config.set_visible_devices([], "GPU") # disable GPU
|
| 108 |
+
strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
|
| 109 |
+
|
| 110 |
+
return strategy
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def is_tpu(self) -> bool:
|
| 114 |
+
requires_backends(self, ["tf"])
|
| 115 |
+
return self._setup_tpu is not None
|
| 116 |
+
|
| 117 |
+
@property
|
| 118 |
+
def strategy(self) -> "tf.distribute.Strategy":
|
| 119 |
+
requires_backends(self, ["tf"])
|
| 120 |
+
return self._setup_strategy
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def gpu_list(self):
|
| 124 |
+
requires_backends(self, ["tf"])
|
| 125 |
+
return tf.config.list_physical_devices("GPU")
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def n_gpu(self) -> int:
|
| 129 |
+
requires_backends(self, ["tf"])
|
| 130 |
+
if self.cuda:
|
| 131 |
+
return len(self.gpu_list)
|
| 132 |
+
return 0
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def is_gpu(self) -> bool:
|
| 136 |
+
return self.n_gpu > 0
|
.venv/Lib/site-packages/transformers/commands/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseTransformersCLICommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()
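Every `transformers-cli` subcommand implements this two-method contract. A minimal hypothetical subcommand might look like the following; the `hello` command does not exist in the CLI and is purely illustrative.

```py
from argparse import ArgumentParser

from transformers.commands import BaseTransformersCLICommand


class HelloCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # `parser` is the subparsers action created by the CLI entry point.
        hello_parser = parser.add_parser("hello", help="Print a greeting (illustrative only)")
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self._name = name

    def run(self):
        print(f"Hello, {self._name}!")
```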
.venv/Lib/site-packages/transformers/commands/run.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def try_infer_format_from_ext(path: str):
    if not path:
        return "pipe"

    for ext in PipelineDataFormat.SUPPORTED_FORMATS:
        if path.endswith(ext):
            return ext

    raise Exception(
        f"Unable to determine file format from file extension {path}. "
        f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
    )


def run_command_factory(args):
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
    reader = PipelineDataFormat.from_str(
        format=format,
        output_path=args.output,
        input_path=args.input,
        column=args.column if args.column else nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)


class RunCommand(BaseTransformersCLICommand):
    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        nlp, outputs = self._nlp, []

        for entry in self._reader:
            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
            if isinstance(output, dict):
                outputs.append(output)
            else:
                outputs += output

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
        else:
            self._reader.save(outputs)
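The factory above is normally reached through `transformers-cli run ...`; the sketch below wires the same pieces together directly in Python, with placeholder file names for the input and output.

```py
from transformers import pipeline
from transformers.pipelines import PipelineDataFormat
from transformers.commands.run import RunCommand

nlp = pipeline(task="text-classification")  # default model for the task
reader = PipelineDataFormat.from_str(
    format="csv",
    input_path="reviews.csv",       # placeholder input file with a `text` column
    output_path="predictions.csv",  # placeholder output file
    column="text",
    overwrite=False,
)
RunCommand(nlp, reader).run()  # iterates the reader, runs the pipeline, saves outputs
```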
.venv/Lib/site-packages/transformers/commands/serving.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand


try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    BaseModel = object

    def Body(*x, **y):
        pass

    _serve_dependencies_installed = False


logger = logging.get_logger("transformers-cli/serving")


def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.

    Returns: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port, args.workers)


class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict


class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]]


class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str


class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any


class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info(f"Serving model over {host}:{port}")
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
            )

    def run(self):
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually return the corresponding token ids: - **text_input**: String to
        tokenize - **return_ids**: Boolean flag indicating if the tokens have to be converted to their integer
        mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
        **skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
        Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            return ServeDeTokenizeResult(model="", text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    async def forward(self, inputs=Body(None, embed=True)):
        """
        **inputs**: **attention_mask**: **tokens_type_ids**:
        """

        # Check we don't have empty string
        if len(inputs) == 0:
            return ServeForwardResult(output=[], attention=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
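A minimal sketch of how the REST endpoints defined above might be exercised once the server is running (for example via `transformers-cli serve --task feature-extraction`); the host, port, and sample sentence below are illustrative assumptions, not part of the file.

    import requests

    BASE = "http://localhost:8888"  # assumed host/port, matching the CLI defaults above

    # GET / returns the model config exposed through ServeModelInfoResult
    print(requests.get(f"{BASE}/").json())

    # POST /tokenize with embedded body fields, as declared with Body(..., embed=True)
    resp = requests.post(f"{BASE}/tokenize", json={"text_input": "Hello world", "return_ids": True})
    tokens = resp.json()
    print(tokens["tokens"], tokens["tokens_ids"])

    # POST /detokenize round-trips the ids back to text
    resp = requests.post(f"{BASE}/detokenize", json={"tokens_ids": tokens["tokens_ids"]})
    print(resp.json()["text"])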
.venv/Lib/site-packages/transformers/commands/train.py
ADDED
|
@@ -0,0 +1,158 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser, Namespace

from ..data import SingleSentenceClassificationProcessor as Processor
from ..pipelines import TextClassificationPipeline
from ..utils import is_tf_available, is_torch_available, logging
from . import BaseTransformersCLICommand


if not is_tf_available() and not is_torch_available():
    raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# TF training parameters
USE_XLA = False
USE_AMP = False


def train_command_factory(args: Namespace):
    """
    Factory function used to instantiate training command from provided command line arguments.

    Returns: TrainCommand
    """
    return TrainCommand(args)


class TrainCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with tab separated labels and sentences.",
        )
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )

        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
        )

        train_parser.add_argument("--output", type=str, default="./", help="path to save the trained model.")

        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )
        train_parser.add_argument(
            "--model", type=str, default="google-bert/bert-base-uncased", help="Model's name or path to stored model."
        )
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
        train_parser.set_defaults(func=train_command_factory)

    def __init__(self, args: Namespace):
        self.logger = logging.get_logger("transformers-cli/training")

        self.framework = "tf" if is_tf_available() else "torch"

        os.makedirs(args.output, exist_ok=True)
        self.output = args.output

        self.column_label = args.column_label
        self.column_text = args.column_text
        self.column_id = args.column_id

        self.logger.info(f"Loading {args.task} pipeline for {args.model}")
        if args.task == "text_classification":
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError
        elif args.task == "question_answering":
            raise NotImplementedError

        self.logger.info(f"Loading dataset from {args.train_data}")
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )
        self.valid_dataset = None
        if args.validation_data:
            self.logger.info(f"Loading validation dataset from {args.validation_data}")
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
        self.valid_batch_size = args.valid_batch_size
        self.learning_rate = args.learning_rate
        self.adam_epsilon = args.adam_epsilon

    def run(self):
        if self.framework == "tf":
            return self.run_tf()
        return self.run_torch()

    def run_torch(self):
        raise NotImplementedError

    def run_tf(self):
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save trained pipeline
        self.pipeline.save_pretrained(self.output)
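The `--train_data` argument above expects a tab-separated file whose columns follow `--column_label`, `--column_text` and `--column_id`. A small sketch of writing a toy file in that layout (the file name and rows are invented for illustration):

    import csv

    # Toy dataset matching the default column layout: label (0), text (1), id (2)
    rows = [
        ["positive", "I really enjoyed this movie.", "0"],
        ["negative", "The plot made no sense at all.", "1"],
    ]
    with open("train.tsv", "w", newline="") as f:
        csv.writer(f, delimiter="\t").writerows(rows)

    # The file could then be consumed with, e.g.:
    #   transformers-cli train --train_data train.tsv --task text_classification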
.venv/Lib/site-packages/transformers/commands/transformers_cli.py
ADDED
|
@@ -0,0 +1,57 @@
#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from .add_new_model_like import AddNewModelLikeCommand
from .convert import ConvertCommand
from .download import DownloadCommand
from .env import EnvironmentCommand
from .lfs import LfsCommands
from .pt_to_tf import PTtoTFCommand
from .run import RunCommand
from .serving import ServeCommand
from .user import UserCommands


def main():
    parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="transformers-cli command helpers")

    # Register commands
    ConvertCommand.register_subcommand(commands_parser)
    DownloadCommand.register_subcommand(commands_parser)
    EnvironmentCommand.register_subcommand(commands_parser)
    RunCommand.register_subcommand(commands_parser)
    ServeCommand.register_subcommand(commands_parser)
    UserCommands.register_subcommand(commands_parser)
    AddNewModelLikeCommand.register_subcommand(commands_parser)
    LfsCommands.register_subcommand(commands_parser)
    PTtoTFCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
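The `main()` above relies on a simple convention: each command registers a sub-parser, stores a factory in `func` via `set_defaults`, and the object that factory returns exposes `run()`. A small standalone sketch of that pattern (the `hello` command is invented for illustration):

    from argparse import ArgumentParser


    class HelloCommand:
        def __init__(self, args):
            self.name = args.name

        def run(self):
            print(f"hello, {self.name}")


    parser = ArgumentParser("demo", usage="demo <command> [<args>]")
    commands = parser.add_subparsers(help="demo command helpers")
    hello = commands.add_parser("hello")
    hello.add_argument("--name", default="world")
    hello.set_defaults(func=lambda args: HelloCommand(args))

    # Dispatch exactly as main() does: build the command, then run it
    args = parser.parse_args(["hello", "--name", "transformers"])
    args.func(args).run()  # prints "hello, transformers"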
.venv/Lib/site-packages/transformers/commands/user.py
ADDED
|
@@ -0,0 +1,197 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from argparse import ArgumentParser
from typing import List, Union

from huggingface_hub.hf_api import HfFolder, create_repo, whoami
from requests.exceptions import HTTPError

from . import BaseTransformersCLICommand


class UserCommands(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
        login_parser.set_defaults(func=lambda args: LoginCommand(args))
        whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
        logout_parser = parser.add_parser("logout", help="Log out")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))

        # new system: git-based repo system
        repo_parser = parser.add_parser(
            "repo",
            help="Deprecated: use `huggingface-cli` instead. Commands to interact with your huggingface.co repos.",
        )
        repo_subparsers = repo_parser.add_subparsers(
            help="Deprecated: use `huggingface-cli` instead. huggingface.co repos related commands"
        )
        repo_create_parser = repo_subparsers.add_parser(
            "create", help="Deprecated: use `huggingface-cli` instead. Create a new repo on huggingface.co"
        )
        repo_create_parser.add_argument(
            "name",
            type=str,
            help="Name for your model's repo. Will be namespaced under your username to build the model id.",
        )
        repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
        repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
        repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))


class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = "\u001b[1m"
    _red = "\u001b[31m"
    _gray = "\u001b[90m"
    _reset = "\u001b[0m"

    @classmethod
    def bold(cls, s):
        return f"{cls._bold}{s}{cls._reset}"

    @classmethod
    def red(cls, s):
        return f"{cls._bold}{cls._red}{s}{cls._reset}"

    @classmethod
    def gray(cls, s):
        return f"{cls._gray}{s}{cls._reset}"


def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """
    Inspired by:

    - stackoverflow.com/a/8356620/593036
    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
    """
    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
    lines = []
    lines.append(row_format.format(*headers))
    lines.append(row_format.format(*["-" * w for w in col_widths]))
    for row in rows:
        lines.append(row_format.format(*row))
    return "\n".join(lines)


class BaseUserCommand:
    def __init__(self, args):
        self.args = args


class LoginCommand(BaseUserCommand):
    def run(self):
        print(
            ANSI.red(
                "ERROR! `transformers-cli login` uses an outdated login mechanism "
                "that is not compatible with the Hugging Face Hub backend anymore. "
                "Please use `huggingface-cli login` instead."
            )
        )


class WhoamiCommand(BaseUserCommand):
    def run(self):
        print(
            ANSI.red(
                "WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
                "`huggingface-cli whoami` instead."
            )
        )
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit()
        try:
            user, orgs = whoami(token)
            print(user)
            if orgs:
                print(ANSI.bold("orgs: "), ",".join(orgs))
        except HTTPError as e:
            print(e)
            print(ANSI.red(e.response.text))
            exit(1)


class LogoutCommand(BaseUserCommand):
    def run(self):
        print(
            ANSI.red(
                "ERROR! `transformers-cli logout` uses an outdated logout mechanism "
                "that is not compatible with the Hugging Face Hub backend anymore. "
                "Please use `huggingface-cli logout` instead."
            )
        )


class RepoCreateCommand(BaseUserCommand):
    def run(self):
        print(
            ANSI.red(
                "WARNING! Managing repositories through transformers-cli is deprecated. "
                "Please use `huggingface-cli` instead."
            )
        )
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            print("Looks like you do not have git installed, please install.")

        try:
            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            print(
                ANSI.red(
                    "Looks like you do not have git-lfs installed, please install."
                    " You can install from https://git-lfs.github.com/."
                    " Then run `git lfs install` (you only have to do this once)."
                )
            )
        print("")

        user, _ = whoami(token)
        namespace = self.args.organization if self.args.organization is not None else user
        full_name = f"{namespace}/{self.args.name}"
        print(f"You are about to create {ANSI.bold(full_name)}")

        if not self.args.yes:
            choice = input("Proceed? [Y/n] ").lower()
            if not (choice == "" or choice == "y" or choice == "yes"):
                print("Abort")
                exit()
        try:
            url = create_repo(repo_id=full_name, token=token)
        except HTTPError as e:
            print(e)
            print(ANSI.red(e.response.text))
            exit(1)
        print("\nYour repo now lives at:")
        print(f"  {ANSI.bold(url)}")
        print("\nYou can clone it locally with the command below, and commit/push as usual.")
        print(f"\n  git clone {url}")
        print("")
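For reference, `tabulate` above sizes each column to its longest cell (headers included) and relies on Python's default format alignment, so strings come out left-aligned and integers right-aligned. A small usage sketch; the model names and parameter counts are placeholders:

    # Assuming the module above is importable as transformers.commands.user
    from transformers.commands.user import tabulate

    rows = [["bert-base-uncased", 110], ["distilbert-base-uncased", 66]]
    print(tabulate(rows, headers=["model", "params (M)"]))
    # Prints a header row, a dashed separator row, then one fixed-width line per entry.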
.venv/Lib/site-packages/transformers/data/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .data_collator import (
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    DataCollatorForSeq2Seq,
    DataCollatorForSOP,
    DataCollatorForTokenClassification,
    DataCollatorForWholeWordMask,
    DataCollatorWithFlattening,
    DataCollatorWithPadding,
    DefaultDataCollator,
    default_data_collator,
)
from .metrics import glue_compute_metrics, xnli_compute_metrics
from .processors import (
    DataProcessor,
    InputExample,
    InputFeatures,
    SingleSentenceClassificationProcessor,
    SquadExample,
    SquadFeatures,
    SquadV1Processor,
    SquadV2Processor,
    glue_convert_examples_to_features,
    glue_output_modes,
    glue_processors,
    glue_tasks_num_labels,
    squad_convert_examples_to_features,
    xnli_output_modes,
    xnli_processors,
    xnli_tasks_num_labels,
)
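These re-exports mean the collators and processors can be imported from `transformers.data` (and, via the package root, from `transformers` itself) rather than from their defining modules, for example:

    from transformers.data import DataCollatorWithPadding, SquadV2Processor, default_data_collator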
.venv/Lib/site-packages/transformers/data/data_collator.py
ADDED
|
@@ -0,0 +1,1653 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import warnings
from collections.abc import Mapping
from dataclasses import dataclass
from random import randint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import numpy as np

from ..models.bert import BertTokenizer, BertTokenizerFast
from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import PaddingStrategy


InputDataClass = NewType("InputDataClass", Any)

"""
A DataCollator is a function that takes a list of samples from a Dataset and collates them into a batch, as a
dictionary of PyTorch/TensorFlow tensors or NumPy arrays.
"""
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])


class DataCollatorMixin:
    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        if return_tensors == "tf":
            return self.tf_call(features)
        elif return_tensors == "pt":
            return self.torch_call(features)
        elif return_tensors == "np":
            return self.numpy_call(features)
        else:
            raise ValueError(f"Framework '{return_tensors}' not recognized!")


def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """

    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded


def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
    """
    Very simple data collator that simply collates batches of dict-like objects and performs special handling for
    potential keys named:

        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object

    Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
    to the model. See glue and ner for examples of how it's useful.
    """

    # In this function we'll make the assumption that all `features` in the batch
    # have the same attributes.
    # So we will look at the first element as a proxy for what attributes exist
    # on the whole batch.

    if return_tensors == "pt":
        return torch_default_data_collator(features)
    elif return_tensors == "tf":
        return tf_default_data_collator(features)
    elif return_tensors == "np":
        return numpy_default_data_collator(features)


@dataclass
class DefaultDataCollator(DataCollatorMixin):
    """
    Very simple data collator that simply collates batches of dict-like objects and performs special handling for
    potential keys named:

        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object

    Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
    to the model. See glue and ner for examples of how it's useful.

    This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
    helpful if you need to set a return_tensors value at initialization.

    Args:
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
        if return_tensors is None:
            return_tensors = self.return_tensors
        return default_data_collator(features, return_tensors)

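A small sketch of what `default_data_collator` produces for two dict-like features; the feature values are invented, and `return_tensors="np"` is used so the example only needs NumPy:

    from transformers import default_data_collator

    features = [
        {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "label": 0},
        {"input_ids": [101, 2008, 102], "attention_mask": [1, 1, 1], "label": 1},
    ]
    batch = default_data_collator(features, return_tensors="np")
    # "label" is renamed to "labels"; every other key is stacked along a new batch axis
    print(batch["labels"])           # [0 1]
    print(batch["input_ids"].shape)  # (2, 3)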
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
    import torch

    if not isinstance(features[0], Mapping):
        features = [vars(f) for f in features]
    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
        dtype = torch.long if isinstance(label, int) else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], torch.Tensor):
            batch["labels"] = torch.stack([f["label_ids"] for f in features])
        else:
            dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            elif isinstance(v, np.ndarray):
                batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
            else:
                batch[k] = torch.tensor([f[k] for f in features])

    return batch


def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
    import tensorflow as tf

    if not isinstance(features[0], Mapping):
        features = [vars(f) for f in features]
    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label_col_name = "label"
    elif "label_ids" in first and first["label_ids"] is not None:
        label_col_name = "label_ids"
    elif "labels" in first and first["labels"] is not None:
        label_col_name = "labels"
    else:
        label_col_name = None
    if label_col_name is not None:
        if isinstance(first[label_col_name], tf.Tensor):
            dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
        elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
            dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
        elif isinstance(first[label_col_name], (tuple, list)):
            dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
        else:
            dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
        batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
            if isinstance(v, (tf.Tensor, np.ndarray)):
                batch[k] = tf.stack([f[k] for f in features])
            else:
                batch[k] = tf.convert_to_tensor([f[k] for f in features])

    return batch


def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
    if not isinstance(features[0], Mapping):
        features = [vars(f) for f in features]
    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
        dtype = np.int64 if isinstance(label, int) else np.float32
        batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], np.ndarray):
            batch["labels"] = np.stack([f["label_ids"] for f in features])
        else:
            dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
            batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, np.ndarray):
                batch[k] = np.stack([f[k] for f in features])
            else:
                batch[k] = np.array([f[k] for f in features])

    return batch


@dataclass
class DataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        if "label" in batch:
            batch["labels"] = batch["label"]
            del batch["label"]
        if "label_ids" in batch:
            batch["labels"] = batch["label_ids"]
            del batch["label_ids"]
        return batch

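A sketch of typical `DataCollatorWithPadding` usage; it assumes a fast tokenizer checkpoint such as `bert-base-uncased` is available locally or can be downloaded, and the sentences are placeholders:

    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    features = [tokenizer("a short sentence"), tokenizer("a somewhat longer example sentence")]
    collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8, return_tensors="pt")
    batch = collator(features)
    # Both sequences are padded to the same length, rounded up to a multiple of 8
    print(batch["input_ids"].shape, batch["attention_mask"].shape)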
@dataclass
|
| 289 |
+
class DataCollatorForTokenClassification(DataCollatorMixin):
|
| 290 |
+
"""
|
| 291 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
| 295 |
+
The tokenizer used for encoding the data.
|
| 296 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
| 297 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
| 298 |
+
among:
|
| 299 |
+
|
| 300 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
| 301 |
+
sequence is provided).
|
| 302 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
| 303 |
+
acceptable input length for the model if that argument is not provided.
|
| 304 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
| 305 |
+
max_length (`int`, *optional*):
|
| 306 |
+
Maximum length of the returned list and optionally padding length (see above).
|
| 307 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 308 |
+
If set will pad the sequence to a multiple of the provided value.
|
| 309 |
+
|
| 310 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
| 311 |
+
7.5 (Volta).
|
| 312 |
+
label_pad_token_id (`int`, *optional*, defaults to -100):
|
| 313 |
+
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
| 314 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
| 315 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
tokenizer: PreTrainedTokenizerBase
|
| 319 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
| 320 |
+
max_length: Optional[int] = None
|
| 321 |
+
pad_to_multiple_of: Optional[int] = None
|
| 322 |
+
label_pad_token_id: int = -100
|
| 323 |
+
return_tensors: str = "pt"
|
| 324 |
+
|
| 325 |
+
def torch_call(self, features):
|
| 326 |
+
import torch
|
| 327 |
+
|
| 328 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
| 329 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
| 330 |
+
|
| 331 |
+
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
|
| 332 |
+
|
| 333 |
+
batch = pad_without_fast_tokenizer_warning(
|
| 334 |
+
self.tokenizer,
|
| 335 |
+
no_labels_features,
|
| 336 |
+
padding=self.padding,
|
| 337 |
+
max_length=self.max_length,
|
| 338 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
| 339 |
+
return_tensors="pt",
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if labels is None:
|
| 343 |
+
return batch
|
| 344 |
+
|
| 345 |
+
sequence_length = batch["input_ids"].shape[1]
|
| 346 |
+
padding_side = self.tokenizer.padding_side
|
| 347 |
+
|
| 348 |
+
def to_list(tensor_or_iterable):
|
| 349 |
+
if isinstance(tensor_or_iterable, torch.Tensor):
|
| 350 |
+
return tensor_or_iterable.tolist()
|
| 351 |
+
return list(tensor_or_iterable)
|
| 352 |
+
|
| 353 |
+
if padding_side == "right":
|
| 354 |
+
batch[label_name] = [
|
| 355 |
+
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
| 356 |
+
]
|
| 357 |
+
else:
|
| 358 |
+
batch[label_name] = [
|
| 359 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
|
| 363 |
+
return batch
|
| 364 |
+
|
| 365 |
+
    def tf_call(self, features):
        import tensorflow as tf

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
            return_tensors="tf" if labels is None else None,
        )

        if labels is None:
            return batch

        sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
        padding_side = self.tokenizer.padding_side
        if padding_side == "right":
            batch["labels"] = [
                list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
        else:
            batch["labels"] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
            ]

        batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
        return batch

    def numpy_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
            return_tensors="np" if labels is None else None,
        )

        if labels is None:
            return batch

        sequence_length = np.array(batch["input_ids"]).shape[1]
        padding_side = self.tokenizer.padding_side
        if padding_side == "right":
            batch["labels"] = [
                list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
        else:
            batch["labels"] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
            ]

        batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
        return batch


def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.

    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        if not isinstance(examples, torch.Tensor):
            return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result


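# Illustrative sketch (not part of the upstream module, shown only as a usage note for
# `_torch_collate_batch` above): given a tokenizer that defines a pad token, a ragged
# list of token-id lists is padded on the tokenizer's padding side, e.g.
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
#     padded = _torch_collate_batch([[101, 2023, 102], [101, 102]], tok)
#     # padded.shape == (2, 3); the shorter example is right-padded with tok.pad_token_id

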
def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    import tensorflow as tf

    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple)):
        examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]

    # Check if padding is necessary.
    length_of_first = len(examples[0])
    are_tensors_same_length = all(len(x) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return tf.stack(examples, axis=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(len(x) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    # result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    result = []
    rank = tf.rank(examples[0])
    paddings = np.zeros((rank, 2), dtype=np.int32)
    for example in examples:
        if tokenizer.padding_side == "right":
            paddings[0, 1] = max_length - len(example)
        else:
            paddings[0, 0] = max_length - len(example)
        result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
    return tf.stack(result, axis=0)


def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple)):
        examples = [np.array(e, dtype=np.int64) for e in examples]

    # Check if padding is necessary.
    length_of_first = len(examples[0])
    are_tensors_same_length = all(len(x) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return np.stack(examples, axis=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(len(x) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result


def tolist(x):
    if isinstance(x, list):
        return x
    elif hasattr(x, "numpy"):  # Checks for TF tensors without needing the import
        x = x.numpy()
    return x.tolist()


@dataclass
class DataCollatorForSeq2Seq:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`], *optional*):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*

            This is useful when using *label_smoothing* to avoid calculating loss twice.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        # reconvert list[None] to None if necessary
        # this might occur when we pass {..., "labels": None}
        if labels is not None and all(label is None for label in labels):
            labels = None
        non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        # run through tokenizer without labels to ensure no side effects
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            non_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
        no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
        if labels is not None:
            if no_padding:
                if isinstance(features[0][label_name], list):
                    batch["labels"] = list(labels)
                else:
                    batch["labels"] = [np.concatenate([label, []]) for label in labels]
            else:
                max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
                max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
                if self.pad_to_multiple_of is not None:
                    max_label_length = (
                        (max_label_length + self.pad_to_multiple_of - 1)
                        // self.pad_to_multiple_of
                        * self.pad_to_multiple_of
                    )

                padding_side = self.tokenizer.padding_side
                if isinstance(features[0][label_name], list):
                    batch["labels"] = [
                        label + [self.label_pad_token_id] * (max_label_length - len(label))
                        if padding_side == "right"
                        else [self.label_pad_token_id] * (max_label_length - len(label)) + label
                        for label in labels
                    ]
                else:
                    batch["labels"] = [
                        np.concatenate(
                            [
                                label,
                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
                            ]
                        )
                        if padding_side == "right"
                        else np.concatenate(
                            [
                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
                                label,
                            ]
                        )
                        for label in labels
                    ]

        # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
        if batch.get("labels", None) is not None:
            if return_tensors == "pt":
                import torch

                batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
            elif return_tensors == "tf":
                import tensorflow as tf

                batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
            else:
                batch["labels"] = np.array(batch["labels"], dtype=np.int64)
        else:
            batch["labels"] = None

        # prepare decoder_input_ids
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
            batch["decoder_input_ids"] = decoder_input_ids

        return batch


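# Illustrative usage sketch for DataCollatorForSeq2Seq above (assumptions: a local
# seq2seq checkpoint such as "t5-small" is available and feature values are plain lists):
#
#     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#     tok = AutoTokenizer.from_pretrained("t5-small")
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#     collator = DataCollatorForSeq2Seq(tok, model=model, padding="longest")
#     batch = collator([{"input_ids": [1, 2, 3], "labels": [4, 5]},
#                       {"input_ids": [6, 7], "labels": [8]}])
#     # batch["labels"] is padded with label_pad_token_id (-100) and
#     # batch["decoder_input_ids"] is derived from the padded labels via
#     # model.prepare_decoder_input_ids_from_labels.

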
@dataclass
class DataCollatorForLanguageModeling(DataCollatorMixin):
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        mlm (`bool`, *optional*, defaults to `True`):
            Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
            with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
            tokens and the value to predict for the masked token.
        mlm_probability (`float`, *optional*, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".

    <Tip>

    For best performance, this data collator should be used with a dataset having items that are dictionaries or
    BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
    [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.

    </Tip>"""

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None
    tf_experimental_compile: bool = False
    return_tensors: str = "pt"

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )
        if self.tf_experimental_compile:
            import tensorflow as tf

            self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)

    @staticmethod
    def tf_bernoulli(shape, probability):
        import tensorflow as tf

        prob_matrix = tf.fill(shape, probability)
        return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)

    def tf_mask_tokens(
        self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
    ) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        import tensorflow as tf

        mask_token_id = tf.cast(mask_token_id, inputs.dtype)

        input_shape = tf.shape(inputs)
        # 1 for a special token, 0 for a normal token in the special tokens mask
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
        # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
        labels = tf.where(masked_indices, inputs, -100)

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices

        inputs = tf.where(indices_replaced, mask_token_id, inputs)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
        random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)

        inputs = tf.where(indices_random, random_words, inputs)

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        import tensorflow as tf

        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = pad_without_fast_tokenizer_warning(
                self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
            )
        else:
            batch = {
                "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            if special_tokens_mask is None:
                special_tokens_mask = [
                    self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                    for val in batch["input_ids"].numpy().tolist()
                ]
                # Cannot directly create as bool
                special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
            else:
                special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
            batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
                tf.cast(batch["input_ids"], tf.int64),
                special_tokens_mask=special_tokens_mask,
                mask_token_id=self.tokenizer.mask_token_id,
                vocab_size=len(self.tokenizer),
            )
        else:
            labels = batch["input_ids"]
            if self.tokenizer.pad_token_id is not None:
                # Replace self.tokenizer.pad_token_id with -100
                labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
            else:
                labels = tf.identity(labels)  # Makes a copy, just in case
            batch["labels"] = labels
        return batch

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = pad_without_fast_tokenizer_warning(
                self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
            )
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        import torch

        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = pad_without_fast_tokenizer_warning(
                self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
            )
        else:
            batch = {
                "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = np.copy(batch["input_ids"])
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = np.copy(inputs)
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
        else:
            special_tokens_mask = special_tokens_mask.astype(bool)

        probability_matrix[special_tokens_mask] = 0
        # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
        masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
        inputs[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        indices_random = (
            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
        )
        random_words = np.random.randint(
            low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
        )
        inputs[indices_random] = random_words

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels


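# Illustrative usage sketch for DataCollatorForLanguageModeling above (assumption:
# the "bert-base-uncased" tokenizer is available):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#     collator = DataCollatorForLanguageModeling(tok, mlm=True, mlm_probability=0.15)
#     batch = collator([tok("hello world"), tok("data collators pad and mask")])
#     # About 15% of the non-special tokens in batch["input_ids"] are selected; of
#     # those, 80% become [MASK], 10% become random ids and 10% stay unchanged.
#     # batch["labels"] is -100 everywhere except at the selected positions, so the
#     # loss is computed only on masked tokens.

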
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
    Data collator used for language modeling that masks entire words.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling

    <Tip>

    This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
    that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
    produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].

    </Tip>"""

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        import tensorflow as tf

        if isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy
        """
        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information."
            )

        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
        return mask_labels

    def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
        'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to its ref.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        probability_matrix = mask_labels

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
        'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to its ref.
        """
        import tensorflow as tf

        input_shape = tf.shape(inputs)
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = tf.identity(inputs)
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        masked_indices = tf.cast(mask_labels, tf.bool)

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
        ]
        masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
        if self.tokenizer.pad_token is not None:
            padding_mask = inputs == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask

        # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
        labels = tf.where(masked_indices, inputs, -100)

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices

        inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
        random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
        inputs = tf.where(indices_random, random_words, inputs)

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
        'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to its ref.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = np.copy(inputs)
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        masked_indices = mask_labels.astype(bool)

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
        if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0

        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        indices_random = (
            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
        )
        random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels


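# Note on the whole-word grouping used by DataCollatorForWholeWordMask above
# (illustrative, assuming a BERT-style WordPiece tokenizer): a word such as
# "unaffable" is split into ["una", "##ffa", "##ble"], and _whole_word_mask treats
# the three subword positions as one candidate unit, so they are either all masked
# together or all left intact.

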
@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
    """
    Data collator used for sentence order prediction task.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for both masked language modeling and sentence order prediction
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
            "DataCollatorForLanguageModeling instead.",
            FutureWarning,
        )

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        import torch
        from torch.nn.utils.rnn import pad_sequence

        input_ids = [example["input_ids"] for example in examples]
        input_ids = _torch_collate_batch(input_ids, self.tokenizer)
        input_ids, labels, attention_mask = self.mask_tokens(input_ids)

        token_type_ids = [example["token_type_ids"] for example in examples]
        # the size of segment_ids varies because of randomness; pad zeros to the end as in the original implementation
        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        sop_label_list = [example["sentence_order_label"] for example in examples]
        sentence_order_label = torch.stack(sop_label_list)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "sentence_order_label": sentence_order_label,
        }

    def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
        """
        Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
        original. N-gram not applied yet.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # The sampled value is `1` for masked positions; however, in the ALBERT model the attention mask uses `0`
        # for masked positions, so the value is inverted here
        attention_mask = (~masked_indices).float()
        if self.tokenizer.pad_token is not None:
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens, -100 is default for CE compute

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, attention_mask


@dataclass
class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
    """
    Data collator used for permutation language modeling.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet
    """

    tokenizer: PreTrainedTokenizerBase
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens
    return_tensors: str = "pt"

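    # Worked example of the span-sampling arithmetic used by the mask_tokens methods
    # below (illustrative only): with plm_probability = 1/6 and max_span_length = 5,
    # a sampled span_length of 3 reserves context_length = int(3 / (1/6)) = 18
    # positions, of which 3 consecutive positions are masked, so on average roughly
    # 1/6 of the tokens in each sequence are masked.
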
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        batch = _torch_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        batch = _tf_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        batch = _numpy_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

        0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
        1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
        2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
           masked
        3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
           span_length]` and mask tokens `start_index:start_index + span_length`
        4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
           sequence to be processed), repeat from Step 1.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling."
                " Please add a mask token if you want to use this tokenizer."
            )

        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
                " relevant comments in source code for details."
            )

        labels = inputs.clone()
        # Creating the mask and target_mapping tensors
        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.size(1)

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th predict corresponds to the i-th token.
            target_mapping[i] = torch.eye(labels.size(1))

        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(labels.size(1))
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
            ) & masked_indices[i]

        return inputs.long(), perm_mask, target_mapping, labels.long()

    def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

        0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
        1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
        2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
           masked
        3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
           span_length]` and mask tokens `start_index:start_index + span_length`
        4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
           sequence to be processed), repeat from Step 1.
        """
        import tensorflow as tf

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling."
                " Please add a mask token if you want to use this tokenizer."
            )

        if tf.shape(inputs)[1] % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
                " relevant comments in source code for details."
            )

        labels = tf.identity(inputs)
        # Creating the mask and target_mapping tensors
        masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
        labels_shape = tf.shape(labels)
        target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)

        for i in range(len(labels)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = tf.shape(labels)[1]

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = randint(1, self.max_span_length + 1)
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + randint(0, context_length - span_length + 1)
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th predict corresponds to the i-th token.
            target_mapping[i] = np.eye(labels_shape[1])
        masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
        target_mapping = tf.convert_to_tensor(target_mapping)
        special_tokens_mask = tf.convert_to_tensor(
            [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.numpy().tolist()
            ],
        )
        special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
        masked_indices = masked_indices & ~special_tokens_mask
        if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
        labels = tf.where(masked_indices, labels, -100)  # We only compute loss on masked tokens

        perm_mask = []

        for i in range(len(labels)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            # tf.range is the equivalent of torch.arange
            perm_index = tf.range(labels_shape[1])
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
            # Permute the two halves such that they do not cross over
|
| 1500 |
+
perm_index = tf.random.shuffle(perm_index) # Shuffles along the first dimension
|
| 1501 |
+
# Flatten this out into the desired permuted factorisation order
|
| 1502 |
+
perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
|
| 1503 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
| 1504 |
+
# smallest index (-1) so that:
|
| 1505 |
+
# (1) They can be seen by all other positions
|
| 1506 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
| 1507 |
+
perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
|
| 1508 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
| 1509 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
| 1510 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
| 1511 |
+
perm_mask.append(
|
| 1512 |
+
(tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
|
| 1513 |
+
& masked_indices[i]
|
| 1514 |
+
)
|
| 1515 |
+
perm_mask = tf.stack(perm_mask, axis=0)
|
| 1516 |
+
|
| 1517 |
+
return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
|
| 1518 |
+
|
| 1519 |
+
def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
| 1520 |
+
"""
|
| 1521 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
| 1522 |
+
|
| 1523 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1524 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1525 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
| 1526 |
+
masked
|
| 1527 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
| 1528 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1529 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
| 1530 |
+
sequence to be processed), repeat from Step 1.
|
| 1531 |
+
"""
|
| 1532 |
+
if self.tokenizer.mask_token is None:
|
| 1533 |
+
raise ValueError(
|
| 1534 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
| 1535 |
+
" Please add a mask token if you want to use this tokenizer."
|
| 1536 |
+
)
|
| 1537 |
+
|
| 1538 |
+
if inputs.shape[1] % 2 != 0:
|
| 1539 |
+
raise ValueError(
|
| 1540 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
| 1541 |
+
" relevant comments in source code for details."
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
labels = np.copy(inputs)
|
| 1545 |
+
# Creating the mask and target_mapping tensors
|
| 1546 |
+
masked_indices = np.full(labels.shape, 0, dtype=bool)
|
| 1547 |
+
target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
|
| 1548 |
+
|
| 1549 |
+
for i in range(labels.shape[0]):
|
| 1550 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
| 1551 |
+
cur_len = 0
|
| 1552 |
+
max_len = labels.shape[1]
|
| 1553 |
+
|
| 1554 |
+
while cur_len < max_len:
|
| 1555 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
| 1556 |
+
span_length = randint(1, self.max_span_length + 1)
|
| 1557 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
| 1558 |
+
context_length = int(span_length / self.plm_probability)
|
| 1559 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
| 1560 |
+
start_index = cur_len + randint(0, context_length - span_length + 1)
|
| 1561 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
| 1562 |
+
# Set `cur_len = cur_len + context_length`
|
| 1563 |
+
cur_len += context_length
|
| 1564 |
+
|
| 1565 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
| 1566 |
+
# the i-th predict corresponds to the i-th token.
|
| 1567 |
+
target_mapping[i] = np.eye(labels.shape[1])
|
| 1568 |
+
|
| 1569 |
+
special_tokens_mask = np.array(
|
| 1570 |
+
[self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
|
| 1571 |
+
dtype=bool,
|
| 1572 |
+
)
|
| 1573 |
+
masked_indices[special_tokens_mask] = 0
|
| 1574 |
+
if self.tokenizer.pad_token is not None:
|
| 1575 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
| 1576 |
+
masked_indices[padding_mask] = 0.0
|
| 1577 |
+
|
| 1578 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
| 1579 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
| 1580 |
+
|
| 1581 |
+
inputs[masked_indices] = self.tokenizer.mask_token_id
|
| 1582 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 1583 |
+
|
| 1584 |
+
perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
|
| 1585 |
+
|
| 1586 |
+
for i in range(labels.shape[0]):
|
| 1587 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
| 1588 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
| 1589 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
| 1590 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
| 1591 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
| 1592 |
+
# This requires that the sequence length be even.
|
| 1593 |
+
|
| 1594 |
+
# Create a linear factorisation order
|
| 1595 |
+
perm_index = np.arange(labels.shape[1])
|
| 1596 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
| 1597 |
+
perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
|
| 1598 |
+
# Permute the two halves such that they do not cross over
|
| 1599 |
+
np.random.shuffle(perm_index)
|
| 1600 |
+
# Flatten this out into the desired permuted factorisation order
|
| 1601 |
+
perm_index = perm_index.T.flatten()
|
| 1602 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
| 1603 |
+
# smallest index (-1) so that:
|
| 1604 |
+
# (1) They can be seen by all other positions
|
| 1605 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
| 1606 |
+
perm_index[~masked_indices[i] & non_func_mask[i]] = -1
|
| 1607 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
| 1608 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
| 1609 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
| 1610 |
+
perm_mask[i] = (
|
| 1611 |
+
perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
|
| 1612 |
+
) & masked_indices[i]
|
| 1613 |
+
|
| 1614 |
+
return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
|
| 1615 |
+
|
| 1616 |
+
|
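The span-masking steps 0-4 described in the docstrings above can be hard to picture from the loops alone. The standalone sketch below reproduces that sampling logic outside the collator; the function name, signature, and example values are illustrative only and are not part of the library.

import random

def sample_plm_spans(seq_len: int, max_span_length: int = 5, plm_probability: float = 1 / 6):
    """Illustrative re-implementation of the span sampling documented above (not library code)."""
    masked = [False] * seq_len
    cur_len = 0
    while cur_len < seq_len:
        # Step 1: length of the span of tokens to be masked
        span_length = random.randint(1, max_span_length)
        # Step 2: reserve a surrounding context of length span_length / plm_probability
        context_length = int(span_length / plm_probability)
        # Step 3: place the span somewhere inside the reserved context
        start_index = cur_len + random.randint(0, context_length - span_length)
        for j in range(start_index, min(start_index + span_length, seq_len)):
            masked[j] = True
        # Step 4: advance past the reserved context and repeat
        cur_len += context_length
    return masked

# Roughly a plm_probability fraction of positions ends up masked, e.g.
# sum(sample_plm_spans(512)) / 512 is close to 1 / 6 on average.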
+@dataclass
+class DataCollatorWithFlattening(DefaultDataCollator):
+    """
+    Data collator used for padding-free approach. Does the following:
+
+    - concatenate the entire mini batch into single long sequence [1, total_tokens]
+    - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
+    - no padding will be added, returns `input_ids`, `labels` and `position_ids`
+    """
+
+    def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.return_position_ids = return_position_ids
+        self.separator_id = separator_id
+        warnings.warn(
+            "Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
+            " Make sure your attention computation is able to handle it!"
+        )
+
+    def __call__(self, features, return_tensors=None, separator_id=None):
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+        if separator_id is None:
+            separator_id = self.separator_id
+        is_labels_provided = "labels" in features[0]
+        ret = {"input_ids": [], "labels": []}
+        if self.return_position_ids:
+            ret.update({"position_ids": []})
+        for idx in range(0, len(features)):
+            ret["input_ids"] += features[idx]["input_ids"]
+            if is_labels_provided:
+                ret["labels"] += [separator_id] + features[idx]["labels"][1:]
+            else:
+                ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
+            if self.return_position_ids:
+                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
+        return default_data_collator([ret], return_tensors)
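As a quick illustration of what the flattening collator above returns, consider two short features. The concrete token ids are made up for the example, and the import path assumes a transformers release recent enough to export the class.

from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening()  # return_position_ids=True, separator_id=-100 by default
features = [{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 2088, 102]}]
batch = collator(features, return_tensors="np")
# batch["input_ids"]    -> [[101, 7592, 102, 101, 2088, 102]]   one long sequence, no padding
# batch["labels"]       -> [[-100, 7592, 102, -100, 2088, 102]] separator_id marks each sequence start
# batch["position_ids"] -> [[0, 1, 2, 0, 1, 2]]                 positions restart at every packed sequence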
.venv/Lib/site-packages/transformers/data/datasets/__init__.py
ADDED
@@ -0,0 +1,23 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .glue import GlueDataset, GlueDataTrainingArguments
+from .language_modeling import (
+    LineByLineTextDataset,
+    LineByLineWithRefDataset,
+    LineByLineWithSOPTextDataset,
+    TextDataset,
+    TextDatasetForNextSentencePrediction,
+)
+from .squad import SquadDataset, SquadDataTrainingArguments
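These re-exports define the public surface of `transformers.data.datasets`. A minimal import sketch, assuming the package layout above:

from transformers.data.datasets import GlueDataset, SquadDataset, TextDataset
# GlueDataset and the language-modeling datasets emit a FutureWarning on construction;
# the 🤗 Datasets library is the recommended replacement for this preprocessing.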
.venv/Lib/site-packages/transformers/data/datasets/glue.py
ADDED
@@ -0,0 +1,161 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import warnings
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List, Optional, Union
+
+import torch
+from filelock import FileLock
+from torch.utils.data import Dataset
+
+from ...tokenization_utils_base import PreTrainedTokenizerBase
+from ...utils import logging
+from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
+from ..processors.utils import InputFeatures
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class GlueDataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+
+    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
+    line.
+    """
+
+    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+    data_dir: str = field(
+        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
+    )
+    max_seq_length: int = field(
+        default=128,
+        metadata={
+            "help": (
+                "The maximum total input sequence length after tokenization. Sequences longer "
+                "than this will be truncated, sequences shorter will be padded."
+            )
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+
+    def __post_init__(self):
+        self.task_name = self.task_name.lower()
+
+
+class Split(Enum):
+    train = "train"
+    dev = "dev"
+    test = "test"
+
+
+class GlueDataset(Dataset):
+    """
+    This will be superseded by a framework-agnostic approach soon.
+    """
+
+    args: GlueDataTrainingArguments
+    output_mode: str
+    features: List[InputFeatures]
+
+    def __init__(
+        self,
+        args: GlueDataTrainingArguments,
+        tokenizer: PreTrainedTokenizerBase,
+        limit_length: Optional[int] = None,
+        mode: Union[str, Split] = Split.train,
+        cache_dir: Optional[str] = None,
+    ):
+        warnings.warn(
+            "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
+            "library. You can have a look at this example script for pointers: "
+            "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
+            FutureWarning,
+        )
+        self.args = args
+        self.processor = glue_processors[args.task_name]()
+        self.output_mode = glue_output_modes[args.task_name]
+        if isinstance(mode, str):
+            try:
+                mode = Split[mode]
+            except KeyError:
+                raise KeyError("mode is not a valid split name")
+        # Load data features from cache or dataset file
+        cached_features_file = os.path.join(
+            cache_dir if cache_dir is not None else args.data_dir,
+            f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
+        )
+        label_list = self.processor.get_labels()
+        if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
+            "RobertaTokenizer",
+            "RobertaTokenizerFast",
+            "XLMRobertaTokenizer",
+            "BartTokenizer",
+            "BartTokenizerFast",
+        ):
+            # HACK(label indices are swapped in RoBERTa pretrained model)
+            label_list[1], label_list[2] = label_list[2], label_list[1]
+        self.label_list = label_list
+
+        # Make sure only the first process in distributed training processes the dataset,
+        # and the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
+            if os.path.exists(cached_features_file) and not args.overwrite_cache:
+                start = time.time()
+                self.features = torch.load(cached_features_file)
+                logger.info(
+                    f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+                )
+            else:
+                logger.info(f"Creating features from dataset file at {args.data_dir}")
+
+                if mode == Split.dev:
+                    examples = self.processor.get_dev_examples(args.data_dir)
+                elif mode == Split.test:
+                    examples = self.processor.get_test_examples(args.data_dir)
+                else:
+                    examples = self.processor.get_train_examples(args.data_dir)
+                if limit_length is not None:
+                    examples = examples[:limit_length]
+                self.features = glue_convert_examples_to_features(
+                    examples,
+                    tokenizer,
+                    max_length=args.max_seq_length,
+                    label_list=label_list,
+                    output_mode=self.output_mode,
+                )
+                start = time.time()
+                torch.save(self.features, cached_features_file)
+                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+                logger.info(
+                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+                )
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, i) -> InputFeatures:
+        return self.features[i]
+
+    def get_labels(self):
+        return self.label_list
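A minimal usage sketch for the dataset above; the data directory is a placeholder, and this assumes the GLUE MRPC .tsv files have already been downloaded to it.

from transformers import AutoTokenizer
from transformers.data.datasets import GlueDataset, GlueDataTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
args = GlueDataTrainingArguments(task_name="mrpc", data_dir="/path/to/MRPC", max_seq_length=128)

train_dataset = GlueDataset(args, tokenizer=tokenizer)             # mode defaults to Split.train
eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")  # string modes are mapped via Split[...]
print(len(train_dataset), train_dataset.get_labels())              # e.g. 3668 ['0', '1'] for MRPC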
.venv/Lib/site-packages/transformers/data/datasets/language_modeling.py
ADDED
@@ -0,0 +1,530 @@
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import pickle
|
| 18 |
+
import random
|
| 19 |
+
import time
|
| 20 |
+
import warnings
|
| 21 |
+
from typing import Dict, List, Optional
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from filelock import FileLock
|
| 25 |
+
from torch.utils.data import Dataset
|
| 26 |
+
|
| 27 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 28 |
+
from ...utils import logging
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
DEPRECATION_WARNING = (
|
| 35 |
+
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
| 36 |
+
"library. You can have a look at this example script for pointers: {0}"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class TextDataset(Dataset):
|
| 41 |
+
"""
|
| 42 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
tokenizer: PreTrainedTokenizer,
|
| 48 |
+
file_path: str,
|
| 49 |
+
block_size: int,
|
| 50 |
+
overwrite_cache=False,
|
| 51 |
+
cache_dir: Optional[str] = None,
|
| 52 |
+
):
|
| 53 |
+
warnings.warn(
|
| 54 |
+
DEPRECATION_WARNING.format(
|
| 55 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 56 |
+
),
|
| 57 |
+
FutureWarning,
|
| 58 |
+
)
|
| 59 |
+
if os.path.isfile(file_path) is False:
|
| 60 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 61 |
+
|
| 62 |
+
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
|
| 63 |
+
|
| 64 |
+
directory, filename = os.path.split(file_path)
|
| 65 |
+
cached_features_file = os.path.join(
|
| 66 |
+
cache_dir if cache_dir is not None else directory,
|
| 67 |
+
f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 71 |
+
# and the others will use the cache.
|
| 72 |
+
lock_path = cached_features_file + ".lock"
|
| 73 |
+
with FileLock(lock_path):
|
| 74 |
+
if os.path.exists(cached_features_file) and not overwrite_cache:
|
| 75 |
+
start = time.time()
|
| 76 |
+
with open(cached_features_file, "rb") as handle:
|
| 77 |
+
self.examples = pickle.load(handle)
|
| 78 |
+
logger.info(
|
| 79 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
else:
|
| 83 |
+
logger.info(f"Creating features from dataset file at {directory}")
|
| 84 |
+
|
| 85 |
+
self.examples = []
|
| 86 |
+
with open(file_path, encoding="utf-8") as f:
|
| 87 |
+
text = f.read()
|
| 88 |
+
|
| 89 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
| 90 |
+
|
| 91 |
+
for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
|
| 92 |
+
self.examples.append(
|
| 93 |
+
tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
|
| 94 |
+
)
|
| 95 |
+
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
|
| 96 |
+
# If your dataset is small, first you should look for a bigger one :-) and second you
|
| 97 |
+
# can change this behavior by adding (model specific) padding.
|
| 98 |
+
|
| 99 |
+
start = time.time()
|
| 100 |
+
with open(cached_features_file, "wb") as handle:
|
| 101 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
| 102 |
+
logger.info(
|
| 103 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return len(self.examples)
|
| 108 |
+
|
| 109 |
+
def __getitem__(self, i) -> torch.Tensor:
|
| 110 |
+
return torch.tensor(self.examples[i], dtype=torch.long)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class LineByLineTextDataset(Dataset):
|
| 114 |
+
"""
|
| 115 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
|
| 119 |
+
warnings.warn(
|
| 120 |
+
DEPRECATION_WARNING.format(
|
| 121 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 122 |
+
),
|
| 123 |
+
FutureWarning,
|
| 124 |
+
)
|
| 125 |
+
if os.path.isfile(file_path) is False:
|
| 126 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 127 |
+
# Here, we do not cache the features, operating under the assumption
|
| 128 |
+
# that we will soon use fast multithreaded tokenizers from the
|
| 129 |
+
# `tokenizers` repo everywhere =)
|
| 130 |
+
logger.info(f"Creating features from dataset file at {file_path}")
|
| 131 |
+
|
| 132 |
+
with open(file_path, encoding="utf-8") as f:
|
| 133 |
+
lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
| 134 |
+
|
| 135 |
+
batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
|
| 136 |
+
self.examples = batch_encoding["input_ids"]
|
| 137 |
+
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
| 138 |
+
|
| 139 |
+
def __len__(self):
|
| 140 |
+
return len(self.examples)
|
| 141 |
+
|
| 142 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
| 143 |
+
return self.examples[i]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class LineByLineWithRefDataset(Dataset):
|
| 147 |
+
"""
|
| 148 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
|
| 152 |
+
warnings.warn(
|
| 153 |
+
DEPRECATION_WARNING.format(
|
| 154 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
|
| 155 |
+
),
|
| 156 |
+
FutureWarning,
|
| 157 |
+
)
|
| 158 |
+
if os.path.isfile(file_path) is False:
|
| 159 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 160 |
+
if os.path.isfile(ref_path) is False:
|
| 161 |
+
raise ValueError(f"Ref file path {file_path} not found")
|
| 162 |
+
# Here, we do not cache the features, operating under the assumption
|
| 163 |
+
# that we will soon use fast multithreaded tokenizers from the
|
| 164 |
+
# `tokenizers` repo everywhere =)
|
| 165 |
+
logger.info(f"Creating features from dataset file at {file_path}")
|
| 166 |
+
logger.info(f"Use ref segment results at {ref_path}")
|
| 167 |
+
with open(file_path, encoding="utf-8") as f:
|
| 168 |
+
data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
|
| 169 |
+
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
|
| 170 |
+
# Get ref inf from file
|
| 171 |
+
with open(ref_path, encoding="utf-8") as f:
|
| 172 |
+
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
| 173 |
+
if len(data) != len(ref):
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
|
| 176 |
+
f"while length of {ref_path} is {len(ref)}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
|
| 180 |
+
self.examples = batch_encoding["input_ids"]
|
| 181 |
+
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
| 182 |
+
|
| 183 |
+
n = len(self.examples)
|
| 184 |
+
for i in range(n):
|
| 185 |
+
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
|
| 186 |
+
|
| 187 |
+
def __len__(self):
|
| 188 |
+
return len(self.examples)
|
| 189 |
+
|
| 190 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
| 191 |
+
return self.examples[i]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class LineByLineWithSOPTextDataset(Dataset):
|
| 195 |
+
"""
|
| 196 |
+
Dataset for sentence order prediction task, prepare sentence pairs for SOP task
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
|
| 200 |
+
warnings.warn(
|
| 201 |
+
DEPRECATION_WARNING.format(
|
| 202 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 203 |
+
),
|
| 204 |
+
FutureWarning,
|
| 205 |
+
)
|
| 206 |
+
if os.path.isdir(file_dir) is False:
|
| 207 |
+
raise ValueError(f"{file_dir} is not a directory")
|
| 208 |
+
logger.info(f"Creating features from dataset file folder at {file_dir}")
|
| 209 |
+
self.examples = []
|
| 210 |
+
# TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
|
| 211 |
+
# file path looks like ./dataset/wiki_1, ./dataset/wiki_2
|
| 212 |
+
for file_name in os.listdir(file_dir):
|
| 213 |
+
file_path = os.path.join(file_dir, file_name)
|
| 214 |
+
if os.path.isfile(file_path) is False:
|
| 215 |
+
raise ValueError(f"{file_path} is not a file")
|
| 216 |
+
article_open = False
|
| 217 |
+
with open(file_path, encoding="utf-8") as f:
|
| 218 |
+
original_lines = f.readlines()
|
| 219 |
+
article_lines = []
|
| 220 |
+
for line in original_lines:
|
| 221 |
+
if "<doc id=" in line:
|
| 222 |
+
article_open = True
|
| 223 |
+
elif "</doc>" in line:
|
| 224 |
+
article_open = False
|
| 225 |
+
document = [
|
| 226 |
+
tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
|
| 227 |
+
for line in article_lines[1:]
|
| 228 |
+
if (len(line) > 0 and not line.isspace())
|
| 229 |
+
]
|
| 230 |
+
|
| 231 |
+
examples = self.create_examples_from_document(document, block_size, tokenizer)
|
| 232 |
+
self.examples.extend(examples)
|
| 233 |
+
article_lines = []
|
| 234 |
+
else:
|
| 235 |
+
if article_open:
|
| 236 |
+
article_lines.append(line)
|
| 237 |
+
|
| 238 |
+
logger.info("Dataset parse finished.")
|
| 239 |
+
|
| 240 |
+
def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
|
| 241 |
+
"""Creates examples for a single document."""
|
| 242 |
+
|
| 243 |
+
# Account for special tokens
|
| 244 |
+
max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
|
| 245 |
+
|
| 246 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
| 247 |
+
# to `block_size` anyways, so short sequences are generally wasted
|
| 248 |
+
# computation. However, we *sometimes*
|
| 249 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
| 250 |
+
# sequences to minimize the mismatch between pretraining and fine-tuning.
|
| 251 |
+
# The `target_seq_length` is just a rough target however, whereas
|
| 252 |
+
# `block_size` is a hard limit.
|
| 253 |
+
target_seq_length = max_num_tokens
|
| 254 |
+
if random.random() < short_seq_prob:
|
| 255 |
+
target_seq_length = random.randint(2, max_num_tokens)
|
| 256 |
+
|
| 257 |
+
# We DON'T just concatenate all of the tokens from a document into a long
|
| 258 |
+
# sequence and choose an arbitrary split point because this would make the
|
| 259 |
+
# next sentence prediction task too easy. Instead, we split the input into
|
| 260 |
+
# segments "A" and "B" based on the actual "sentences" provided by the user
|
| 261 |
+
# input.
|
| 262 |
+
examples = []
|
| 263 |
+
current_chunk = [] # a buffer stored current working segments
|
| 264 |
+
current_length = 0
|
| 265 |
+
i = 0
|
| 266 |
+
while i < len(document):
|
| 267 |
+
segment = document[i] # get a segment
|
| 268 |
+
if not segment:
|
| 269 |
+
i += 1
|
| 270 |
+
continue
|
| 271 |
+
current_chunk.append(segment) # add a segment to current chunk
|
| 272 |
+
current_length += len(segment) # overall token length
|
| 273 |
+
# if current length goes to the target length or reaches the end of file, start building token a and b
|
| 274 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
| 275 |
+
if current_chunk:
|
| 276 |
+
# `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
|
| 277 |
+
a_end = 1
|
| 278 |
+
# if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
|
| 279 |
+
if len(current_chunk) >= 2:
|
| 280 |
+
a_end = random.randint(1, len(current_chunk) - 1)
|
| 281 |
+
# token a
|
| 282 |
+
tokens_a = []
|
| 283 |
+
for j in range(a_end):
|
| 284 |
+
tokens_a.extend(current_chunk[j])
|
| 285 |
+
|
| 286 |
+
# token b
|
| 287 |
+
tokens_b = []
|
| 288 |
+
for j in range(a_end, len(current_chunk)):
|
| 289 |
+
tokens_b.extend(current_chunk[j])
|
| 290 |
+
|
| 291 |
+
if len(tokens_a) == 0 or len(tokens_b) == 0:
|
| 292 |
+
continue
|
| 293 |
+
|
| 294 |
+
# switch tokens_a and tokens_b randomly
|
| 295 |
+
if random.random() < 0.5:
|
| 296 |
+
is_next = False
|
| 297 |
+
tokens_a, tokens_b = tokens_b, tokens_a
|
| 298 |
+
else:
|
| 299 |
+
is_next = True
|
| 300 |
+
|
| 301 |
+
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
| 302 |
+
"""Truncates a pair of sequences to a maximum sequence length."""
|
| 303 |
+
while True:
|
| 304 |
+
total_length = len(tokens_a) + len(tokens_b)
|
| 305 |
+
if total_length <= max_num_tokens:
|
| 306 |
+
break
|
| 307 |
+
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
| 308 |
+
if not (len(trunc_tokens) >= 1):
|
| 309 |
+
raise ValueError("Sequence length to be truncated must be no less than one")
|
| 310 |
+
# We want to sometimes truncate from the front and sometimes from the
|
| 311 |
+
# back to add more randomness and avoid biases.
|
| 312 |
+
if random.random() < 0.5:
|
| 313 |
+
del trunc_tokens[0]
|
| 314 |
+
else:
|
| 315 |
+
trunc_tokens.pop()
|
| 316 |
+
|
| 317 |
+
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
|
| 318 |
+
if not (len(tokens_a) >= 1):
|
| 319 |
+
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
|
| 320 |
+
if not (len(tokens_b) >= 1):
|
| 321 |
+
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
|
| 322 |
+
|
| 323 |
+
# add special tokens
|
| 324 |
+
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
| 325 |
+
# add token type ids, 0 for sentence a, 1 for sentence b
|
| 326 |
+
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
| 327 |
+
|
| 328 |
+
example = {
|
| 329 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 330 |
+
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
| 331 |
+
"sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
|
| 332 |
+
}
|
| 333 |
+
examples.append(example)
|
| 334 |
+
current_chunk = [] # clear current chunk
|
| 335 |
+
current_length = 0 # reset current text length
|
| 336 |
+
i += 1 # go to next line
|
| 337 |
+
return examples
|
| 338 |
+
|
| 339 |
+
def __len__(self):
|
| 340 |
+
return len(self.examples)
|
| 341 |
+
|
| 342 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
| 343 |
+
return self.examples[i]
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class TextDatasetForNextSentencePrediction(Dataset):
|
| 347 |
+
"""
|
| 348 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
def __init__(
|
| 352 |
+
self,
|
| 353 |
+
tokenizer: PreTrainedTokenizer,
|
| 354 |
+
file_path: str,
|
| 355 |
+
block_size: int,
|
| 356 |
+
overwrite_cache=False,
|
| 357 |
+
short_seq_probability=0.1,
|
| 358 |
+
nsp_probability=0.5,
|
| 359 |
+
):
|
| 360 |
+
warnings.warn(
|
| 361 |
+
DEPRECATION_WARNING.format(
|
| 362 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
| 363 |
+
),
|
| 364 |
+
FutureWarning,
|
| 365 |
+
)
|
| 366 |
+
if not os.path.isfile(file_path):
|
| 367 |
+
raise ValueError(f"Input file path {file_path} not found")
|
| 368 |
+
|
| 369 |
+
self.short_seq_probability = short_seq_probability
|
| 370 |
+
self.nsp_probability = nsp_probability
|
| 371 |
+
|
| 372 |
+
directory, filename = os.path.split(file_path)
|
| 373 |
+
cached_features_file = os.path.join(
|
| 374 |
+
directory,
|
| 375 |
+
f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
self.tokenizer = tokenizer
|
| 379 |
+
|
| 380 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 381 |
+
# and the others will use the cache.
|
| 382 |
+
lock_path = cached_features_file + ".lock"
|
| 383 |
+
|
| 384 |
+
# Input file format:
|
| 385 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
| 386 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
| 387 |
+
# sentence boundaries for the "next sentence prediction" task).
|
| 388 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
| 389 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
| 390 |
+
#
|
| 391 |
+
# Example:
|
| 392 |
+
# I am very happy.
|
| 393 |
+
# Here is the second sentence.
|
| 394 |
+
#
|
| 395 |
+
# A new document.
|
| 396 |
+
|
| 397 |
+
with FileLock(lock_path):
|
| 398 |
+
if os.path.exists(cached_features_file) and not overwrite_cache:
|
| 399 |
+
start = time.time()
|
| 400 |
+
with open(cached_features_file, "rb") as handle:
|
| 401 |
+
self.examples = pickle.load(handle)
|
| 402 |
+
logger.info(
|
| 403 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 404 |
+
)
|
| 405 |
+
else:
|
| 406 |
+
logger.info(f"Creating features from dataset file at {directory}")
|
| 407 |
+
|
| 408 |
+
self.documents = [[]]
|
| 409 |
+
with open(file_path, encoding="utf-8") as f:
|
| 410 |
+
while True:
|
| 411 |
+
line = f.readline()
|
| 412 |
+
if not line:
|
| 413 |
+
break
|
| 414 |
+
line = line.strip()
|
| 415 |
+
|
| 416 |
+
# Empty lines are used as document delimiters
|
| 417 |
+
if not line and len(self.documents[-1]) != 0:
|
| 418 |
+
self.documents.append([])
|
| 419 |
+
tokens = tokenizer.tokenize(line)
|
| 420 |
+
tokens = tokenizer.convert_tokens_to_ids(tokens)
|
| 421 |
+
if tokens:
|
| 422 |
+
self.documents[-1].append(tokens)
|
| 423 |
+
|
| 424 |
+
logger.info(f"Creating examples from {len(self.documents)} documents.")
|
| 425 |
+
self.examples = []
|
| 426 |
+
for doc_index, document in enumerate(self.documents):
|
| 427 |
+
self.create_examples_from_document(document, doc_index, block_size)
|
| 428 |
+
|
| 429 |
+
start = time.time()
|
| 430 |
+
with open(cached_features_file, "wb") as handle:
|
| 431 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
| 432 |
+
logger.info(
|
| 433 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
|
| 437 |
+
"""Creates examples for a single document."""
|
| 438 |
+
|
| 439 |
+
max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
|
| 440 |
+
|
| 441 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
| 442 |
+
# to `block_size` anyways, so short sequences are generally wasted
|
| 443 |
+
# computation. However, we *sometimes*
|
| 444 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
| 445 |
+
# sequences to minimize the mismatch between pretraining and fine-tuning.
|
| 446 |
+
# The `target_seq_length` is just a rough target however, whereas
|
| 447 |
+
# `block_size` is a hard limit.
|
| 448 |
+
target_seq_length = max_num_tokens
|
| 449 |
+
if random.random() < self.short_seq_probability:
|
| 450 |
+
target_seq_length = random.randint(2, max_num_tokens)
|
| 451 |
+
|
| 452 |
+
current_chunk = [] # a buffer stored current working segments
|
| 453 |
+
current_length = 0
|
| 454 |
+
i = 0
|
| 455 |
+
|
| 456 |
+
while i < len(document):
|
| 457 |
+
segment = document[i]
|
| 458 |
+
current_chunk.append(segment)
|
| 459 |
+
current_length += len(segment)
|
| 460 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
| 461 |
+
if current_chunk:
|
| 462 |
+
# `a_end` is how many segments from `current_chunk` go into the `A`
|
| 463 |
+
# (first) sentence.
|
| 464 |
+
a_end = 1
|
| 465 |
+
if len(current_chunk) >= 2:
|
| 466 |
+
a_end = random.randint(1, len(current_chunk) - 1)
|
| 467 |
+
|
| 468 |
+
tokens_a = []
|
| 469 |
+
for j in range(a_end):
|
| 470 |
+
tokens_a.extend(current_chunk[j])
|
| 471 |
+
|
| 472 |
+
tokens_b = []
|
| 473 |
+
|
| 474 |
+
if len(current_chunk) == 1 or random.random() < self.nsp_probability:
|
| 475 |
+
is_random_next = True
|
| 476 |
+
target_b_length = target_seq_length - len(tokens_a)
|
| 477 |
+
|
| 478 |
+
# This should rarely go for more than one iteration for large
|
| 479 |
+
# corpora. However, just to be careful, we try to make sure that
|
| 480 |
+
# the random document is not the same as the document
|
| 481 |
+
# we're processing.
|
| 482 |
+
for _ in range(10):
|
| 483 |
+
random_document_index = random.randint(0, len(self.documents) - 1)
|
| 484 |
+
if random_document_index != doc_index:
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
random_document = self.documents[random_document_index]
|
| 488 |
+
random_start = random.randint(0, len(random_document) - 1)
|
| 489 |
+
for j in range(random_start, len(random_document)):
|
| 490 |
+
tokens_b.extend(random_document[j])
|
| 491 |
+
if len(tokens_b) >= target_b_length:
|
| 492 |
+
break
|
| 493 |
+
# We didn't actually use these segments so we "put them back" so
|
| 494 |
+
# they don't go to waste.
|
| 495 |
+
num_unused_segments = len(current_chunk) - a_end
|
| 496 |
+
i -= num_unused_segments
|
| 497 |
+
# Actual next
|
| 498 |
+
else:
|
| 499 |
+
is_random_next = False
|
| 500 |
+
for j in range(a_end, len(current_chunk)):
|
| 501 |
+
tokens_b.extend(current_chunk[j])
|
| 502 |
+
|
| 503 |
+
if not (len(tokens_a) >= 1):
|
| 504 |
+
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
|
| 505 |
+
if not (len(tokens_b) >= 1):
|
| 506 |
+
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
|
| 507 |
+
|
| 508 |
+
# add special tokens
|
| 509 |
+
input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
| 510 |
+
# add token type ids, 0 for sentence a, 1 for sentence b
|
| 511 |
+
token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
| 512 |
+
|
| 513 |
+
example = {
|
| 514 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 515 |
+
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
| 516 |
+
"next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
self.examples.append(example)
|
| 520 |
+
|
| 521 |
+
current_chunk = []
|
| 522 |
+
current_length = 0
|
| 523 |
+
|
| 524 |
+
i += 1
|
| 525 |
+
|
| 526 |
+
def __len__(self):
|
| 527 |
+
return len(self.examples)
|
| 528 |
+
|
| 529 |
+
def __getitem__(self, i):
|
| 530 |
+
return self.examples[i]
|
.venv/Lib/site-packages/transformers/data/datasets/squad.py
ADDED
@@ -0,0 +1,229 @@
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from enum import Enum
|
| 19 |
+
from typing import Dict, List, Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from filelock import FileLock
|
| 23 |
+
from torch.utils.data import Dataset
|
| 24 |
+
|
| 25 |
+
from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
| 26 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 27 |
+
from ...utils import logging
|
| 28 |
+
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
| 34 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class SquadDataTrainingArguments:
|
| 39 |
+
"""
|
| 40 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
model_type: str = field(
|
| 44 |
+
default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
|
| 45 |
+
)
|
| 46 |
+
data_dir: str = field(
|
| 47 |
+
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
| 48 |
+
)
|
| 49 |
+
max_seq_length: int = field(
|
| 50 |
+
default=128,
|
| 51 |
+
metadata={
|
| 52 |
+
"help": (
|
| 53 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 54 |
+
"than this will be truncated, sequences shorter will be padded."
|
| 55 |
+
)
|
| 56 |
+
},
|
| 57 |
+
)
|
| 58 |
+
doc_stride: int = field(
|
| 59 |
+
default=128,
|
| 60 |
+
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
|
| 61 |
+
)
|
| 62 |
+
max_query_length: int = field(
|
| 63 |
+
default=64,
|
| 64 |
+
metadata={
|
| 65 |
+
"help": (
|
| 66 |
+
"The maximum number of tokens for the question. Questions longer than this will "
|
| 67 |
+
"be truncated to this length."
|
| 68 |
+
)
|
| 69 |
+
},
|
| 70 |
+
)
|
| 71 |
+
max_answer_length: int = field(
|
| 72 |
+
default=30,
|
| 73 |
+
metadata={
|
| 74 |
+
"help": (
|
| 75 |
+
"The maximum length of an answer that can be generated. This is needed because the start "
|
| 76 |
+
"and end predictions are not conditioned on one another."
|
| 77 |
+
)
|
| 78 |
+
},
|
| 79 |
+
)
|
| 80 |
+
overwrite_cache: bool = field(
|
| 81 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 82 |
+
)
|
| 83 |
+
version_2_with_negative: bool = field(
|
| 84 |
+
default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
|
| 85 |
+
)
|
| 86 |
+
null_score_diff_threshold: float = field(
|
| 87 |
+
default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
| 88 |
+
)
|
| 89 |
+
n_best_size: int = field(
|
| 90 |
+
default=20, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
| 91 |
+
)
|
| 92 |
+
lang_id: int = field(
|
| 93 |
+
default=0,
|
| 94 |
+
metadata={
|
| 95 |
+
"help": (
|
| 96 |
+
"language id of input for language-specific xlm models (see"
|
| 97 |
+
" tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
| 98 |
+
)
|
| 99 |
+
},
|
| 100 |
+
)
|
| 101 |
+
threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Split(Enum):
|
| 105 |
+
train = "train"
|
| 106 |
+
dev = "dev"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SquadDataset(Dataset):
|
| 110 |
+
"""
|
| 111 |
+
This will be superseded by a framework-agnostic approach soon.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
args: SquadDataTrainingArguments
|
| 115 |
+
features: List[SquadFeatures]
|
| 116 |
+
mode: Split
|
| 117 |
+
is_language_sensitive: bool
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
args: SquadDataTrainingArguments,
|
| 122 |
+
tokenizer: PreTrainedTokenizer,
|
| 123 |
+
limit_length: Optional[int] = None,
|
| 124 |
+
mode: Union[str, Split] = Split.train,
|
| 125 |
+
is_language_sensitive: Optional[bool] = False,
|
| 126 |
+
cache_dir: Optional[str] = None,
|
| 127 |
+
dataset_format: Optional[str] = "pt",
|
| 128 |
+
):
|
| 129 |
+
self.args = args
|
| 130 |
+
self.is_language_sensitive = is_language_sensitive
|
| 131 |
+
self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
| 132 |
+
if isinstance(mode, str):
|
| 133 |
+
try:
|
| 134 |
+
mode = Split[mode]
|
| 135 |
+
except KeyError:
|
| 136 |
+
raise KeyError("mode is not a valid split name")
|
| 137 |
+
self.mode = mode
|
| 138 |
+
# Load data features from cache or dataset file
|
| 139 |
+
version_tag = "v2" if args.version_2_with_negative else "v1"
|
| 140 |
+
cached_features_file = os.path.join(
|
| 141 |
+
cache_dir if cache_dir is not None else args.data_dir,
|
| 142 |
+
f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Make sure only the first process in distributed training processes the dataset,
|
| 146 |
+
# and the others will use the cache.
|
| 147 |
+
lock_path = cached_features_file + ".lock"
|
| 148 |
+
with FileLock(lock_path):
|
| 149 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
| 150 |
+
start = time.time()
|
| 151 |
+
self.old_features = torch.load(cached_features_file)
|
| 152 |
+
|
| 153 |
+
# Legacy cache files have only features, while new cache files
|
| 154 |
+
# will have dataset and examples also.
|
| 155 |
+
self.features = self.old_features["features"]
|
| 156 |
+
self.dataset = self.old_features.get("dataset", None)
|
| 157 |
+
self.examples = self.old_features.get("examples", None)
|
| 158 |
+
logger.info(
|
| 159 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if self.dataset is None or self.examples is None:
|
| 163 |
+
logger.warning(
|
| 164 |
+
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
|
| 165 |
+
" future run"
|
| 166 |
+
)
|
| 167 |
+
else:
|
| 168 |
+
if mode == Split.dev:
|
| 169 |
+
self.examples = self.processor.get_dev_examples(args.data_dir)
|
| 170 |
+
else:
|
| 171 |
+
self.examples = self.processor.get_train_examples(args.data_dir)
|
| 172 |
+
|
| 173 |
+
self.features, self.dataset = squad_convert_examples_to_features(
|
| 174 |
+
examples=self.examples,
|
| 175 |
+
tokenizer=tokenizer,
|
| 176 |
+
max_seq_length=args.max_seq_length,
|
| 177 |
+
doc_stride=args.doc_stride,
|
| 178 |
+
max_query_length=args.max_query_length,
|
| 179 |
+
is_training=mode == Split.train,
|
| 180 |
+
threads=args.threads,
|
| 181 |
+
return_dataset=dataset_format,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
start = time.time()
|
| 185 |
+
torch.save(
|
| 186 |
+
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
|
| 187 |
+
cached_features_file,
|
| 188 |
+
)
|
| 189 |
+
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
| 190 |
+
logger.info(
|
| 191 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def __len__(self):
|
| 195 |
+
return len(self.features)
|
| 196 |
+
|
| 197 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 198 |
+
# Convert to Tensors and build dataset
|
| 199 |
+
feature = self.features[i]
|
| 200 |
+
|
| 201 |
+
input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
|
| 202 |
+
attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
|
| 203 |
+
token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
|
| 204 |
+
cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
|
| 205 |
+
p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
|
| 206 |
+
is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
|
| 207 |
+
|
| 208 |
+
inputs = {
|
| 209 |
+
"input_ids": input_ids,
|
| 210 |
+
"attention_mask": attention_mask,
|
| 211 |
+
"token_type_ids": token_type_ids,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
| 215 |
+
del inputs["token_type_ids"]
|
| 216 |
+
|
| 217 |
+
if self.args.model_type in ["xlnet", "xlm"]:
|
| 218 |
+
inputs.update({"cls_index": cls_index, "p_mask": p_mask})
|
| 219 |
+
if self.args.version_2_with_negative:
|
| 220 |
+
inputs.update({"is_impossible": is_impossible})
|
| 221 |
+
if self.is_language_sensitive:
|
| 222 |
+
inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
|
| 223 |
+
|
| 224 |
+
if self.mode == Split.train:
|
| 225 |
+
start_positions = torch.tensor(feature.start_position, dtype=torch.long)
|
| 226 |
+
end_positions = torch.tensor(feature.end_position, dtype=torch.long)
|
| 227 |
+
inputs.update({"start_positions": start_positions, "end_positions": end_positions})
|
| 228 |
+
|
| 229 |
+
return inputs
|
.venv/Lib/site-packages/transformers/data/metrics/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
from ...utils import is_sklearn_available, requires_backends
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if is_sklearn_available():
|
| 19 |
+
from scipy.stats import pearsonr, spearmanr
|
| 20 |
+
from sklearn.metrics import f1_score, matthews_corrcoef
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
DEPRECATION_WARNING = (
|
| 24 |
+
"This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
|
| 25 |
+
"library. You can have a look at this example script for pointers: "
|
| 26 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def simple_accuracy(preds, labels):
|
| 31 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 32 |
+
requires_backends(simple_accuracy, "sklearn")
|
| 33 |
+
return (preds == labels).mean()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def acc_and_f1(preds, labels):
|
| 37 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 38 |
+
requires_backends(acc_and_f1, "sklearn")
|
| 39 |
+
acc = simple_accuracy(preds, labels)
|
| 40 |
+
f1 = f1_score(y_true=labels, y_pred=preds)
|
| 41 |
+
return {
|
| 42 |
+
"acc": acc,
|
| 43 |
+
"f1": f1,
|
| 44 |
+
"acc_and_f1": (acc + f1) / 2,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def pearson_and_spearman(preds, labels):
|
| 49 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 50 |
+
requires_backends(pearson_and_spearman, "sklearn")
|
| 51 |
+
pearson_corr = pearsonr(preds, labels)[0]
|
| 52 |
+
spearman_corr = spearmanr(preds, labels)[0]
|
| 53 |
+
return {
|
| 54 |
+
"pearson": pearson_corr,
|
| 55 |
+
"spearmanr": spearman_corr,
|
| 56 |
+
"corr": (pearson_corr + spearman_corr) / 2,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def glue_compute_metrics(task_name, preds, labels):
|
| 61 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 62 |
+
requires_backends(glue_compute_metrics, "sklearn")
|
| 63 |
+
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
| 64 |
+
if task_name == "cola":
|
| 65 |
+
return {"mcc": matthews_corrcoef(labels, preds)}
|
| 66 |
+
elif task_name == "sst-2":
|
| 67 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 68 |
+
elif task_name == "mrpc":
|
| 69 |
+
return acc_and_f1(preds, labels)
|
| 70 |
+
elif task_name == "sts-b":
|
| 71 |
+
return pearson_and_spearman(preds, labels)
|
| 72 |
+
elif task_name == "qqp":
|
| 73 |
+
return acc_and_f1(preds, labels)
|
| 74 |
+
elif task_name == "mnli":
|
| 75 |
+
return {"mnli/acc": simple_accuracy(preds, labels)}
|
| 76 |
+
elif task_name == "mnli-mm":
|
| 77 |
+
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
|
| 78 |
+
elif task_name == "qnli":
|
| 79 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 80 |
+
elif task_name == "rte":
|
| 81 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 82 |
+
elif task_name == "wnli":
|
| 83 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 84 |
+
elif task_name == "hans":
|
| 85 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 86 |
+
else:
|
| 87 |
+
raise KeyError(task_name)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def xnli_compute_metrics(task_name, preds, labels):
|
| 91 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
| 92 |
+
requires_backends(xnli_compute_metrics, "sklearn")
|
| 93 |
+
if len(preds) != len(labels):
|
| 94 |
+
raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
|
| 95 |
+
if task_name == "xnli":
|
| 96 |
+
return {"acc": simple_accuracy(preds, labels)}
|
| 97 |
+
else:
|
| 98 |
+
raise KeyError(task_name)
|
.venv/Lib/site-packages/transformers/data/metrics/squad_metrics.py
ADDED
|
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to
|
| 16 |
+
update `find_best_threshold` scripts for SQuAD V2.0
|
| 17 |
+
|
| 18 |
+
In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
|
| 19 |
+
additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted
|
| 20 |
+
probability that a question is unanswerable.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import collections
|
| 24 |
+
import json
|
| 25 |
+
import math
|
| 26 |
+
import re
|
| 27 |
+
import string
|
| 28 |
+
|
| 29 |
+
from ...models.bert import BasicTokenizer
|
| 30 |
+
from ...utils import logging
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def normalize_answer(s):
|
| 37 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
| 38 |
+
|
| 39 |
+
def remove_articles(text):
|
| 40 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
| 41 |
+
return re.sub(regex, " ", text)
|
| 42 |
+
|
| 43 |
+
def white_space_fix(text):
|
| 44 |
+
return " ".join(text.split())
|
| 45 |
+
|
| 46 |
+
def remove_punc(text):
|
| 47 |
+
exclude = set(string.punctuation)
|
| 48 |
+
return "".join(ch for ch in text if ch not in exclude)
|
| 49 |
+
|
| 50 |
+
def lower(text):
|
| 51 |
+
return text.lower()
|
| 52 |
+
|
| 53 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_tokens(s):
|
| 57 |
+
if not s:
|
| 58 |
+
return []
|
| 59 |
+
return normalize_answer(s).split()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_exact(a_gold, a_pred):
|
| 63 |
+
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def compute_f1(a_gold, a_pred):
|
| 67 |
+
gold_toks = get_tokens(a_gold)
|
| 68 |
+
pred_toks = get_tokens(a_pred)
|
| 69 |
+
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
| 70 |
+
num_same = sum(common.values())
|
| 71 |
+
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
| 72 |
+
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
| 73 |
+
return int(gold_toks == pred_toks)
|
| 74 |
+
if num_same == 0:
|
| 75 |
+
return 0
|
| 76 |
+
precision = 1.0 * num_same / len(pred_toks)
|
| 77 |
+
recall = 1.0 * num_same / len(gold_toks)
|
| 78 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
| 79 |
+
return f1
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_raw_scores(examples, preds):
|
| 83 |
+
"""
|
| 84 |
+
Computes the exact and f1 scores from the examples and the model predictions
|
| 85 |
+
"""
|
| 86 |
+
exact_scores = {}
|
| 87 |
+
f1_scores = {}
|
| 88 |
+
|
| 89 |
+
for example in examples:
|
| 90 |
+
qas_id = example.qas_id
|
| 91 |
+
gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
|
| 92 |
+
|
| 93 |
+
if not gold_answers:
|
| 94 |
+
# For unanswerable questions, only correct answer is empty string
|
| 95 |
+
gold_answers = [""]
|
| 96 |
+
|
| 97 |
+
if qas_id not in preds:
|
| 98 |
+
print(f"Missing prediction for {qas_id}")
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
prediction = preds[qas_id]
|
| 102 |
+
exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
|
| 103 |
+
f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
|
| 104 |
+
|
| 105 |
+
return exact_scores, f1_scores
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
| 109 |
+
new_scores = {}
|
| 110 |
+
for qid, s in scores.items():
|
| 111 |
+
pred_na = na_probs[qid] > na_prob_thresh
|
| 112 |
+
if pred_na:
|
| 113 |
+
new_scores[qid] = float(not qid_to_has_ans[qid])
|
| 114 |
+
else:
|
| 115 |
+
new_scores[qid] = s
|
| 116 |
+
return new_scores
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
| 120 |
+
if not qid_list:
|
| 121 |
+
total = len(exact_scores)
|
| 122 |
+
return collections.OrderedDict(
|
| 123 |
+
[
|
| 124 |
+
("exact", 100.0 * sum(exact_scores.values()) / total),
|
| 125 |
+
("f1", 100.0 * sum(f1_scores.values()) / total),
|
| 126 |
+
("total", total),
|
| 127 |
+
]
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
total = len(qid_list)
|
| 131 |
+
return collections.OrderedDict(
|
| 132 |
+
[
|
| 133 |
+
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
| 134 |
+
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
| 135 |
+
("total", total),
|
| 136 |
+
]
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def merge_eval(main_eval, new_eval, prefix):
|
| 141 |
+
for k in new_eval:
|
| 142 |
+
main_eval[f"{prefix}_{k}"] = new_eval[k]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
| 146 |
+
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
| 147 |
+
cur_score = num_no_ans
|
| 148 |
+
best_score = cur_score
|
| 149 |
+
best_thresh = 0.0
|
| 150 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
| 151 |
+
for i, qid in enumerate(qid_list):
|
| 152 |
+
if qid not in scores:
|
| 153 |
+
continue
|
| 154 |
+
if qid_to_has_ans[qid]:
|
| 155 |
+
diff = scores[qid]
|
| 156 |
+
else:
|
| 157 |
+
if preds[qid]:
|
| 158 |
+
diff = -1
|
| 159 |
+
else:
|
| 160 |
+
diff = 0
|
| 161 |
+
cur_score += diff
|
| 162 |
+
if cur_score > best_score:
|
| 163 |
+
best_score = cur_score
|
| 164 |
+
best_thresh = na_probs[qid]
|
| 165 |
+
|
| 166 |
+
has_ans_score, has_ans_cnt = 0, 0
|
| 167 |
+
for qid in qid_list:
|
| 168 |
+
if not qid_to_has_ans[qid]:
|
| 169 |
+
continue
|
| 170 |
+
has_ans_cnt += 1
|
| 171 |
+
|
| 172 |
+
if qid not in scores:
|
| 173 |
+
continue
|
| 174 |
+
has_ans_score += scores[qid]
|
| 175 |
+
|
| 176 |
+
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
| 180 |
+
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
| 181 |
+
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
| 182 |
+
main_eval["best_exact"] = best_exact
|
| 183 |
+
main_eval["best_exact_thresh"] = exact_thresh
|
| 184 |
+
main_eval["best_f1"] = best_f1
|
| 185 |
+
main_eval["best_f1_thresh"] = f1_thresh
|
| 186 |
+
main_eval["has_ans_exact"] = has_ans_exact
|
| 187 |
+
main_eval["has_ans_f1"] = has_ans_f1
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
| 191 |
+
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
| 192 |
+
cur_score = num_no_ans
|
| 193 |
+
best_score = cur_score
|
| 194 |
+
best_thresh = 0.0
|
| 195 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
| 196 |
+
for _, qid in enumerate(qid_list):
|
| 197 |
+
if qid not in scores:
|
| 198 |
+
continue
|
| 199 |
+
if qid_to_has_ans[qid]:
|
| 200 |
+
diff = scores[qid]
|
| 201 |
+
else:
|
| 202 |
+
if preds[qid]:
|
| 203 |
+
diff = -1
|
| 204 |
+
else:
|
| 205 |
+
diff = 0
|
| 206 |
+
cur_score += diff
|
| 207 |
+
if cur_score > best_score:
|
| 208 |
+
best_score = cur_score
|
| 209 |
+
best_thresh = na_probs[qid]
|
| 210 |
+
return 100.0 * best_score / len(scores), best_thresh
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
| 214 |
+
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
| 215 |
+
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
| 216 |
+
|
| 217 |
+
main_eval["best_exact"] = best_exact
|
| 218 |
+
main_eval["best_exact_thresh"] = exact_thresh
|
| 219 |
+
main_eval["best_f1"] = best_f1
|
| 220 |
+
main_eval["best_f1_thresh"] = f1_thresh
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
| 224 |
+
qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
|
| 225 |
+
has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
|
| 226 |
+
no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
|
| 227 |
+
|
| 228 |
+
if no_answer_probs is None:
|
| 229 |
+
no_answer_probs = {k: 0.0 for k in preds}
|
| 230 |
+
|
| 231 |
+
exact, f1 = get_raw_scores(examples, preds)
|
| 232 |
+
|
| 233 |
+
exact_threshold = apply_no_ans_threshold(
|
| 234 |
+
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
|
| 235 |
+
)
|
| 236 |
+
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
| 237 |
+
|
| 238 |
+
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
| 239 |
+
|
| 240 |
+
if has_answer_qids:
|
| 241 |
+
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
| 242 |
+
merge_eval(evaluation, has_ans_eval, "HasAns")
|
| 243 |
+
|
| 244 |
+
if no_answer_qids:
|
| 245 |
+
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
| 246 |
+
merge_eval(evaluation, no_ans_eval, "NoAns")
|
| 247 |
+
|
| 248 |
+
if no_answer_probs:
|
| 249 |
+
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
| 250 |
+
|
| 251 |
+
return evaluation
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
| 255 |
+
"""Project the tokenized prediction back to the original text."""
|
| 256 |
+
|
| 257 |
+
# When we created the data, we kept track of the alignment between original
|
| 258 |
+
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
| 259 |
+
# now `orig_text` contains the span of our original text corresponding to the
|
| 260 |
+
# span that we predicted.
|
| 261 |
+
#
|
| 262 |
+
# However, `orig_text` may contain extra characters that we don't want in
|
| 263 |
+
# our prediction.
|
| 264 |
+
#
|
| 265 |
+
# For example, let's say:
|
| 266 |
+
# pred_text = steve smith
|
| 267 |
+
# orig_text = Steve Smith's
|
| 268 |
+
#
|
| 269 |
+
# We don't want to return `orig_text` because it contains the extra "'s".
|
| 270 |
+
#
|
| 271 |
+
# We don't want to return `pred_text` because it's already been normalized
|
| 272 |
+
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
| 273 |
+
# our tokenizer does additional normalization like stripping accent
|
| 274 |
+
# characters).
|
| 275 |
+
#
|
| 276 |
+
# What we really want to return is "Steve Smith".
|
| 277 |
+
#
|
| 278 |
+
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
| 279 |
+
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
| 280 |
+
# can fail in certain cases in which case we just return `orig_text`.
|
| 281 |
+
|
| 282 |
+
def _strip_spaces(text):
|
| 283 |
+
ns_chars = []
|
| 284 |
+
ns_to_s_map = collections.OrderedDict()
|
| 285 |
+
for i, c in enumerate(text):
|
| 286 |
+
if c == " ":
|
| 287 |
+
continue
|
| 288 |
+
ns_to_s_map[len(ns_chars)] = i
|
| 289 |
+
ns_chars.append(c)
|
| 290 |
+
ns_text = "".join(ns_chars)
|
| 291 |
+
return (ns_text, ns_to_s_map)
|
| 292 |
+
|
| 293 |
+
# We first tokenize `orig_text`, strip whitespace from the result
|
| 294 |
+
# and `pred_text`, and check if they are the same length. If they are
|
| 295 |
+
# NOT the same length, the heuristic has failed. If they are the same
|
| 296 |
+
# length, we assume the characters are one-to-one aligned.
|
| 297 |
+
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
| 298 |
+
|
| 299 |
+
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
| 300 |
+
|
| 301 |
+
start_position = tok_text.find(pred_text)
|
| 302 |
+
if start_position == -1:
|
| 303 |
+
if verbose_logging:
|
| 304 |
+
logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
|
| 305 |
+
return orig_text
|
| 306 |
+
end_position = start_position + len(pred_text) - 1
|
| 307 |
+
|
| 308 |
+
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
| 309 |
+
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
| 310 |
+
|
| 311 |
+
if len(orig_ns_text) != len(tok_ns_text):
|
| 312 |
+
if verbose_logging:
|
| 313 |
+
logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
|
| 314 |
+
return orig_text
|
| 315 |
+
|
| 316 |
+
# We then project the characters in `pred_text` back to `orig_text` using
|
| 317 |
+
# the character-to-character alignment.
|
| 318 |
+
tok_s_to_ns_map = {}
|
| 319 |
+
for i, tok_index in tok_ns_to_s_map.items():
|
| 320 |
+
tok_s_to_ns_map[tok_index] = i
|
| 321 |
+
|
| 322 |
+
orig_start_position = None
|
| 323 |
+
if start_position in tok_s_to_ns_map:
|
| 324 |
+
ns_start_position = tok_s_to_ns_map[start_position]
|
| 325 |
+
if ns_start_position in orig_ns_to_s_map:
|
| 326 |
+
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
| 327 |
+
|
| 328 |
+
if orig_start_position is None:
|
| 329 |
+
if verbose_logging:
|
| 330 |
+
logger.info("Couldn't map start position")
|
| 331 |
+
return orig_text
|
| 332 |
+
|
| 333 |
+
orig_end_position = None
|
| 334 |
+
if end_position in tok_s_to_ns_map:
|
| 335 |
+
ns_end_position = tok_s_to_ns_map[end_position]
|
| 336 |
+
if ns_end_position in orig_ns_to_s_map:
|
| 337 |
+
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
| 338 |
+
|
| 339 |
+
if orig_end_position is None:
|
| 340 |
+
if verbose_logging:
|
| 341 |
+
logger.info("Couldn't map end position")
|
| 342 |
+
return orig_text
|
| 343 |
+
|
| 344 |
+
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
| 345 |
+
return output_text
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def _get_best_indexes(logits, n_best_size):
|
| 349 |
+
"""Get the n-best logits from a list."""
|
| 350 |
+
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
| 351 |
+
|
| 352 |
+
best_indexes = []
|
| 353 |
+
for i in range(len(index_and_score)):
|
| 354 |
+
if i >= n_best_size:
|
| 355 |
+
break
|
| 356 |
+
best_indexes.append(index_and_score[i][0])
|
| 357 |
+
return best_indexes
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _compute_softmax(scores):
|
| 361 |
+
"""Compute softmax probability over raw logits."""
|
| 362 |
+
if not scores:
|
| 363 |
+
return []
|
| 364 |
+
|
| 365 |
+
max_score = None
|
| 366 |
+
for score in scores:
|
| 367 |
+
if max_score is None or score > max_score:
|
| 368 |
+
max_score = score
|
| 369 |
+
|
| 370 |
+
exp_scores = []
|
| 371 |
+
total_sum = 0.0
|
| 372 |
+
for score in scores:
|
| 373 |
+
x = math.exp(score - max_score)
|
| 374 |
+
exp_scores.append(x)
|
| 375 |
+
total_sum += x
|
| 376 |
+
|
| 377 |
+
probs = []
|
| 378 |
+
for score in exp_scores:
|
| 379 |
+
probs.append(score / total_sum)
|
| 380 |
+
return probs
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def compute_predictions_logits(
|
| 384 |
+
all_examples,
|
| 385 |
+
all_features,
|
| 386 |
+
all_results,
|
| 387 |
+
n_best_size,
|
| 388 |
+
max_answer_length,
|
| 389 |
+
do_lower_case,
|
| 390 |
+
output_prediction_file,
|
| 391 |
+
output_nbest_file,
|
| 392 |
+
output_null_log_odds_file,
|
| 393 |
+
verbose_logging,
|
| 394 |
+
version_2_with_negative,
|
| 395 |
+
null_score_diff_threshold,
|
| 396 |
+
tokenizer,
|
| 397 |
+
):
|
| 398 |
+
"""Write final predictions to the json file and log-odds of null if needed."""
|
| 399 |
+
if output_prediction_file:
|
| 400 |
+
logger.info(f"Writing predictions to: {output_prediction_file}")
|
| 401 |
+
if output_nbest_file:
|
| 402 |
+
logger.info(f"Writing nbest to: {output_nbest_file}")
|
| 403 |
+
if output_null_log_odds_file and version_2_with_negative:
|
| 404 |
+
logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
|
| 405 |
+
|
| 406 |
+
example_index_to_features = collections.defaultdict(list)
|
| 407 |
+
for feature in all_features:
|
| 408 |
+
example_index_to_features[feature.example_index].append(feature)
|
| 409 |
+
|
| 410 |
+
unique_id_to_result = {}
|
| 411 |
+
for result in all_results:
|
| 412 |
+
unique_id_to_result[result.unique_id] = result
|
| 413 |
+
|
| 414 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 415 |
+
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
all_predictions = collections.OrderedDict()
|
| 419 |
+
all_nbest_json = collections.OrderedDict()
|
| 420 |
+
scores_diff_json = collections.OrderedDict()
|
| 421 |
+
|
| 422 |
+
for example_index, example in enumerate(all_examples):
|
| 423 |
+
features = example_index_to_features[example_index]
|
| 424 |
+
|
| 425 |
+
prelim_predictions = []
|
| 426 |
+
# keep track of the minimum score of null start+end of position 0
|
| 427 |
+
score_null = 1000000 # large and positive
|
| 428 |
+
min_null_feature_index = 0 # the paragraph slice with min null score
|
| 429 |
+
null_start_logit = 0 # the start logit at the slice with min null score
|
| 430 |
+
null_end_logit = 0 # the end logit at the slice with min null score
|
| 431 |
+
for feature_index, feature in enumerate(features):
|
| 432 |
+
result = unique_id_to_result[feature.unique_id]
|
| 433 |
+
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
| 434 |
+
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
| 435 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
| 436 |
+
if version_2_with_negative:
|
| 437 |
+
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
| 438 |
+
if feature_null_score < score_null:
|
| 439 |
+
score_null = feature_null_score
|
| 440 |
+
min_null_feature_index = feature_index
|
| 441 |
+
null_start_logit = result.start_logits[0]
|
| 442 |
+
null_end_logit = result.end_logits[0]
|
| 443 |
+
for start_index in start_indexes:
|
| 444 |
+
for end_index in end_indexes:
|
| 445 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
| 446 |
+
# that the start of the span is in the question. We throw out all
|
| 447 |
+
# invalid predictions.
|
| 448 |
+
if start_index >= len(feature.tokens):
|
| 449 |
+
continue
|
| 450 |
+
if end_index >= len(feature.tokens):
|
| 451 |
+
continue
|
| 452 |
+
if start_index not in feature.token_to_orig_map:
|
| 453 |
+
continue
|
| 454 |
+
if end_index not in feature.token_to_orig_map:
|
| 455 |
+
continue
|
| 456 |
+
if not feature.token_is_max_context.get(start_index, False):
|
| 457 |
+
continue
|
| 458 |
+
if end_index < start_index:
|
| 459 |
+
continue
|
| 460 |
+
length = end_index - start_index + 1
|
| 461 |
+
if length > max_answer_length:
|
| 462 |
+
continue
|
| 463 |
+
prelim_predictions.append(
|
| 464 |
+
_PrelimPrediction(
|
| 465 |
+
feature_index=feature_index,
|
| 466 |
+
start_index=start_index,
|
| 467 |
+
end_index=end_index,
|
| 468 |
+
start_logit=result.start_logits[start_index],
|
| 469 |
+
end_logit=result.end_logits[end_index],
|
| 470 |
+
)
|
| 471 |
+
)
|
| 472 |
+
if version_2_with_negative:
|
| 473 |
+
prelim_predictions.append(
|
| 474 |
+
_PrelimPrediction(
|
| 475 |
+
feature_index=min_null_feature_index,
|
| 476 |
+
start_index=0,
|
| 477 |
+
end_index=0,
|
| 478 |
+
start_logit=null_start_logit,
|
| 479 |
+
end_logit=null_end_logit,
|
| 480 |
+
)
|
| 481 |
+
)
|
| 482 |
+
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
| 483 |
+
|
| 484 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 485 |
+
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
seen_predictions = {}
|
| 489 |
+
nbest = []
|
| 490 |
+
for pred in prelim_predictions:
|
| 491 |
+
if len(nbest) >= n_best_size:
|
| 492 |
+
break
|
| 493 |
+
feature = features[pred.feature_index]
|
| 494 |
+
if pred.start_index > 0: # this is a non-null prediction
|
| 495 |
+
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
| 496 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
| 497 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
| 498 |
+
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
| 499 |
+
|
| 500 |
+
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
| 501 |
+
|
| 502 |
+
# tok_text = " ".join(tok_tokens)
|
| 503 |
+
#
|
| 504 |
+
# # De-tokenize WordPieces that have been split off.
|
| 505 |
+
# tok_text = tok_text.replace(" ##", "")
|
| 506 |
+
# tok_text = tok_text.replace("##", "")
|
| 507 |
+
|
| 508 |
+
# Clean whitespace
|
| 509 |
+
tok_text = tok_text.strip()
|
| 510 |
+
tok_text = " ".join(tok_text.split())
|
| 511 |
+
orig_text = " ".join(orig_tokens)
|
| 512 |
+
|
| 513 |
+
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
| 514 |
+
if final_text in seen_predictions:
|
| 515 |
+
continue
|
| 516 |
+
|
| 517 |
+
seen_predictions[final_text] = True
|
| 518 |
+
else:
|
| 519 |
+
final_text = ""
|
| 520 |
+
seen_predictions[final_text] = True
|
| 521 |
+
|
| 522 |
+
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
| 523 |
+
# if we didn't include the empty option in the n-best, include it
|
| 524 |
+
if version_2_with_negative:
|
| 525 |
+
if "" not in seen_predictions:
|
| 526 |
+
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
| 527 |
+
|
| 528 |
+
# In very rare edge cases we could only have single null prediction.
|
| 529 |
+
# So we just create a nonce prediction in this case to avoid failure.
|
| 530 |
+
if len(nbest) == 1:
|
| 531 |
+
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
| 532 |
+
|
| 533 |
+
# In very rare edge cases we could have no valid predictions. So we
|
| 534 |
+
# just create a nonce prediction in this case to avoid failure.
|
| 535 |
+
if not nbest:
|
| 536 |
+
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
| 537 |
+
|
| 538 |
+
if len(nbest) < 1:
|
| 539 |
+
raise ValueError("No valid predictions")
|
| 540 |
+
|
| 541 |
+
total_scores = []
|
| 542 |
+
best_non_null_entry = None
|
| 543 |
+
for entry in nbest:
|
| 544 |
+
total_scores.append(entry.start_logit + entry.end_logit)
|
| 545 |
+
if not best_non_null_entry:
|
| 546 |
+
if entry.text:
|
| 547 |
+
best_non_null_entry = entry
|
| 548 |
+
|
| 549 |
+
probs = _compute_softmax(total_scores)
|
| 550 |
+
|
| 551 |
+
nbest_json = []
|
| 552 |
+
for i, entry in enumerate(nbest):
|
| 553 |
+
output = collections.OrderedDict()
|
| 554 |
+
output["text"] = entry.text
|
| 555 |
+
output["probability"] = probs[i]
|
| 556 |
+
output["start_logit"] = entry.start_logit
|
| 557 |
+
output["end_logit"] = entry.end_logit
|
| 558 |
+
nbest_json.append(output)
|
| 559 |
+
|
| 560 |
+
if len(nbest_json) < 1:
|
| 561 |
+
raise ValueError("No valid predictions")
|
| 562 |
+
|
| 563 |
+
if not version_2_with_negative:
|
| 564 |
+
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
| 565 |
+
else:
|
| 566 |
+
# predict "" iff the null score - the score of best non-null > threshold
|
| 567 |
+
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
| 568 |
+
scores_diff_json[example.qas_id] = score_diff
|
| 569 |
+
if score_diff > null_score_diff_threshold:
|
| 570 |
+
all_predictions[example.qas_id] = ""
|
| 571 |
+
else:
|
| 572 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
| 573 |
+
all_nbest_json[example.qas_id] = nbest_json
|
| 574 |
+
|
| 575 |
+
if output_prediction_file:
|
| 576 |
+
with open(output_prediction_file, "w") as writer:
|
| 577 |
+
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
| 578 |
+
|
| 579 |
+
if output_nbest_file:
|
| 580 |
+
with open(output_nbest_file, "w") as writer:
|
| 581 |
+
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
| 582 |
+
|
| 583 |
+
if output_null_log_odds_file and version_2_with_negative:
|
| 584 |
+
with open(output_null_log_odds_file, "w") as writer:
|
| 585 |
+
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
| 586 |
+
|
| 587 |
+
return all_predictions
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def compute_predictions_log_probs(
|
| 591 |
+
all_examples,
|
| 592 |
+
all_features,
|
| 593 |
+
all_results,
|
| 594 |
+
n_best_size,
|
| 595 |
+
max_answer_length,
|
| 596 |
+
output_prediction_file,
|
| 597 |
+
output_nbest_file,
|
| 598 |
+
output_null_log_odds_file,
|
| 599 |
+
start_n_top,
|
| 600 |
+
end_n_top,
|
| 601 |
+
version_2_with_negative,
|
| 602 |
+
tokenizer,
|
| 603 |
+
verbose_logging,
|
| 604 |
+
):
|
| 605 |
+
"""
|
| 606 |
+
XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
|
| 607 |
+
null if needed.
|
| 608 |
+
|
| 609 |
+
Requires utils_squad_evaluate.py
|
| 610 |
+
"""
|
| 611 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 612 |
+
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
| 616 |
+
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
logger.info(f"Writing predictions to: {output_prediction_file}")
|
| 620 |
+
|
| 621 |
+
example_index_to_features = collections.defaultdict(list)
|
| 622 |
+
for feature in all_features:
|
| 623 |
+
example_index_to_features[feature.example_index].append(feature)
|
| 624 |
+
|
| 625 |
+
unique_id_to_result = {}
|
| 626 |
+
for result in all_results:
|
| 627 |
+
unique_id_to_result[result.unique_id] = result
|
| 628 |
+
|
| 629 |
+
all_predictions = collections.OrderedDict()
|
| 630 |
+
all_nbest_json = collections.OrderedDict()
|
| 631 |
+
scores_diff_json = collections.OrderedDict()
|
| 632 |
+
|
| 633 |
+
for example_index, example in enumerate(all_examples):
|
| 634 |
+
features = example_index_to_features[example_index]
|
| 635 |
+
|
| 636 |
+
prelim_predictions = []
|
| 637 |
+
# keep track of the minimum score of null start+end of position 0
|
| 638 |
+
score_null = 1000000 # large and positive
|
| 639 |
+
|
| 640 |
+
for feature_index, feature in enumerate(features):
|
| 641 |
+
result = unique_id_to_result[feature.unique_id]
|
| 642 |
+
|
| 643 |
+
cur_null_score = result.cls_logits
|
| 644 |
+
|
| 645 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
| 646 |
+
score_null = min(score_null, cur_null_score)
|
| 647 |
+
|
| 648 |
+
for i in range(start_n_top):
|
| 649 |
+
for j in range(end_n_top):
|
| 650 |
+
start_log_prob = result.start_logits[i]
|
| 651 |
+
start_index = result.start_top_index[i]
|
| 652 |
+
|
| 653 |
+
j_index = i * end_n_top + j
|
| 654 |
+
|
| 655 |
+
end_log_prob = result.end_logits[j_index]
|
| 656 |
+
end_index = result.end_top_index[j_index]
|
| 657 |
+
|
| 658 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
| 659 |
+
# that the start of the span is in the question. We throw out all
|
| 660 |
+
# invalid predictions.
|
| 661 |
+
if start_index >= feature.paragraph_len - 1:
|
| 662 |
+
continue
|
| 663 |
+
if end_index >= feature.paragraph_len - 1:
|
| 664 |
+
continue
|
| 665 |
+
|
| 666 |
+
if not feature.token_is_max_context.get(start_index, False):
|
| 667 |
+
continue
|
| 668 |
+
if end_index < start_index:
|
| 669 |
+
continue
|
| 670 |
+
length = end_index - start_index + 1
|
| 671 |
+
if length > max_answer_length:
|
| 672 |
+
continue
|
| 673 |
+
|
| 674 |
+
prelim_predictions.append(
|
| 675 |
+
_PrelimPrediction(
|
| 676 |
+
feature_index=feature_index,
|
| 677 |
+
start_index=start_index,
|
| 678 |
+
end_index=end_index,
|
| 679 |
+
start_log_prob=start_log_prob,
|
| 680 |
+
end_log_prob=end_log_prob,
|
| 681 |
+
)
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
prelim_predictions = sorted(
|
| 685 |
+
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
seen_predictions = {}
|
| 689 |
+
nbest = []
|
| 690 |
+
for pred in prelim_predictions:
|
| 691 |
+
if len(nbest) >= n_best_size:
|
| 692 |
+
break
|
| 693 |
+
feature = features[pred.feature_index]
|
| 694 |
+
|
| 695 |
+
# XLNet un-tokenizer
|
| 696 |
+
# Let's keep it simple for now and see if we need all this later.
|
| 697 |
+
#
|
| 698 |
+
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
| 699 |
+
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
| 700 |
+
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
| 701 |
+
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
| 702 |
+
# paragraph_text = example.paragraph_text
|
| 703 |
+
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
| 704 |
+
|
| 705 |
+
# Previously used Bert untokenizer
|
| 706 |
+
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
| 707 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
| 708 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
| 709 |
+
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
| 710 |
+
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
| 711 |
+
|
| 712 |
+
# Clean whitespace
|
| 713 |
+
tok_text = tok_text.strip()
|
| 714 |
+
tok_text = " ".join(tok_text.split())
|
| 715 |
+
orig_text = " ".join(orig_tokens)
|
| 716 |
+
|
| 717 |
+
if hasattr(tokenizer, "do_lower_case"):
|
| 718 |
+
do_lower_case = tokenizer.do_lower_case
|
| 719 |
+
else:
|
| 720 |
+
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
| 721 |
+
|
| 722 |
+
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
| 723 |
+
|
| 724 |
+
if final_text in seen_predictions:
|
| 725 |
+
continue
|
| 726 |
+
|
| 727 |
+
seen_predictions[final_text] = True
|
| 728 |
+
|
| 729 |
+
nbest.append(
|
| 730 |
+
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# In very rare edge cases we could have no valid predictions. So we
|
| 734 |
+
# just create a nonce prediction in this case to avoid failure.
|
| 735 |
+
if not nbest:
|
| 736 |
+
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
| 737 |
+
|
| 738 |
+
total_scores = []
|
| 739 |
+
best_non_null_entry = None
|
| 740 |
+
for entry in nbest:
|
| 741 |
+
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
| 742 |
+
if not best_non_null_entry:
|
| 743 |
+
best_non_null_entry = entry
|
| 744 |
+
|
| 745 |
+
probs = _compute_softmax(total_scores)
|
| 746 |
+
|
| 747 |
+
nbest_json = []
|
| 748 |
+
for i, entry in enumerate(nbest):
|
| 749 |
+
output = collections.OrderedDict()
|
| 750 |
+
output["text"] = entry.text
|
| 751 |
+
output["probability"] = probs[i]
|
| 752 |
+
output["start_log_prob"] = entry.start_log_prob
|
| 753 |
+
output["end_log_prob"] = entry.end_log_prob
|
| 754 |
+
nbest_json.append(output)
|
| 755 |
+
|
| 756 |
+
if len(nbest_json) < 1:
|
| 757 |
+
raise ValueError("No valid predictions")
|
| 758 |
+
if best_non_null_entry is None:
|
| 759 |
+
raise ValueError("No valid predictions")
|
| 760 |
+
|
| 761 |
+
score_diff = score_null
|
| 762 |
+
scores_diff_json[example.qas_id] = score_diff
|
| 763 |
+
# note(zhiliny): always predict best_non_null_entry
|
| 764 |
+
# and the evaluation script will search for the best threshold
|
| 765 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
| 766 |
+
|
| 767 |
+
all_nbest_json[example.qas_id] = nbest_json
|
| 768 |
+
|
| 769 |
+
with open(output_prediction_file, "w") as writer:
|
| 770 |
+
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
| 771 |
+
|
| 772 |
+
with open(output_nbest_file, "w") as writer:
|
| 773 |
+
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
| 774 |
+
|
| 775 |
+
if version_2_with_negative:
|
| 776 |
+
with open(output_null_log_odds_file, "w") as writer:
|
| 777 |
+
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
| 778 |
+
|
| 779 |
+
return all_predictions
|
.venv/Lib/site-packages/transformers/data/processors/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
|
| 16 |
+
from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
| 17 |
+
from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
|
| 18 |
+
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
.venv/Lib/site-packages/transformers/data/processors/glue.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE processors and helpers"""

import os
import warnings
from dataclasses import asdict
from enum import Enum
from typing import List, Optional, Union

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import is_tf_available, logging
from .utils import DataProcessor, InputExample, InputFeatures


if is_tf_available():
    import tensorflow as tf

logger = logging.get_logger(__name__)

DEPRECATION_WARNING = (
    "This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
    "library. You can have a look at this example script for pointers: "
    "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
)


def glue_convert_examples_to_features(
    examples: Union[List[InputExample], "tf.data.Dataset"],
    tokenizer: PreTrainedTokenizer,
    max_length: Optional[int] = None,
    task=None,
    label_list=None,
    output_mode=None,
):
    """
    Loads a data file into a list of `InputFeatures`

    Args:
        examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
        tokenizer: Instance of a tokenizer that will tokenize the examples
        max_length: Maximum example length. Defaults to the tokenizer's max_len
        task: GLUE task
        label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
        output_mode: String indicating the output mode. Either `regression` or `classification`

    Returns:
        If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
        features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
        can be fed to the model.

    """
    warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
    if is_tf_available() and isinstance(examples, tf.data.Dataset):
        if task is None:
            raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
        return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
    return _glue_convert_examples_to_features(
        examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
    )


if is_tf_available():

    def _tf_glue_convert_examples_to_features(
        examples: tf.data.Dataset,
        tokenizer: PreTrainedTokenizer,
        task=str,
        max_length: Optional[int] = None,
    ) -> tf.data.Dataset:
        """
        Returns:
            A `tf.data.Dataset` containing the task-specific features.

        """
        processor = glue_processors[task]()
        examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
        features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
        label_type = tf.float32 if task == "sts-b" else tf.int64

        def gen():
            for ex in features:
                d = {k: v for k, v in asdict(ex).items() if v is not None}
                label = d.pop("label")
                yield (d, label)

        input_names = tokenizer.model_input_names

        return tf.data.Dataset.from_generator(
            gen,
            ({k: tf.int32 for k in input_names}, label_type),
            ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
        )


def _glue_convert_examples_to_features(
    examples: List[InputExample],
    tokenizer: PreTrainedTokenizer,
    max_length: Optional[int] = None,
    task=None,
    label_list=None,
    output_mode=None,
):
    if max_length is None:
        max_length = tokenizer.model_max_length

    if task is not None:
        processor = glue_processors[task]()
        if label_list is None:
            label_list = processor.get_labels()
            logger.info(f"Using label list {label_list} for task {task}")
        if output_mode is None:
            output_mode = glue_output_modes[task]
            logger.info(f"Using output mode {output_mode} for task {task}")

    label_map = {label: i for i, label in enumerate(label_list)}

    def label_from_example(example: InputExample) -> Union[int, float, None]:
        if example.label is None:
            return None
        if output_mode == "classification":
            return label_map[example.label]
        elif output_mode == "regression":
            return float(example.label)
        raise KeyError(output_mode)

    labels = [label_from_example(example) for example in examples]

    batch_encoding = tokenizer(
        [(example.text_a, example.text_b) for example in examples],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )

    features = []
    for i in range(len(examples)):
        inputs = {k: batch_encoding[k][i] for k in batch_encoding}

        feature = InputFeatures(**inputs, label=labels[i])
        features.append(feature)

    for i, example in enumerate(examples[:5]):
        logger.info("*** Example ***")
        logger.info(f"guid: {example.guid}")
        logger.info(f"features: {features[i]}")

    return features


class OutputMode(Enum):
    classification = "classification"
    regression = "regression"


class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{i}"
            text_a = line[3]
            text_b = line[4]
            label = None if set_type == "test" else line[0]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            text_a = line[8]
            text_b = line[9]
            label = None if set_type.startswith("test") else line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        test_mode = set_type == "test"
        if test_mode:
            lines = lines[1:]
        text_index = 1 if test_mode else 3
        examples = []
        for i, line in enumerate(lines):
            guid = f"{set_type}-{i}"
            text_a = line[text_index]
            label = None if test_mode else line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        text_index = 1 if set_type == "test" else 0
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{i}"
            text_a = line[text_index]
            label = None if set_type == "test" else line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return [None]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            text_a = line[7]
            text_b = line[8]
            label = None if set_type == "test" else line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        test_mode = set_type == "test"
        q1_index = 1 if test_mode else 3
        q2_index = 2 if test_mode else 4
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            try:
                text_a = line[q1_index]
                text_b = line[q2_index]
                label = None if test_mode else line[5]
            except IndexError:
                continue
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question"].numpy().decode("utf-8"),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            text_a = line[1]
            text_b = line[2]
            label = None if set_type == "test" else line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            text_a = line[1]
            text_b = line[2]
            label = None if set_type == "test" else line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"{set_type}-{line[0]}"
            text_a = line[1]
            text_b = line[2]
            label = None if set_type == "test" else line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}

glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

glue_output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}
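
Note: the GLUE processors and `glue_convert_examples_to_features` above are deprecated in favour of the 🤗 Datasets library, but they still chain together in the obvious way: pick a processor from `glue_processors`, read the task's TSV files, then tokenize the resulting `InputExample`s into `InputFeatures`. A minimal usage sketch, assuming a locally downloaded MRPC task and a BERT checkpoint (the data path and checkpoint name are illustrative, not taken from this file):

```python
from transformers import AutoTokenizer
from transformers.data.processors.glue import (
    glue_convert_examples_to_features,
    glue_output_modes,
    glue_processors,
)

task = "mrpc"  # any key of glue_processors works the same way
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

processor = glue_processors[task]()  # MrpcProcessor
examples = processor.get_train_examples("./glue_data/MRPC")  # hypothetical local path
features = glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=128,
    task=task,  # lets the helper look up the label list from the processor
    output_mode=glue_output_modes[task],  # "classification" for MRPC
)
print(features[0].input_ids[:10], features[0].label)
```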
.venv/Lib/site-packages/transformers/data/processors/squad.py
ADDED
@@ -0,0 +1,845 @@
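
The file below defines the SQuAD example/feature processors (`SquadV1Processor`, `SquadV2Processor`, `squad_convert_examples_to_features`), which split long contexts into overlapping spans of at most `max_seq_length` tokens. A rough usage sketch for building a PyTorch training dataset, assuming a local copy of `train-v2.0.json` and a BERT checkpoint (both placeholders, not taken from this diff):

```python
from transformers import AutoTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV2Processor()
examples = processor.get_train_examples("./squad_data")  # hypothetical directory holding train-v2.0.json

# One SquadExample can yield several SquadFeatures when its context exceeds max_seq_length.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=True,
    return_dataset="pt",  # "tf" gives a tf.data.Dataset; False returns only the features list
)
print(len(examples), len(features), len(dataset))
```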
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
from tqdm import tqdm

from ...models.bert.tokenization_bert import whitespace_tokenize
from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
from ...utils import is_tf_available, is_torch_available, logging
from .utils import DataProcessor


# Store the tokenizers which insert 2 separators tokens
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}


if is_torch_available():
    import torch
    from torch.utils.data import TensorDataset

if is_tf_available():
    import tensorflow as tf

logger = logging.get_logger(__name__)


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    best_score = None
    best_span_index = None
    for span_index, doc_span in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def _new_check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    # if len(doc_spans) == 1:
    #     return True
    best_score = None
    best_span_index = None
    for span_index, doc_span in enumerate(doc_spans):
        end = doc_span["start"] + doc_span["length"] - 1
        if position < doc_span["start"]:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span["start"]
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def _is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False


def squad_convert_example_to_features(
    example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
):
    features = []
    if is_training and not example.is_impossible:
        # Get start and end position
        start_position = example.start_position
        end_position = example.end_position

        # If the answer cannot be found in the text, then skip this example.
        actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
        cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
        if actual_text.find(cleaned_answer_text) == -1:
            logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
            return []

    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    for i, token in enumerate(example.doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))
        if tokenizer.__class__.__name__ in [
            "RobertaTokenizer",
            "LongformerTokenizer",
            "BartTokenizer",
            "RobertaTokenizerFast",
            "LongformerTokenizerFast",
            "BartTokenizerFast",
        ]:
            sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
        else:
            sub_tokens = tokenizer.tokenize(token)
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)

    if is_training and not example.is_impossible:
        tok_start_position = orig_to_tok_index[example.start_position]
        if example.end_position < len(example.doc_tokens) - 1:
            tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
        else:
            tok_end_position = len(all_doc_tokens) - 1

        (tok_start_position, tok_end_position) = _improve_answer_span(
            all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
        )

    spans = []

    truncated_query = tokenizer.encode(
        example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
    )

    # Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
    # in the way they compute mask of added tokens.
    tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
    sequence_added_tokens = (
        tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
        if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
        else tokenizer.model_max_length - tokenizer.max_len_single_sentence
    )
    sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair

    span_doc_tokens = all_doc_tokens
    while len(spans) * doc_stride < len(all_doc_tokens):
        # Define the side we want to truncate / pad and the text/pair sorting
        if tokenizer.padding_side == "right":
            texts = truncated_query
            pairs = span_doc_tokens
            truncation = TruncationStrategy.ONLY_SECOND.value
        else:
            texts = span_doc_tokens
            pairs = truncated_query
            truncation = TruncationStrategy.ONLY_FIRST.value

        encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
            texts,
            pairs,
            truncation=truncation,
            padding=padding_strategy,
            max_length=max_seq_length,
            return_overflowing_tokens=True,
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            return_token_type_ids=True,
        )

        paragraph_len = min(
            len(all_doc_tokens) - len(spans) * doc_stride,
            max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
        )

        if tokenizer.pad_token_id in encoded_dict["input_ids"]:
            if tokenizer.padding_side == "right":
                non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
            else:
                last_padding_id_position = (
                    len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
                )
                non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]

        else:
            non_padded_ids = encoded_dict["input_ids"]

        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

        token_to_orig_map = {}
        for i in range(paragraph_len):
            index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
            token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]

        encoded_dict["paragraph_len"] = paragraph_len
        encoded_dict["tokens"] = tokens
        encoded_dict["token_to_orig_map"] = token_to_orig_map
        encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
        encoded_dict["token_is_max_context"] = {}
        encoded_dict["start"] = len(spans) * doc_stride
        encoded_dict["length"] = paragraph_len

        spans.append(encoded_dict)

        if "overflowing_tokens" not in encoded_dict or (
            "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
        ):
            break
        span_doc_tokens = encoded_dict["overflowing_tokens"]

    for doc_span_index in range(len(spans)):
        for j in range(spans[doc_span_index]["paragraph_len"]):
            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
            index = (
                j
                if tokenizer.padding_side == "left"
                else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
            )
            spans[doc_span_index]["token_is_max_context"][index] = is_max_context

    for span in spans:
        # Identify the position of the CLS token
        cls_index = span["input_ids"].index(tokenizer.cls_token_id)

        # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
        # Original TF implementation also keep the classification token (set to 0)
        p_mask = np.ones_like(span["token_type_ids"])
        if tokenizer.padding_side == "right":
            p_mask[len(truncated_query) + sequence_added_tokens :] = 0
        else:
            p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0

        pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
        special_token_indices = np.asarray(
            tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
        ).nonzero()

        p_mask[pad_token_indices] = 1
        p_mask[special_token_indices] = 1

        # Set the cls index to 0: the CLS index can be used for impossible answers
        p_mask[cls_index] = 0

        span_is_impossible = example.is_impossible
        start_position = 0
        end_position = 0
        if is_training and not span_is_impossible:
            # For training, if our document chunk does not contain an annotation
            # we throw it out, since there is nothing to predict.
            doc_start = span["start"]
            doc_end = span["start"] + span["length"] - 1
            out_of_span = False

            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                out_of_span = True

            if out_of_span:
                start_position = cls_index
                end_position = cls_index
                span_is_impossible = True
            else:
                if tokenizer.padding_side == "left":
                    doc_offset = 0
                else:
                    doc_offset = len(truncated_query) + sequence_added_tokens

                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

        features.append(
            SquadFeatures(
                span["input_ids"],
                span["attention_mask"],
                span["token_type_ids"],
                cls_index,
                p_mask.tolist(),
                example_index=0,  # Can not set unique_id and example_index here. They will be set after multiple processing.
                unique_id=0,
                paragraph_len=span["paragraph_len"],
                token_is_max_context=span["token_is_max_context"],
                tokens=span["tokens"],
                token_to_orig_map=span["token_to_orig_map"],
                start_position=start_position,
                end_position=end_position,
                is_impossible=span_is_impossible,
                qas_id=example.qas_id,
            )
        )
    return features


def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
    global tokenizer
    tokenizer = tokenizer_for_convert


def squad_convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length,
    doc_stride,
    max_query_length,
    is_training,
    padding_strategy="max_length",
    return_dataset=False,
    threads=1,
    tqdm_enabled=True,
):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model. It is
    model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.

    Args:
        examples: list of [`~data.processors.squad.SquadExample`]
        tokenizer: an instance of a child of [`PreTrainedTokenizer`]
        max_seq_length: The maximum sequence length of the inputs.
        doc_stride: The stride used when the context is too large and is split across several features.
        max_query_length: The maximum length of the query.
        is_training: whether to create features for model evaluation or model training.
        padding_strategy: Default to "max_length". Which padding strategy to use
        return_dataset: Default False. Either 'pt' or 'tf'.
            if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset
        threads: multiple processing threads.


    Returns:
        list of [`~data.processors.squad.SquadFeatures`]

    Example:

    ```python
    processor = SquadV2Processor()
    examples = processor.get_dev_examples(data_dir)

    features = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=not evaluate,
    )
    ```"""
    # Defining helper methods
    features = []

    threads = min(threads, cpu_count())
    with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
        annotate_ = partial(
            squad_convert_example_to_features,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            padding_strategy=padding_strategy,
            is_training=is_training,
        )
        features = list(
            tqdm(
                p.imap(annotate_, examples, chunksize=32),
                total=len(examples),
                desc="convert squad examples to features",
                disable=not tqdm_enabled,
            )
        )

    new_features = []
    unique_id = 1000000000
    example_index = 0
    for example_features in tqdm(
        features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
    ):
        if not example_features:
            continue
        for example_feature in example_features:
            example_feature.example_index = example_index
            example_feature.unique_id = unique_id
            new_features.append(example_feature)
            unique_id += 1
        example_index += 1
    features = new_features
    del new_features
    if return_dataset == "pt":
        if not is_torch_available():
            raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")

        # Convert to Tensors and build dataset
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
        all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
        all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)

        if not is_training:
            all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
            )
        else:
            all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
            all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids,
                all_attention_masks,
                all_token_type_ids,
                all_start_positions,
                all_end_positions,
                all_cls_index,
                all_p_mask,
                all_is_impossible,
            )

        return features, dataset
    elif return_dataset == "tf":
        if not is_tf_available():
            raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")

        def gen():
            for i, ex in enumerate(features):
                if ex.token_type_ids is None:
                    yield (
                        {
                            "input_ids": ex.input_ids,
                            "attention_mask": ex.attention_mask,
                            "feature_index": i,
                            "qas_id": ex.qas_id,
                        },
                        {
                            "start_positions": ex.start_position,
                            "end_positions": ex.end_position,
                            "cls_index": ex.cls_index,
                            "p_mask": ex.p_mask,
                            "is_impossible": ex.is_impossible,
                        },
                    )
                else:
                    yield (
                        {
                            "input_ids": ex.input_ids,
                            "attention_mask": ex.attention_mask,
                            "token_type_ids": ex.token_type_ids,
                            "feature_index": i,
                            "qas_id": ex.qas_id,
                        },
                        {
                            "start_positions": ex.start_position,
                            "end_positions": ex.end_position,
                            "cls_index": ex.cls_index,
                            "p_mask": ex.p_mask,
                            "is_impossible": ex.is_impossible,
                        },
                    )

        # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
        if "token_type_ids" in tokenizer.model_input_names:
            train_types = (
                {
                    "input_ids": tf.int32,
                    "attention_mask": tf.int32,
                    "token_type_ids": tf.int32,
                    "feature_index": tf.int64,
                    "qas_id": tf.string,
                },
                {
                    "start_positions": tf.int64,
                    "end_positions": tf.int64,
                    "cls_index": tf.int64,
                    "p_mask": tf.int32,
                    "is_impossible": tf.int32,
                },
            )

            train_shapes = (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                    "feature_index": tf.TensorShape([]),
                    "qas_id": tf.TensorShape([]),
                },
                {
                    "start_positions": tf.TensorShape([]),
                    "end_positions": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                    "is_impossible": tf.TensorShape([]),
                },
            )
        else:
            train_types = (
                {"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
                {
                    "start_positions": tf.int64,
                    "end_positions": tf.int64,
                    "cls_index": tf.int64,
                    "p_mask": tf.int32,
                    "is_impossible": tf.int32,
                },
            )

            train_shapes = (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "feature_index": tf.TensorShape([]),
                    "qas_id": tf.TensorShape([]),
                },
                {
                    "start_positions": tf.TensorShape([]),
                    "end_positions": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                    "is_impossible": tf.TensorShape([]),
                },
            )

        return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
    else:
        return features


class SquadProcessor(DataProcessor):
    """
    Processor for the SQuAD data set. overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and
    version 2.0 of SQuAD, respectively.
    """

    train_file = None
    dev_file = None

    def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
        if not evaluate:
            answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
            answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
            answers = []
        else:
            answers = [
                {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
                for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
            ]

            answer = None
            answer_start = None

        return SquadExample(
            qas_id=tensor_dict["id"].numpy().decode("utf-8"),
            question_text=tensor_dict["question"].numpy().decode("utf-8"),
            context_text=tensor_dict["context"].numpy().decode("utf-8"),
            answer_text=answer,
            start_position_character=answer_start,
            title=tensor_dict["title"].numpy().decode("utf-8"),
            answers=answers,
        )

    def get_examples_from_dataset(self, dataset, evaluate=False):
        """
        Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.

        Args:
            dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
            evaluate: Boolean specifying if in evaluation mode or in training mode

        Returns:
            List of SquadExample

        Examples:

        ```python
        >>> import tensorflow_datasets as tfds

        >>> dataset = tfds.load("squad")

        >>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
        >>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
        ```"""

        if evaluate:
            dataset = dataset["validation"]
        else:
            dataset = dataset["train"]

        examples = []
        for tensor_dict in tqdm(dataset):
            examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))

        return examples

    def get_train_examples(self, data_dir, filename=None):
        """
        Returns the training examples from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the training file has a different name than the original one
                which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.

        """
        if data_dir is None:
            data_dir = ""

        if self.train_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "train")

    def get_dev_examples(self, data_dir, filename=None):
        """
        Returns the evaluation example from the data directory.

        Args:
            data_dir: Directory containing the data files used for training and evaluating.
            filename: None by default, specify this if the evaluation file has a different name than the original one
                which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
        """
        if data_dir is None:
            data_dir = ""

        if self.dev_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "dev")

    def _create_examples(self, input_data, set_type):
        is_training = set_type == "train"
        examples = []
        for entry in tqdm(input_data):
            title = entry["title"]
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]
                    start_position_character = None
                    answer_text = None
                    answers = []

                    is_impossible = qa.get("is_impossible", False)
                    if not is_impossible:
                        if is_training:
                            answer = qa["answers"][0]
                            answer_text = answer["text"]
                            start_position_character = answer["answer_start"]
                        else:
                            answers = qa["answers"]

                    example = SquadExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        context_text=context_text,
                        answer_text=answer_text,
                        start_position_character=start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )
                    examples.append(example)
        return examples


class SquadV1Processor(SquadProcessor):
    train_file = "train-v1.1.json"
    dev_file = "dev-v1.1.json"


class SquadV2Processor(SquadProcessor):
    train_file = "train-v2.0.json"
    dev_file = "dev-v2.0.json"


class SquadExample:
    """
    A single training/test example for the Squad dataset, as loaded from disk.

    Args:
        qas_id: The example's unique identifier
        question_text: The question string
        context_text: The context string
        answer_text: The answer string
        start_position_character: The character position of the start of the answer
        title: The title of the example
        answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
        is_impossible: False by default, set to True if the example has no possible answer.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        context_text,
        answer_text,
        start_position_character,
        title,
        answers=[],
        is_impossible=False,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.context_text = context_text
        self.answer_text = answer_text
        self.title = title
        self.is_impossible = is_impossible
        self.answers = answers

        self.start_position, self.end_position = 0, 0

        doc_tokens = []
        char_to_word_offset = []
        prev_is_whitespace = True

        # Split on whitespace so that different tokens may be attributed to their original position.
        for c in self.context_text:
            if _is_whitespace(c):
                prev_is_whitespace = True
            else:
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False
            char_to_word_offset.append(len(doc_tokens) - 1)

        self.doc_tokens = doc_tokens
        self.char_to_word_offset = char_to_word_offset

        # Start and end positions only has a value during evaluation.
        if start_position_character is not None and not is_impossible:
            self.start_position = char_to_word_offset[start_position_character]
            self.end_position = char_to_word_offset[
                min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
            ]


class SquadFeatures:
    """
    Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
    [`~data.processors.squad.SquadExample`] using the
    :method:*~transformers.data.processors.squad.squad_convert_examples_to_features* method.

    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        cls_index: the index of the CLS token.
        p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
            Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer
        example_index: the index of the example
        unique_id: The unique Feature identifier
        paragraph_len: The length of the context
        token_is_max_context:
            List of booleans identifying which tokens have their maximum context in this feature object. If a token
|
| 778 |
+
does not have their maximum context in this feature object, it means that another feature object has more
|
| 779 |
+
information related to that token and should be prioritized over this feature for that token.
|
| 780 |
+
tokens: list of tokens corresponding to the input ids
|
| 781 |
+
token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
|
| 782 |
+
start_position: start of the answer token index
|
| 783 |
+
end_position: end of the answer token index
|
| 784 |
+
encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
|
| 785 |
+
"""
|
| 786 |
+
|
| 787 |
+
def __init__(
|
| 788 |
+
self,
|
| 789 |
+
input_ids,
|
| 790 |
+
attention_mask,
|
| 791 |
+
token_type_ids,
|
| 792 |
+
cls_index,
|
| 793 |
+
p_mask,
|
| 794 |
+
example_index,
|
| 795 |
+
unique_id,
|
| 796 |
+
paragraph_len,
|
| 797 |
+
token_is_max_context,
|
| 798 |
+
tokens,
|
| 799 |
+
token_to_orig_map,
|
| 800 |
+
start_position,
|
| 801 |
+
end_position,
|
| 802 |
+
is_impossible,
|
| 803 |
+
qas_id: str = None,
|
| 804 |
+
encoding: BatchEncoding = None,
|
| 805 |
+
):
|
| 806 |
+
self.input_ids = input_ids
|
| 807 |
+
self.attention_mask = attention_mask
|
| 808 |
+
self.token_type_ids = token_type_ids
|
| 809 |
+
self.cls_index = cls_index
|
| 810 |
+
self.p_mask = p_mask
|
| 811 |
+
|
| 812 |
+
self.example_index = example_index
|
| 813 |
+
self.unique_id = unique_id
|
| 814 |
+
self.paragraph_len = paragraph_len
|
| 815 |
+
self.token_is_max_context = token_is_max_context
|
| 816 |
+
self.tokens = tokens
|
| 817 |
+
self.token_to_orig_map = token_to_orig_map
|
| 818 |
+
|
| 819 |
+
self.start_position = start_position
|
| 820 |
+
self.end_position = end_position
|
| 821 |
+
self.is_impossible = is_impossible
|
| 822 |
+
self.qas_id = qas_id
|
| 823 |
+
|
| 824 |
+
self.encoding = encoding
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
class SquadResult:
|
| 828 |
+
"""
|
| 829 |
+
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
|
| 830 |
+
|
| 831 |
+
Args:
|
| 832 |
+
unique_id: The unique identifier corresponding to that example.
|
| 833 |
+
start_logits: The logits corresponding to the start of the answer
|
| 834 |
+
end_logits: The logits corresponding to the end of the answer
|
| 835 |
+
"""
|
| 836 |
+
|
| 837 |
+
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
|
| 838 |
+
self.start_logits = start_logits
|
| 839 |
+
self.end_logits = end_logits
|
| 840 |
+
self.unique_id = unique_id
|
| 841 |
+
|
| 842 |
+
if start_top_index:
|
| 843 |
+
self.start_top_index = start_top_index
|
| 844 |
+
self.end_top_index = end_top_index
|
| 845 |
+
self.cls_logits = cls_logits
|
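Illustrative usage sketch (not part of the diff above): the SQuAD processors are pointed at a local data directory, and SquadExample precomputes whitespace-delimited tokens plus a character-to-word offset map that turns an answer's character span into token start/end positions. The `./squad_data` path and the tiny context below are assumptions for illustration only.

from transformers.data.processors.squad import SquadExample, SquadV2Processor

# Load evaluation examples from a local copy of dev-v2.0.json (the path is an assumption).
processor = SquadV2Processor()
examples = processor.get_dev_examples("./squad_data")
print(len(examples))

# SquadExample maps the answer's character offset onto whitespace-delimited words.
example = SquadExample(
    qas_id="demo-0",
    question_text="Where is the answer?",
    context_text="The answer is here.",
    answer_text="here",
    start_position_character=14,
    title="demo",
)
print(example.doc_tokens)                            # ['The', 'answer', 'is', 'here.']
print(example.start_position, example.end_position)  # 3 3 (both point at the word 'here.')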
.venv/Lib/site-packages/transformers/data/processors/utils.py
ADDED
@@ -0,0 +1,349 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import dataclasses
import json
from dataclasses import dataclass
from typing import List, Optional, Union

from ...utils import is_tf_available, is_torch_available, logging


logger = logging.get_logger(__name__)


@dataclass
class InputExample:
    """
    A single training/test example for simple sequence classification.

    Args:
        guid: Unique id for the example.
        text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
        text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
        label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
    """

    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(dataclasses.asdict(self), indent=2) + "\n"


@dataclass(frozen=True)
class InputFeatures:
    """
    A single set of features of data. Property names are the same names as the corresponding inputs to a model.

    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
            Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
            tokens.
        token_type_ids: (Optional) Segment token indices to indicate first and second
            portions of the inputs. Only some models use them.
        label: (Optional) Label corresponding to the input. Int for classification problems,
            float for regression problems.
    """

    input_ids: List[int]
    attention_mask: Optional[List[int]] = None
    token_type_ids: Optional[List[int]] = None
    label: Optional[Union[int, float]] = None

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(dataclasses.asdict(self)) + "\n"


class DataProcessor:
    """Base class for data converters for sequence classification data sets."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """
        Gets an example from a dict with tensorflow tensors.

        Args:
            tensor_dict: Keys and values should match the corresponding Glue
                tensorflow_dataset examples.
        """
        raise NotImplementedError()

    def get_train_examples(self, data_dir):
        """Gets a collection of [`InputExample`] for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of [`InputExample`] for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of [`InputExample`] for the test set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    def tfds_map(self, example):
        """
        Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
        examples to the correct format.
        """
        if len(self.get_labels()) > 1:
            example.label = self.get_labels()[int(example.label)]
        return example

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))


class SingleSentenceClassificationProcessor(DataProcessor):
    """Generic processor for a single sentence classification data set."""

    def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
        self.labels = [] if labels is None else labels
        self.examples = [] if examples is None else examples
        self.mode = mode
        self.verbose = verbose

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
        return self.examples[idx]

    @classmethod
    def create_from_csv(
        cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
    ):
        processor = cls(**kwargs)
        processor.add_examples_from_csv(
            file_name,
            split_name=split_name,
            column_label=column_label,
            column_text=column_text,
            column_id=column_id,
            skip_first_row=skip_first_row,
            overwrite_labels=True,
            overwrite_examples=True,
        )
        return processor

    @classmethod
    def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
        processor = cls(**kwargs)
        processor.add_examples(texts_or_text_and_labels, labels=labels)
        return processor

    def add_examples_from_csv(
        self,
        file_name,
        split_name="",
        column_label=0,
        column_text=1,
        column_id=None,
        skip_first_row=False,
        overwrite_labels=False,
        overwrite_examples=False,
    ):
        lines = self._read_tsv(file_name)
        if skip_first_row:
            lines = lines[1:]
        texts = []
        labels = []
        ids = []
        for i, line in enumerate(lines):
            texts.append(line[column_text])
            labels.append(line[column_label])
            if column_id is not None:
                ids.append(line[column_id])
            else:
                guid = f"{split_name}-{i}" if split_name else str(i)
                ids.append(guid)

        return self.add_examples(
            texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
        )

    def add_examples(
        self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
    ):
        if labels is not None and len(texts_or_text_and_labels) != len(labels):
            raise ValueError(
                f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
            )
        if ids is not None and len(texts_or_text_and_labels) != len(ids):
            raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
        if ids is None:
            ids = [None] * len(texts_or_text_and_labels)
        if labels is None:
            labels = [None] * len(texts_or_text_and_labels)
        examples = []
        added_labels = set()
        for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
            if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
                text, label = text_or_text_and_label
            else:
                text = text_or_text_and_label
            added_labels.add(label)
            examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))

        # Update examples
        if overwrite_examples:
            self.examples = examples
        else:
            self.examples.extend(examples)

        # Update labels
        if overwrite_labels:
            self.labels = list(added_labels)
        else:
            self.labels = list(set(self.labels).union(added_labels))

        return self.examples

    def get_features(
        self,
        tokenizer,
        max_length=None,
        pad_on_left=False,
        pad_token=0,
        mask_padding_with_zero=True,
        return_tensors=None,
    ):
        """
        Convert examples in a list of `InputFeatures`

        Args:
            tokenizer: Instance of a tokenizer that will tokenize the examples
            max_length: Maximum example length
            pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
            pad_token: Padding token
            mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
                and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
                values)

        Returns:
            If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the
            task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific
            `InputFeatures` which can be fed to the model.

        """
        if max_length is None:
            max_length = tokenizer.max_len

        label_map = {label: i for i, label in enumerate(self.labels)}

        all_input_ids = []
        for ex_index, example in enumerate(self.examples):
            if ex_index % 10000 == 0:
                logger.info(f"Tokenizing example {ex_index}")

            input_ids = tokenizer.encode(
                example.text_a,
                add_special_tokens=True,
                max_length=min(max_length, tokenizer.max_len),
            )
            all_input_ids.append(input_ids)

        batch_length = max(len(input_ids) for input_ids in all_input_ids)

        features = []
        for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
            if ex_index % 10000 == 0:
                logger.info(f"Writing example {ex_index}/{len(self.examples)}")
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = batch_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)

            if len(input_ids) != batch_length:
                raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
            if len(attention_mask) != batch_length:
                raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")

            if self.mode == "classification":
                label = label_map[example.label]
            elif self.mode == "regression":
                label = float(example.label)
            else:
                raise ValueError(self.mode)

            if ex_index < 5 and self.verbose:
                logger.info("*** Example ***")
                logger.info(f"guid: {example.guid}")
                logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
                logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
                logger.info(f"label: {example.label} (id = {label})")

            features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))

        if return_tensors is None:
            return features
        elif return_tensors == "tf":
            if not is_tf_available():
                raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
            import tensorflow as tf

            def gen():
                for ex in features:
                    yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)

            dataset = tf.data.Dataset.from_generator(
                gen,
                ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
                ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
            )
            return dataset
        elif return_tensors == "pt":
            if not is_torch_available():
                raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
            import torch
            from torch.utils.data import TensorDataset

            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
            all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
            if self.mode == "classification":
                all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
            elif self.mode == "regression":
                all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

            dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
            return dataset
        else:
            raise ValueError("return_tensors should be one of 'tf' or 'pt'")
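Illustrative usage sketch (not part of the diff above): SingleSentenceClassificationProcessor can be built directly from (text, label) pairs; add_examples collects the label set automatically and slicing returns a new processor over the selected examples. The example texts and labels below are assumptions for illustration.

from transformers.data.processors.utils import SingleSentenceClassificationProcessor

# Build a processor from (text, label) pairs; labels are gathered into processor.labels.
processor = SingleSentenceClassificationProcessor.create_from_examples(
    [("the movie was great", "pos"), ("the plot made no sense", "neg")]
)
print(len(processor))            # 2
print(sorted(processor.labels))  # ['neg', 'pos']

# Slicing returns a new processor restricted to the selected examples.
subset = processor[0:1]
print(len(subset), subset[0].text_a)  # 1 the movie was great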
.venv/Lib/site-packages/transformers/data/processors/xnli.py
ADDED
@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XNLI utils (dataset loading and evaluation)"""

import os

from ...utils import logging
from .utils import DataProcessor, InputExample


logger = logging.get_logger(__name__)


class XnliProcessor(DataProcessor):
    """
    Processor for the XNLI dataset. Adapted from
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
    """

    def __init__(self, language, train_language=None):
        self.language = language
        self.train_language = train_language

    def get_train_examples(self, data_dir):
        """See base class."""
        lg = self.language if self.train_language is None else self.train_language
        lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f"train-{i}"
            text_a = line[0]
            text_b = line[1]
            label = "contradiction" if line[2] == "contradictory" else line[2]
            if not isinstance(text_a, str):
                raise TypeError(f"Training input {text_a} is not a string")
            if not isinstance(text_b, str):
                raise TypeError(f"Training input {text_b} is not a string")
            if not isinstance(label, str):
                raise TypeError(f"Training label {label} is not a string")
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            language = line[0]
            if language != self.language:
                continue
            guid = f"test-{i}"
            text_a = line[6]
            text_b = line[7]
            label = line[1]
            if not isinstance(text_a, str):
                raise TypeError(f"Training input {text_a} is not a string")
            if not isinstance(text_b, str):
                raise TypeError(f"Training input {text_b} is not a string")
            if not isinstance(label, str):
                raise TypeError(f"Training label {label} is not a string")
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]


xnli_processors = {
    "xnli": XnliProcessor,
}

xnli_output_modes = {
    "xnli": "classification",
}

xnli_tasks_num_labels = {
    "xnli": 3,
}
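Illustrative usage sketch (not part of the diff above): XnliProcessor is language-specific, and train_language supports cross-lingual setups (train on one language's data, evaluate on another). The directory layout in the comments is the official XNLI distribution, assumed here for illustration.

from transformers.data.processors.xnli import XnliProcessor, xnli_tasks_num_labels

# Train on English MultiNLI translations, evaluate on the German test split.
processor = XnliProcessor(language="de", train_language="en")
print(processor.get_labels())          # ['contradiction', 'entailment', 'neutral']
print(xnli_tasks_num_labels["xnli"])   # 3

# Reading examples expects the official XNLI layout under data_dir (an assumption):
#   data_dir/XNLI-MT-1.0/multinli/multinli.train.en.tsv  -> get_train_examples(data_dir)
#   data_dir/XNLI-1.0/xnli.test.tsv                      -> get_test_examples(data_dir)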
.venv/Lib/site-packages/transformers/generation/__init__.py
ADDED
@@ -0,0 +1,352 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {
    "configuration_utils": [
        "BaseWatermarkingConfig",
        "CompileConfig",
        "GenerationConfig",
        "GenerationMode",
        "SynthIDTextWatermarkingConfig",
        "WatermarkingConfig",
    ],
    "streamers": ["TextIteratorStreamer", "TextStreamer"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["beam_constraints"] = [
        "Constraint",
        "ConstraintListState",
        "DisjunctiveConstraint",
        "PhrasalConstraint",
    ]
    _import_structure["beam_search"] = [
        "BeamHypotheses",
        "BeamScorer",
        "BeamSearchScorer",
        "ConstrainedBeamSearchScorer",
    ]
    _import_structure["candidate_generator"] = [
        "AssistedCandidateGenerator",
        "CandidateGenerator",
        "EarlyExitCandidateGenerator",
        "PromptLookupCandidateGenerator",
    ]
    _import_structure["logits_process"] = [
        "AlternatingCodebooksLogitsProcessor",
        "ClassifierFreeGuidanceLogitsProcessor",
        "EncoderNoRepeatNGramLogitsProcessor",
        "EncoderRepetitionPenaltyLogitsProcessor",
        "EpsilonLogitsWarper",
        "EtaLogitsWarper",
        "ExponentialDecayLengthPenalty",
        "ForcedBOSTokenLogitsProcessor",
        "ForcedEOSTokenLogitsProcessor",
        "HammingDiversityLogitsProcessor",
        "InfNanRemoveLogitsProcessor",
        "LogitNormalization",
        "LogitsProcessor",
        "LogitsProcessorList",
        "LogitsWarper",
        "MinLengthLogitsProcessor",
        "MinNewTokensLengthLogitsProcessor",
        "MinPLogitsWarper",
        "NoBadWordsLogitsProcessor",
        "NoRepeatNGramLogitsProcessor",
        "PrefixConstrainedLogitsProcessor",
        "RepetitionPenaltyLogitsProcessor",
        "SequenceBiasLogitsProcessor",
        "SuppressTokensLogitsProcessor",
        "SuppressTokensAtBeginLogitsProcessor",
        "SynthIDTextWatermarkLogitsProcessor",
        "TemperatureLogitsWarper",
        "TopKLogitsWarper",
        "TopPLogitsWarper",
        "TypicalLogitsWarper",
        "UnbatchedClassifierFreeGuidanceLogitsProcessor",
        "WhisperTimeStampLogitsProcessor",
        "WatermarkLogitsProcessor",
    ]
    _import_structure["stopping_criteria"] = [
        "MaxNewTokensCriteria",
        "MaxLengthCriteria",
        "MaxTimeCriteria",
        "ConfidenceCriteria",
        "EosTokenCriteria",
        "StoppingCriteria",
        "StoppingCriteriaList",
        "validate_stopping_criteria",
        "StopStringCriteria",
    ]
    _import_structure["utils"] = [
        "GenerationMixin",
        "GreedySearchEncoderDecoderOutput",
        "GreedySearchDecoderOnlyOutput",
        "SampleEncoderDecoderOutput",
        "SampleDecoderOnlyOutput",
        "BeamSearchEncoderDecoderOutput",
        "BeamSearchDecoderOnlyOutput",
        "BeamSampleEncoderDecoderOutput",
        "BeamSampleDecoderOnlyOutput",
        "ContrastiveSearchEncoderDecoderOutput",
        "ContrastiveSearchDecoderOnlyOutput",
        "GenerateBeamDecoderOnlyOutput",
        "GenerateBeamEncoderDecoderOutput",
        "GenerateDecoderOnlyOutput",
        "GenerateEncoderDecoderOutput",
    ]
    _import_structure["watermarking"] = [
        "WatermarkDetector",
        "WatermarkDetectorOutput",
        "BayesianDetectorModel",
        "BayesianDetectorConfig",
        "SynthIDTextWatermarkDetector",
    ]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tf_logits_process"] = [
        "TFForcedBOSTokenLogitsProcessor",
        "TFForcedEOSTokenLogitsProcessor",
        "TFForceTokensLogitsProcessor",
        "TFLogitsProcessor",
        "TFLogitsProcessorList",
        "TFLogitsWarper",
        "TFMinLengthLogitsProcessor",
        "TFNoBadWordsLogitsProcessor",
        "TFNoRepeatNGramLogitsProcessor",
        "TFRepetitionPenaltyLogitsProcessor",
        "TFSuppressTokensAtBeginLogitsProcessor",
        "TFSuppressTokensLogitsProcessor",
        "TFTemperatureLogitsWarper",
        "TFTopKLogitsWarper",
        "TFTopPLogitsWarper",
    ]
    _import_structure["tf_utils"] = [
        "TFGenerationMixin",
        "TFGreedySearchDecoderOnlyOutput",
        "TFGreedySearchEncoderDecoderOutput",
        "TFSampleEncoderDecoderOutput",
        "TFSampleDecoderOnlyOutput",
        "TFBeamSearchEncoderDecoderOutput",
        "TFBeamSearchDecoderOnlyOutput",
        "TFBeamSampleEncoderDecoderOutput",
        "TFBeamSampleDecoderOnlyOutput",
        "TFContrastiveSearchEncoderDecoderOutput",
        "TFContrastiveSearchDecoderOnlyOutput",
    ]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["flax_logits_process"] = [
        "FlaxForcedBOSTokenLogitsProcessor",
        "FlaxForcedEOSTokenLogitsProcessor",
        "FlaxForceTokensLogitsProcessor",
        "FlaxLogitsProcessor",
        "FlaxLogitsProcessorList",
        "FlaxLogitsWarper",
        "FlaxMinLengthLogitsProcessor",
        "FlaxSuppressTokensAtBeginLogitsProcessor",
        "FlaxSuppressTokensLogitsProcessor",
        "FlaxTemperatureLogitsWarper",
        "FlaxTopKLogitsWarper",
        "FlaxTopPLogitsWarper",
        "FlaxWhisperTimeStampLogitsProcessor",
        "FlaxNoRepeatNGramLogitsProcessor",
    ]
    _import_structure["flax_utils"] = [
        "FlaxGenerationMixin",
        "FlaxGreedySearchOutput",
        "FlaxSampleOutput",
        "FlaxBeamSearchOutput",
    ]

if TYPE_CHECKING:
    from .configuration_utils import (
        BaseWatermarkingConfig,
        CompileConfig,
        GenerationConfig,
        GenerationMode,
        SynthIDTextWatermarkingConfig,
        WatermarkingConfig,
    )
    from .streamers import TextIteratorStreamer, TextStreamer

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
        from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
        from .candidate_generator import (
            AssistedCandidateGenerator,
            CandidateGenerator,
            EarlyExitCandidateGenerator,
            PromptLookupCandidateGenerator,
        )
        from .logits_process import (
            AlternatingCodebooksLogitsProcessor,
            ClassifierFreeGuidanceLogitsProcessor,
            EncoderNoRepeatNGramLogitsProcessor,
            EncoderRepetitionPenaltyLogitsProcessor,
            EpsilonLogitsWarper,
            EtaLogitsWarper,
            ExponentialDecayLengthPenalty,
            ForcedBOSTokenLogitsProcessor,
            ForcedEOSTokenLogitsProcessor,
            HammingDiversityLogitsProcessor,
            InfNanRemoveLogitsProcessor,
            LogitNormalization,
            LogitsProcessor,
            LogitsProcessorList,
            LogitsWarper,
            MinLengthLogitsProcessor,
            MinNewTokensLengthLogitsProcessor,
            MinPLogitsWarper,
            NoBadWordsLogitsProcessor,
            NoRepeatNGramLogitsProcessor,
            PrefixConstrainedLogitsProcessor,
            RepetitionPenaltyLogitsProcessor,
            SequenceBiasLogitsProcessor,
            SuppressTokensAtBeginLogitsProcessor,
            SuppressTokensLogitsProcessor,
            SynthIDTextWatermarkLogitsProcessor,
            TemperatureLogitsWarper,
            TopKLogitsWarper,
            TopPLogitsWarper,
            TypicalLogitsWarper,
            UnbatchedClassifierFreeGuidanceLogitsProcessor,
            WatermarkLogitsProcessor,
            WhisperTimeStampLogitsProcessor,
        )
        from .stopping_criteria import (
            ConfidenceCriteria,
            EosTokenCriteria,
            MaxLengthCriteria,
            MaxNewTokensCriteria,
            MaxTimeCriteria,
            StoppingCriteria,
            StoppingCriteriaList,
            StopStringCriteria,
            validate_stopping_criteria,
        )
        from .utils import (
            BeamSampleDecoderOnlyOutput,
            BeamSampleEncoderDecoderOutput,
            BeamSearchDecoderOnlyOutput,
            BeamSearchEncoderDecoderOutput,
            ContrastiveSearchDecoderOnlyOutput,
            ContrastiveSearchEncoderDecoderOutput,
            GenerateBeamDecoderOnlyOutput,
            GenerateBeamEncoderDecoderOutput,
            GenerateDecoderOnlyOutput,
            GenerateEncoderDecoderOutput,
            GenerationMixin,
            GreedySearchDecoderOnlyOutput,
            GreedySearchEncoderDecoderOutput,
            SampleDecoderOnlyOutput,
            SampleEncoderDecoderOutput,
        )
        from .watermarking import (
            BayesianDetectorConfig,
            BayesianDetectorModel,
            SynthIDTextWatermarkDetector,
            WatermarkDetector,
            WatermarkDetectorOutput,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tf_logits_process import (
            TFForcedBOSTokenLogitsProcessor,
            TFForcedEOSTokenLogitsProcessor,
            TFForceTokensLogitsProcessor,
            TFLogitsProcessor,
            TFLogitsProcessorList,
            TFLogitsWarper,
            TFMinLengthLogitsProcessor,
            TFNoBadWordsLogitsProcessor,
            TFNoRepeatNGramLogitsProcessor,
            TFRepetitionPenaltyLogitsProcessor,
            TFSuppressTokensAtBeginLogitsProcessor,
            TFSuppressTokensLogitsProcessor,
            TFTemperatureLogitsWarper,
            TFTopKLogitsWarper,
            TFTopPLogitsWarper,
        )
        from .tf_utils import (
            TFBeamSampleDecoderOnlyOutput,
            TFBeamSampleEncoderDecoderOutput,
            TFBeamSearchDecoderOnlyOutput,
            TFBeamSearchEncoderDecoderOutput,
            TFContrastiveSearchDecoderOnlyOutput,
            TFContrastiveSearchEncoderDecoderOutput,
            TFGenerationMixin,
            TFGreedySearchDecoderOnlyOutput,
            TFGreedySearchEncoderDecoderOutput,
            TFSampleDecoderOnlyOutput,
            TFSampleEncoderDecoderOutput,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .flax_logits_process import (
            FlaxForcedBOSTokenLogitsProcessor,
            FlaxForcedEOSTokenLogitsProcessor,
            FlaxForceTokensLogitsProcessor,
            FlaxLogitsProcessor,
            FlaxLogitsProcessorList,
            FlaxLogitsWarper,
            FlaxMinLengthLogitsProcessor,
            FlaxNoRepeatNGramLogitsProcessor,
            FlaxSuppressTokensAtBeginLogitsProcessor,
            FlaxSuppressTokensLogitsProcessor,
            FlaxTemperatureLogitsWarper,
            FlaxTopKLogitsWarper,
            FlaxTopPLogitsWarper,
            FlaxWhisperTimeStampLogitsProcessor,
        )
        from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
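Illustrative usage sketch (not part of the diff above): because this __init__.py registers a _LazyModule, importing transformers.generation is cheap, and the torch/TF/Flax submodules listed in _import_structure are only loaded when one of their names is first accessed. The sketch below assumes PyTorch is installed so the torch-backed names resolve.

# Importing names triggers the lazy module to load only the needed submodules.
from transformers.generation import GenerationConfig, LogitsProcessorList, MinLengthLogitsProcessor

config = GenerationConfig(max_new_tokens=32, do_sample=False)
processors = LogitsProcessorList([MinLengthLogitsProcessor(5, eos_token_id=2)])
print(config.max_new_tokens, len(processors))  # 32 1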
.venv/Lib/site-packages/transformers/generation/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (6.45 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-39.pyc
ADDED
Binary file (16.3 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/beam_search.cpython-39.pyc
ADDED
Binary file (28.8 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-39.pyc
ADDED
Binary file (25.4 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-39.pyc
ADDED
Binary file (65.9 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/logits_process.cpython-39.pyc
ADDED
Binary file (122 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-39.pyc
ADDED
Binary file (23.7 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (130 kB)