ayousanz committed
Commit 865b15f · verified · 1 parent: 4ca3bd3

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. .venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd +3 -0
  3. .venv/Lib/site-packages/transformers/__init__.py +0 -0
  4. .venv/Lib/site-packages/transformers/agents/__init__.py +69 -0
  5. .venv/Lib/site-packages/transformers/agents/agent_types.py +260 -0
  6. .venv/Lib/site-packages/transformers/agents/agents.py +1278 -0
  7. .venv/Lib/site-packages/transformers/agents/default_tools.py +187 -0
  8. .venv/Lib/site-packages/transformers/agents/document_question_answering.py +89 -0
  9. .venv/Lib/site-packages/transformers/agents/evaluate_agent.py +414 -0
  10. .venv/Lib/site-packages/transformers/agents/image_question_answering.py +58 -0
  11. .venv/Lib/site-packages/transformers/agents/llm_engine.py +238 -0
  12. .venv/Lib/site-packages/transformers/agents/monitoring.py +117 -0
  13. .venv/Lib/site-packages/transformers/agents/prompts.py +789 -0
  14. .venv/Lib/site-packages/transformers/agents/python_interpreter.py +908 -0
  15. .venv/Lib/site-packages/transformers/agents/search.py +77 -0
  16. .venv/Lib/site-packages/transformers/agents/speech_to_text.py +39 -0
  17. .venv/Lib/site-packages/transformers/agents/text_to_speech.py +67 -0
  18. .venv/Lib/site-packages/transformers/agents/tools.py +1003 -0
  19. .venv/Lib/site-packages/transformers/agents/translation.py +279 -0
  20. .venv/Lib/site-packages/transformers/benchmark/benchmark.py +270 -0
  21. .venv/Lib/site-packages/transformers/benchmark/benchmark_args.py +124 -0
  22. .venv/Lib/site-packages/transformers/benchmark/benchmark_args_tf.py +136 -0
  23. .venv/Lib/site-packages/transformers/commands/__init__.py +27 -0
  24. .venv/Lib/site-packages/transformers/commands/run.py +110 -0
  25. .venv/Lib/site-packages/transformers/commands/serving.py +228 -0
  26. .venv/Lib/site-packages/transformers/commands/train.py +158 -0
  27. .venv/Lib/site-packages/transformers/commands/transformers_cli.py +57 -0
  28. .venv/Lib/site-packages/transformers/commands/user.py +197 -0
  29. .venv/Lib/site-packages/transformers/data/__init__.py +45 -0
  30. .venv/Lib/site-packages/transformers/data/data_collator.py +1653 -0
  31. .venv/Lib/site-packages/transformers/data/datasets/__init__.py +23 -0
  32. .venv/Lib/site-packages/transformers/data/datasets/glue.py +161 -0
  33. .venv/Lib/site-packages/transformers/data/datasets/language_modeling.py +530 -0
  34. .venv/Lib/site-packages/transformers/data/datasets/squad.py +229 -0
  35. .venv/Lib/site-packages/transformers/data/metrics/__init__.py +98 -0
  36. .venv/Lib/site-packages/transformers/data/metrics/squad_metrics.py +779 -0
  37. .venv/Lib/site-packages/transformers/data/processors/__init__.py +18 -0
  38. .venv/Lib/site-packages/transformers/data/processors/glue.py +643 -0
  39. .venv/Lib/site-packages/transformers/data/processors/squad.py +845 -0
  40. .venv/Lib/site-packages/transformers/data/processors/utils.py +349 -0
  41. .venv/Lib/site-packages/transformers/data/processors/xnli.py +96 -0
  42. .venv/Lib/site-packages/transformers/generation/__init__.py +352 -0
  43. .venv/Lib/site-packages/transformers/generation/__pycache__/__init__.cpython-39.pyc +0 -0
  44. .venv/Lib/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-39.pyc +0 -0
  45. .venv/Lib/site-packages/transformers/generation/__pycache__/beam_search.cpython-39.pyc +0 -0
  46. .venv/Lib/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-39.pyc +0 -0
  47. .venv/Lib/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-39.pyc +0 -0
  48. .venv/Lib/site-packages/transformers/generation/__pycache__/logits_process.cpython-39.pyc +0 -0
  49. .venv/Lib/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-39.pyc +0 -0
  50. .venv/Lib/site-packages/transformers/generation/__pycache__/utils.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -83,3 +83,4 @@ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs -text
  .venv/Lib/site-packages/torchaudio/lib/libtorchaudio.pyd filter=lfs diff=lfs merge=lfs -text
  .venv/Lib/site-packages/torchvision/nvjpeg64_12.dll filter=lfs diff=lfs merge=lfs -text
  .venv/Lib/site-packages/torchvision/_C.pyd filter=lfs diff=lfs merge=lfs -text
+ .venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd filter=lfs diff=lfs merge=lfs -text
.venv/Lib/site-packages/torio/lib/libtorio_ffmpeg5.pyd ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7abc7280260cda768b24d26ab52f7f1409d073b921bb57b52ffde627d2200bb5
+ size 1094656
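
Note: the three added lines are a Git LFS pointer file. The `.pyd` binary itself lives in LFS storage; the repository only records the spec version, a sha256 oid, and the byte size. A minimal sketch of reading those fields back out of a pointer (illustrative helper, not part of this commit):

```py
# Parse a git-lfs pointer file into a dict of its space-separated key/value fields.
def parse_lfs_pointer(text: str) -> dict:
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    assert fields["version"].startswith("https://git-lfs.github.com/spec/v1")
    return fields

pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:7abc7280260cda768b24d26ab52f7f1409d073b921bb57b52ffde627d2200bb5\n"
    "size 1094656\n"
)
print(pointer["oid"], pointer["size"])  # sha256:7abc... 1094656
```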
.venv/Lib/site-packages/transformers/__init__.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/Lib/site-packages/transformers/agents/__init__.py ADDED
@@ -0,0 +1,69 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ..utils import (
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_torch_available,
+ )
+
+
+ _import_structure = {
+     "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
+     "llm_engine": ["HfApiEngine", "TransformersEngine"],
+     "monitoring": ["stream_to_gradio"],
+     "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
+ }
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     pass
+ else:
+     _import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
+     _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
+     _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
+     _import_structure["search"] = ["DuckDuckGoSearchTool", "VisitWebpageTool"]
+     _import_structure["speech_to_text"] = ["SpeechToTextTool"]
+     _import_structure["text_to_speech"] = ["TextToSpeechTool"]
+     _import_structure["translation"] = ["TranslationTool"]
+
+ if TYPE_CHECKING:
+     from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
+     from .llm_engine import HfApiEngine, TransformersEngine
+     from .monitoring import stream_to_gradio
+     from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
+
+     try:
+         if not is_torch_available():
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         pass
+     else:
+         from .default_tools import FinalAnswerTool, PythonInterpreterTool
+         from .document_question_answering import DocumentQuestionAnsweringTool
+         from .image_question_answering import ImageQuestionAnsweringTool
+         from .search import DuckDuckGoSearchTool, VisitWebpageTool
+         from .speech_to_text import SpeechToTextTool
+         from .text_to_speech import TextToSpeechTool
+         from .translation import TranslationTool
+ else:
+     import sys
+
+     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
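
Note on the `_LazyModule` pattern above: `import transformers.agents` stays cheap because no submodule is imported until one of the names in `_import_structure` is first accessed. A minimal sketch of the behavior from the consumer side (assuming `transformers` is installed; the torch-gated tools only register when `is_torch_available()` is true):

```py
import transformers.agents as agents

# At this point only _import_structure has been built; .agents, .tools, etc.
# are not imported yet. Attribute access triggers the lazy import:
CodeAgent = agents.CodeAgent

print(CodeAgent.__module__)  # transformers.agents.agents
```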
.venv/Lib/site-packages/transformers/agents/agent_types.py ADDED
@@ -0,0 +1,260 @@
+ # coding=utf-8
+ # Copyright 2024 HuggingFace Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import pathlib
+ import tempfile
+ import uuid
+
+ import numpy as np
+
+ from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
+
+
+ logger = logging.get_logger(__name__)
+
+ if is_vision_available():
+     from PIL import Image
+     from PIL.Image import Image as ImageType
+ else:
+     ImageType = object
+
+ if is_torch_available():
+     import torch
+     from torch import Tensor
+ else:
+     Tensor = object
+
+ if is_soundfile_availble():
+     import soundfile as sf
+
+
+ class AgentType:
+     """
+     Abstract class to be reimplemented to define types that can be returned by agents.
+
+     These objects serve three purposes:
+
+     - They behave as if they were the type they're meant to be, e.g., a string for text, a PIL.Image for images
+     - They can be stringified: str(object) in order to return a string defining the object
+     - They should be displayed correctly in ipython notebooks/colab/jupyter
+     """
+
+     def __init__(self, value):
+         self._value = value
+
+     def __str__(self):
+         return self.to_string()
+
+     def to_raw(self):
+         logger.error(
+             "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
+         )
+         return self._value
+
+     def to_string(self) -> str:
+         logger.error(
+             "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
+         )
+         return str(self._value)
+
+
+ class AgentText(AgentType, str):
+     """
+     Text type returned by the agent. Behaves as a string.
+     """
+
+     def to_raw(self):
+         return self._value
+
+     def to_string(self):
+         return str(self._value)
+
+
+ class AgentImage(AgentType, ImageType):
+     """
+     Image type returned by the agent. Behaves as a PIL.Image.
+     """
+
+     def __init__(self, value):
+         AgentType.__init__(self, value)
+         ImageType.__init__(self)
+
+         if not is_vision_available():
+             raise ImportError("PIL must be installed in order to handle images.")
+
+         self._path = None
+         self._raw = None
+         self._tensor = None
+
+         if isinstance(value, ImageType):
+             self._raw = value
+         elif isinstance(value, (str, pathlib.Path)):
+             self._path = value
+         elif isinstance(value, torch.Tensor):
+             self._tensor = value
+         elif isinstance(value, np.ndarray):
+             self._tensor = torch.from_numpy(value)
+         else:
+             raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
+
+     def _ipython_display_(self, include=None, exclude=None):
+         """
+         Displays this type correctly in an ipython notebook (ipython, colab, jupyter, ...)
+         """
+         from IPython.display import Image, display
+
+         display(Image(self.to_string()))
+
+     def to_raw(self):
+         """
+         Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
+         """
+         if self._raw is not None:
+             return self._raw
+
+         if self._path is not None:
+             self._raw = Image.open(self._path)
+             return self._raw
+
+         if self._tensor is not None:
+             array = self._tensor.cpu().detach().numpy()
+             return Image.fromarray((255 - array * 255).astype(np.uint8))
+
+     def to_string(self):
+         """
+         Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
+         version of the image.
+         """
+         if self._path is not None:
+             return self._path
+
+         if self._raw is not None:
+             directory = tempfile.mkdtemp()
+             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
+             self._raw.save(self._path)
+             return self._path
+
+         if self._tensor is not None:
+             array = self._tensor.cpu().detach().numpy()
+
+             # There is likely a simpler way than loading into an image and saving it
+             img = Image.fromarray((255 - array * 255).astype(np.uint8))
+
+             directory = tempfile.mkdtemp()
+             self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
+
+             img.save(self._path)
+
+             return self._path
+
+     def save(self, output_bytes, format, **params):
+         """
+         Saves the image to a file.
+         Args:
+             output_bytes (bytes): The output bytes to save the image to.
+             format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
+             **params: Additional parameters to pass to PIL.Image.save.
+         """
+         img = self.to_raw()
+         img.save(output_bytes, format, **params)
+
+
+ class AgentAudio(AgentType, str):
+     """
+     Audio type returned by the agent.
+     """
+
+     def __init__(self, value, samplerate=16_000):
+         super().__init__(value)
+
+         if not is_soundfile_availble():
+             raise ImportError("soundfile must be installed in order to handle audio.")
+
+         self._path = None
+         self._tensor = None
+
+         self.samplerate = samplerate
+         if isinstance(value, (str, pathlib.Path)):
+             self._path = value
+         elif is_torch_available() and isinstance(value, torch.Tensor):
+             self._tensor = value
+         elif isinstance(value, tuple):
+             self.samplerate = value[0]
+             if isinstance(value[1], np.ndarray):
+                 self._tensor = torch.from_numpy(value[1])
+             else:
+                 self._tensor = torch.tensor(value[1])
+         else:
+             raise ValueError(f"Unsupported audio type: {type(value)}")
+
+     def _ipython_display_(self, include=None, exclude=None):
+         """
+         Displays this type correctly in an ipython notebook (ipython, colab, jupyter, ...)
+         """
+         from IPython.display import Audio, display
+
+         display(Audio(self.to_string(), rate=self.samplerate))
+
+     def to_raw(self):
+         """
+         Returns the "raw" version of that object. It is a `torch.Tensor` object.
+         """
+         if self._tensor is not None:
+             return self._tensor
+
+         if self._path is not None:
+             tensor, self.samplerate = sf.read(self._path)
+             self._tensor = torch.tensor(tensor)
+             return self._tensor
+
+     def to_string(self):
+         """
+         Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
+         version of the audio.
+         """
+         if self._path is not None:
+             return self._path
+
+         if self._tensor is not None:
+             directory = tempfile.mkdtemp()
+             self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
+             sf.write(self._path, self._tensor, samplerate=self.samplerate)
+             return self._path
+
+
+ AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
+ INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
+
+ if is_torch_available():
+     INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
+
+
+ def handle_agent_inputs(*args, **kwargs):
+     args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
+     kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
+     return args, kwargs
+
+
+ def handle_agent_outputs(output, output_type=None):
+     if output_type in AGENT_TYPE_MAPPING:
+         # If the class has defined outputs, we can map directly according to the class definition
+         decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
+         return decoded_outputs
+     else:
+         # If the class does not have defined output, then we map according to the type
+         for _k, _v in INSTANCE_TYPE_MAPPING.items():
+             if isinstance(output, _k):
+                 return _v(output)
+         return output
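
To make the two mappings at the bottom concrete, here is a minimal sketch of `handle_agent_outputs` in both modes; the text path needs no optional dependencies:

```py
from transformers.agents.agent_types import AgentText, handle_agent_outputs

# Explicit output_type: AGENT_TYPE_MAPPING["string"] maps straight to AgentText.
out = handle_agent_outputs("It is 42.", output_type="string")
print(type(out).__name__, str(out))  # AgentText It is 42.

# No output_type: INSTANCE_TYPE_MAPPING falls back on isinstance checks (str -> AgentText).
out = handle_agent_outputs("It is 42.")
print(isinstance(out, AgentText), isinstance(out, str))  # True True
```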
.venv/Lib/site-packages/transformers/agents/agents.py ADDED
@@ -0,0 +1,1278 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import json
+ import logging
+ import re
+ import time
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ from .. import is_torch_available
+ from ..utils import logging as transformers_logging
+ from ..utils.import_utils import is_pygments_available
+ from .agent_types import AgentAudio, AgentImage
+ from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
+ from .llm_engine import HfApiEngine, MessageRole
+ from .monitoring import Monitor
+ from .prompts import (
+     DEFAULT_CODE_SYSTEM_PROMPT,
+     DEFAULT_REACT_CODE_SYSTEM_PROMPT,
+     DEFAULT_REACT_JSON_SYSTEM_PROMPT,
+     PLAN_UPDATE_FINAL_PLAN_REDACTION,
+     PROMPTS_FOR_INITIAL_PLAN,
+     PROMPTS_FOR_PLAN_UPDATE,
+     SUPPORTED_PLAN_TYPES,
+     SYSTEM_PROMPT_FACTS,
+     SYSTEM_PROMPT_FACTS_UPDATE,
+     USER_PROMPT_FACTS_UPDATE,
+ )
+ from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
+ from .tools import (
+     DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+     Tool,
+     get_tool_description_with_args,
+     load_tool,
+ )
+
+
+ if is_pygments_available():
+     from pygments import highlight
+     from pygments.formatters import Terminal256Formatter
+     from pygments.lexers import PythonLexer
+
+
+ class CustomFormatter(logging.Formatter):
+     grey = "\x1b[38;20m"
+     bold_yellow = "\x1b[33;1m"
+     red = "\x1b[31;20m"
+     green = "\x1b[32;20m"
+     bold_green = "\x1b[32;20;1m"
+     bold_red = "\x1b[31;1m"
+     bold_white = "\x1b[37;1m"
+     orange = "\x1b[38;5;214m"
+     bold_orange = "\x1b[38;5;214;1m"
+     reset = "\x1b[0m"
+     format = "%(message)s"
+
+     FORMATS = {
+         logging.DEBUG: grey + format + reset,
+         logging.INFO: format,
+         logging.WARNING: bold_yellow + format + reset,
+         logging.ERROR: red + format + reset,
+         logging.CRITICAL: bold_red + format + reset,
+         31: reset + format + reset,
+         32: green + format + reset,
+         33: bold_green + format + reset,
+         34: bold_white + format + reset,
+         35: orange + format + reset,
+         36: bold_orange + format + reset,
+     }
+
+     def format(self, record):
+         log_fmt = self.FORMATS.get(record.levelno)
+         formatter = logging.Formatter(log_fmt)
+         return formatter.format(record)
+
+
+ logger = transformers_logging.get_logger(__name__)
+ logger.propagate = False
+ ch = logging.StreamHandler()
+ ch.setFormatter(CustomFormatter())
+ logger.addHandler(ch)
+
+
+ def parse_json_blob(json_blob: str) -> Dict[str, str]:
+     try:
+         first_accolade_index = json_blob.find("{")
+         last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
+         json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
+         json_data = json.loads(json_blob, strict=False)
+         return json_data
+     except json.JSONDecodeError as e:
+         place = e.pos
+         if json_blob[place - 1 : place + 2] == "},\n":
+             raise ValueError(
+                 "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
+             )
+         raise ValueError(
+             f"The JSON blob you used is invalid due to the following error: {e}.\n"
+             f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
+             f"'{json_blob[place-4:place+5]}'."
+         )
+     except Exception as e:
+         raise ValueError(f"Error in parsing the JSON blob: {e}")
+
+
+ def parse_code_blob(code_blob: str) -> str:
+     try:
+         pattern = r"```(?:py|python)?\n(.*?)\n```"
+         match = re.search(pattern, code_blob, re.DOTALL)
+         return match.group(1).strip()
+     except Exception as e:
+         raise ValueError(
+             f"""
+ The code blob you used is invalid: due to the following error: {e}
+ This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
+ Thoughts: Your thoughts
+ Code:
+ ```py
+ # Your python code here
+ ```<end_action>"""
+         )
+
+
+ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
+     json_blob = json_blob.replace("```json", "").replace("```", "")
+     tool_call = parse_json_blob(json_blob)
+     if "action" in tool_call and "action_input" in tool_call:
+         return tool_call["action"], tool_call["action_input"]
+     elif "action" in tool_call:
+         return tool_call["action"], None
+     else:
+         raise ValueError(
+             f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
+         )
+
+
+ def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
+     """
+     Expects a text in the format: 'Action:', 'Action input:', 'Observation:'. 'Action input:' contains a json string with input arguments.
+     """
+     try:
+         if "Observation:" in text:
+             text = text.split("Observation:")[0]
+         if "Action:" in text:
+             text = text.split("Action:")[1]
+         tool_name, tool_input = text.split("Action input:")
+         if "{" in tool_input:
+             tool_input = parse_json_blob(tool_input)
+         else:
+             tool_input = tool_input.strip().replace('"', "")
+         return tool_name.strip().replace('"', "").replace("\\", ""), tool_input
+     except Exception as e:
+         raise ValueError(
+             f"Error in parsing the text tool call: {e}. Be sure to provide the correct format. DO NOT repeat your previous incorrect tool call."
+         )
+
+
+ def to_text(input: Union[List[Dict[str, str]], Dict[str, str], str]) -> str:
+     if isinstance(input, list):
+         return "\n".join([m["content"] for m in input])
+     elif isinstance(input, dict):
+         return input["content"]
+     else:
+         return input
+
+
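
A minimal sketch of the two main parsers above on made-up LLM outputs (illustrative strings only):

```py
from transformers.agents.agents import parse_code_blob, parse_json_tool_call

# parse_code_blob pulls the fenced block a CodeAgent expects after "Code:".
llm_output = "Thoughts: compute it\nCode:\n```py\nresult = 2 ** 3.7384\n```<end_action>"
print(parse_code_blob(llm_output))  # result = 2 ** 3.7384

# parse_json_tool_call strips the fences, then requires "action" (+ optional "action_input").
action = '```json\n{"action": "final_answer", "action_input": {"answer": "42"}}\n```'
print(parse_json_tool_call(action))  # ('final_answer', {'answer': '42'})
```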
+ HUGGINGFACE_DEFAULT_TOOLS = {}
+ _tools_are_initialized = False
+
+
+ class Toolbox:
+     """
+     The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
+     manage them.
+
+     Args:
+         tools (`List[Tool]`):
+             The list of tools to instantiate the toolbox with
+         add_base_tools (`bool`, *optional*, defaults to `False`):
+             Whether to add the tools available within `transformers` to the toolbox.
+     """
+
+     def __init__(self, tools: List[Tool], add_base_tools: bool = False):
+         self._tools = {tool.name: tool for tool in tools}
+         if add_base_tools:
+             self.add_base_tools()
+         self._load_tools_if_needed()
+
+     def add_base_tools(self, add_python_interpreter: bool = False):
+         global _tools_are_initialized
+         global HUGGINGFACE_DEFAULT_TOOLS
+         if not _tools_are_initialized:
+             HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
+             _tools_are_initialized = True
+         for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
+             if tool.name != "python_interpreter" or add_python_interpreter:
+                 self.add_tool(tool)
+         self._load_tools_if_needed()
+
+     @property
+     def tools(self) -> Dict[str, Tool]:
+         """Get all tools currently in the toolbox"""
+         return self._tools
+
+     def show_tool_descriptions(self, tool_description_template: str = None) -> str:
+         """
+         Returns the description of all tools in the toolbox
+
+         Args:
+             tool_description_template (`str`, *optional*):
+                 The template to use to describe the tools. If not provided, the default template will be used.
+         """
+         return "\n".join(
+             [get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
+         )
+
+     def add_tool(self, tool: Tool):
+         """
+         Adds a tool to the toolbox
+
+         Args:
+             tool (`Tool`):
+                 The tool to add to the toolbox.
+         """
+         if tool.name in self._tools:
+             raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
+         self._tools[tool.name] = tool
+
+     def remove_tool(self, tool_name: str):
+         """
+         Removes a tool from the toolbox
+
+         Args:
+             tool_name (`str`):
+                 The tool to remove from the toolbox.
+         """
+         if tool_name not in self._tools:
+             raise KeyError(
+                 f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
+             )
+         del self._tools[tool_name]
+
+     def update_tool(self, tool: Tool):
+         """
+         Updates a tool in the toolbox according to its name.
+
+         Args:
+             tool (`Tool`):
+                 The tool to update in the toolbox.
+         """
+         if tool.name not in self._tools:
+             raise KeyError(
+                 f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
+             )
+         self._tools[tool.name] = tool
+
+     def clear_toolbox(self):
+         """Clears the toolbox"""
+         self._tools = {}
+
+     def _load_tools_if_needed(self):
+         for name, tool in self._tools.items():
+             if not isinstance(tool, Tool):
+                 task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
+                 self._tools[name] = load_tool(task_or_repo_id)
+
+     def __repr__(self):
+         toolbox_description = "Toolbox contents:\n"
+         for tool in self._tools.values():
+             toolbox_description += f"\t{tool.name}: {tool.description}\n"
+         return toolbox_description
+
+
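
A minimal usage sketch of the `Toolbox` API above, with a toy `Tool` subclass (the class attributes follow the `Tool` interface; `output_type="string"` matches the AGENT_TYPE_MAPPING keys from agent_types.py):

```py
from transformers.agents import Tool, Toolbox

class JokeTool(Tool):
    # Toy tool for illustration only.
    name = "joke"
    description = "Returns a canned joke."
    inputs = {}
    output_type = "string"

    def forward(self):
        return "Why did the tensor cross the graph?"

box = Toolbox([JokeTool()])    # keyed by tool.name internally
print(list(box.tools.keys()))  # ['joke']
print(box)                     # __repr__: "Toolbox contents:" + one line per tool
box.remove_tool("joke")        # raises KeyError for names not in the toolbox
```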
+ class AgentError(Exception):
+     """Base class for other agent-related exceptions"""
+
+     def __init__(self, message):
+         super().__init__(message)
+         self.message = message
+
+
+ class AgentParsingError(AgentError):
+     """Exception raised for errors in parsing in the agent"""
+
+     pass
+
+
+ class AgentExecutionError(AgentError):
+     """Exception raised for errors in execution in the agent"""
+
+     pass
+
+
+ class AgentMaxIterationsError(AgentError):
+     """Exception raised when the agent reaches the maximum number of iterations"""
+
+     pass
+
+
+ class AgentGenerationError(AgentError):
+     """Exception raised for errors in generation in the agent"""
+
+     pass
+
+
+ def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
+     tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
+     prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
+
+     if "<<tool_names>>" in prompt:
+         tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
+         prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
+
+     return prompt
+
+
+ def show_agents_descriptions(managed_agents: list):
+     managed_agents_descriptions = """
+ You can also give requests to team members.
+ Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaning your request.
+ Given that this team member is a real human, you should be very verbose in your request.
+ Here is a list of the team members that you can call:"""
+     for agent in managed_agents.values():
+         managed_agents_descriptions += f"\n- {agent.name}: {agent.description}"
+     return managed_agents_descriptions
+
+
+ def format_prompt_with_managed_agents_descriptions(prompt_template, managed_agents=None) -> str:
+     if managed_agents is not None:
+         return prompt_template.replace("<<managed_agents_descriptions>>", show_agents_descriptions(managed_agents))
+     else:
+         return prompt_template.replace("<<managed_agents_descriptions>>", "")
+
+
+ def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
+     if "<<authorized_imports>>" not in prompt_template:
+         raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
+     return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
+
+
+ class Agent:
+     def __init__(
+         self,
+         tools: Union[List[Tool], Toolbox],
+         llm_engine: Callable = None,
+         system_prompt: Optional[str] = None,
+         tool_description_template: Optional[str] = None,
+         additional_args: Dict = {},
+         max_iterations: int = 6,
+         tool_parser: Optional[Callable] = None,
+         add_base_tools: bool = False,
+         verbose: int = 0,
+         grammar: Optional[Dict[str, str]] = None,
+         managed_agents: Optional[List] = None,
+         step_callbacks: Optional[List[Callable]] = None,
+         monitor_metrics: bool = True,
+     ):
+         if system_prompt is None:
+             system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
+         if tool_parser is None:
+             tool_parser = parse_json_tool_call
+         self.agent_name = self.__class__.__name__
+         self.llm_engine = llm_engine
+         self.system_prompt_template = system_prompt
+         self.tool_description_template = (
+             tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+         )
+         self.additional_args = additional_args
+         self.max_iterations = max_iterations
+         self.logger = logger
+         self.tool_parser = tool_parser
+         self.grammar = grammar
+
+         self.managed_agents = None
+         if managed_agents is not None:
+             self.managed_agents = {agent.name: agent for agent in managed_agents}
+
+         if isinstance(tools, Toolbox):
+             self._toolbox = tools
+             if add_base_tools:
+                 if not is_torch_available():
+                     raise ImportError("Using the base tools requires torch to be installed.")
+
+                 self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == ReactJsonAgent))
+         else:
+             self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
+         self._toolbox.add_tool(FinalAnswerTool())
+
+         self.system_prompt = format_prompt_with_tools(
+             self._toolbox, self.system_prompt_template, self.tool_description_template
+         )
+         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+         self.prompt = None
+         self.logs = []
+         self.task = None
+
+         if verbose == 0:
+             logger.setLevel(logging.WARNING)
+         elif verbose == 1:
+             logger.setLevel(logging.INFO)
+         elif verbose == 2:
+             logger.setLevel(logging.DEBUG)
+
+         # Initialize step callbacks
+         self.step_callbacks = step_callbacks if step_callbacks is not None else []
+
+         # Initialize Monitor if monitor_metrics is True
+         self.monitor = None
+         if monitor_metrics:
+             self.monitor = Monitor(self.llm_engine)
+             self.step_callbacks.append(self.monitor.update_metrics)
+
+     @property
+     def toolbox(self) -> Toolbox:
+         """Get the toolbox currently available to the agent"""
+         return self._toolbox
+
+     def initialize_for_run(self):
+         self.token_count = 0
+         self.system_prompt = format_prompt_with_tools(
+             self._toolbox,
+             self.system_prompt_template,
+             self.tool_description_template,
+         )
+         self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
+         if hasattr(self, "authorized_imports"):
+             self.system_prompt = format_prompt_with_imports(
+                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
+             )
+         self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
+         self.logger.log(33, "======== New task ========")
+         self.logger.log(34, self.task)
+         self.logger.debug("System prompt is as follows:")
+         self.logger.debug(self.system_prompt)
+
+     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
+         """
+         Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
+         that can be used as input to the LLM.
+         """
+         prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0]["system_prompt"]}
+         task_message = {
+             "role": MessageRole.USER,
+             "content": "Task: " + self.logs[0]["task"],
+         }
+         if summary_mode:
+             memory = [task_message]
+         else:
+             memory = [prompt_message, task_message]
+         for i, step_log in enumerate(self.logs[1:]):
+             if "llm_output" in step_log and not summary_mode:
+                 thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
+                 memory.append(thought_message)
+             if "facts" in step_log:
+                 thought_message = {
+                     "role": MessageRole.ASSISTANT,
+                     "content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
+                 }
+                 memory.append(thought_message)
+
+             if "plan" in step_log and not summary_mode:
+                 thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
+                 memory.append(thought_message)
+
+             if "tool_call" in step_log and summary_mode:
+                 tool_call_message = {
+                     "role": MessageRole.ASSISTANT,
+                     "content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
+                 }
+                 memory.append(tool_call_message)
+
+             if "task" in step_log:
+                 tool_call_message = {
+                     "role": MessageRole.USER,
+                     "content": "New task:\n" + step_log["task"],
+                 }
+                 memory.append(tool_call_message)
+
+             if "error" in step_log or "observation" in step_log:
+                 if "error" in step_log:
+                     message_content = (
+                         f"[OUTPUT OF STEP {i}] -> Error:\n"
+                         + str(step_log["error"])
+                         + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
+                     )
+                 elif "observation" in step_log:
+                     message_content = f"[OUTPUT OF STEP {i}] -> Observation:\n{step_log['observation']}"
+                 tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
+                 memory.append(tool_response_message)
+
+         return memory
+
+     def get_succinct_logs(self):
+         return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
+
+     def extract_action(self, llm_output: str, split_token: str) -> str:
+         """
+         Parse action from the LLM output
+
+         Args:
+             llm_output (`str`): Output of the LLM
+             split_token (`str`): Separator for the action. Should match the example in the system prompt.
+         """
+         try:
+             split = llm_output.split(split_token)
+             rationale, action = (
+                 split[-2],
+                 split[-1],
+             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
+         except Exception as e:
+             self.logger.error(e, exc_info=1)
+             raise AgentParsingError(
+                 f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
+             )
+         return rationale.strip(), action.strip()
+
+     def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
+         """
+         Execute tool with the provided input and returns the result.
+         This method replaces arguments with the actual values from the state if they refer to state variables.
+
+         Args:
+             tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
+             arguments (Dict[str, str]): Arguments passed to the Tool.
+         """
+         available_tools = self.toolbox.tools
+         if self.managed_agents is not None:
+             available_tools = {**available_tools, **self.managed_agents}
+         if tool_name not in available_tools:
+             error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
+             self.logger.error(error_msg, exc_info=1)
+             raise AgentExecutionError(error_msg)
+
+         try:
+             if isinstance(arguments, str):
+                 observation = available_tools[tool_name](arguments)
+             elif isinstance(arguments, dict):
+                 for key, value in arguments.items():
+                     # if the value is the name of a state variable like "image.png", replace it with the actual value
+                     if isinstance(value, str) and value in self.state:
+                         arguments[key] = self.state[value]
+                 observation = available_tools[tool_name](**arguments)
+             else:
+                 raise AgentExecutionError(
+                     f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
+                 )
+             return observation
+         except Exception as e:
+             if tool_name in self.toolbox.tools:
+                 raise AgentExecutionError(
+                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
+                     f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
+                 )
+             elif tool_name in self.managed_agents:
+                 raise AgentExecutionError(
+                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
+                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
+                 )
+
+     def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
+         self.logger.warning("=== Agent thoughts:")
+         self.logger.log(31, rationale)
+         self.logger.warning(">>> Agent is executing the code below:")
+         if is_pygments_available():
+             self.logger.log(
+                 31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
+             )
+         else:
+             self.logger.log(31, code_action)
+         self.logger.warning("====")
+
+     def run(self, **kwargs):
+         """To be implemented in the child class"""
+         raise NotImplementedError
+
+
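
One subtlety in `execute_tool_call` above deserves a callout: string arguments that name an entry in `agent.state` are replaced by the stored object before the tool is invoked. A standalone re-implementation of just that substitution rule (not the library API, purely illustrative):

```py
# State is seeded from run(**kwargs); tools then receive objects, not names.
state = {"document.png": "<a PIL.Image would live here>"}
arguments = {"image": "document.png", "question": "What is the title?"}

for key, value in arguments.items():
    if isinstance(value, str) and value in state:
        arguments[key] = state[value]

print(arguments["image"])  # the stored object, not the string "document.png"
```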
+ class CodeAgent(Agent):
+     """
+     A class for an agent that solves the given task using a single block of code. It plans all its actions, then executes all in one shot.
+     """
+
+     def __init__(
+         self,
+         tools: List[Tool],
+         llm_engine: Optional[Callable] = None,
+         system_prompt: Optional[str] = None,
+         tool_description_template: Optional[str] = None,
+         grammar: Optional[Dict[str, str]] = None,
+         additional_authorized_imports: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         if llm_engine is None:
+             llm_engine = HfApiEngine()
+         if system_prompt is None:
+             system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
+         if tool_description_template is None:
+             tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+         super().__init__(
+             tools=tools,
+             llm_engine=llm_engine,
+             system_prompt=system_prompt,
+             tool_description_template=tool_description_template,
+             grammar=grammar,
+             **kwargs,
+         )
+
+         if not is_pygments_available():
+             transformers_logging.warning_once(
+                 logger,
+                 "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
+                 "CodeAgent.",
+             )
+
+         self.python_evaluator = evaluate_python_code
+         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
+         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
+         self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
+
+     def parse_code_blob(self, result: str) -> str:
+         """
+         Override this method if you want to change the way the code is
+         cleaned in the `run` method.
+         """
+         return parse_code_blob(result)
+
+     def run(self, task: str, return_generated_code: bool = False, **kwargs):
+         """
+         Runs the agent for the given task.
+
+         Args:
+             task (`str`): The task to perform
+             return_generated_code (`bool`, *optional*, defaults to `False`): Whether to return the generated code instead of running it
+             kwargs (additional keyword arguments, *optional*):
+                 Any keyword argument to send to the agent when evaluating the code.
+
+         Example:
+
+         ```py
+         from transformers.agents import CodeAgent
+
+         agent = CodeAgent(tools=[])
+         agent.run("What is the result of 2 power 3.7384?")
+         ```
+         """
+         self.task = task
+         if len(kwargs) > 0:
+             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
+         self.state = kwargs.copy()
+         self.initialize_for_run()
+
+         # Run LLM
+         prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
+         task_message = {
+             "role": MessageRole.USER,
+             "content": "Task: " + self.task,
+         }
+
+         self.prompt = [prompt_message, task_message]
+         self.logger.info("====Executing with this prompt====")
+         self.logger.info(self.prompt)
+
+         additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+         llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
+
+         if return_generated_code:
+             return llm_output
+
+         # Parse
+         try:
+             rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
+         except Exception as e:
+             self.logger.debug(
+                 f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
+             )
+             rationale, code_action = "", llm_output
+
+         try:
+             code_action = self.parse_code_blob(code_action)
+         except Exception as e:
+             error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
+             self.logger.error(error_msg, exc_info=1)
+             return error_msg
+
+         # Execute
+         self.log_rationale_code_action(rationale, code_action)
+         try:
+             available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
+             output = self.python_evaluator(
+                 code_action,
+                 static_tools=available_tools,
+                 custom_tools={},
+                 state=self.state,
+                 authorized_imports=self.authorized_imports,
+             )
+             self.logger.info(self.state["print_outputs"])
+             return output
+         except Exception as e:
+             error_msg = f"Error in execution: {e}. Be sure to provide correct code."
+             self.logger.error(error_msg, exc_info=1)
+             return error_msg
+
+
+ class ReactAgent(Agent):
+     """
+     This agent solves the given task step by step, using the ReAct framework:
+     While the objective is not reached, the agent will perform a cycle of thinking and acting.
+     The action will be parsed from the LLM output: it consists of calls to tools from the toolbox, with arguments chosen by the LLM engine.
+     """
+
+     def __init__(
+         self,
+         tools: List[Tool],
+         llm_engine: Optional[Callable] = None,
+         system_prompt: Optional[str] = None,
+         tool_description_template: Optional[str] = None,
+         grammar: Optional[Dict[str, str]] = None,
+         plan_type: Optional[str] = None,
+         planning_interval: Optional[int] = None,
+         **kwargs,
+     ):
+         if llm_engine is None:
+             llm_engine = HfApiEngine()
+         if system_prompt is None:
+             system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
+         if tool_description_template is None:
+             tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
+         if plan_type is None:
+             plan_type = SUPPORTED_PLAN_TYPES[0]
+         else:
+             assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
+         super().__init__(
+             tools=tools,
+             llm_engine=llm_engine,
+             system_prompt=system_prompt,
+             tool_description_template=tool_description_template,
+             grammar=grammar,
+             **kwargs,
+         )
+         self.planning_interval = planning_interval
+         self.plan_type = plan_type
+
+     def provide_final_answer(self, task) -> str:
+         """
+         This method provides a final answer to the task, based on the logs of the agent's interactions.
+         """
+         self.prompt = [
+             {
+                 "role": MessageRole.SYSTEM,
+                 "content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
+             }
+         ]
+         self.prompt += self.write_inner_memory_from_logs()[1:]
+         self.prompt += [
+             {
+                 "role": MessageRole.USER,
+                 "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
+             }
+         ]
+         try:
+             return self.llm_engine(self.prompt)
+         except Exception as e:
+             return f"Error in generating final llm output: {e}."
+
+     def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
+         """
+         Runs the agent for the given task.
+
+         Args:
+             task (`str`): The task to perform
+
+         Example:
+         ```py
+         from transformers.agents import ReactCodeAgent
+         agent = ReactCodeAgent(tools=[])
+         agent.run("What is the result of 2 power 3.7384?")
+         ```
+         """
+         self.task = task
+         if len(kwargs) > 0:
+             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
+         self.state = kwargs.copy()
+         if reset:
+             self.initialize_for_run()
+         else:
+             self.logs.append({"task": task})
+         if stream:
+             return self.stream_run(task)
+         else:
+             return self.direct_run(task)
+
+     def stream_run(self, task: str):
+         """
+         Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
+         """
+         final_answer = None
+         iteration = 0
+         while final_answer is None and iteration < self.max_iterations:
+             step_start_time = time.time()
+             step_log_entry = {"iteration": iteration, "start_time": step_start_time}
+             try:
+                 self.step(step_log_entry)
+                 if "final_answer" in step_log_entry:
+                     final_answer = step_log_entry["final_answer"]
+             except AgentError as e:
+                 self.logger.error(e, exc_info=1)
+                 step_log_entry["error"] = e
+             finally:
+                 step_end_time = time.time()
+                 step_log_entry["step_end_time"] = step_end_time
+                 step_log_entry["step_duration"] = step_end_time - step_start_time
+                 self.logs.append(step_log_entry)
+                 for callback in self.step_callbacks:
+                     callback(step_log_entry)
+                 iteration += 1
+                 yield step_log_entry
+
+         if final_answer is None and iteration == self.max_iterations:
+             error_message = "Reached max iterations."
+             final_step_log = {"error": AgentMaxIterationsError(error_message)}
+             self.logs.append(final_step_log)
+             self.logger.error(error_message, exc_info=1)
+             final_answer = self.provide_final_answer(task)
+             final_step_log["final_answer"] = final_answer
+             final_step_log["step_duration"] = 0
+             for callback in self.step_callbacks:
+                 callback(final_step_log)
+             yield final_step_log
+
+         yield final_answer
+
+     def direct_run(self, task: str):
+         """
+         Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
+         """
+         final_answer = None
+         iteration = 0
+         while final_answer is None and iteration < self.max_iterations:
+             step_start_time = time.time()
+             step_log_entry = {"iteration": iteration, "start_time": step_start_time}
+             try:
+                 if self.planning_interval is not None and iteration % self.planning_interval == 0:
+                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+                 self.step(step_log_entry)
+                 if "final_answer" in step_log_entry:
+                     final_answer = step_log_entry["final_answer"]
+             except AgentError as e:
+                 self.logger.error(e, exc_info=1)
+                 step_log_entry["error"] = e
+             finally:
+                 step_end_time = time.time()
+                 step_log_entry["step_end_time"] = step_end_time
+                 step_log_entry["step_duration"] = step_end_time - step_start_time
+                 self.logs.append(step_log_entry)
+                 for callback in self.step_callbacks:
+                     callback(step_log_entry)
+                 iteration += 1
+
+         if final_answer is None and iteration == self.max_iterations:
+             error_message = "Reached max iterations."
+             final_step_log = {"error": AgentMaxIterationsError(error_message)}
+             self.logs.append(final_step_log)
+             self.logger.error(error_message, exc_info=1)
+             final_answer = self.provide_final_answer(task)
+             final_step_log["final_answer"] = final_answer
+             final_step_log["step_duration"] = 0
+             for callback in self.step_callbacks:
+                 callback(final_step_log)
+
+         return final_answer
+
+     def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
+         """
+         Used periodically by the agent to plan the next steps to reach the objective.
+
+         Args:
+             task (`str`): The task to perform
+             is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
+             iteration (`int`): The number of the current step, used as an indication for the LLM.
+         """
+         if is_first_step:
+             message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
+             message_prompt_task = {
+                 "role": MessageRole.USER,
+                 "content": f"""Here is the task:
+ ```
+ {task}
+ ```
+ Now begin!""",
+             }
+
+             answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
+
+             message_system_prompt_plan = {
+                 "role": MessageRole.SYSTEM,
+                 "content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["system"],
+             }
+             message_user_prompt_plan = {
+                 "role": MessageRole.USER,
+                 "content": PROMPTS_FOR_INITIAL_PLAN[self.plan_type]["user"].format(
+                     task=task,
+                     tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
+                     managed_agents_descriptions=(
+                         show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
+                     ),
+                     answer_facts=answer_facts,
+                 ),
+             }
+             answer_plan = self.llm_engine(
+                 [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
+             )
+
+             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
+ ```
+ {answer_plan}
+ ```"""
+             final_facts_redaction = f"""Here are the facts that I know so far:
+ ```
+ {answer_facts}
+ ```""".strip()
+             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
+             self.logger.log(36, "===== Initial plan =====")
+             self.logger.log(35, final_plan_redaction)
+         else:  # update plan
+             agent_memory = self.write_inner_memory_from_logs(
+                 summary_mode=False
+             )  # This will not log the plan but will log facts
+
+             # Redact updated facts
+             facts_update_system_prompt = {
+                 "role": MessageRole.SYSTEM,
+                 "content": SYSTEM_PROMPT_FACTS_UPDATE,
+             }
+             facts_update_message = {
+                 "role": MessageRole.USER,
+                 "content": USER_PROMPT_FACTS_UPDATE,
+             }
+             facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
+
+             # Redact updated plan
+             plan_update_message = {
+                 "role": MessageRole.SYSTEM,
+                 "content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["system"].format(task=task),
+             }
+             plan_update_message_user = {
+                 "role": MessageRole.USER,
+                 "content": PROMPTS_FOR_PLAN_UPDATE[self.plan_type]["user"].format(
+                     task=task,
+                     tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
+                     managed_agents_descriptions=(
+                         show_agents_descriptions(self.managed_agents) if self.managed_agents is not None else ""
+                     ),
+                     facts_update=facts_update,
+                     remaining_steps=(self.max_iterations - iteration),
+                 ),
+             }
+             plan_update = self.llm_engine(
+                 [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
+             )
+
+             # Log final facts and plan
+             final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
+             final_facts_redaction = f"""Here is the updated list of the facts that I know:
+ ```
+ {facts_update}
+ ```"""
+             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
+             self.logger.log(36, "===== Updated plan =====")
+             self.logger.log(35, final_plan_redaction)
+
+
984
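The method above is invoked from the agent's run loop rather than called directly. As a rough sketch (illustrative names; the real loop lives in `ReactAgent.run`), a planning interval would gate it like this:

```py
# Hypothetical driver loop, assuming `planning_interval` gates planning as described above.
iteration = 0
while iteration < agent.max_iterations:
    if agent.planning_interval is not None and iteration % agent.planning_interval == 0:
        # The first call writes the initial plan; later calls update it from memory.
        agent.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
    # ... one think/act/observe step would execute here ...
    iteration += 1
```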
+ class ReactJsonAgent(ReactAgent):
985
+ """
986
+ This agent solves the given task step by step, using the ReAct framework:
987
+ While the objective is not reached, the agent will perform a cycle of thinking and acting.
988
+ The tool calls will be formulated by the LLM in JSON format, then parsed and executed.
989
+ """
990
+
991
+ def __init__(
992
+ self,
993
+ tools: List[Tool],
994
+ llm_engine: Optional[Callable] = None,
995
+ system_prompt: Optional[str] = None,
996
+ tool_description_template: Optional[str] = None,
997
+ grammar: Optional[Dict[str, str]] = None,
998
+ planning_interval: Optional[int] = None,
999
+ **kwargs,
1000
+ ):
1001
+ if llm_engine is None:
1002
+ llm_engine = HfApiEngine()
1003
+ if system_prompt is None:
1004
+ system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
1005
+ if tool_description_template is None:
1006
+ tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
1007
+ super().__init__(
1008
+ tools=tools,
1009
+ llm_engine=llm_engine,
1010
+ system_prompt=system_prompt,
1011
+ tool_description_template=tool_description_template,
1012
+ grammar=grammar,
1013
+ planning_interval=planning_interval,
1014
+ **kwargs,
1015
+ )
1016
+
1017
+ def step(self, log_entry: Dict[str, Any]):
1018
+ """
1019
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1020
+ The errors are raised here, they are caught and logged in the run() method.
1021
+ """
1022
+ agent_memory = self.write_inner_memory_from_logs()
1023
+
1024
+ self.prompt = agent_memory
1025
+ self.logger.debug("===== New step =====")
1026
+
1027
+ # Add new step in logs
1028
+ log_entry["agent_memory"] = agent_memory.copy()
1029
+
1030
+ self.logger.info("===== Calling LLM with this last message: =====")
1031
+ self.logger.info(self.prompt[-1])
1032
+
1033
+ try:
1034
+ additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
1035
+ llm_output = self.llm_engine(
1036
+ self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
1037
+ )
1038
+ except Exception as e:
1039
+ raise AgentGenerationError(f"Error in generating llm output: {e}.")
1040
+ self.logger.debug("===== Output message of the LLM: =====")
1041
+ self.logger.debug(llm_output)
1042
+ log_entry["llm_output"] = llm_output
1043
+
1044
+ # Parse
1045
+ self.logger.debug("===== Extracting action =====")
1046
+ rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
1047
+
1048
+ try:
1049
+ tool_name, arguments = self.tool_parser(action)
1050
+ except Exception as e:
1051
+ raise AgentParsingError(f"Could not parse the given action: {e}.")
1052
+
1053
+ log_entry["rationale"] = rationale
1054
+ log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
1055
+
1056
+ # Execute
1057
+ self.logger.warning("=== Agent thoughts:")
1058
+ self.logger.log(31, rationale)
1059
+ self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
1060
+ if tool_name == "final_answer":
1061
+ if isinstance(arguments, dict):
1062
+ if "answer" in arguments:
1063
+ answer = arguments["answer"]
1064
+ if (
1065
+ isinstance(answer, str) and answer in self.state.keys()
1066
+ ): # if the answer is a state variable, return the value
1067
+ answer = self.state[answer]
1068
+ else:
1069
+ answer = arguments
1070
+ else:
1071
+ answer = arguments
1072
+ log_entry["final_answer"] = answer
1073
+ return answer
1074
+ else:
1075
+ if arguments is None:
1076
+ arguments = {}
1077
+ observation = self.execute_tool_call(tool_name, arguments)
1078
+ observation_type = type(observation)
1079
+ if observation_type in [AgentImage, AgentAudio]:
1080
+ if observation_type == AgentImage:
1081
+ observation_name = "image.png"
1082
+ elif observation_type == AgentAudio:
1083
+ observation_name = "audio.mp3"
1084
+ # TODO: observation naming could allow for different names of same type
1085
+
1086
+ self.state[observation_name] = observation
1087
+ updated_information = f"Stored '{observation_name}' in memory."
1088
+ else:
1089
+ updated_information = str(observation).strip()
1090
+ self.logger.info(updated_information)
1091
+ log_entry["observation"] = updated_information
1092
+ return log_entry
1093
+
1094
+
1095
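A minimal usage sketch for `ReactJsonAgent` as defined above, assuming it is exported from `transformers.agents` and that a Hugging Face API token is configured for the default `HfApiEngine`:

```py
from transformers.agents import ReactJsonAgent

# Even with an empty tool list, the agent should still have the built-in final_answer tool.
agent = ReactJsonAgent(tools=[])
result = agent.run("What is 2 to the power of 10?")
print(result)
```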
+ class ReactCodeAgent(ReactAgent):
1096
+ """
1097
+ This agent solves the given task step by step, using the ReAct framework:
1098
+ While the objective is not reached, the agent will perform a cycle of thinking and acting.
1099
+ The tool calls will be formulated by the LLM in code format, then parsed and executed.
1100
+ """
1101
+
1102
+ def __init__(
1103
+ self,
1104
+ tools: List[Tool],
1105
+ llm_engine: Optional[Callable] = None,
1106
+ system_prompt: Optional[str] = None,
1107
+ tool_description_template: Optional[str] = None,
1108
+ grammar: Optional[Dict[str, str]] = None,
1109
+ additional_authorized_imports: Optional[List[str]] = None,
1110
+ planning_interval: Optional[int] = None,
1111
+ **kwargs,
1112
+ ):
1113
+ if llm_engine is None:
1114
+ llm_engine = HfApiEngine()
1115
+ if system_prompt is None:
1116
+ system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
1117
+ if tool_description_template is None:
1118
+ tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
1119
+ super().__init__(
1120
+ tools=tools,
1121
+ llm_engine=llm_engine,
1122
+ system_prompt=system_prompt,
1123
+ tool_description_template=tool_description_template,
1124
+ grammar=grammar,
1125
+ planning_interval=planning_interval,
1126
+ **kwargs,
1127
+ )
1128
+
1129
+ if not is_pygments_available():
1130
+ transformers_logging.warning_once(
1131
+ logger,
1132
+ "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
1133
+ "ReactCodeAgent.",
1134
+ )
1135
+
1136
+ self.python_evaluator = evaluate_python_code
1137
+ self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
1138
+ self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
1139
+ self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
1140
+ self.custom_tools = {}
1141
+
1142
+ def step(self, log_entry: Dict[str, Any]):
1143
+ """
1144
+ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1145
+ The errors are raised here, they are caught and logged in the run() method.
1146
+ """
1147
+ agent_memory = self.write_inner_memory_from_logs()
1148
+
1149
+ self.prompt = agent_memory.copy()
1150
+ self.logger.debug("===== New step =====")
1151
+
1152
+ # Add new step in logs
1153
+ log_entry["agent_memory"] = agent_memory.copy()
1154
+
1155
+ self.logger.info("===== Calling LLM with these last messages: =====")
1156
+ self.logger.info(self.prompt[-2:])
1157
+
1158
+ try:
1159
+ additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
1160
+ llm_output = self.llm_engine(
1161
+ self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
1162
+ )
1163
+ except Exception as e:
1164
+ raise AgentGenerationError(f"Error in generating llm output: {e}.")
1165
+
1166
+ self.logger.debug("=== Output message of the LLM:")
1167
+ self.logger.debug(llm_output)
1168
+ log_entry["llm_output"] = llm_output
1169
+
1170
+ # Parse
1171
+ self.logger.debug("=== Extracting action ===")
1172
+ try:
1173
+ rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
1174
+ except Exception as e:
1175
+ self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
1176
+ rationale, raw_code_action = llm_output, llm_output
1177
+
1178
+ try:
1179
+ code_action = parse_code_blob(raw_code_action)
1180
+ except Exception as e:
1181
+ error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
1182
+ raise AgentParsingError(error_msg)
1183
+
1184
+ log_entry["rationale"] = rationale
1185
+ log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
1186
+
1187
+ # Execute
1188
+ self.log_rationale_code_action(rationale, code_action)
1189
+ try:
1190
+ static_tools = {
1191
+ **BASE_PYTHON_TOOLS.copy(),
1192
+ **self.toolbox.tools,
1193
+ }
1194
+ if self.managed_agents is not None:
1195
+ static_tools = {**static_tools, **self.managed_agents}
1196
+ result = self.python_evaluator(
1197
+ code_action,
1198
+ static_tools=static_tools,
1199
+ custom_tools=self.custom_tools,
1200
+ state=self.state,
1201
+ authorized_imports=self.authorized_imports,
1202
+ )
1203
+ self.logger.warning("Print outputs:")
1204
+ self.logger.log(32, self.state["print_outputs"])
1205
+ observation = "Print outputs:\n" + self.state["print_outputs"]
1206
+ if result is not None:
1207
+ self.logger.warning("Last output from code snippet:")
1208
+ self.logger.log(32, str(result))
1209
+ observation += "Last output from code snippet:\n" + str(result)[:100000]
1210
+ log_entry["observation"] = observation
1211
+ except Exception as e:
1212
+ error_msg = f"Code execution failed due to the following error:\n{str(e)}"
1213
+ if "'dict' object has no attribute 'read'" in str(e):
1214
+ error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
1215
+ raise AgentExecutionError(error_msg)
1216
+ for line in code_action.split("\n"):
1217
+ if line.startswith("final_answer"):
1218
+ self.logger.log(33, "Final answer:")
1219
+ self.logger.log(32, result)
1220
+ log_entry["final_answer"] = result
1221
+ return result
1222
+
1223
+
1224
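Similarly, a hedged sketch of running `ReactCodeAgent` with an extra authorized import (network access for the default engine is assumed):

```py
from transformers.agents import ReactCodeAgent

# "datetime" is added on top of LIST_SAFE_MODULES, as in __init__ above.
agent = ReactCodeAgent(tools=[], additional_authorized_imports=["datetime"])
answer = agent.run("How many days are there between 2024-01-01 and 2024-03-01?")
print(answer)
```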
+ LENGTH_TRUNCATE_REPORTS = 1000
1225
+
1226
+
1227
+ class ManagedAgent:
1228
+ def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
1229
+ self.agent = agent
1230
+ self.name = name
1231
+ self.description = description
1232
+ self.additional_prompting = additional_prompting
1233
+ self.provide_run_summary = provide_run_summary
1234
+
1235
+ def write_full_task(self, task):
1236
+ full_task = f"""You're a helpful agent named '{self.name}'.
1237
+ You have been submitted this task by your manager.
1238
+ ---
1239
+ Task:
1240
+ {task}
1241
+ ---
1242
+ You're helping your manager solve a wider task, so make sure not to provide a one-line answer; give as much information as possible so that they have a clear understanding of the answer.
1243
+
1244
+ Your final_answer WILL HAVE to contain these parts:
1245
+ ### 1. Task outcome (short version):
1246
+ ### 2. Task outcome (extremely detailed version):
1247
+ ### 3. Additional context (if relevant):
1248
+
1249
+ Put all these in your final_answer tool; everything that you do not pass as an argument to final_answer will be lost.
1250
+ And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
1251
+ <<additional_prompting>>"""
1252
+ if self.additional_prompting:
1253
+ full_task = full_task.replace("\n<<additional_prompting>>", self.additional_prompting).strip()
1254
+ else:
1255
+ full_task = full_task.replace("\n<<additional_prompting>>", "").strip()
1256
+ return full_task
1257
+
1258
+ def __call__(self, request, **kwargs):
1259
+ full_task = self.write_full_task(request)
1260
+ output = self.agent.run(full_task, **kwargs)
1261
+ if self.provide_run_summary:
1262
+ answer = f"Here is the final answer from your managed agent '{self.name}':\n"
1263
+ answer += str(output)
1264
+ answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
1265
+ for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
1266
+ content = message["content"]
1267
+ if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
1268
+ answer += "\n" + str(content) + "\n---"
1269
+ else:
1270
+ answer += (
1271
+ "\n"
1272
+ + str(content)[:LENGTH_TRUNCATE_REPORTS]
1273
+ + "\n(...Step was truncated because too long)...\n---"
1274
+ )
1275
+ answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
1276
+ return answer
1277
+ else:
1278
+ return output
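To show how `ManagedAgent` plugs into a manager, a hedged sketch; it assumes the agent constructors accept a `managed_agents` argument, which the references to `self.managed_agents` above suggest:

```py
from transformers.agents import ReactCodeAgent, ReactJsonAgent

worker = ManagedAgent(
    agent=ReactJsonAgent(tools=[]),
    name="worker",
    description="Answers sub-questions delegated by the manager.",
    provide_run_summary=True,
)
manager = ReactCodeAgent(tools=[], managed_agents=[worker])
print(manager.run("Ask the 'worker' agent what 6 * 7 is, then report its answer."))
```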
.venv/Lib/site-packages/transformers/agents/default_tools.py ADDED
@@ -0,0 +1,187 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import importlib.util
18
+ import json
19
+ import math
20
+ from dataclasses import dataclass
21
+ from math import sqrt
22
+ from typing import Dict
23
+
24
+ from huggingface_hub import hf_hub_download, list_spaces
25
+
26
+ from ..utils import is_offline_mode
27
+ from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
28
+ from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
29
+
30
+
31
+ def custom_print(*args):
32
+ return None
33
+
34
+
35
+ BASE_PYTHON_TOOLS = {
36
+ "print": custom_print,
37
+ "isinstance": isinstance,
38
+ "range": range,
39
+ "float": float,
40
+ "int": int,
41
+ "bool": bool,
42
+ "str": str,
43
+ "set": set,
44
+ "list": list,
45
+ "dict": dict,
46
+ "tuple": tuple,
47
+ "round": round,
48
+ "ceil": math.ceil,
49
+ "floor": math.floor,
50
+ "log": math.log,
51
+ "exp": math.exp,
52
+ "sin": math.sin,
53
+ "cos": math.cos,
54
+ "tan": math.tan,
55
+ "asin": math.asin,
56
+ "acos": math.acos,
57
+ "atan": math.atan,
58
+ "atan2": math.atan2,
59
+ "degrees": math.degrees,
60
+ "radians": math.radians,
61
+ "pow": math.pow,
62
+ "sqrt": sqrt,
63
+ "len": len,
64
+ "sum": sum,
65
+ "max": max,
66
+ "min": min,
67
+ "abs": abs,
68
+ "enumerate": enumerate,
69
+ "zip": zip,
70
+ "reversed": reversed,
71
+ "sorted": sorted,
72
+ "all": all,
73
+ "any": any,
74
+ "map": map,
75
+ "filter": filter,
76
+ "ord": ord,
77
+ "chr": chr,
78
+ "next": next,
79
+ "iter": iter,
80
+ "divmod": divmod,
81
+ "callable": callable,
82
+ "getattr": getattr,
83
+ "hasattr": hasattr,
84
+ "setattr": setattr,
85
+ "issubclass": issubclass,
86
+ "type": type,
87
+ }
88
+
89
+
90
+ @dataclass
91
+ class PreTool:
92
+ name: str
93
+ inputs: Dict[str, str]
94
+ output_type: type
95
+ task: str
96
+ description: str
97
+ repo_id: str
98
+
99
+
100
+ HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
101
+ "image-transformation",
102
+ "text-to-image",
103
+ ]
104
+
105
+
106
+ def get_remote_tools(logger, organization="huggingface-tools"):
107
+ if is_offline_mode():
108
+ logger.info("You are in offline mode, so remote tools are not available.")
109
+ return {}
110
+
111
+ spaces = list_spaces(author=organization)
112
+ tools = {}
113
+ for space_info in spaces:
114
+ repo_id = space_info.id
115
+ resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
116
+ with open(resolved_config_file, encoding="utf-8") as reader:
117
+ config = json.load(reader)
118
+ task = repo_id.split("/")[-1]
119
+ tools[config["name"]] = PreTool(
120
+ task=task,
121
+ description=config["description"],
122
+ repo_id=repo_id,
123
+ name=task,
124
+ inputs=config["inputs"],
125
+ output_type=config["output_type"],
126
+ )
127
+
128
+ return tools
129
+
130
+
131
+ def setup_default_tools(logger):
132
+ default_tools = {}
133
+ main_module = importlib.import_module("transformers")
134
+ tools_module = main_module.agents
135
+
136
+ for task_name, tool_class_name in TOOL_MAPPING.items():
137
+ tool_class = getattr(tools_module, tool_class_name)
138
+ tool_instance = tool_class()
139
+ default_tools[tool_class.name] = PreTool(
140
+ name=tool_instance.name,
141
+ inputs=tool_instance.inputs,
142
+ output_type=tool_instance.output_type,
143
+ task=task_name,
144
+ description=tool_instance.description,
145
+ repo_id=None,
146
+ )
147
+
148
+ return default_tools
149
+
150
+
151
+ class PythonInterpreterTool(Tool):
152
+ name = "python_interpreter"
153
+ description = "This is a tool that evaluates python code. It can be used to perform calculations."
154
+
155
+ output_type = "string"
156
+
157
+ def __init__(self, *args, authorized_imports=None, **kwargs):
158
+ if authorized_imports is None:
159
+ self.authorized_imports = list(set(LIST_SAFE_MODULES))
160
+ else:
161
+ self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
162
+ self.inputs = {
163
+ "code": {
164
+ "type": "string",
165
+ "description": (
166
+ "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
167
+ f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
168
+ ),
169
+ }
170
+ }
171
+ super().__init__(*args, **kwargs)
172
+
173
+ def forward(self, code):
174
+ output = str(
175
+ evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
176
+ )
177
+ return output
178
+
179
+
180
+ class FinalAnswerTool(Tool):
181
+ name = "final_answer"
182
+ description = "Provides a final answer to the given problem."
183
+ inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
184
+ output_type = "any"
185
+
186
+ def forward(self, answer):
187
+ return answer
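A short sketch of calling the two tools above directly, assuming `Tool.__call__` dispatches to `forward`:

```py
interpreter = PythonInterpreterTool()
# evaluate_python_code returns the value of the last expression; forward wraps it in str().
print(interpreter(code="result = 2 ** 10\nresult"))  # expected: "1024"

final_answer = FinalAnswerTool()
print(final_answer(answer="Done."))  # "Done."
```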
.venv/Lib/site-packages/transformers/agents/document_question_answering.py ADDED
@@ -0,0 +1,89 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import re
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..models.auto import AutoProcessor
23
+ from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
24
+ from ..utils import is_vision_available
25
+ from .tools import PipelineTool
26
+
27
+
28
+ if is_vision_available():
29
+ from PIL import Image
30
+
31
+
32
+ class DocumentQuestionAnsweringTool(PipelineTool):
33
+ default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
34
+ description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
35
+ name = "document_qa"
36
+ pre_processor_class = AutoProcessor
37
+ model_class = VisionEncoderDecoderModel
38
+
39
+ inputs = {
40
+ "document": {
41
+ "type": "image",
42
+ "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
43
+ },
44
+ "question": {"type": "string", "description": "The question in English"},
45
+ }
46
+ output_type = "string"
47
+
48
+ def __init__(self, *args, **kwargs):
49
+ if not is_vision_available():
50
+ raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
51
+
52
+ super().__init__(*args, **kwargs)
53
+
54
+ def encode(self, document: "Image", question: str):
55
+ task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
56
+ prompt = task_prompt.replace("{user_input}", question)
57
+ decoder_input_ids = self.pre_processor.tokenizer(
58
+ prompt, add_special_tokens=False, return_tensors="pt"
59
+ ).input_ids
60
+ if isinstance(document, str):
61
+ img = Image.open(document).convert("RGB")
62
+ img_array = np.array(img).transpose(2, 0, 1)
63
+ document = torch.from_numpy(img_array)
64
+ pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
65
+
66
+ return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
67
+
68
+ def forward(self, inputs):
69
+ return self.model.generate(
70
+ inputs["pixel_values"].to(self.device),
71
+ decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
72
+ max_length=self.model.decoder.config.max_position_embeddings,
73
+ early_stopping=True,
74
+ pad_token_id=self.pre_processor.tokenizer.pad_token_id,
75
+ eos_token_id=self.pre_processor.tokenizer.eos_token_id,
76
+ use_cache=True,
77
+ num_beams=1,
78
+ bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
79
+ return_dict_in_generate=True,
80
+ ).sequences
81
+
82
+ def decode(self, outputs):
83
+ sequence = self.pre_processor.batch_decode(outputs)[0]
84
+ sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
85
+ sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
86
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
87
+ sequence = self.pre_processor.token2json(sequence)
88
+
89
+ return sequence["answer"]
.venv/Lib/site-packages/transformers/agents/evaluate_agent.py ADDED
@@ -0,0 +1,414 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from .agents import BASE_PYTHON_TOOLS
18
+ from .python_interpreter import InterpreterError, evaluate
19
+
20
+
21
+ ### Fake tools for test
22
+ def classifier(text, labels):
23
+ return f"This is the classification of {text} along {labels}."
24
+
25
+
26
+ def translator(text, src_lang, tgt_lang):
27
+ return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
28
+
29
+
30
+ def speaker(text):
31
+ return f"This is actually a sound reading {text}."
32
+
33
+
34
+ def transcriber(audio):
35
+ if "sound" not in audio:
36
+ raise ValueError(f"`audio` ({audio}) is not a sound.")
37
+ return f"This is the transcribed text from {audio}."
38
+
39
+
40
+ def image_generator(prompt):
41
+ return f"This is actually an image representing {prompt}."
42
+
43
+
44
+ def image_captioner(image):
45
+ if "image" not in image:
46
+ raise ValueError(f"`image` ({image}) is not an image.")
47
+ return f"This is a description of {image}."
48
+
49
+
50
+ def image_transformer(image, prompt):
51
+ if "image" not in image:
52
+ raise ValueError(f"`image` ({image}) is not an image.")
53
+ return f"This is a transformation of {image} according to {prompt}."
54
+
55
+
56
+ def question_answerer(text, question):
57
+ return f"This is the answer to {question} from {text}."
58
+
59
+
60
+ def image_qa(image, question):
61
+ if "image" not in image:
62
+ raise ValueError(f"`image` ({image}) is not an image.")
63
+ return f"This is the answer to {question} from {image}."
64
+
65
+
66
+ def text_downloader(url):
67
+ return f"This is the content of {url}."
68
+
69
+
70
+ def summarizer(text):
71
+ return f"This is a summary of {text}."
72
+
73
+
74
+ def video_generator(prompt, seconds=2):
75
+ return f"A video of {prompt}"
76
+
77
+
78
+ def document_qa(image, question):
79
+ return f"This is the answer to {question} from the document {image}."
80
+
81
+
82
+ def image_segmenter(image, prompt):
83
+ return f"This is the mask of {prompt} in {image}"
84
+
85
+
86
+ TEST_TOOLS = {
87
+ "text_classifier": classifier,
88
+ "translator": translator,
89
+ "text_reader": speaker,
90
+ "summarizer": summarizer,
91
+ "transcriber": transcriber,
92
+ "image_generator": image_generator,
93
+ "image_captioner": image_captioner,
94
+ "image_transformer": image_transformer,
95
+ "text_qa": question_answerer,
96
+ "text_downloader": text_downloader,
97
+ "image_qa": image_qa,
98
+ "video_generator": video_generator,
99
+ "document_qa": document_qa,
100
+ "image_segmenter": image_segmenter,
101
+ }
102
+
103
+
104
+ class Problem:
105
+ """
106
+ A class regrouping all the information to solve a problem on which we will evaluate agents.
107
+
108
+ Args:
109
+ task (`str` or `list[str]`):
110
+ One or several descriptions of the task to perform. If a list, it should contain variations on the
111
+ phrasing, but for the same task.
112
+ inputs (`list[str]` or `dict[str, str]`):
113
+ The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
114
+ values. Pass along a dictionary when you want to specify the values of each input, or just the list of
115
+ inputs expected (the value used will be `<<input_name>>` in this case).
116
+ answer (`str` or `list[str]`):
117
+ The theoretical answer (or list of possible valid answers) to the problem, as code.
118
+ """
119
+
120
+ def __init__(self, task, inputs, answer):
121
+ self.task = task
122
+ self.inputs = inputs
123
+ self.answer = answer
124
+
125
+
126
+ ### The list of problems the agent will be evaluated on.
127
+ EVALUATION_TASKS = [
128
+ Problem(
129
+ task=[
130
+ "Is the following `text` (in Spanish) positive or negative?",
131
+ "Is the text in the variable `text` (in Spanish) positive or negative?",
132
+ "Translate the following `text` from Spanish to English then tell me if its positive or negative.",
133
+ ],
134
+ inputs=["text"],
135
+ answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
136
+ ),
137
+ Problem(
138
+ task=[
139
+ "Tell me out loud what the `image` contains.",
140
+ "Describe the following `image` out loud.",
141
+ "Find what is in the picture stored in `image` then read it out loud.",
142
+ ],
143
+ inputs=["image"],
144
+ answer=[
145
+ "text_reader(image_captioner(image))",
146
+ "text_reader(image_qa(image, question='What is in the image?'))",
147
+ ],
148
+ ),
149
+ Problem(
150
+ task=[
151
+ "Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
152
+ "Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
153
+ ],
154
+ inputs=["text_input", "prompt"],
155
+ answer="image_transformer(image_generator(text_input), prompt)",
156
+ ),
157
+ Problem(
158
+ task=[
159
+ "Download the content of `url`, summarize it then generate an image from its content.",
160
+ "Use a summary of the web page at `url` to generate an image.",
161
+ "Summarize the content of the web page at `url`, and use the result to generate an image.",
162
+ ],
163
+ inputs=["url"],
164
+ answer="image_generator(summarizer(text_downloader(url)))",
165
+ ),
166
+ Problem(
167
+ task=[
168
+ "Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
169
+ "Use the text prompt in `text` (in Spanish) to transform the following `image`.",
170
+ "Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
171
+ ],
172
+ inputs=["text", "image"],
173
+ answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
174
+ ),
175
+ Problem(
176
+ task=[
177
+ "Download the content of `url`, summarize it then read it out loud to me.",
178
+ "Read me a summary of the web page at `url`.",
179
+ ],
180
+ inputs=["url"],
181
+ answer="text_reader(summarizer(text_downloader(url)))",
182
+ ),
183
+ Problem(
184
+ task=[
185
+ "Generate an image from the text given in `text_input`.",
186
+ ],
187
+ inputs=["text_input"],
188
+ answer="image_generator(text_input)",
189
+ ),
190
+ Problem(
191
+ task=[
192
+ "Replace the beaver in the `image` by the `prompt`.",
193
+ "Transform the `image` so that it contains the `prompt`.",
194
+ "Use `prompt` to transform this `image`.",
195
+ ],
196
+ inputs=["image", "prompt"],
197
+ answer="image_transformer(image, prompt)",
198
+ ),
199
+ Problem(
200
+ task=[
201
+ "Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
202
+ "Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
203
+ "Read me a summary of the `text` out loud. Transcribe this and translate it in French.",
204
+ ],
205
+ inputs=["text"],
206
+ answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
207
+ ),
208
+ Problem(
209
+ task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
210
+ inputs={"prompt": "A lobster swimming"},
211
+ answer="video_generator('A lobster swimming')",
212
+ ),
213
+ Problem(
214
+ task=[
215
+ "Download the following file `url`, summarize it in a few words and generate a video from it."
216
+ "Fetch the file at this `url`, summarize it, and create an animation out of it."
217
+ ],
218
+ inputs=["url"],
219
+ answer="video_generator(summarizer(text_downloader(url)))",
220
+ ),
221
+ ]
222
+
223
+
224
+ def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
225
+ if not isinstance(theoretical_answer, list):
226
+ return {name for name in TEST_TOOLS if name in code_answer}
227
+
228
+ if isinstance(agent_answer, dict):
229
+ for one_answer, one_code in zip(theoretical_answer, code_answer):
230
+ if one_answer in agent_answer.values():
231
+ return {name for name in TEST_TOOLS if name in one_code}
232
+
233
+ for one_answer, one_code in zip(theoretical_answer, code_answer):
234
+ if agent_answer == one_answer:
235
+ return {name for name in TEST_TOOLS if name in one_code}
236
+
237
+ return {name for name in TEST_TOOLS if name in code_answer[0]}
238
+
239
+
240
+ def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
241
+ tools = BASE_PYTHON_TOOLS.copy()
242
+ for name, tool in TEST_TOOLS.items():
243
+ if name not in code:
244
+ continue
245
+ tools[name] = tool
246
+
247
+ if isinstance(inputs, dict):
248
+ inputs = inputs.copy()
249
+ elif inputs is not None:
250
+ inputs = {inp: f"<<{inp}>>" for inp in inputs}
251
+
252
+ if state is not None:
253
+ state.update(inputs)
254
+ else:
255
+ state = inputs
256
+
257
+ try:
258
+ return evaluate(code, tools, state)
259
+ except InterpreterError as e:
260
+ return str(e)
261
+ except Exception as e:
262
+ if verbose:
263
+ print(e)
264
+ return None
265
+
266
+
267
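To make the placeholder mechanics concrete, a small example of `evaluate_code` against the fake tools above:

```py
result = evaluate_code(
    "translator(text, src_lang='Spanish', tgt_lang='English')",
    inputs=["text"],  # expands to {"text": "<<text>>"} in the interpreter state
)
print(result)
# "This is the translation of <<text>> from Spanish to English."
```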
+ def score_code(agent_answer, theoretical_answer, verbose: bool = False):
268
+ if verbose:
269
+ print(agent_answer, theoretical_answer)
270
+ theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
271
+
272
+ if agent_answer in theoretical_answer:
273
+ if verbose:
274
+ print("Perfect!")
275
+ return 1
276
+ elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
277
+ if verbose:
278
+ print("Almsot perfect, result in state!")
279
+ return 0.75
280
+ else:
281
+ if verbose:
282
+ print("Result is not the right one but code executed.")
283
+ return 0.3
284
+
285
+
286
+ def evaluate_one_result(code, agent_answer, theoretical_answer, answer, verbose=False):
287
+ tools_in_code = {name for name in TEST_TOOLS if f"`{name}`" in code}
288
+ theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
289
+ if tools_in_code == theoretical_tools:
290
+ tool_selection_score = 1.0
291
+ tool_selection_errors = None
292
+ else:
293
+ missing_tools = len(theoretical_tools - tools_in_code)
294
+ unexpected_tools = len(tools_in_code - theoretical_tools)
295
+ tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
296
+
297
+ tool_selection_errors = {
298
+ "selected_tools": tools_in_code,
299
+ "theoretical_tools": theoretical_tools,
300
+ }
301
+
302
+ tools_in_code = {name for name in TEST_TOOLS if name in code}
303
+ if tools_in_code == theoretical_tools:
304
+ tool_used_score = 1.0
305
+ tool_used_errors = None
306
+ else:
307
+ missing_tools = len(theoretical_tools - tools_in_code)
308
+ unexpected_tools = len(tools_in_code - theoretical_tools)
309
+ tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
310
+
311
+ tool_used_errors = {
312
+ "selected_tools": tools_in_code,
313
+ "theoretical_tools": theoretical_tools,
314
+ }
315
+
316
+ score = score_code(agent_answer, theoretical_answer, verbose=verbose)
317
+ if score < 1.0:
318
+ code_errors = {
319
+ "code_produced": code,
320
+ "evaluation": agent_answer,
321
+ "theoretical_answer": theoretical_answer,
322
+ }
323
+ else:
324
+ code_errors = None
325
+
326
+ return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
327
+
328
+
329
+ def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
330
+ """
331
+ Evaluates a new agent on all `EVALUATION_TASKS`.
332
+
333
+ Example:
334
+
335
+ ```py
336
+ agent = ReactJsonAgent(tools=[...])  # an agent whose toolbox matches TEST_TOOLS
+ scores = evaluate_agent(agent)
+ print(scores)
340
+ ```
341
+ """
342
+ # Sanity check
343
+ agent_tools = set(agent.toolbox.keys())
344
+ if agent_tools != set(TEST_TOOLS):
345
+ missing_tools = set(TEST_TOOLS) - agent_tools
346
+ unexpected_tools = set(agent_tools) - TEST_TOOLS
347
+ raise ValueError(
348
+ f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
349
+ )
350
+
351
+ eval_tasks = []
352
+ eval_idx = []
353
+ for idx, pb in enumerate(EVALUATION_TASKS):
354
+ if isinstance(pb.task, list):
355
+ eval_tasks.extend(pb.task)
356
+ eval_idx.extend([idx] * len(pb.task))
357
+ else:
358
+ eval_tasks.append(pb.task)
359
+ eval_idx.append(idx)
360
+
361
+ tool_selection_score = 0
362
+ tool_used_score = 0
363
+ code_score = 0
364
+
365
+ if return_errors:
366
+ tool_selection_errors = {}
367
+ tool_used_errors = {}
368
+ code_errors = {}
369
+
370
+ for start_idx in range(0, len(eval_tasks), batch_size):
371
+ end_idx = min(start_idx + batch_size, len(eval_tasks))
372
+ batch_tasks = eval_tasks[start_idx:end_idx]
373
+
374
+ results = [agent.run(task, return_generated_code=True) for task in batch_tasks]
375
+
376
+ for idx, result in enumerate(results):
377
+ problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
378
+ if verbose:
379
+ print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
380
+ code = agent.extract_action(result, split_token="Answer:")
381
+
382
+ # Evaluate agent answer and code answer
383
+ agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
384
+ if isinstance(problem.answer, list):
385
+ theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
386
+ else:
387
+ theoretical_answer = evaluate_code(problem.answer, problem.inputs)
388
+
389
+ scores, errors = evaluate_one_result(
390
+ code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
391
+ )
392
+
393
+ tool_selection_score += scores[0]
394
+ tool_used_score += scores[1]
395
+ code_score += scores[2]
396
+
397
+ if return_errors:
398
+ if errors[0] is not None:
399
+ tool_selection_errors[batch_tasks[idx]] = errors[0]
400
+ if errors[1] is not None:
401
+ tool_used_errors[batch_tasks[idx]] = errors[1]
402
+ if errors[2] is not None:
403
+ code_errors[batch_tasks[idx]] = errors[2]
404
+
405
+ scores = {
406
+ "tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
407
+ "tool used score": 100 * (tool_used_score / len(eval_tasks)),
408
+ "code score": 100 * (code_score / len(eval_tasks)),
409
+ }
410
+
411
+ if return_errors:
412
+ return scores, tool_selection_errors, tool_used_errors, code_errors
413
+ else:
414
+ return scores
.venv/Lib/site-packages/transformers/agents/image_question_answering.py ADDED
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import torch
19
+ from PIL import Image
20
+
21
+ from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
22
+ from ..utils import requires_backends
23
+ from .tools import PipelineTool
24
+
25
+
26
+ class ImageQuestionAnsweringTool(PipelineTool):
27
+ default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
28
+ description = (
29
+ "This is a tool that answers a question about an image. It "
30
+ "returns a text that is the answer to the question."
31
+ )
32
+ name = "image_qa"
33
+ pre_processor_class = AutoProcessor
34
+ model_class = AutoModelForVisualQuestionAnswering
35
+
36
+ inputs = {
37
+ "image": {
38
+ "type": "image",
39
+ "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
40
+ },
41
+ "question": {"type": "string", "description": "The question in English"},
42
+ }
43
+ output_type = "string"
44
+
45
+ def __init__(self, *args, **kwargs):
46
+ requires_backends(self, ["vision"])
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def encode(self, image: "Image", question: str):
50
+ return self.pre_processor(image, question, return_tensors="pt")
51
+
52
+ def forward(self, inputs):
53
+ with torch.no_grad():
54
+ return self.model(**inputs).logits
55
+
56
+ def decode(self, outputs):
57
+ idx = outputs.argmax(-1).item()
58
+ return self.model.config.id2label[idx]
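A hedged usage sketch; the ViLT checkpoint is downloaded on first use and the image path is a placeholder:

```py
from PIL import Image

tool = ImageQuestionAnsweringTool()
image = Image.open("photo.png")
print(tool(image=image, question="How many cats are there?"))
```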
.venv/Lib/site-packages/transformers/agents/llm_engine.py ADDED
@@ -0,0 +1,238 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from copy import deepcopy
18
+ from enum import Enum
19
+ from typing import Dict, List, Optional
20
+
21
+ from huggingface_hub import InferenceClient
22
+
23
+ from .. import AutoTokenizer
24
+ from ..pipelines.base import Pipeline
25
+ from ..utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class MessageRole(str, Enum):
32
+ USER = "user"
33
+ ASSISTANT = "assistant"
34
+ SYSTEM = "system"
35
+ TOOL_CALL = "tool-call"
36
+ TOOL_RESPONSE = "tool-response"
37
+
38
+ @classmethod
39
+ def roles(cls):
40
+ return [r.value for r in cls]
41
+
42
+
43
+ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
44
+ """
45
+ Creates a clean list of chat messages: checks that each message only has 'role' and 'content' keys, applies the given role conversions, and concatenates subsequent messages with the same role into a single message.
46
+
47
+ Args:
48
+ message_list (`List[Dict[str, str]]`): List of chat messages.
+ role_conversions (`Dict[str, str]`, *optional*): Mapping applied to message roles, e.g. to convert tool responses into a role the model accepts.
+ """
50
+ final_message_list = []
51
+ message_list = deepcopy(message_list) # Avoid modifying the original list
52
+ for message in message_list:
53
+ if not set(message.keys()) == {"role", "content"}:
54
+ raise ValueError("Message should contain only 'role' and 'content' keys!")
55
+
56
+ role = message["role"]
57
+ if role not in MessageRole.roles():
58
+ raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
59
+
60
+ if role in role_conversions:
61
+ message["role"] = role_conversions[role]
62
+
63
+ if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
64
+ final_message_list[-1]["content"] += "\n=======\n" + message["content"]
65
+ else:
66
+ final_message_list.append(message)
67
+ return final_message_list
68
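For example, under the merging rule above, consecutive messages that end up with the same role collapse into one:

```py
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "tool-response", "content": "Observation: 42"},
    {"role": "tool-response", "content": "Observation: done"},
]
clean = get_clean_message_list(messages, role_conversions={MessageRole.TOOL_RESPONSE: MessageRole.USER})
# Both tool-response messages are converted to "user" and merged into one message,
# joined with "\n=======\n".
```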
+
69
+
70
+ llama_role_conversions = {
71
+ MessageRole.TOOL_RESPONSE: MessageRole.USER,
72
+ }
73
+
74
+
75
+ class HfEngine:
76
+ def __init__(self, model_id: Optional[str] = None):
77
+ self.last_input_token_count = None
78
+ self.last_output_token_count = None
79
+ if model_id is None:
80
+ model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
81
+ logger.warning(f"Using default model for token counting: '{model_id}'")
82
+ try:
83
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
84
+ except Exception as e:
85
+ logger.warning(f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead.")
86
+ self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
87
+
88
+ def get_token_counts(self):
89
+ return {
90
+ "input_token_count": self.last_input_token_count,
91
+ "output_token_count": self.last_output_token_count,
92
+ }
93
+
94
+ def generate(
95
+ self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
96
+ ):
97
+ raise NotImplementedError
98
+
99
+ def __call__(
100
+ self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
101
+ ) -> str:
102
+ """Process the input messages and return the model's response.
103
+
104
+ This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization.
105
+
106
+ Parameters:
107
+ messages (`List[Dict[str, str]]`):
108
+ A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
109
+ stop_sequences (`List[str]`, *optional*):
110
+ A list of strings that will stop the generation if encountered in the model's output.
111
+ grammar (`str`, *optional*):
112
+ The grammar or formatting structure to use in the model's response.
113
+
114
+ Returns:
115
+ `str`: The text content of the model's response.
116
+
117
+ Example:
118
+ ```python
119
+ >>> engine = HfApiEngine(
120
+ ... model="meta-llama/Meta-Llama-3.1-8B-Instruct",
121
+ ... token="your_hf_token_here",
122
+ ... max_tokens=2000
123
+ ... )
124
+ >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
125
+ >>> response = engine(messages, stop_sequences=["END"])
126
+ >>> print(response)
127
+ "Quantum mechanics is the branch of physics that studies..."
128
+ ```
129
+ """
130
+ if not isinstance(messages, List):
131
+ raise ValueError("Messages should be a list of dictionaries with 'role' and 'content' keys.")
132
+ if stop_sequences is None:
133
+ stop_sequences = []
134
+ response = self.generate(messages, stop_sequences, grammar)
135
+ self.last_input_token_count = len(self.tokenizer.apply_chat_template(messages, tokenize=True))
136
+ self.last_output_token_count = len(self.tokenizer.encode(response))
137
+
138
+ # Remove stop sequences from LLM output
139
+ for stop_seq in stop_sequences:
140
+ if response.endswith(stop_seq):
141
+ response = response[: -len(stop_seq)]
142
+ return response
143
+
144
+
145
+ class HfApiEngine(HfEngine):
146
+ """A class to interact with Hugging Face's Inference API for language model interaction.
147
+
148
+ This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
149
+
150
+ Parameters:
151
+ model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
152
+ The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
153
+ token (`str`, *optional*):
154
+ Token used by the Hugging Face API for authentication.
155
+ If not provided, the class will use the token stored in the Hugging Face CLI configuration.
156
+ max_tokens (`int`, *optional*, defaults to 1500):
157
+ The maximum number of tokens allowed in the output.
158
+ timeout (`int`, *optional*, defaults to 120):
159
+ Timeout for the API request, in seconds.
160
+
161
+ Raises:
162
+ ValueError:
163
+ If the model name is not provided.
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
169
+ token: Optional[str] = None,
170
+ max_tokens: Optional[int] = 1500,
171
+ timeout: Optional[int] = 120,
172
+ ):
173
+ super().__init__(model_id=model)
174
+ self.model = model
175
+ self.client = InferenceClient(self.model, token=token, timeout=timeout)
176
+ self.max_tokens = max_tokens
177
+
178
+ def generate(
179
+ self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
180
+ ) -> str:
181
+ # Get clean message list
182
+ messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
183
+
184
+ # Send messages to the Hugging Face Inference API
185
+ if grammar is not None:
186
+ response = self.client.chat_completion(
187
+ messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
188
+ )
189
+ else:
190
+ response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
191
+
192
+ response = response.choices[0].message.content
193
+ return response
194
+
195
+
196
+ class TransformersEngine(HfEngine):
197
+ """This engine uses a pre-initialized local text-generation pipeline."""
198
+
199
+ def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
200
+ super().__init__(model_id)
201
+ self.pipeline = pipeline
202
+
203
+ def generate(
204
+ self,
205
+ messages: List[Dict[str, str]],
206
+ stop_sequences: Optional[List[str]] = None,
207
+ grammar: Optional[str] = None,
208
+ max_length: int = 1500,
209
+ ) -> str:
210
+ # Get clean message list
211
+ messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
212
+
213
+ # Get LLM output
214
+ if stop_sequences is not None and len(stop_sequences) > 0:
215
+ stop_strings = stop_sequences
216
+ else:
217
+ stop_strings = None
218
+
219
+ output = self.pipeline(
220
+ messages,
221
+ stop_strings=stop_strings,
222
+ max_length=max_length,
223
+ tokenizer=self.pipeline.tokenizer,
224
+ )
225
+
226
+ response = output[0]["generated_text"][-1]["content"]
227
+ return response
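A hedged sketch of wiring a local pipeline into `TransformersEngine`; any chat-capable text-generation checkpoint should work, SmolLM2 being the default referenced above:

```py
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B-Instruct")
engine = TransformersEngine(pipe, model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct")
print(engine([{"role": "user", "content": "Say hello."}]))
```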
228
+
229
+
230
+ DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
231
+ "type": "regex",
232
+ "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
233
+ }
234
+
235
+ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
236
+ "type": "regex",
237
+ "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
238
+ }
.venv/Lib/site-packages/transformers/agents/monitoring.py ADDED
@@ -0,0 +1,117 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from ..utils import logging
18
+ from .agent_types import AgentAudio, AgentImage, AgentText
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ def pull_message(step_log: dict, test_mode: bool = True):
25
+ try:
26
+ from gradio import ChatMessage
27
+ except ImportError:
28
+ if test_mode:
29
+
30
+ class ChatMessage:
31
+ def __init__(self, role, content, metadata=None):
32
+ self.role = role
33
+ self.content = content
34
+ self.metadata = metadata
35
+ else:
36
+ raise ImportError("Gradio should be installed in order to launch a gradio demo.")
37
+
38
+ if step_log.get("rationale"):
39
+ yield ChatMessage(role="assistant", content=step_log["rationale"])
40
+ if step_log.get("tool_call"):
41
+ used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
42
+ content = step_log["tool_call"]["tool_arguments"]
43
+ if used_code:
44
+ content = f"```py\n{content}\n```"
45
+ yield ChatMessage(
46
+ role="assistant",
47
+ metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
48
+ content=str(content),
49
+ )
50
+ if step_log.get("observation"):
51
+ yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
52
+ if step_log.get("error"):
53
+ yield ChatMessage(
54
+ role="assistant",
55
+ content=str(step_log["error"]),
56
+ metadata={"title": "💥 Error"},
57
+ )
58
+
59
+
60
+ def stream_to_gradio(agent, task: str, test_mode: bool = False, **kwargs):
61
+ """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
62
+
63
+ try:
64
+ from gradio import ChatMessage
65
+ except ImportError:
66
+ if test_mode:
67
+
68
+ class ChatMessage:
69
+ def __init__(self, role, content, metadata=None):
70
+ self.role = role
71
+ self.content = content
72
+ self.metadata = metadata
73
+ else:
74
+ raise ImportError("Gradio should be installed in order to launch a gradio demo.")
75
+
76
+ for step_log in agent.run(task, stream=True, **kwargs):
77
+ if isinstance(step_log, dict):
78
+ for message in pull_message(step_log, test_mode=test_mode):
79
+ yield message
80
+
81
+ final_answer = step_log # Last log is the run's final_answer
82
+
83
+ if isinstance(final_answer, AgentText):
84
+ yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
85
+ elif isinstance(final_answer, AgentImage):
86
+ yield ChatMessage(
87
+ role="assistant",
88
+ content={"path": final_answer.to_string(), "mime_type": "image/png"},
89
+ )
90
+ elif isinstance(final_answer, AgentAudio):
91
+ yield ChatMessage(
92
+ role="assistant",
93
+ content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
94
+ )
95
+ else:
96
+ yield ChatMessage(role="assistant", content=str(final_answer))
97
+
98
+
99
+ class Monitor:
100
+ def __init__(self, tracked_llm_engine):
101
+ self.step_durations = []
102
+ self.tracked_llm_engine = tracked_llm_engine
103
+ if getattr(self.tracked_llm_engine, "last_input_token_count", "Not found") != "Not found":
104
+ self.total_input_token_count = 0
105
+ self.total_output_token_count = 0
106
+
107
+ def update_metrics(self, step_log):
108
+ step_duration = step_log["step_duration"]
109
+ self.step_durations.append(step_duration)
110
+ logger.info(f"Step {len(self.step_durations)}:")
111
+ logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
112
+
113
+ if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
114
+ self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
115
+ self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
116
+ logger.info(f"- Input tokens: {self.total_input_token_count}")
117
+ logger.info(f"- Output tokens: {self.total_output_token_count}")
.venv/Lib/site-packages/transformers/agents/prompts.py ADDED
@@ -0,0 +1,789 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import re
+
+ from ..utils import cached_file
+
+
+ # docstyle-ignore
+ CHAT_MESSAGE_PROMPT = """
+ Human: <<task>>
+
+ Assistant: """
+
+
+ DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
+ PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
+
+
+ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
+     """
+     Downloads and caches the prompt from a repo and returns its contents (if necessary).
+     """
+     if prompt_or_repo_id is None:
+         prompt_or_repo_id = DEFAULT_PROMPTS_REPO
+
+     # A value containing any whitespace is treated as a literal prompt rather than a repo ID.
+     if re.search("\\s", prompt_or_repo_id) is not None:
+         return prompt_or_repo_id
+
+     prompt_file = cached_file(
+         prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
+     )
+     with open(prompt_file, "r", encoding="utf-8") as f:
+         return f.read()
+
+
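As a quick illustration (a sketch, not library documentation): any string containing whitespace is returned verbatim as the prompt, while a bare repo ID would be resolved on the Hub.

```python
from transformers.agents.prompts import download_prompt

# A literal prompt (contains whitespace) is returned unchanged.
prompt = download_prompt("Answer the question: <<task>>", agent_name="demo")
assert prompt.startswith("Answer")

# A bare repo ID (no whitespace) is instead downloaded and cached; e.g.
# download_prompt(None, agent_name="demo") would fetch run_prompt_template.txt
# from the default "huggingface-tools/default-prompts" dataset (needs network
# access, so it is left commented out here).
# default_run_prompt = download_prompt(None, agent_name="demo")
```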
+ DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve; your job is to come up with a series of simple commands in Python that will perform the task.
+ To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
+ You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
+ Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
+ In the end, use the tool 'final_answer' to return your answer; its argument will be what gets returned.
+ You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
+ Be sure to provide a 'Code:' token, else the run will fail.
+
+ Tools:
+ <<tool_descriptions>>
+
+ Examples:
+ ---
+ Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
+
+ Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
+ Code:
+ ```py
+ translated_question = translator(question=question, src_lang="French", tgt_lang="English")
+ print(f"The translated question is {translated_question}.")
+ answer = image_qa(image=image, question=translated_question)
+ final_answer(f"The answer is {answer}")
+ ```<end_action>
+
+ ---
+ Task: "Identify the oldest person in the `document` and create an image showcasing the result."
+
+ Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+ Code:
+ ```py
+ answer = document_qa(document, question="What is the oldest person?")
+ print(f"The answer is {answer}.")
+ image = image_generator(answer)
+ final_answer(image)
+ ```<end_action>
+
+ ---
+ Task: "Generate an image using the text given in the variable `caption`."
+
+ Thought: I will use the following tool: `image_generator` to generate an image.
+ Code:
+ ```py
+ image = image_generator(prompt=caption)
+ final_answer(image)
+ ```<end_action>
+
+ ---
+ Task: "Summarize the text given in the variable `text` and read it out loud."
+
+ Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
+ Code:
+ ```py
+ summarized_text = summarizer(text)
+ print(f"Summary: {summarized_text}")
+ audio_summary = text_reader(summarized_text)
+ final_answer(audio_summary)
+ ```<end_action>
+
+ ---
+ Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
+
+ Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
+ Code:
+ ```py
+ answer = text_qa(text=text, question=question)
+ print(f"The answer is {answer}.")
+ image = image_generator(answer)
+ final_answer(image)
+ ```<end_action>
+
+ ---
+ Task: "Caption the following `image`."
+
+ Thought: I will use the following tool: `image_captioner` to generate a caption for the image.
+ Code:
+ ```py
+ caption = image_captioner(image)
+ final_answer(caption)
+ ```<end_action>
+
+ ---
+ The above examples used tools that might not exist for you. You only have access to these tools:
+ <<tool_names>>
+
+ Remember to make sure that variables you use are all defined.
+ Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
+ DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
+
+ Now begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+ """
+
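The `<<tool_descriptions>>`, `<<tool_names>>` and `<<authorized_imports>>` markers are plain placeholders. A minimal sketch of how a caller might splice real values in with `str.replace` (the tool table below is hypothetical, used only for illustration):

```python
from transformers.agents.prompts import DEFAULT_CODE_SYSTEM_PROMPT

tools = {
    "translator": "translator(question, src_lang, tgt_lang) -> str: translates text.",
    "final_answer": "final_answer(answer): returns the final answer.",
}

system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
system_prompt = system_prompt.replace("<<tool_descriptions>>", "\n".join(tools.values()))
system_prompt = system_prompt.replace("<<tool_names>>", ", ".join(tools))
system_prompt = system_prompt.replace("<<authorized_imports>>", "math, re")
print(system_prompt[:200])
```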
+
+ DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
+ To do so, you have been given access to the following tools: <<tool_names>>
+ The way you use the tools is by specifying a json blob, ending with '<end_action>'.
+ Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
+
+ The $ACTION_JSON_BLOB should only contain a SINGLE action; do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
+ {
+     "action": $TOOL_NAME,
+     "action_input": $INPUT
+ }<end_action>
+
+ Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
+
+ You should ALWAYS use the following format:
+
+ Thought: you should always think about one action to take. Then use the action as follows:
+ Action:
+ $ACTION_JSON_BLOB
+ Observation: the result of the action
+ ... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The $ACTION_JSON_BLOB must only use a SINGLE action at a time.)
+
+ You can use the result of the previous action as input for the next action.
+ The observation will always be a string: it can represent a file, like "image_1.jpg".
+ Then you can use it as input for the next action. You can do it for instance as follows:
+
+ Observation: "image_1.jpg"
+
+ Thought: I need to transform the image that I received in the previous observation to make it green.
+ Action:
+ {
+     "action": "image_transformer",
+     "action_input": {"image": "image_1.jpg"}
+ }<end_action>
+
+ To provide the final answer to the task, use an action blob with the "final_answer" tool. It is the only way to complete the task; otherwise you will be stuck in a loop. So your final output should look like this:
+ Action:
+ {
+     "action": "final_answer",
+     "action_input": {"answer": "insert your final answer here"}
+ }<end_action>
+
+
+ Here are a few examples using notional tools:
+ ---
+ Task: "Generate an image of the oldest person in this document."
+
+ Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+ Action:
+ {
+     "action": "document_qa",
+     "action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
+ }<end_action>
+ Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+
+ Thought: I will now generate an image showcasing the oldest person.
+ Action:
+ {
+     "action": "image_generator",
+     "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
+ }<end_action>
+ Observation: "image.png"
+
+ Thought: I will now return the generated image.
+ Action:
+ {
+     "action": "final_answer",
+     "action_input": "image.png"
+ }<end_action>
+
+ ---
+ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+ Thought: I will use the Python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool.
+ Action:
+ {
+     "action": "python_interpreter",
+     "action_input": {"code": "5 + 3 + 1294.678"}
+ }<end_action>
+ Observation: 1302.678
+
+ Thought: Now that I know the result, I will return it.
+ Action:
+ {
+     "action": "final_answer",
+     "action_input": "1302.678"
+ }<end_action>
+
+ ---
+ Task: "Which city has the highest population, Guangzhou or Shanghai?"
+
+ Thought: I need to get the populations of both cities and compare them: I will use the tool `search` to get the population of both cities.
+ Action:
+ {
+     "action": "search",
+     "action_input": "Population Guangzhou"
+ }<end_action>
+ Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+
+
+ Thought: Now let's get the population of Shanghai using the tool 'search'.
+ Action:
+ {
+     "action": "search",
+     "action_input": "Population Shanghai"
+ }<end_action>
+ Observation: '26 million (2019)'
+
+ Thought: Now I know that Shanghai has a larger population. Let's return the result.
+ Action:
+ {
+     "action": "final_answer",
+     "action_input": "Shanghai"
+ }<end_action>
+
+
+ The above examples used notional tools that might not exist for you. You only have access to these tools:
+ <<tool_descriptions>>
+
+ Here are the rules you should always follow to solve your task:
+ 1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
+ 2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
+ 3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
+ 4. Never re-do a tool call that you previously did with the exact same parameters.
+
+ Now begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+ """
+
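A sketch of how a caller might pull the action blob out of a model completion that follows this format. The regex and error handling are illustrative, not the library's actual parser, and the sample completion is made up:

```python
import json
import re

completion = """Thought: I will return the answer.
Action:
{
  "action": "final_answer",
  "action_input": {"answer": "Shanghai"}
}<end_action>"""

# Capture the JSON object between "Action:" and the closing <end_action> tag.
match = re.search(r"Action:\s*(\{.*?\})\s*<end_action>", completion, re.DOTALL)
if match is None:
    raise ValueError("No action blob found in the completion")
blob = json.loads(match.group(1))
print(blob["action"], blob["action_input"])  # final_answer {'answer': 'Shanghai'}
```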
+
+ DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
+ To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
+ To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
+
+ At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
+ Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with the '<end_action>' sequence.
+ During each intermediate step, you can use 'print()' to save whatever important information you will then need.
+ These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
+ In the end you have to return a final answer using the `final_answer` tool.
+
+ Here are a few examples using notional tools:
+ ---
+ Task: "Generate an image of the oldest person in this document."
+
+ Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+ Code:
+ ```py
+ answer = document_qa(document=document, question="Who is the oldest person mentioned?")
+ print(answer)
+ ```<end_action>
+ Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+ Thought: I will now generate an image showcasing the oldest person.
+ Code:
+ ```py
+ image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
+ final_answer(image)
+ ```<end_action>
+
+ ---
+ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+ Thought: I will use Python code to compute the result of the operation and then return the final answer using the `final_answer` tool.
+ Code:
+ ```py
+ result = 5 + 3 + 1294.678
+ final_answer(result)
+ ```<end_action>
+
+ ---
+ Task: "Which city has the highest population: Guangzhou or Shanghai?"
+
+ Thought: I need to get the populations of both cities and compare them: I will use the tool `search` to get the population of both cities.
+ Code:
+ ```py
+ population_guangzhou = search("Guangzhou population")
+ print("Population Guangzhou:", population_guangzhou)
+ population_shanghai = search("Shanghai population")
+ print("Population Shanghai:", population_shanghai)
+ ```<end_action>
+ Observation:
+ Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+ Population Shanghai: '26 million (2019)'
+
+ Thought: Now I know that Shanghai has the highest population.
+ Code:
+ ```py
+ final_answer("Shanghai")
+ ```<end_action>
+
+ ---
+ Task: "What is the current age of the pope, raised to the power 0.36?"
+
+ Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
+ Code:
+ ```py
+ pope_age = wiki(query="current pope age")
+ print("Pope age:", pope_age)
+ ```<end_action>
+ Observation:
+ Pope age: "The pope Francis is currently 85 years old."
+
+ Thought: I know that the pope is 85 years old. Let's compute the result using Python code.
+ Code:
+ ```py
+ pope_current_age = 85 ** 0.36
+ final_answer(pope_current_age)
+ ```<end_action>
+
+ The above examples used notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have access to these tools (and no other tools):
+
+ <<tool_descriptions>>
+
+ <<managed_agents_descriptions>>
+
+ Here are the rules you should always follow to solve your task:
+ 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with the '```<end_action>' sequence, else you will fail.
+ 2. Use only variables that you have defined!
+ 3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
+ 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
+ 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
+ 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
+ 7. Never create any notional variables in your code, as having these in your logs might derail you from the true variables.
+ 8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
+ 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
+ 10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
+
+ Now begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+ """
+
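And a corresponding sketch of extracting the code blob from a completion that follows this 'Thought:/Code:' format (illustrative parsing only; it assumes the fenced block ends with `<end_action>` as the prompt demands):

````python
import re

completion = """Thought: compute it directly.
Code:
```py
result = 5 + 3 + 1294.678
final_answer(result)
```<end_action>"""

# Grab everything between the opening ```py fence and the closing fence.
match = re.search(r"Code:\s*```(?:py|python)?\n(.*?)\n```", completion, re.DOTALL)
if match is None:
    raise ValueError("No code blob found in the completion")
print(match.group(1))
````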
+ SYSTEM_PROMPT_FACTS = """Below I will present you a task.
+
+ You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
+ To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
+ Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
+
+ ---
+ ### 1. Facts given in the task
+ List here the specific facts given in the task that could help you (there might be nothing here).
+
+ ### 2. Facts to look up
+ List here any facts that we may need to look up.
+ Also list where to find each of these, for instance a website, a file, etc. - maybe the task contains some sources that you should re-use here.
+
+ ### 3. Facts to derive
+ List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
+
+ Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
+ ### 1. Facts given in the task
+ ### 2. Facts to look up
+ ### 3. Facts to derive
+ Do not add anything else."""
+
+ SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
+
+ Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
+ This plan should involve individual tasks based on the available tools that, if executed correctly, will yield the correct answer.
+ Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
+ After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
+
+ USER_PROMPT_PLAN = """
+ Here is your task:
+
+ Task:
+ ```
+ {task}
+ ```
+
+ Your plan can leverage any of these tools:
+ {tool_descriptions}
+
+ {managed_agents_descriptions}
+
+ List of facts that you know:
+ ```
+ {answer_facts}
+ ```
+
+ Now begin! Write your plan below."""
+
+ SYSTEM_PROMPT_FACTS_UPDATE = """
+ You are a world expert at gathering known and unknown facts based on a conversation.
+ Below you will find a task, and a history of attempts made to solve the task. You will have to produce a list of these:
+ ### 1. Facts given in the task
+ ### 2. Facts that we have learned
+ ### 3. Facts still to look up
+ ### 4. Facts still to derive
+ Find the task and history below."""
+
+ USER_PROMPT_FACTS_UPDATE = """Earlier we built a list of facts.
+ Since then, in your previous steps, you may have learned useful new facts or invalidated some false ones.
+ Please update your list of facts based on the previous history, and provide these headings:
+ ### 1. Facts given in the task
+ ### 2. Facts that we have learned
+ ### 3. Facts still to look up
+ ### 4. Facts still to derive
+
+ Now write your new list of facts below."""
+
+ SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
+
+ You have been given a task:
+ ```
+ {task}
+ ```
+
+ Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
+ If the previous attempts have met with some success, you can make an updated plan based on these actions.
+ If you are stalled, you can make a completely new plan starting from scratch.
+ """
+
+ USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
+ ```
+ {task}
+ ```
+
+ You have access to these tools and only these:
+ {tool_descriptions}
+
+ {managed_agents_descriptions}
+
+ Here is the up-to-date list of facts that you know:
+ ```
+ {facts_update}
+ ```
+
+ Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
+ This plan should involve individual tasks based on the available tools that, if executed correctly, will yield the correct answer.
+ Beware that you have {remaining_steps} steps remaining.
+ Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
+ After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
+
+ Now write your new plan below."""
+
+ SYSTEM_PROMPT_PLAN_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
+ This plan should involve individual tasks based on the available tools that, if executed correctly, will yield the correct answer. Each step should be structured as follows:
+ Step #n: {
+     "description": <description of what the step does and its output>,
+     "tool": <tool to use>,
+     "params": {
+         <parameters to pass to the tool as a valid dict>
+     },
+     "output_var": <output variable name>
+ }
+ Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
+
+ Below are some examples:
+
+ Example 1:
+ ------
+ Inputs:
+ ---
+ Task:
+ How many encoder blocks were in the first attention-only ML architecture published?
+
+ [FACTS LIST]:
+ ### 1. Facts given in the task
+ - The paper first introduced an attention-only ML architecture.
+ - The specific information required is the page number where the number of encoder blocks is stated.
+ - No local files are provided for access.
+
+ ### 2. Facts to look up
+ - The title and authors of the paper that first introduced an attention-only ML architecture.
+     - Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
+ - The full text of the identified paper.
+     - Source: Online academic repositories (e.g., arXiv, journal websites)
+ - The specific page number in the paper where the number of encoder blocks is mentioned.
+     - Source: The content of the identified paper
+
+ ### 3. Facts to derive
+ - By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
+     - Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
+
+ [STEP 1 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}
+ [OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
+ **Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
+ [STEP 2 TOOL CALL]: {'tool_name': 'code interpreter', 'tool_arguments': '# Step 2: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}
+ [OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
+ ---
+
+ Output plan:
+ ---
+ Step #1: {
+     "description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of 'encoder blocks'",
+     "tool": "inspect_file_as_text",
+     "params": {
+         "file_path": "https://arxiv.org/pdf/1706.03762.pdf",
+         "question": "On which page is the number of encoder blocks mentioned?"
+     },
+     "output_var": "page_number"
+ }
+
+ Step #2: {
+     "description": "Provide the final answer",
+     "tool": "final_answer",
+     "params": {
+         "answer": "{page_number}"
+     },
+     "output_var": ""
+ }
+ ------
+
+ Example 2:
+ ------
+ Inputs:
+ ---
+ Task:
+ How many golf balls fit into a Boeing-747?
+
+ [FACTS LIST]:
+ ### 1. Facts given in the task
+ - The task requires calculating the number of golf balls that fit into a Boeing-747
+ ### 2. Facts to look up
+ - The volume of a golf ball
+ - The volume of a Boeing-747
+ ### 3. Facts to derive
+ - Once the volumes are known the final answer can be calculated
+ ---
+ Output plan:
+ ---
+ Step #1: {
+     "description": "Find the volume of a Boeing-747",
+     "tool": "web_search",
+     "params": {
+         "query": "What is the internal volume of a Boeing-747 in cubic meters?"
+     },
+     "output_var": "boeing_volume"
+ }
+
+ Step #2: {
+     "description": "Find the volume of a standard golf ball",
+     "tool": "ask_search_agent",
+     "params": {
+         "query": "What is the volume of a standard golf ball in cubic centimeters?"
+     },
+     "output_var": "golf_ball_volume"
+ }
+
+ Step #3: {
+     "description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
+     "tool": "python_code",
+     "params": {
+         "code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
+     },
+     "output_var": "number_of_golf_balls"
+ }
+
+ Step #4: {
+     "description": "Provide the final answer",
+     "tool": "final_answer",
+     "params": {
+         "answer": "{number_of_golf_balls}"
+     },
+     "output_var": ""
+ }
+ ------
+ The above examples used tools that might not exist for you.
+ Your goal is to create a plan to solve the task."""
+
+ USER_PROMPT_PLAN_STRUCTURED = """
+ Here are your inputs:
+
+ Task:
+ ```
+ {task}
+ ```
+
+ Your plan can leverage any of these tools:
+ {tool_descriptions}
+ These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
+
+ List of facts that you know:
+ ```
+ {answer_facts}
+ ```
+
+ Now for the given task, create a plan taking into account the list of facts.
+ After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. Output the plan only and nothing else."""
+
+ SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED = """Output a step-by-step plan to solve the task using the given tools.
+ This plan should involve individual tasks based on the available tools that, if executed correctly, will yield the correct answer. Each step should be structured as follows:
+ Step #n: {{
+     "description": <description of what the step does and its output>,
+     "tool": <tool to use>,
+     "params": {{
+         <parameters to pass to the tool as a valid dict>
+     }},
+     "output_var": <output variable name>
+ }}
+ Each step must be necessary to reach the final answer. Steps should reuse outputs produced by earlier steps. The last step must be the final answer.
+
+ Below are some examples:
+
+ Example 1:
+ ------
+ Inputs:
+ ---
+ Task:
+ How many encoder blocks were in the first attention-only ML architecture published?
+
+ [FACTS LIST]:
+ ### 1. Facts given in the task
+ - The paper first introduced an attention-only ML architecture.
+ - The specific information required is the page number where the number of encoder blocks is stated.
+ - No local files are provided for access.
+
+ ### 2. Facts to look up
+ - The title and authors of the paper that first introduced an attention-only ML architecture.
+     - Source: Online search (e.g., Google Scholar, arXiv, or other academic databases)
+ - The full text of the identified paper.
+     - Source: Online academic repositories (e.g., arXiv, journal websites)
+ - The specific page number in the paper where the number of encoder blocks is mentioned.
+     - Source: The content of the identified paper
+
+ ### 3. Facts to derive
+ - By identifying the correct paper and locating the specific page, we will derive the page number where the number of encoder blocks is stated.
+     - Logical steps: Identify the correct paper, access its content, search for the term "encoder blocks," and note the page number where this information is found.
+
+ [STEP 1 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 1: Identify the title and authors of the paper that first introduced an attention-only ML architecture.\nanswer = ask_search_agent(query="Can you find the title and authors of the paper that first introduced an attention-only machine learning architecture? Please provide the full citation.")\nprint(answer)'}}
+ [OUTPUT OF STEP 1] Observation: **Title**: Attention Is All You Need
+ **Authors**: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
+ [STEP 2 TOOL CALL]: {{'tool_name': 'code interpreter', 'tool_arguments': '# Step 2: Find the full text of the identified paper on arXiv\\npaper_url = "https://arxiv.org/pdf/1706.03762.pdf"\\nprint(paper_url)'}}
+ [OUTPUT OF STEP 2] Observation: https://arxiv.org/pdf/1706.03762.pdf
+ ---
+
+ Output plan:
+ ---
+ Step #1: {{
+     "description": "Open the PDF of the paper from the provided URL and search within the text of the paper for the mention of 'encoder blocks'",
+     "tool": "inspect_file_as_text",
+     "params": {{
+         "file_path": "https://arxiv.org/pdf/1706.03762.pdf",
+         "question": "On which page is the number of encoder blocks mentioned?"
+     }},
+     "output_var": "page_number"
+ }}
+
+ Step #2: {{
+     "description": "Provide the final answer",
+     "tool": "final_answer",
+     "params": {{
+         "answer": "{{page_number}}"
+     }},
+     "output_var": ""
+ }}
+ ------
+
+ Example 2:
+ ------
+ Inputs:
+ ---
+ Task:
+ How many golf balls fit into a Boeing-747?
+
+ [FACTS LIST]:
+ ### 1. Facts given in the task
+ - The task requires calculating the number of golf balls that fit into a Boeing-747
+ ### 2. Facts to look up
+ - The volume of a golf ball
+ - The volume of a Boeing-747
+ ### 3. Facts to derive
+ - Once the volumes are known the final answer can be calculated
+ ---
+ Output plan:
+ ---
+ Step #1: {{
+     "description": "Find the volume of a Boeing-747",
+     "tool": "web_search",
+     "params": {{
+         "query": "What is the internal volume of a Boeing-747 in cubic meters?"
+     }},
+     "output_var": "boeing_volume"
+ }}
+
+ Step #2: {{
+     "description": "Find the volume of a standard golf ball",
+     "tool": "ask_search_agent",
+     "params": {{
+         "query": "What is the volume of a standard golf ball in cubic centimeters?"
+     }},
+     "output_var": "golf_ball_volume"
+ }}
+
+ Step #3: {{
+     "description": "Convert the volume of a golf ball from cubic centimeters to cubic meters. Calculate the number of golf balls that fit into the Boeing-747 by dividing the internal volume of the Boeing-747 by the volume of a golf ball.",
+     "tool": "python_code",
+     "params": {{
+         "code": "golf_ball_volume_m3 = golf_ball_volume / 1e6\nnumber_of_golf_balls = boeing_volume / golf_ball_volume_m3"
+     }},
+     "output_var": "number_of_golf_balls"
+ }}
+
+ Step #4: {{
+     "description": "Provide the final answer",
+     "tool": "final_answer",
+     "params": {{
+         "answer": "{{number_of_golf_balls}}"
+     }},
+     "output_var": ""
+ }}
+ ------
+ The above examples used tools that might not exist for you.
+ Find below the record of what has been tried so far to solve it. Your goal is to create an updated plan to solve the task."""
+
+ USER_PROMPT_PLAN_UPDATE_STRUCTURED = """
+ Here are your inputs:
+
+ Task:
+ ```
+ {task}
+ ```
+
+ Your plan can leverage any of these tools:
+ {tool_descriptions}
+ These tools are Python functions which you can call with code. You also have access to a Python interpreter so you can run Python code.
+
+ List of facts that you know:
+ ```
+ {facts_update}
+ ```
+
+ Now for the given task, create a plan taking into account the above inputs and list of facts.
+ Beware that you have {remaining_steps} steps remaining.
+ After writing the final step of the plan, write the '\n<end_plan>' tag and stop there. Output the plan only and nothing else."""
+
+ PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
+ ```
+ {task}
+ ```
+
+ Here is my new/updated plan of action to solve the task:
+ ```
+ {plan_update}
+ ```"""
+
+ SUPPORTED_PLAN_TYPES = ["default", "structured"]
+
+ PROMPTS_FOR_INITIAL_PLAN = {
+     "default": {"system": SYSTEM_PROMPT_PLAN, "user": USER_PROMPT_PLAN},
+     "structured": {"system": SYSTEM_PROMPT_PLAN_STRUCTURED, "user": USER_PROMPT_PLAN_STRUCTURED},
+ }
+
+ PROMPTS_FOR_PLAN_UPDATE = {
+     "default": {"system": SYSTEM_PROMPT_PLAN_UPDATE, "user": USER_PROMPT_PLAN_UPDATE},
+     "structured": {"system": SYSTEM_PROMPT_PLAN_UPDATE_STRUCTURED, "user": USER_PROMPT_PLAN_UPDATE_STRUCTURED},
+ }
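A small sketch of how these lookup tables might be used to pick and fill the prompt pair for a given plan type. The wiring is illustrative; the task, tool, and fact strings are placeholders:

```python
from transformers.agents.prompts import PROMPTS_FOR_INITIAL_PLAN, SUPPORTED_PLAN_TYPES

plan_type = "default"
assert plan_type in SUPPORTED_PLAN_TYPES

system = PROMPTS_FOR_INITIAL_PLAN[plan_type]["system"]
user = PROMPTS_FOR_INITIAL_PLAN[plan_type]["user"].format(
    task="Which city has the larger population, Guangzhou or Shanghai?",
    tool_descriptions="- search(query): web search",
    managed_agents_descriptions="",
    answer_facts="### 1. Facts given in the task\n- Two candidate cities.",
)
print(system[:80])
print(user[:80])
```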
.venv/Lib/site-packages/transformers/agents/python_interpreter.py ADDED
@@ -0,0 +1,908 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import ast
+ import builtins
+ import difflib
+ from collections.abc import Mapping
+ from importlib import import_module
+ from typing import Any, Callable, Dict, List, Optional
+
+ import numpy as np
+
+ from ..utils import is_pandas_available
+
+
+ if is_pandas_available():
+     import pandas as pd
+
+
+ class InterpreterError(ValueError):
+     """
+     An error raised when the interpreter cannot evaluate a Python expression, due to a syntax error or unsupported
+     operations.
+     """
+
+     pass
+
+
+ ERRORS = {
+     name: getattr(builtins, name)
+     for name in dir(builtins)
+     if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
+ }
+
+
+ LIST_SAFE_MODULES = [
+     "random",
+     "collections",
+     "math",
+     "time",
+     "queue",
+     "itertools",
+     "re",
+     "stat",
+     "statistics",
+     "unicodedata",
+ ]
+
+ PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
+ OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
+
+
+ class BreakException(Exception):
+     pass
+
+
+ class ContinueException(Exception):
+     pass
+
+
+ class ReturnException(Exception):
+     def __init__(self, value):
+         self.value = value
+
+
+ def get_iterable(obj):
+     if isinstance(obj, list):
+         return obj
+     elif hasattr(obj, "__iter__"):
+         return list(obj)
+     else:
+         raise InterpreterError("Object is not iterable")
+
+
+ def evaluate_unaryop(expression, state, static_tools, custom_tools):
+     operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
+     if isinstance(expression.op, ast.USub):
+         return -operand
+     elif isinstance(expression.op, ast.UAdd):
+         return operand
+     elif isinstance(expression.op, ast.Not):
+         return not operand
+     elif isinstance(expression.op, ast.Invert):
+         return ~operand
+     else:
+         raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
+
+
+ def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
+     args = [arg.arg for arg in lambda_expression.args.args]
+
+     def lambda_func(*values):
+         new_state = state.copy()
+         for arg, value in zip(args, values):
+             new_state[arg] = value
+         return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
+
+     return lambda_func
+
+
+ def evaluate_while(while_loop, state, static_tools, custom_tools):
+     max_iterations = 1000
+     iterations = 0
+     while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
+         for node in while_loop.body:
+             try:
+                 evaluate_ast(node, state, static_tools, custom_tools)
+             except BreakException:
+                 return None
+             except ContinueException:
+                 break
+         iterations += 1
+         if iterations > max_iterations:
+             raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
+     return None
+
+
+ def create_function(func_def, state, static_tools, custom_tools):
+     def new_func(*args, **kwargs):
+         func_state = state.copy()
+         arg_names = [arg.arg for arg in func_def.args.args]
+         default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
+
+         # Apply default values
+         defaults = dict(zip(arg_names[-len(default_values) :], default_values))
+
+         # Set positional arguments
+         for name, value in zip(arg_names, args):
+             func_state[name] = value
+
+         # Set keyword arguments
+         for name, value in kwargs.items():
+             func_state[name] = value
+
+         # Handle variable arguments
+         if func_def.args.vararg:
+             vararg_name = func_def.args.vararg.arg
+             func_state[vararg_name] = args
+
+         if func_def.args.kwarg:
+             kwarg_name = func_def.args.kwarg.arg
+             func_state[kwarg_name] = kwargs
+
+         # Set default values for arguments that were not provided
+         for name, value in defaults.items():
+             if name not in func_state:
+                 func_state[name] = value
+
+         # Update function state with self and __class__
+         if func_def.args.args and func_def.args.args[0].arg == "self":
+             if args:
+                 func_state["self"] = args[0]
+                 func_state["__class__"] = args[0].__class__
+
+         result = None
+         try:
+             for stmt in func_def.body:
+                 result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
+         except ReturnException as e:
+             result = e.value
+         return result
+
+     return new_func
+
+
+ def create_class(class_name, class_bases, class_body):
+     class_dict = {}
+     for key, value in class_body.items():
+         class_dict[key] = value
+     return type(class_name, tuple(class_bases), class_dict)
+
+
+ def evaluate_function_def(func_def, state, static_tools, custom_tools):
+     custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
+     return custom_tools[func_def.name]
+
+
+ def evaluate_class_def(class_def, state, static_tools, custom_tools):
+     class_name = class_def.name
+     bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
+     class_dict = {}
+
+     for stmt in class_def.body:
+         if isinstance(stmt, ast.FunctionDef):
+             class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
+         elif isinstance(stmt, ast.Assign):
+             for target in stmt.targets:
+                 if isinstance(target, ast.Name):
+                     class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
+                 elif isinstance(target, ast.Attribute):
+                     class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
+         else:
+             raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
+
+     new_class = type(class_name, tuple(bases), class_dict)
+     state[class_name] = new_class
+     return new_class
+
+
+ def evaluate_augassign(expression, state, static_tools, custom_tools):
+     # Helper function to get the current value based on the target type
+     def get_current_value(target):
+         if isinstance(target, ast.Name):
+             return state.get(target.id, 0)
+         elif isinstance(target, ast.Subscript):
+             obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+             key = evaluate_ast(target.slice, state, static_tools, custom_tools)
+             return obj[key]
+         elif isinstance(target, ast.Attribute):
+             obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+             return getattr(obj, target.attr)
+         elif isinstance(target, ast.Tuple):
+             return tuple(get_current_value(elt) for elt in target.elts)
+         elif isinstance(target, ast.List):
+             return [get_current_value(elt) for elt in target.elts]
+         else:
+             raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")
+
+     current_value = get_current_value(expression.target)
+     value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
+
+     # Determine the operation and apply it
+     if isinstance(expression.op, ast.Add):
+         if isinstance(current_value, list):
+             if not isinstance(value_to_add, list):
+                 raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
+             updated_value = current_value + value_to_add
+         else:
+             updated_value = current_value + value_to_add
+     elif isinstance(expression.op, ast.Sub):
+         updated_value = current_value - value_to_add
+     elif isinstance(expression.op, ast.Mult):
+         updated_value = current_value * value_to_add
+     elif isinstance(expression.op, ast.Div):
+         updated_value = current_value / value_to_add
+     elif isinstance(expression.op, ast.Mod):
+         updated_value = current_value % value_to_add
+     elif isinstance(expression.op, ast.Pow):
+         updated_value = current_value**value_to_add
+     elif isinstance(expression.op, ast.FloorDiv):
+         updated_value = current_value // value_to_add
+     elif isinstance(expression.op, ast.BitAnd):
+         updated_value = current_value & value_to_add
+     elif isinstance(expression.op, ast.BitOr):
+         updated_value = current_value | value_to_add
+     elif isinstance(expression.op, ast.BitXor):
+         updated_value = current_value ^ value_to_add
+     elif isinstance(expression.op, ast.LShift):
+         updated_value = current_value << value_to_add
+     elif isinstance(expression.op, ast.RShift):
+         updated_value = current_value >> value_to_add
+     else:
+         raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
+
+     # Update the state
+     set_value(expression.target, updated_value, state, static_tools, custom_tools)
+
+     return updated_value
+
+
+ def evaluate_boolop(node, state, static_tools, custom_tools):
+     if isinstance(node.op, ast.And):
+         for value in node.values:
+             if not evaluate_ast(value, state, static_tools, custom_tools):
+                 return False
+         return True
+     elif isinstance(node.op, ast.Or):
+         for value in node.values:
+             if evaluate_ast(value, state, static_tools, custom_tools):
+                 return True
+         return False
+
+
+ def evaluate_binop(binop, state, static_tools, custom_tools):
+     # Recursively evaluate the left and right operands
+     left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
+     right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
+
+     # Determine the operation based on the type of the operator in the BinOp
+     if isinstance(binop.op, ast.Add):
+         return left_val + right_val
+     elif isinstance(binop.op, ast.Sub):
+         return left_val - right_val
+     elif isinstance(binop.op, ast.Mult):
+         return left_val * right_val
+     elif isinstance(binop.op, ast.Div):
+         return left_val / right_val
+     elif isinstance(binop.op, ast.Mod):
+         return left_val % right_val
+     elif isinstance(binop.op, ast.Pow):
+         return left_val**right_val
+     elif isinstance(binop.op, ast.FloorDiv):
+         return left_val // right_val
+     elif isinstance(binop.op, ast.BitAnd):
+         return left_val & right_val
+     elif isinstance(binop.op, ast.BitOr):
+         return left_val | right_val
+     elif isinstance(binop.op, ast.BitXor):
+         return left_val ^ right_val
+     elif isinstance(binop.op, ast.LShift):
+         return left_val << right_val
+     elif isinstance(binop.op, ast.RShift):
+         return left_val >> right_val
+     else:
+         raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
+
+
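To make the dispatch pattern above concrete, here is a tiny self-contained demo (separate from the library code) showing the kind of `ast.BinOp` tree that `evaluate_binop` walks. The `tiny_eval` helper is a deliberately minimal stand-in, not the library's evaluator:

```python
import ast

# Parse "2 + 3 * 4" into an AST; the top-level node is a BinOp.
tree = ast.parse("2 + 3 * 4", mode="eval")
node = tree.body
print(type(node).__name__)       # BinOp
print(type(node.op).__name__)    # Add
print(type(node.right).__name__) # BinOp (the nested 3 * 4)

def tiny_eval(n):
    # Minimal recursive evaluator mirroring evaluate_binop's structure.
    if isinstance(n, ast.Constant):
        return n.value
    if isinstance(n, ast.BinOp):
        left, right = tiny_eval(n.left), tiny_eval(n.right)
        if isinstance(n.op, ast.Add):
            return left + right
        if isinstance(n.op, ast.Mult):
            return left * right
    raise NotImplementedError(type(n).__name__)

print(tiny_eval(node))  # 14
```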
+ def evaluate_assign(assign, state, static_tools, custom_tools):
+     result = evaluate_ast(assign.value, state, static_tools, custom_tools)
+     if len(assign.targets) == 1:
+         target = assign.targets[0]
+         set_value(target, result, state, static_tools, custom_tools)
+     else:
+         if len(assign.targets) != len(result):
+             raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
+         expanded_values = []
+         for tgt in assign.targets:
+             if isinstance(tgt, ast.Starred):
+                 expanded_values.extend(result)
+             else:
+                 expanded_values.append(result)
+         for tgt, val in zip(assign.targets, expanded_values):
+             set_value(tgt, val, state, static_tools, custom_tools)
+     return result
+
+
+ def set_value(target, value, state, static_tools, custom_tools):
+     if isinstance(target, ast.Name):
+         if target.id in static_tools:
+             raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
+         state[target.id] = value
+     elif isinstance(target, ast.Tuple):
+         if not isinstance(value, tuple):
+             if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
+                 value = tuple(value)
+             else:
+                 raise InterpreterError("Cannot unpack non-tuple value")
+         if len(target.elts) != len(value):
+             raise InterpreterError("Cannot unpack tuple of wrong size")
+         for i, elem in enumerate(target.elts):
+             set_value(elem, value[i], state, static_tools, custom_tools)
+     elif isinstance(target, ast.Subscript):
+         obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+         key = evaluate_ast(target.slice, state, static_tools, custom_tools)
+         obj[key] = value
+     elif isinstance(target, ast.Attribute):
+         obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+         setattr(obj, target.attr, value)
+
+
+ def evaluate_call(call, state, static_tools, custom_tools):
+     if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
+         raise InterpreterError(f"This is not a correct function: {call.func}.")
+     if isinstance(call.func, ast.Attribute):
+         obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
+         func_name = call.func.attr
+         if not hasattr(obj, func_name):
+             raise InterpreterError(f"Object {obj} has no attribute {func_name}")
+         func = getattr(obj, func_name)
+
+     elif isinstance(call.func, ast.Name):
+         func_name = call.func.id
+         if func_name in state:
+             func = state[func_name]
+         elif func_name in static_tools:
+             func = static_tools[func_name]
+         elif func_name in custom_tools:
+             func = custom_tools[func_name]
+         elif func_name in ERRORS:
+             func = ERRORS[func_name]
+         else:
+             raise InterpreterError(
+                 f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
+             )
+
+     # Evaluate positional arguments, validating starred (unpacked) arguments
+     args = []
+     for arg in call.args:
+         if isinstance(arg, ast.Starred):
+             unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
+             if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
+                 raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
+             args.extend(unpacked)
+         else:
+             args.append(evaluate_ast(arg, state, static_tools, custom_tools))
+
+     kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
+
+     if isinstance(func, type) and len(func.__module__.split(".")) > 1:  # Check for user-defined classes
+         # Instantiate the class using its constructor
+         obj = func.__new__(func)  # Create a new instance of the class
+         if hasattr(obj, "__init__"):  # Check if the class has an __init__ method
+             obj.__init__(*args, **kwargs)  # Call the __init__ method correctly
+         return obj
+     else:
+         if func_name == "super":
+             if not args:
+                 if "__class__" in state and "self" in state:
+                     return super(state["__class__"], state["self"])
+                 else:
+                     raise InterpreterError("super() needs at least one argument")
+             cls = args[0]
+             if not isinstance(cls, type):
+                 raise InterpreterError("super() argument 1 must be type")
+             if len(args) == 1:
+                 return super(cls)
+             elif len(args) == 2:
+                 instance = args[1]
+                 return super(cls, instance)
+             else:
+                 raise InterpreterError("super() takes at most 2 arguments")
+         else:
+             if func_name == "print":
+                 output = " ".join(map(str, args))
+                 global PRINT_OUTPUTS
+                 PRINT_OUTPUTS += output + "\n"
+                 # The accumulated output is capped at MAX_LEN_OUTPUT by the caller
+                 return None
+             else:  # Assume it's a callable object
+                 output = func(*args, **kwargs)
+                 return output
+
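As a standalone illustration of the whitelisting idea used in `evaluate_call` above (not the library's API, and deliberately much simpler): only names present in an explicit table are callable, and everything else is rejected.

```python
import ast

ALLOWED = {"len": len, "sum": sum}  # hypothetical whitelist

def safe_call(source: str):
    node = ast.parse(source, mode="eval").body
    if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name):
        raise ValueError("only simple calls to named functions are allowed")
    if node.func.id not in ALLOWED:
        raise ValueError(f"function {node.func.id!r} is not whitelisted")
    # Only literal arguments are accepted in this toy version.
    args = [ast.literal_eval(a) for a in node.args]
    return ALLOWED[node.func.id](*args)

print(safe_call("sum([1, 2, 3])"))  # 6
# safe_call("__import__('os')")     # would raise: not whitelisted
```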
+
+ def evaluate_subscript(subscript, state, static_tools, custom_tools):
+     index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
+     value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
+
+     if isinstance(value, str) and isinstance(index, str):
+         raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
+     if isinstance(value, pd.core.indexing._LocIndexer):
+         parent_object = value.obj
+         return parent_object.loc[index]
+     if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
+         return value[index]
+     elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
+         return value[index]
+     elif isinstance(index, slice):
+         return value[index]
+     elif isinstance(value, (list, tuple)):
+         if not (-len(value) <= index < len(value)):
+             raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
+         return value[int(index)]
+     elif isinstance(value, str):
+         if not (-len(value) <= index < len(value)):
+             raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
+         return value[index]
+     elif index in value:
+         return value[index]
+     elif isinstance(index, str) and isinstance(value, Mapping):
+         close_matches = difflib.get_close_matches(index, list(value.keys()))
+         if len(close_matches) > 0:
+             return value[close_matches[0]]
+     raise InterpreterError(f"Could not index {value} with '{index}'.")
+
+
+ def evaluate_name(name, state, static_tools, custom_tools):
+     if name.id in state:
+         return state[name.id]
+     elif name.id in static_tools:
+         return static_tools[name.id]
+     elif name.id in ERRORS:
+         return ERRORS[name.id]
+     close_matches = difflib.get_close_matches(name.id, list(state.keys()))
+     if len(close_matches) > 0:
+         return state[close_matches[0]]
+     raise InterpreterError(f"The variable `{name.id}` is not defined.")
+
+
+ def evaluate_condition(condition, state, static_tools, custom_tools):
+     left = evaluate_ast(condition.left, state, static_tools, custom_tools)
+     comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
+     ops = [type(op) for op in condition.ops]
+
+     result = True
+     current_left = left
+
+     for op, comparator in zip(ops, comparators):
+         if op == ast.Eq:
+             current_result = current_left == comparator
+         elif op == ast.NotEq:
+             current_result = current_left != comparator
+         elif op == ast.Lt:
+             current_result = current_left < comparator
+         elif op == ast.LtE:
+             current_result = current_left <= comparator
+         elif op == ast.Gt:
+             current_result = current_left > comparator
+         elif op == ast.GtE:
+             current_result = current_left >= comparator
+         elif op == ast.Is:
+             current_result = current_left is comparator
+         elif op == ast.IsNot:
+             current_result = current_left is not comparator
+         elif op == ast.In:
+             current_result = current_left in comparator
+         elif op == ast.NotIn:
+             current_result = current_left not in comparator
+         else:
+             raise InterpreterError(f"Operator not supported: {op}")
+
+         result = result & current_result
+         current_left = comparator
+
+         if isinstance(result, bool) and not result:
+             break
+
+     return result if isinstance(result, (bool, pd.Series)) else result.all()
+
+
+ def evaluate_if(if_statement, state, static_tools, custom_tools):
+     result = None
+     test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
+     if test_result:
+         for line in if_statement.body:
+             line_result = evaluate_ast(line, state, static_tools, custom_tools)
+             if line_result is not None:
+                 result = line_result
+     else:
+         for line in if_statement.orelse:
+             line_result = evaluate_ast(line, state, static_tools, custom_tools)
+             if line_result is not None:
+                 result = line_result
+     return result
+
+
+ def evaluate_for(for_loop, state, static_tools, custom_tools):
+     result = None
+     iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
+     for counter in iterator:
+         set_value(for_loop.target, counter, state, static_tools, custom_tools)
+         for node in for_loop.body:
+             try:
+                 line_result = evaluate_ast(node, state, static_tools, custom_tools)
+                 if line_result is not None:
+                     result = line_result
+             except BreakException:
+                 break
+             except ContinueException:
+                 continue
+         else:
+             continue
+         break
+     return result
+
+
+ def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
+     def inner_evaluate(generators, index, current_state):
+         if index >= len(generators):
+             return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
+         generator = generators[index]
+         iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
+         result = []
+         for value in iter_value:
+             new_state = current_state.copy()
+             if isinstance(generator.target, ast.Tuple):
+                 for idx, elem in enumerate(generator.target.elts):
+                     new_state[elem.id] = value[idx]
+             else:
+                 new_state[generator.target.id] = value
+             if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
+                 result.extend(inner_evaluate(generators, index + 1, new_state))
+         return result
+
+     return inner_evaluate(listcomp.generators, 0, state)
+
+
+ def evaluate_try(try_node, state, static_tools, custom_tools):
+     try:
+         for stmt in try_node.body:
+             evaluate_ast(stmt, state, static_tools, custom_tools)
+     except Exception as e:
+         matched = False
+         for handler in try_node.handlers:
+             if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
+                 matched = True
+                 if handler.name:
+                     state[handler.name] = e
+                 for stmt in handler.body:
+                     evaluate_ast(stmt, state, static_tools, custom_tools)
+                 break
+         if not matched:
+             raise e
+     else:
+         if try_node.orelse:
+             for stmt in try_node.orelse:
+                 evaluate_ast(stmt, state, static_tools, custom_tools)
+     finally:
+         if try_node.finalbody:
+             for stmt in try_node.finalbody:
+                 evaluate_ast(stmt, state, static_tools, custom_tools)
+
+
+ def evaluate_raise(raise_node, state, static_tools, custom_tools):
+     if raise_node.exc is not None:
+         exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
615
+ else:
616
+ exc = None
617
+ if raise_node.cause is not None:
618
+ cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
619
+ else:
620
+ cause = None
621
+ if exc is not None:
622
+ if cause is not None:
623
+ raise exc from cause
624
+ else:
625
+ raise exc
626
+ else:
627
+ raise InterpreterError("Re-raise is not supported without an active exception")
628
+
629
+
630
+ def evaluate_assert(assert_node, state, static_tools, custom_tools):
631
+ test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
632
+ if not test_result:
633
+ if assert_node.msg:
634
+ msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
635
+ raise AssertionError(msg)
636
+ else:
637
+ # Include the failing condition in the assertion message
638
+ test_code = ast.unparse(assert_node.test)
639
+ raise AssertionError(f"Assertion failed: {test_code}")
640
+
641
+
642
+ def evaluate_with(with_node, state, static_tools, custom_tools):
643
+ contexts = []
644
+ for item in with_node.items:
645
+ context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
646
+ if item.optional_vars:
647
+ state[item.optional_vars.id] = context_expr.__enter__()
648
+ contexts.append(state[item.optional_vars.id])
649
+ else:
650
+ context_var = context_expr.__enter__()
651
+ contexts.append(context_var)
652
+
653
+ try:
654
+ for stmt in with_node.body:
655
+ evaluate_ast(stmt, state, static_tools, custom_tools)
656
+ except Exception as e:
657
+ for context in reversed(contexts):
658
+ context.__exit__(type(e), e, e.__traceback__)
659
+ raise
660
+ else:
661
+ for context in reversed(contexts):
662
+ context.__exit__(None, None, None)
663
+
664
+
665
+ def import_modules(expression, state, authorized_imports):
666
+ def check_module_authorized(module_name):
667
+ module_path = module_name.split(".")
668
+ module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
669
+ return any(subpath in authorized_imports for subpath in module_subpaths)
670
+
671
+ if isinstance(expression, ast.Import):
672
+ for alias in expression.names:
673
+ if check_module_authorized(alias.name):
674
+ module = import_module(alias.name)
675
+ state[alias.asname or alias.name] = module
676
+ else:
677
+ raise InterpreterError(
678
+ f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
679
+ )
680
+ return None
681
+ elif isinstance(expression, ast.ImportFrom):
682
+ if check_module_authorized(expression.module):
683
+ module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
684
+ for alias in expression.names:
685
+ state[alias.asname or alias.name] = getattr(module, alias.name)
686
+ else:
687
+ raise InterpreterError(f"Import from {expression.module} is not allowed.")
688
+ return None
689
+
690
+
691
+ def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
692
+ result = {}
693
+ for gen in dictcomp.generators:
694
+ iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
695
+ for value in iter_value:
696
+ new_state = state.copy()
697
+ set_value(gen.target, value, new_state, static_tools, custom_tools)
698
+ if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
699
+ key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
700
+ val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
701
+ result[key] = val
702
+ return result
703
+
704
+
705
+ def evaluate_ast(
706
+ expression: ast.AST,
707
+ state: Dict[str, Any],
708
+ static_tools: Dict[str, Callable],
709
+ custom_tools: Dict[str, Callable],
710
+ authorized_imports: List[str] = LIST_SAFE_MODULES,
711
+ ):
712
+ """
713
+ Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
714
+ set of functions.
715
+
716
+ This function will recurse trough the nodes of the tree provided.
717
+
718
+ Args:
719
+ expression (`ast.AST`):
720
+ The code to evaluate, as an abstract syntax tree.
721
+ state (`Dict[str, Any]`):
722
+ A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
723
+ encounters assignements.
724
+ static_tools (`Dict[str, Callable]`):
725
+ Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
726
+ custom_tools (`Dict[str, Callable]`):
727
+ Functions that may be called during the evaluation. These static_tools can be overwritten.
728
+ authorized_imports (`List[str]`):
729
+ The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
730
+ Add more at your own risk!
731
+ """
732
+ global OPERATIONS_COUNT
733
+ if OPERATIONS_COUNT >= MAX_OPERATIONS:
734
+ raise InterpreterError(
735
+ f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
736
+ )
737
+ OPERATIONS_COUNT += 1
738
+ if isinstance(expression, ast.Assign):
739
+ # Assignement -> we evaluate the assignment which should update the state
740
+ # We return the variable assigned as it may be used to determine the final result.
741
+ return evaluate_assign(expression, state, static_tools, custom_tools)
742
+ elif isinstance(expression, ast.AugAssign):
743
+ return evaluate_augassign(expression, state, static_tools, custom_tools)
744
+ elif isinstance(expression, ast.Call):
745
+ # Function call -> we return the value of the function call
746
+ return evaluate_call(expression, state, static_tools, custom_tools)
747
+ elif isinstance(expression, ast.Constant):
748
+ # Constant -> just return the value
749
+ return expression.value
750
+ elif isinstance(expression, ast.Tuple):
751
+ return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
752
+ elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
753
+ return evaluate_listcomp(expression, state, static_tools, custom_tools)
754
+ elif isinstance(expression, ast.UnaryOp):
755
+ return evaluate_unaryop(expression, state, static_tools, custom_tools)
756
+ elif isinstance(expression, ast.Starred):
757
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
758
+ elif isinstance(expression, ast.BoolOp):
759
+ # Boolean operation -> evaluate the operation
760
+ return evaluate_boolop(expression, state, static_tools, custom_tools)
761
+ elif isinstance(expression, ast.Break):
762
+ raise BreakException()
763
+ elif isinstance(expression, ast.Continue):
764
+ raise ContinueException()
765
+ elif isinstance(expression, ast.BinOp):
766
+ # Binary operation -> execute operation
767
+ return evaluate_binop(expression, state, static_tools, custom_tools)
768
+ elif isinstance(expression, ast.Compare):
769
+ # Comparison -> evaluate the comparison
770
+ return evaluate_condition(expression, state, static_tools, custom_tools)
771
+ elif isinstance(expression, ast.Lambda):
772
+ return evaluate_lambda(expression, state, static_tools, custom_tools)
773
+ elif isinstance(expression, ast.FunctionDef):
774
+ return evaluate_function_def(expression, state, static_tools, custom_tools)
775
+ elif isinstance(expression, ast.Dict):
776
+ # Dict -> evaluate all keys and values
777
+ keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
778
+ values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
779
+ return dict(zip(keys, values))
780
+ elif isinstance(expression, ast.Expr):
781
+ # Expression -> evaluate the content
782
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
783
+ elif isinstance(expression, ast.For):
784
+ # For loop -> execute the loop
785
+ return evaluate_for(expression, state, static_tools, custom_tools)
786
+ elif isinstance(expression, ast.FormattedValue):
787
+ # Formatted value (part of f-string) -> evaluate the content and return
788
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
789
+ elif isinstance(expression, ast.If):
790
+ # If -> execute the right branch
791
+ return evaluate_if(expression, state, static_tools, custom_tools)
792
+ elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
793
+ return evaluate_ast(expression.value, state, static_tools, custom_tools)
794
+ elif isinstance(expression, ast.JoinedStr):
795
+ return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
796
+ elif isinstance(expression, ast.List):
797
+ # List -> evaluate all elements
798
+ return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
799
+ elif isinstance(expression, ast.Name):
800
+ # Name -> pick up the value in the state
801
+ return evaluate_name(expression, state, static_tools, custom_tools)
802
+ elif isinstance(expression, ast.Subscript):
803
+ # Subscript -> return the value of the indexing
804
+ return evaluate_subscript(expression, state, static_tools, custom_tools)
805
+ elif isinstance(expression, ast.IfExp):
806
+ test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
807
+ if test_val:
808
+ return evaluate_ast(expression.body, state, static_tools, custom_tools)
809
+ else:
810
+ return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
811
+ elif isinstance(expression, ast.Attribute):
812
+ value = evaluate_ast(expression.value, state, static_tools, custom_tools)
813
+ return getattr(value, expression.attr)
814
+ elif isinstance(expression, ast.Slice):
815
+ return slice(
816
+ evaluate_ast(expression.lower, state, static_tools, custom_tools)
817
+ if expression.lower is not None
818
+ else None,
819
+ evaluate_ast(expression.upper, state, static_tools, custom_tools)
820
+ if expression.upper is not None
821
+ else None,
822
+ evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
823
+ )
824
+ elif isinstance(expression, ast.DictComp):
825
+ return evaluate_dictcomp(expression, state, static_tools, custom_tools)
826
+ elif isinstance(expression, ast.While):
827
+ return evaluate_while(expression, state, static_tools, custom_tools)
828
+ elif isinstance(expression, (ast.Import, ast.ImportFrom)):
829
+ return import_modules(expression, state, authorized_imports)
830
+ elif isinstance(expression, ast.ClassDef):
831
+ return evaluate_class_def(expression, state, static_tools, custom_tools)
832
+ elif isinstance(expression, ast.Try):
833
+ return evaluate_try(expression, state, static_tools, custom_tools)
834
+ elif isinstance(expression, ast.Raise):
835
+ return evaluate_raise(expression, state, static_tools, custom_tools)
836
+ elif isinstance(expression, ast.Assert):
837
+ return evaluate_assert(expression, state, static_tools, custom_tools)
838
+ elif isinstance(expression, ast.With):
839
+ return evaluate_with(expression, state, static_tools, custom_tools)
840
+ elif isinstance(expression, ast.Set):
841
+ return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
842
+ elif isinstance(expression, ast.Return):
843
+ raise ReturnException(
844
+ evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
845
+ )
846
+ else:
847
+ # For now we refuse anything else. Let's add things as we need them.
848
+ raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
849
+
850
+
851
+ def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
852
+ if len(print_outputs) < max_len_outputs:
853
+ return print_outputs
854
+ else:
855
+ return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
856
+
857
+
858
+ def evaluate_python_code(
859
+ code: str,
860
+ static_tools: Optional[Dict[str, Callable]] = None,
861
+ custom_tools: Optional[Dict[str, Callable]] = None,
862
+ state: Optional[Dict[str, Any]] = None,
863
+ authorized_imports: List[str] = LIST_SAFE_MODULES,
864
+ ):
865
+ """
866
+ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
867
+ of functions.
868
+
869
+ This function will recurse through the nodes of the tree provided.
870
+
871
+ Args:
872
+ code (`str`):
873
+ The code to evaluate.
874
+ static_tools (`Dict[str, Callable]`):
875
+ The functions that may be called during the evaluation.
876
+ These tools cannot be overwritten in the code: any assignment to their name will raise an error.
877
+ custom_tools (`Dict[str, Callable]`):
878
+ The functions that may be called during the evaluation.
879
+ These tools can be overwritten in the code: any assignment to their name will overwrite them.
880
+ state (`Dict[str, Any]`):
881
+ A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
882
+ updated by this function to contain all variables as they are evaluated.
883
+ The print outputs will be stored in the state under the key 'print_outputs'.
884
+ """
885
+ try:
886
+ expression = ast.parse(code)
887
+ except SyntaxError as e:
888
+ raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
889
+ if state is None:
890
+ state = {}
891
+ if static_tools is None:
892
+ static_tools = {}
893
+ if custom_tools is None:
894
+ custom_tools = {}
895
+ result = None
896
+ global PRINT_OUTPUTS
897
+ PRINT_OUTPUTS = ""
898
+ global OPERATIONS_COUNT
899
+ OPERATIONS_COUNT = 0
900
+ try:
901
+ for node in expression.body:
902
+ result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
903
+ state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
904
+ return result
905
+ except InterpreterError as e:
906
+ msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
907
+ msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
908
+ raise InterpreterError(msg)
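For orientation, the interpreter above can be driven directly. Below is a minimal usage sketch, assuming the vendored module path and an illustrative `add` helper; note that `print` has to be supplied as a static tool so that the name resolves before the interception branch in `evaluate_call` captures the output.

```python
# Minimal sketch (assumptions: vendored module path; `add` is illustrative, not part of the module).
from transformers.agents.python_interpreter import evaluate_python_code


def add(a, b):  # illustrative static tool
    return a + b


state = {}
result = evaluate_python_code(
    "total = add(1, 2)\nprint('total is', total)\ntotal",
    static_tools={"add": add, "print": print},  # callables the sandboxed code may use but not overwrite
    state=state,
)
print(result)                  # 3: the value of the last evaluated expression
print(state["print_outputs"])  # "total is 3\n", captured by the sandboxed print branch
```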
.venv/Lib/site-packages/transformers/agents/search.py ADDED
@@ -0,0 +1,77 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import re
+
+ import requests
+ from requests.exceptions import RequestException
+
+ from .tools import Tool
+
+
+ class DuckDuckGoSearchTool(Tool):
+     name = "web_search"
+     description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
+     Each result has keys 'title', 'href' and 'body'."""
+     inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+     output_type = "any"
+
+     def forward(self, query: str) -> str:
+         try:
+             from duckduckgo_search import DDGS
+         except ImportError:
+             raise ImportError(
+                 "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
+             )
+         results = DDGS().text(query, max_results=7)
+         return results
+
+
+ class VisitWebpageTool(Tool):
+     name = "visit_webpage"
+     description = "Visits a webpage at the given url and returns its content as a markdown string."
+     inputs = {
+         "url": {
+             "type": "string",
+             "description": "The url of the webpage to visit.",
+         }
+     }
+     output_type = "string"
+
+     def forward(self, url: str) -> str:
+         try:
+             from markdownify import markdownify
+         except ImportError:
+             raise ImportError(
+                 "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`."
+             )
+         try:
+             # Send a GET request to the URL
+             response = requests.get(url)
+             response.raise_for_status()  # Raise an exception for bad status codes
+
+             # Convert the HTML content to Markdown
+             markdown_content = markdownify(response.text).strip()
+
+             # Remove multiple line breaks
+             markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
+
+             return markdown_content
+
+         except RequestException as e:
+             return f"Error fetching the webpage: {str(e)}"
+         except Exception as e:
+             return f"An unexpected error occurred: {str(e)}"
.venv/Lib/site-packages/transformers/agents/speech_to_text.py ADDED
@@ -0,0 +1,39 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
+ from .tools import PipelineTool
+
+
+ class SpeechToTextTool(PipelineTool):
+     default_checkpoint = "distil-whisper/distil-large-v3"
+     description = "This is a tool that transcribes audio into text. It returns the transcribed text."
+     name = "transcriber"
+     pre_processor_class = WhisperProcessor
+     model_class = WhisperForConditionalGeneration
+
+     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
+     output_type = "string"
+
+     def encode(self, audio):
+         return self.pre_processor(audio, return_tensors="pt")
+
+     def forward(self, inputs):
+         return self.model.generate(inputs["input_features"])
+
+     def decode(self, outputs):
+         return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
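A hedged sketch of calling the transcriber directly; it assumes `torch` and `accelerate` are installed (both required by `PipelineTool`) and that raw 16 kHz mono samples are passed in, here a silent stand-in array.

```python
# Usage sketch (assumes torch + accelerate; the silent array stands in for real speech samples).
import numpy as np

from transformers.agents.speech_to_text import SpeechToTextTool

transcriber = SpeechToTextTool()  # lazily downloads distil-whisper/distil-large-v3 on first call
samples = np.zeros(16000, dtype=np.float32)  # one second of 16 kHz mono audio (stand-in)
text = transcriber(samples)
print(text)
```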
.venv/Lib/site-packages/transformers/agents/text_to_speech.py ADDED
@@ -0,0 +1,67 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
+ from ..utils import is_datasets_available
+ from .tools import PipelineTool
+
+
+ if is_datasets_available():
+     from datasets import load_dataset
+
+
+ class TextToSpeechTool(PipelineTool):
+     default_checkpoint = "microsoft/speecht5_tts"
+     description = (
+         "This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
+     )
+     name = "text_to_speech"
+     pre_processor_class = SpeechT5Processor
+     model_class = SpeechT5ForTextToSpeech
+     post_processor_class = SpeechT5HifiGan
+
+     inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
+     output_type = "audio"
+
+     def setup(self):
+         if self.post_processor is None:
+             self.post_processor = "microsoft/speecht5_hifigan"
+         super().setup()
+
+     def encode(self, text, speaker_embeddings=None):
+         inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
+
+         if speaker_embeddings is None:
+             if not is_datasets_available():
+                 raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
+
+             embeddings_dataset = load_dataset(
+                 "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
+             )
+             speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
+
+         return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
+
+     def forward(self, inputs):
+         with torch.no_grad():
+             return self.model.generate_speech(**inputs)
+
+     def decode(self, outputs):
+         with torch.no_grad():
+             return self.post_processor(outputs).cpu().detach()
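A hedged sketch of direct use; the return value is wrapped as an agent audio type by `handle_agent_outputs`, which (in this vendored tree) can serialize itself to a temporary `.wav` file.

```python
# Usage sketch (assumes torch, accelerate, and `datasets` for the default speaker embeddings).
from transformers.agents.text_to_speech import TextToSpeechTool

speaker = TextToSpeechTool()
waveform = speaker("Hello from the agents toolbox.")  # wrapped as an AgentAudio object
print(waveform.to_string())  # path of a temporary .wav file the wrapper writes out
```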
.venv/Lib/site-packages/transformers/agents/tools.py ADDED
@@ -0,0 +1,1003 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import ast
+ import base64
+ import importlib
+ import inspect
+ import io
+ import json
+ import os
+ import tempfile
+ from functools import lru_cache, wraps
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
+ from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
+ from packaging import version
+
+ from ..dynamic_module_utils import (
+     custom_object_save,
+     get_class_from_dynamic_module,
+     get_imports,
+ )
+ from ..models.auto import AutoProcessor
+ from ..utils import (
+     CONFIG_NAME,
+     TypeHintParsingException,
+     cached_file,
+     get_json_schema,
+     is_accelerate_available,
+     is_torch_available,
+     is_vision_available,
+     logging,
+ )
+ from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
+
+
+ logger = logging.get_logger(__name__)
+
+
+ if is_torch_available():
+     import torch
+
+ if is_accelerate_available():
+     from accelerate import PartialState
+     from accelerate.utils import send_to_device
+
+
+ TOOL_CONFIG_FILE = "tool_config.json"
+
+
+ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
+     if repo_type is not None:
+         return repo_type
+     try:
+         hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
+         return "space"
+     except RepositoryNotFoundError:
+         try:
+             hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
+             return "model"
+         except RepositoryNotFoundError:
+             raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
+         except Exception:
+             return "model"
+     except Exception:
+         return "space"
+
+
+ # docstyle-ignore
+ APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
+ from {module_name} import {class_name}
+
+ launch_gradio_demo({class_name})
+ """
+
+
+ def validate_after_init(cls, do_validate_forward: bool = True):
+     original_init = cls.__init__
+
+     @wraps(original_init)
+     def new_init(self, *args, **kwargs):
+         original_init(self, *args, **kwargs)
+         if not isinstance(self, PipelineTool):
+             self.validate_arguments(do_validate_forward=do_validate_forward)
+
+     cls.__init__ = new_init
+     return cls
+
+
+ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
+
+
+ class Tool:
+     """
+     A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
+     following class attributes:
+
+     - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
+       will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
+       returns the text contained in the file'.
+     - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
+       `"text-classifier"` or `"image_generator"`.
+     - **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
+       It has one `type` key and a `description` key.
+       This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
+       description for your tool.
+     - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
+       or to make a nice space from your tool, and also can be used in the generated description for your tool.
+
+     You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
+     usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
+     instantiation.
+     """
+
+     name: str
+     description: str
+     inputs: Dict[str, Dict[str, Union[str, type]]]
+     output_type: type
+
+     def __init__(self, *args, **kwargs):
+         self.is_initialized = False
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         validate_after_init(cls, do_validate_forward=False)
+
+     def validate_arguments(self, do_validate_forward: bool = True):
+         required_attributes = {
+             "description": str,
+             "name": str,
+             "inputs": dict,
+             "output_type": str,
+         }
+         authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
+
+         for attr, expected_type in required_attributes.items():
+             attr_value = getattr(self, attr, None)
+             if attr_value is None:
+                 raise TypeError(f"You must set an attribute {attr}.")
+             if not isinstance(attr_value, expected_type):
+                 raise TypeError(
+                     f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
+                 )
+         for input_name, input_content in self.inputs.items():
+             assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
+             assert (
+                 "type" in input_content and "description" in input_content
+             ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
+             if input_content["type"] not in authorized_types:
+                 raise Exception(
+                     f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
+                 )
+
+         assert getattr(self, "output_type", None) in authorized_types
+         if do_validate_forward:
+             if not isinstance(self, PipelineTool):
+                 signature = inspect.signature(self.forward)
+                 if not set(signature.parameters.keys()) == set(self.inputs.keys()):
+                     raise Exception(
+                         "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
+                     )
+
+     def forward(self, *args, **kwargs):
+         raise NotImplementedError("Write this method in your subclass of `Tool`.")
+
+     def __call__(self, *args, **kwargs):
+         args, kwargs = handle_agent_inputs(*args, **kwargs)
+         outputs = self.forward(*args, **kwargs)
+         return handle_agent_outputs(outputs, self.output_type)
+
+     def setup(self):
+         """
+         Overwrite this method here for any operation that is expensive and needs to be executed before you start using
+         your tool. Such as loading a big model.
+         """
+         self.is_initialized = True
+
+     def save(self, output_dir):
+         """
+         Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
+         tool in `output_dir` as well as autogenerate:
+
+         - a config file named `tool_config.json`
+         - an `app.py` file so that your tool can be converted to a space
+         - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
+           code)
+
+         You should only use this method to save tools that are defined in a separate module (not `__main__`).
+
+         Args:
+             output_dir (`str`): The folder in which you want to save your tool.
+         """
+         os.makedirs(output_dir, exist_ok=True)
+         # Save module file
+         if self.__module__ == "__main__":
+             raise ValueError(
+                 f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
+                 "have to put this code in a separate module so we can include it in the saved folder."
+             )
+         module_files = custom_object_save(self, output_dir)
+
+         module_name = self.__class__.__module__
+         last_module = module_name.split(".")[-1]
+         full_name = f"{last_module}.{self.__class__.__name__}"
+
+         # Save config file
+         config_file = os.path.join(output_dir, "tool_config.json")
+         if os.path.isfile(config_file):
+             with open(config_file, "r", encoding="utf-8") as f:
+                 tool_config = json.load(f)
+         else:
+             tool_config = {}
+
+         tool_config = {
+             "tool_class": full_name,
+             "description": self.description,
+             "name": self.name,
+             "inputs": self.inputs,
+             "output_type": str(self.output_type),
+         }
+         with open(config_file, "w", encoding="utf-8") as f:
+             f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
+
+         # Save app file
+         app_file = os.path.join(output_dir, "app.py")
+         with open(app_file, "w", encoding="utf-8") as f:
+             f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
+
+         # Save requirements file
+         requirements_file = os.path.join(output_dir, "requirements.txt")
+         imports = []
+         for module in module_files:
+             imports.extend(get_imports(module))
+         imports = list(set(imports))
+         with open(requirements_file, "w", encoding="utf-8") as f:
+             f.write("\n".join(imports) + "\n")
+
+     @classmethod
+     def from_hub(
+         cls,
+         repo_id: str,
+         token: Optional[str] = None,
+         **kwargs,
+     ):
+         """
+         Loads a tool defined on the Hub.
+
+         <Tip warning={true}>
+
+         Loading a tool from the Hub means that you'll download the tool and execute it locally.
+         ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
+         installing a package using pip/npm/apt.
+
+         </Tip>
+
+         Args:
+             repo_id (`str`):
+                 The name of the repo on the Hub where your tool is defined.
+             token (`str`, *optional*):
+                 The token to identify you on hf.co. If unset, will use the token generated when running
+                 `huggingface-cli login` (stored in `~/.huggingface`).
+             kwargs (additional keyword arguments, *optional*):
+                 Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
+                 `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
+                 others will be passed along to its init.
+         """
+         hub_kwargs_names = [
+             "cache_dir",
+             "force_download",
+             "resume_download",
+             "proxies",
+             "revision",
+             "repo_type",
+             "subfolder",
+             "local_files_only",
+         ]
+         hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
+
+         # Try to get the tool config first.
+         hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
+         resolved_config_file = cached_file(
+             repo_id,
+             TOOL_CONFIG_FILE,
+             token=token,
+             **hub_kwargs,
+             _raise_exceptions_for_gated_repo=False,
+             _raise_exceptions_for_missing_entries=False,
+             _raise_exceptions_for_connection_errors=False,
+         )
+         is_tool_config = resolved_config_file is not None
+         if resolved_config_file is None:
+             resolved_config_file = cached_file(
+                 repo_id,
+                 CONFIG_NAME,
+                 token=token,
+                 **hub_kwargs,
+                 _raise_exceptions_for_gated_repo=False,
+                 _raise_exceptions_for_missing_entries=False,
+                 _raise_exceptions_for_connection_errors=False,
+             )
+         if resolved_config_file is None:
+             raise EnvironmentError(
+                 f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
+             )
+
+         with open(resolved_config_file, encoding="utf-8") as reader:
+             config = json.load(reader)
+
+         if not is_tool_config:
+             if "custom_tool" not in config:
+                 raise EnvironmentError(
+                     f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
+                 )
+             custom_tool = config["custom_tool"]
+         else:
+             custom_tool = config
+
+         tool_class = custom_tool["tool_class"]
+         tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
+
+         if len(tool_class.name) == 0:
+             tool_class.name = custom_tool["name"]
+         if tool_class.name != custom_tool["name"]:
+             logger.warning(
+                 f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
+                 "configuration name."
+             )
+             tool_class.name = custom_tool["name"]
+
+         if len(tool_class.description) == 0:
+             tool_class.description = custom_tool["description"]
+         if tool_class.description != custom_tool["description"]:
+             logger.warning(
+                 f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
+                 "tool configuration description."
+             )
+             tool_class.description = custom_tool["description"]
+
+         if tool_class.inputs != custom_tool["inputs"]:
+             tool_class.inputs = custom_tool["inputs"]
+         if tool_class.output_type != custom_tool["output_type"]:
+             tool_class.output_type = custom_tool["output_type"]
+
+         if not isinstance(tool_class.inputs, dict):
+             tool_class.inputs = ast.literal_eval(tool_class.inputs)
+
+         return tool_class(**kwargs)
+
+     def push_to_hub(
+         self,
+         repo_id: str,
+         commit_message: str = "Upload tool",
+         private: Optional[bool] = None,
+         token: Optional[Union[bool, str]] = None,
+         create_pr: bool = False,
+     ) -> str:
+         """
+         Upload the tool to the Hub.
+
+         For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
+         For instance:
+         ```
+         from my_tool_module import MyTool
+         my_tool = MyTool()
+         my_tool.push_to_hub("my-username/my-space")
+         ```
+
+         Parameters:
+             repo_id (`str`):
+                 The name of the repository you want to push your tool to. It should contain your organization name when
+                 pushing to a given organization.
+             commit_message (`str`, *optional*, defaults to `"Upload tool"`):
+                 Message to commit while pushing.
+             private (`bool`, *optional*):
+                 Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+             token (`bool` or `str`, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
+                 when running `huggingface-cli login` (stored in `~/.huggingface`).
+             create_pr (`bool`, *optional*, defaults to `False`):
+                 Whether or not to create a PR with the uploaded files or directly commit.
+         """
+         repo_url = create_repo(
+             repo_id=repo_id,
+             token=token,
+             private=private,
+             exist_ok=True,
+             repo_type="space",
+             space_sdk="gradio",
+         )
+         repo_id = repo_url.repo_id
+         metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
+
+         with tempfile.TemporaryDirectory() as work_dir:
+             # Save all files.
+             self.save(work_dir)
+             logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
+             return upload_folder(
+                 repo_id=repo_id,
+                 commit_message=commit_message,
+                 folder_path=work_dir,
+                 token=token,
+                 create_pr=create_pr,
+                 repo_type="space",
+             )
+
+     @staticmethod
+     def from_space(
+         space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
+     ):
+         """
+         Creates a [`Tool`] from a Space given its id on the Hub.
+
+         Args:
+             space_id (`str`):
+                 The id of the Space on the Hub.
+             name (`str`):
+                 The name of the tool.
+             description (`str`):
+                 The description of the tool.
+             api_name (`str`, *optional*):
+                 The specific api_name to use, if the space has several tabs. If not specified, will default to the first available api.
+             token (`str`, *optional*):
+                 Add your token to access private spaces or increase your GPU quotas.
+         Returns:
+             [`Tool`]:
+                 The Space, as a tool.
+
+         Examples:
+         ```
+         image_generator = Tool.from_space(
+             space_id="black-forest-labs/FLUX.1-schnell",
+             name="image-generator",
+             description="Generate an image from a prompt"
+         )
+         image = image_generator("Generate an image of a cool surfer in Tahiti")
+         ```
+         ```
+         face_swapper = Tool.from_space(
+             "tuan2308/face-swap",
+             "face_swapper",
+             "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
+         )
+         image = face_swapper('./aymeric.jpeg', './ruth.jpg')
+         ```
+         """
+         from gradio_client import Client, handle_file
+         from gradio_client.utils import is_http_url_like
+
+         class SpaceToolWrapper(Tool):
+             def __init__(
+                 self,
+                 space_id: str,
+                 name: str,
+                 description: str,
+                 api_name: Optional[str] = None,
+                 token: Optional[str] = None,
+             ):
+                 self.client = Client(space_id, hf_token=token)
+                 self.name = name
+                 self.description = description
+                 space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
+
+                 # If api_name is not defined, take the first of the available APIs for this space
+                 if api_name is None:
+                     api_name = list(space_description.keys())[0]
+                     logger.warning(
+                         f"Since `api_name` was not defined, it was automatically set to the first available API: `{api_name}`."
+                     )
+                 self.api_name = api_name
+
+                 try:
+                     space_description_api = space_description[api_name]
+                 except KeyError:
+                     raise KeyError(f"Could not find specified {api_name=} among available api names.")
+
+                 self.inputs = {}
+                 for parameter in space_description_api["parameters"]:
+                     if not parameter["parameter_has_default"]:
+                         parameter_type = parameter["type"]["type"]
+                         if parameter_type == "object":
+                             parameter_type = "any"
+                         self.inputs[parameter["parameter_name"]] = {
+                             "type": parameter_type,
+                             "description": parameter["python_type"]["description"],
+                         }
+                 output_component = space_description_api["returns"][0]["component"]
+                 if output_component == "Image":
+                     self.output_type = "image"
+                 elif output_component == "Audio":
+                     self.output_type = "audio"
+                 else:
+                     self.output_type = "any"
+
+             def sanitize_argument_for_prediction(self, arg):
+                 if isinstance(arg, ImageType):
+                     temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+                     arg.save(temp_file.name)
+                     arg = temp_file.name
+                 if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
+                     arg
+                 ):
+                     arg = handle_file(arg)
+                 return arg
+
+             def forward(self, *args, **kwargs):
+                 # Preprocess args and kwargs:
+                 args = list(args)
+                 for i, arg in enumerate(args):
+                     args[i] = self.sanitize_argument_for_prediction(arg)
+                 for arg_name, arg in kwargs.items():
+                     kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
+
+                 output = self.client.predict(*args, api_name=self.api_name, **kwargs)
+                 if isinstance(output, (tuple, list)):
+                     # Sometimes the space also returns the generation seed, in which case the result is at index 0
+                     return output[0]
+                 return output
+
+         return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)
+
+     @staticmethod
+     def from_gradio(gradio_tool):
+         """
+         Creates a [`Tool`] from a gradio tool.
+         """
+         import inspect
+
+         class GradioToolWrapper(Tool):
+             def __init__(self, _gradio_tool):
+                 self.name = _gradio_tool.name
+                 self.description = _gradio_tool.description
+                 self.output_type = "string"
+                 self._gradio_tool = _gradio_tool
+                 func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
+                 self.inputs = {
+                     key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
+                 }
+                 self.forward = self._gradio_tool.run
+
+         return GradioToolWrapper(gradio_tool)
+
+     @staticmethod
+     def from_langchain(langchain_tool):
+         """
+         Creates a [`Tool`] from a langchain tool.
+         """
+
+         class LangChainToolWrapper(Tool):
+             def __init__(self, _langchain_tool):
+                 self.name = _langchain_tool.name.lower()
+                 self.description = _langchain_tool.description
+                 self.inputs = _langchain_tool.args.copy()
+                 for input_content in self.inputs.values():
+                     if "title" in input_content:
+                         input_content.pop("title")
+                     input_content["description"] = ""
+                 self.output_type = "string"
+                 self.langchain_tool = _langchain_tool
+
+             def forward(self, *args, **kwargs):
+                 tool_input = kwargs.copy()
+                 for index, argument in enumerate(args):
+                     if index < len(self.inputs):
+                         input_key = next(iter(self.inputs))
+                         tool_input[input_key] = argument
+                 return self.langchain_tool.run(tool_input)
+
+         return LangChainToolWrapper(langchain_tool)
+
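To make the contract above concrete, here is a minimal custom tool; the class name and the URL are illustrative, not part of the module.

```python
# Minimal sketch of a custom Tool subclass (class name and URL are illustrative).
import requests

from transformers.agents.tools import Tool


class FileDownloaderTool(Tool):
    name = "file_downloader"
    description = "Downloads a file from a `url` and returns its text content."
    inputs = {"url": {"type": "string", "description": "The url of the file to download."}}
    output_type = "string"

    def forward(self, url: str) -> str:
        response = requests.get(url)
        response.raise_for_status()
        return response.text


tool = FileDownloaderTool()         # validate_arguments() runs automatically via __init_subclass__
print(tool("https://example.com"))  # __call__ wraps forward() with agent input/output handling
```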
587
+ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
588
+ - {{ tool.name }}: {{ tool.description }}
589
+ Takes inputs: {{tool.inputs}}
590
+ Returns an output of type: {{tool.output_type}}
591
+ """
592
+
593
+
594
+ def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
595
+ compiled_template = compile_jinja_template(description_template)
596
+ rendered = compiled_template.render(
597
+ tool=tool,
598
+ )
599
+ return rendered
600
+
601
+
602
+ @lru_cache
603
+ def compile_jinja_template(template):
604
+ try:
605
+ import jinja2
606
+ from jinja2.exceptions import TemplateError
607
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
608
+ except ImportError:
609
+ raise ImportError("template requires jinja2 to be installed.")
610
+
611
+ if version.parse(jinja2.__version__) < version.parse("3.1.0"):
612
+ raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
613
+
614
+ def raise_exception(message):
615
+ raise TemplateError(message)
616
+
617
+ jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
618
+ jinja_env.globals["raise_exception"] = raise_exception
619
+ return jinja_env.from_string(template)
620
+
621
+
622
+ class PipelineTool(Tool):
623
+ """
624
+ A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
625
+ need to specify:
626
+
627
+ - **model_class** (`type`) -- The class to use to load the model in this tool.
628
+ - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
629
+ - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
630
+ pre-processor
631
+ - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
632
+ post-processor (when different from the pre-processor).
633
+
634
+ Args:
635
+ model (`str` or [`PreTrainedModel`], *optional*):
636
+ The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
637
+ value of the class attribute `default_checkpoint`.
638
+ pre_processor (`str` or `Any`, *optional*):
639
+ The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
640
+ tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
641
+ unset.
642
+ post_processor (`str` or `Any`, *optional*):
643
+ The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
644
+ tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
645
+ unset.
646
+ device (`int`, `str` or `torch.device`, *optional*):
647
+ The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
648
+ CPU otherwise.
649
+ device_map (`str` or `dict`, *optional*):
650
+ If passed along, will be used to instantiate the model.
651
+ model_kwargs (`dict`, *optional*):
652
+ Any keyword argument to send to the model instantiation.
653
+ token (`str`, *optional*):
654
+ The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
655
+ running `huggingface-cli login` (stored in `~/.huggingface`).
656
+ hub_kwargs (additional keyword arguments, *optional*):
657
+ Any additional keyword argument to send to the methods that will load the data from the Hub.
658
+ """
659
+
660
+ pre_processor_class = AutoProcessor
661
+ model_class = None
662
+ post_processor_class = AutoProcessor
663
+ default_checkpoint = None
664
+ description = "This is a pipeline tool"
665
+ name = "pipeline"
666
+ inputs = {"prompt": str}
667
+ output_type = str
668
+
669
+ def __init__(
670
+ self,
671
+ model=None,
672
+ pre_processor=None,
673
+ post_processor=None,
674
+ device=None,
675
+ device_map=None,
676
+ model_kwargs=None,
677
+ token=None,
678
+ **hub_kwargs,
679
+ ):
680
+ if not is_torch_available():
681
+ raise ImportError("Please install torch in order to use this tool.")
682
+
683
+ if not is_accelerate_available():
684
+ raise ImportError("Please install accelerate in order to use this tool.")
685
+
686
+ if model is None:
687
+ if self.default_checkpoint is None:
688
+ raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
689
+ model = self.default_checkpoint
690
+ if pre_processor is None:
691
+ pre_processor = model
692
+
693
+ self.model = model
694
+ self.pre_processor = pre_processor
695
+ self.post_processor = post_processor
696
+ self.device = device
697
+ self.device_map = device_map
698
+ self.model_kwargs = {} if model_kwargs is None else model_kwargs
699
+ if device_map is not None:
700
+ self.model_kwargs["device_map"] = device_map
701
+ self.hub_kwargs = hub_kwargs
702
+ self.hub_kwargs["token"] = token
703
+
704
+ super().__init__()
705
+
706
+ def setup(self):
707
+ """
708
+ Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
709
+ """
710
+ if isinstance(self.pre_processor, str):
711
+ self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
712
+
713
+ if isinstance(self.model, str):
714
+ self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
715
+
716
+ if self.post_processor is None:
717
+ self.post_processor = self.pre_processor
718
+ elif isinstance(self.post_processor, str):
719
+ self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
720
+
721
+ if self.device is None:
722
+ if self.device_map is not None:
723
+ self.device = list(self.model.hf_device_map.values())[0]
724
+ else:
725
+ self.device = PartialState().default_device
726
+
727
+ if self.device_map is None:
728
+ self.model.to(self.device)
729
+
730
+ super().setup()
731
+
732
+ def encode(self, raw_inputs):
733
+ """
734
+ Uses the `pre_processor` to prepare the inputs for the `model`.
735
+ """
736
+ return self.pre_processor(raw_inputs)
737
+
738
+ def forward(self, inputs):
739
+ """
740
+ Sends the inputs through the `model`.
741
+ """
742
+ with torch.no_grad():
743
+ return self.model(**inputs)
744
+
745
+ def decode(self, outputs):
746
+ """
747
+ Uses the `post_processor` to decode the model output.
748
+ """
749
+ return self.post_processor(outputs)
750
+
751
+ def __call__(self, *args, **kwargs):
752
+ args, kwargs = handle_agent_inputs(*args, **kwargs)
753
+
754
+ if not self.is_initialized:
755
+ self.setup()
756
+
757
+ encoded_inputs = self.encode(*args, **kwargs)
758
+
759
+ tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
760
+ non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
761
+
762
+ encoded_inputs = send_to_device(tensor_inputs, self.device)
763
+ outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
764
+ outputs = send_to_device(outputs, "cpu")
765
+ decoded_outputs = self.decode(outputs)
766
+
767
+ return handle_agent_outputs(decoded_outputs, self.output_type)
768
+
769
+
770
+ def launch_gradio_demo(tool_class: Tool):
771
+ """
772
+ Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
773
+ `inputs` and `output_type`.
774
+
775
+ Args:
776
+ tool_class (`type`): The class of the tool for which to launch the demo.
777
+ """
778
+ try:
779
+ import gradio as gr
780
+ except ImportError:
781
+ raise ImportError("Gradio should be installed in order to launch a gradio demo.")
782
+
783
+ tool = tool_class()
784
+
785
+ def fn(*args, **kwargs):
786
+ return tool(*args, **kwargs)
787
+
788
+ TYPE_TO_COMPONENT_CLASS_MAPPING = {
789
+ "image": gr.Image,
790
+ "audio": gr.Audio,
791
+ "string": gr.Textbox,
792
+ "integer": gr.Textbox,
793
+ "number": gr.Textbox,
794
+ }
795
+
796
+ gradio_inputs = []
797
+ for input_name, input_details in tool_class.inputs.items():
798
+ input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
799
+ new_component = input_gradio_component_class(label=input_name)
800
+ gradio_inputs.append(new_component)
801
+
802
+ output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[tool_class.output_type]
803
+ gradio_output = output_gradio_componentclass(label=input_name)
804
+
805
+ gr.Interface(
806
+ fn=fn,
807
+ inputs=gradio_inputs,
808
+ outputs=gradio_output,
809
+ title=tool_class.__name__,
810
+ article=tool.description,
811
+ ).launch()
812
+
813
+
814
+ TOOL_MAPPING = {
815
+ "document_question_answering": "DocumentQuestionAnsweringTool",
816
+ "image_question_answering": "ImageQuestionAnsweringTool",
817
+ "speech_to_text": "SpeechToTextTool",
818
+ "text_to_speech": "TextToSpeechTool",
819
+ "translation": "TranslationTool",
820
+ "python_interpreter": "PythonInterpreterTool",
821
+ "web_search": "DuckDuckGoSearchTool",
822
+ }
823
+
824
+
825
+ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
826
+ """
827
+ Main function to quickly load a tool, be it on the Hub or in the Transformers library.
828
+
829
+ <Tip warning={true}>
830
+
831
+ Loading a tool means that you'll download the tool and execute it locally.
832
+ ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
833
+ installing a package using pip/npm/apt.
834
+
835
+ </Tip>
836
+
837
+ Args:
838
+ task_or_repo_id (`str`):
839
+ The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
840
+ are:
841
+
842
+ - `"document_question_answering"`
843
+ - `"image_question_answering"`
844
+ - `"speech_to_text"`
845
+ - `"text_to_speech"`
846
+ - `"translation"`
847
+
848
+ model_repo_id (`str`, *optional*):
849
+ Use this argument to use a different model than the default one for the tool you selected.
850
+ token (`str`, *optional*):
851
+ The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
852
+ login` (stored in `~/.huggingface`).
853
+ kwargs (additional keyword arguments, *optional*):
854
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
855
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
856
+ will be passed along to its init.
857
+ """
858
+ if task_or_repo_id in TOOL_MAPPING:
859
+ tool_class_name = TOOL_MAPPING[task_or_repo_id]
860
+ main_module = importlib.import_module("transformers")
861
+ tools_module = main_module.agents
862
+ tool_class = getattr(tools_module, tool_class_name)
863
+ return tool_class(model_repo_id, token=token, **kwargs)
864
+ else:
865
+ logger.warning_once(
866
+ f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
867
+ f"trust as the code within that tool will be executed on your machine. Always verify the code of "
868
+ f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
869
+ f"code that you have checked."
870
+ )
871
+ return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
872
+
873
+
874
+ def add_description(description):
875
+ """
876
+ A decorator that adds a description to a function.
877
+ """
878
+
879
+ def inner(func):
880
+ func.description = description
881
+ func.name = func.__name__
882
+ return func
883
+
884
+ return inner
885
+
886
+
887
+ ## Will move to the Hub
888
+ class EndpointClient:
889
+ def __init__(self, endpoint_url: str, token: Optional[str] = None):
890
+ self.headers = {
891
+ **build_hf_headers(token=token),
892
+ "Content-Type": "application/json",
893
+ }
894
+ self.endpoint_url = endpoint_url
895
+
896
+ @staticmethod
897
+ def encode_image(image):
898
+ _bytes = io.BytesIO()
899
+ image.save(_bytes, format="PNG")
900
+ b64 = base64.b64encode(_bytes.getvalue())
901
+ return b64.decode("utf-8")
902
+
903
+ @staticmethod
904
+ def decode_image(raw_image):
905
+ if not is_vision_available():
906
+ raise ImportError(
907
+ "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
908
+ )
909
+
910
+ from PIL import Image
911
+
912
+ b64 = base64.b64decode(raw_image)
913
+ _bytes = io.BytesIO(b64)
914
+ return Image.open(_bytes)
915
+
916
+ def __call__(
917
+ self,
918
+ inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
919
+ params: Optional[Dict] = None,
920
+ data: Optional[bytes] = None,
921
+ output_image: bool = False,
922
+ ) -> Any:
923
+ # Build payload
924
+ payload = {}
925
+ if inputs:
926
+ payload["inputs"] = inputs
927
+ if params:
928
+ payload["parameters"] = params
929
+
930
+ # Make API call
931
+ response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
932
+
933
+ # By default, parse the response for the user.
934
+ if output_image:
935
+ return self.decode_image(response.content)
936
+ else:
937
+ return response.json()
938
+
939
+
940
+ class ToolCollection:
941
+ """
942
+ Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
943
+
944
+ > [!NOTE]
945
+ > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
946
+ > like for this collection to showcase them.
947
+
948
+ Args:
949
+ collection_slug (str):
950
+ The collection slug referencing the collection.
951
+ token (str, *optional*):
952
+ The authentication token if the collection is private.
953
+
954
+ Example:
955
+
956
+ ```py
957
+ >>> from transformers import ToolCollection, ReactCodeAgent
958
+
959
+ >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
960
+ >>> agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
961
+
962
+ >>> agent.run("Please draw me a picture of rivers and lakes.")
963
+ ```
964
+ """
965
+
966
+ def __init__(self, collection_slug: str, token: Optional[str] = None):
967
+ self._collection = get_collection(collection_slug, token=token)
968
+ self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
969
+ self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
970
+
971
+
972
+ def tool(tool_function: Callable) -> Tool:
973
+ """
974
+ Converts a function into an instance of a Tool subclass.
975
+
976
+ Args:
977
+ tool_function: Your function. Should have type hints for each input and a type hint for the output.
978
+ Should also have a docstring description including an 'Args:' part where each argument is described.
979
+ """
980
+ parameters = get_json_schema(tool_function)["function"]
981
+ if "return" not in parameters:
982
+ raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
983
+ class_name = f"{parameters['name'].capitalize()}Tool"
984
+
985
+ class SpecificTool(Tool):
986
+ name = parameters["name"]
987
+ description = parameters["description"]
988
+ inputs = parameters["parameters"]["properties"]
989
+ output_type = parameters["return"]["type"]
990
+
991
+ @wraps(tool_function)
992
+ def forward(self, *args, **kwargs):
993
+ return tool_function(*args, **kwargs)
994
+
995
+ original_signature = inspect.signature(tool_function)
996
+ new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
997
+ original_signature.parameters.values()
998
+ )
999
+ new_signature = original_signature.replace(parameters=new_parameters)
1000
+ SpecificTool.forward.__signature__ = new_signature
1001
+
1002
+ SpecificTool.__name__ = class_name
1003
+ return SpecificTool()
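
For reference, a minimal sketch of how the `tool` decorator above is meant to be used; the function below is a hypothetical example, not part of this commit, and it assumes `tool` is exported from `transformers.agents` in this version:

```py
from transformers.agents import tool


@tool
def word_count(text: str) -> int:
    """
    Counts the words in a piece of text.

    Args:
        text: The text whose words should be counted.
    """
    return len(text.split())


# The decorator returned a Tool *instance*: its name, description, inputs and
# output_type were all derived from the signature and docstring above.
print(word_count.name, word_count.output_type)  # word_count integer
print(word_count("one two three"))  # -> 3
```
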
.venv/Lib/site-packages/transformers/agents/translation.py ADDED
@@ -0,0 +1,279 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
+ from .tools import PipelineTool
+
+
+ LANGUAGE_CODES = {
+     "Acehnese Arabic": "ace_Arab",
+     "Acehnese Latin": "ace_Latn",
+     "Mesopotamian Arabic": "acm_Arab",
+     "Ta'izzi-Adeni Arabic": "acq_Arab",
+     "Tunisian Arabic": "aeb_Arab",
+     "Afrikaans": "afr_Latn",
+     "South Levantine Arabic": "ajp_Arab",
+     "Akan": "aka_Latn",
+     "Amharic": "amh_Ethi",
+     "North Levantine Arabic": "apc_Arab",
+     "Modern Standard Arabic": "arb_Arab",
+     "Modern Standard Arabic Romanized": "arb_Latn",
+     "Najdi Arabic": "ars_Arab",
+     "Moroccan Arabic": "ary_Arab",
+     "Egyptian Arabic": "arz_Arab",
+     "Assamese": "asm_Beng",
+     "Asturian": "ast_Latn",
+     "Awadhi": "awa_Deva",
+     "Central Aymara": "ayr_Latn",
+     "South Azerbaijani": "azb_Arab",
+     "North Azerbaijani": "azj_Latn",
+     "Bashkir": "bak_Cyrl",
+     "Bambara": "bam_Latn",
+     "Balinese": "ban_Latn",
+     "Belarusian": "bel_Cyrl",
+     "Bemba": "bem_Latn",
+     "Bengali": "ben_Beng",
+     "Bhojpuri": "bho_Deva",
+     "Banjar Arabic": "bjn_Arab",
+     "Banjar Latin": "bjn_Latn",
+     "Standard Tibetan": "bod_Tibt",
+     "Bosnian": "bos_Latn",
+     "Buginese": "bug_Latn",
+     "Bulgarian": "bul_Cyrl",
+     "Catalan": "cat_Latn",
+     "Cebuano": "ceb_Latn",
+     "Czech": "ces_Latn",
+     "Chokwe": "cjk_Latn",
+     "Central Kurdish": "ckb_Arab",
+     "Crimean Tatar": "crh_Latn",
+     "Welsh": "cym_Latn",
+     "Danish": "dan_Latn",
+     "German": "deu_Latn",
+     "Southwestern Dinka": "dik_Latn",
+     "Dyula": "dyu_Latn",
+     "Dzongkha": "dzo_Tibt",
+     "Greek": "ell_Grek",
+     "English": "eng_Latn",
+     "Esperanto": "epo_Latn",
+     "Estonian": "est_Latn",
+     "Basque": "eus_Latn",
+     "Ewe": "ewe_Latn",
+     "Faroese": "fao_Latn",
+     "Fijian": "fij_Latn",
+     "Finnish": "fin_Latn",
+     "Fon": "fon_Latn",
+     "French": "fra_Latn",
+     "Friulian": "fur_Latn",
+     "Nigerian Fulfulde": "fuv_Latn",
+     "Scottish Gaelic": "gla_Latn",
+     "Irish": "gle_Latn",
+     "Galician": "glg_Latn",
+     "Guarani": "grn_Latn",
+     "Gujarati": "guj_Gujr",
+     "Haitian Creole": "hat_Latn",
+     "Hausa": "hau_Latn",
+     "Hebrew": "heb_Hebr",
+     "Hindi": "hin_Deva",
+     "Chhattisgarhi": "hne_Deva",
+     "Croatian": "hrv_Latn",
+     "Hungarian": "hun_Latn",
+     "Armenian": "hye_Armn",
+     "Igbo": "ibo_Latn",
+     "Ilocano": "ilo_Latn",
+     "Indonesian": "ind_Latn",
+     "Icelandic": "isl_Latn",
+     "Italian": "ita_Latn",
+     "Javanese": "jav_Latn",
+     "Japanese": "jpn_Jpan",
+     "Kabyle": "kab_Latn",
+     "Jingpho": "kac_Latn",
+     "Kamba": "kam_Latn",
+     "Kannada": "kan_Knda",
+     "Kashmiri Arabic": "kas_Arab",
+     "Kashmiri Devanagari": "kas_Deva",
+     "Georgian": "kat_Geor",
+     "Central Kanuri Arabic": "knc_Arab",
+     "Central Kanuri Latin": "knc_Latn",
+     "Kazakh": "kaz_Cyrl",
+     "Kabiyè": "kbp_Latn",
+     "Kabuverdianu": "kea_Latn",
+     "Khmer": "khm_Khmr",
+     "Kikuyu": "kik_Latn",
+     "Kinyarwanda": "kin_Latn",
+     "Kyrgyz": "kir_Cyrl",
+     "Kimbundu": "kmb_Latn",
+     "Northern Kurdish": "kmr_Latn",
+     "Kikongo": "kon_Latn",
+     "Korean": "kor_Hang",
+     "Lao": "lao_Laoo",
+     "Ligurian": "lij_Latn",
+     "Limburgish": "lim_Latn",
+     "Lingala": "lin_Latn",
+     "Lithuanian": "lit_Latn",
+     "Lombard": "lmo_Latn",
+     "Latgalian": "ltg_Latn",
+     "Luxembourgish": "ltz_Latn",
+     "Luba-Kasai": "lua_Latn",
+     "Ganda": "lug_Latn",
+     "Luo": "luo_Latn",
+     "Mizo": "lus_Latn",
+     "Standard Latvian": "lvs_Latn",
+     "Magahi": "mag_Deva",
+     "Maithili": "mai_Deva",
+     "Malayalam": "mal_Mlym",
+     "Marathi": "mar_Deva",
+     "Minangkabau Arabic": "min_Arab",
+     "Minangkabau Latin": "min_Latn",
+     "Macedonian": "mkd_Cyrl",
+     "Plateau Malagasy": "plt_Latn",
+     "Maltese": "mlt_Latn",
+     "Meitei Bengali": "mni_Beng",
+     "Halh Mongolian": "khk_Cyrl",
+     "Mossi": "mos_Latn",
+     "Maori": "mri_Latn",
+     "Burmese": "mya_Mymr",
+     "Dutch": "nld_Latn",
+     "Norwegian Nynorsk": "nno_Latn",
+     "Norwegian Bokmål": "nob_Latn",
+     "Nepali": "npi_Deva",
+     "Northern Sotho": "nso_Latn",
+     "Nuer": "nus_Latn",
+     "Nyanja": "nya_Latn",
+     "Occitan": "oci_Latn",
+     "West Central Oromo": "gaz_Latn",
+     "Odia": "ory_Orya",
+     "Pangasinan": "pag_Latn",
+     "Eastern Panjabi": "pan_Guru",
+     "Papiamento": "pap_Latn",
+     "Western Persian": "pes_Arab",
+     "Polish": "pol_Latn",
+     "Portuguese": "por_Latn",
+     "Dari": "prs_Arab",
+     "Southern Pashto": "pbt_Arab",
+     "Ayacucho Quechua": "quy_Latn",
+     "Romanian": "ron_Latn",
+     "Rundi": "run_Latn",
+     "Russian": "rus_Cyrl",
+     "Sango": "sag_Latn",
+     "Sanskrit": "san_Deva",
+     "Santali": "sat_Olck",
+     "Sicilian": "scn_Latn",
+     "Shan": "shn_Mymr",
+     "Sinhala": "sin_Sinh",
+     "Slovak": "slk_Latn",
+     "Slovenian": "slv_Latn",
+     "Samoan": "smo_Latn",
+     "Shona": "sna_Latn",
+     "Sindhi": "snd_Arab",
+     "Somali": "som_Latn",
+     "Southern Sotho": "sot_Latn",
+     "Spanish": "spa_Latn",
+     "Tosk Albanian": "als_Latn",
+     "Sardinian": "srd_Latn",
+     "Serbian": "srp_Cyrl",
+     "Swati": "ssw_Latn",
+     "Sundanese": "sun_Latn",
+     "Swedish": "swe_Latn",
+     "Swahili": "swh_Latn",
+     "Silesian": "szl_Latn",
+     "Tamil": "tam_Taml",
+     "Tatar": "tat_Cyrl",
+     "Telugu": "tel_Telu",
+     "Tajik": "tgk_Cyrl",
+     "Tagalog": "tgl_Latn",
+     "Thai": "tha_Thai",
+     "Tigrinya": "tir_Ethi",
+     "Tamasheq Latin": "taq_Latn",
+     "Tamasheq Tifinagh": "taq_Tfng",
+     "Tok Pisin": "tpi_Latn",
+     "Tswana": "tsn_Latn",
+     "Tsonga": "tso_Latn",
+     "Turkmen": "tuk_Latn",
+     "Tumbuka": "tum_Latn",
+     "Turkish": "tur_Latn",
+     "Twi": "twi_Latn",
+     "Central Atlas Tamazight": "tzm_Tfng",
+     "Uyghur": "uig_Arab",
+     "Ukrainian": "ukr_Cyrl",
+     "Umbundu": "umb_Latn",
+     "Urdu": "urd_Arab",
+     "Northern Uzbek": "uzn_Latn",
+     "Venetian": "vec_Latn",
+     "Vietnamese": "vie_Latn",
+     "Waray": "war_Latn",
+     "Wolof": "wol_Latn",
+     "Xhosa": "xho_Latn",
+     "Eastern Yiddish": "ydd_Hebr",
+     "Yoruba": "yor_Latn",
+     "Yue Chinese": "yue_Hant",
+     "Chinese Simplified": "zho_Hans",
+     "Chinese Traditional": "zho_Hant",
+     "Standard Malay": "zsm_Latn",
+     "Zulu": "zul_Latn",
+ }
+
+
+ class TranslationTool(PipelineTool):
+     """
+     Example:
+
+     ```py
+     from transformers.agents import TranslationTool
+
+     translator = TranslationTool()
+     translator("This is a super nice API!", src_lang="English", tgt_lang="French")
+     ```
+     """
+
+     lang_to_code = LANGUAGE_CODES
+     default_checkpoint = "facebook/nllb-200-distilled-600M"
+     description = (
+         "This is a tool that translates text from a language to another. "
+         f"Both `src_lang` and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
+     )
+     name = "translator"
+     pre_processor_class = AutoTokenizer
+     model_class = AutoModelForSeq2SeqLM
+
+     inputs = {
+         "text": {"type": "string", "description": "The text to translate"},
+         "src_lang": {
+             "type": "string",
+             "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
+         },
+         "tgt_lang": {
+             "type": "string",
+             "description": "The desired output language. Written in plain English, such as 'Romanian', or 'Albanian'",
+         },
+     }
+     output_type = "string"
+
+     def encode(self, text, src_lang, tgt_lang):
+         if src_lang not in self.lang_to_code:
+             raise ValueError(f"{src_lang} is not a supported language.")
+         if tgt_lang not in self.lang_to_code:
+             raise ValueError(f"{tgt_lang} is not a supported language.")
+         src_lang = self.lang_to_code[src_lang]
+         tgt_lang = self.lang_to_code[tgt_lang]
+         return self.pre_processor._build_translation_inputs(
+             text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
+         )
+
+     def forward(self, inputs):
+         return self.model.generate(**inputs)
+
+     def decode(self, outputs):
+         return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
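
A short usage sketch tying this file back to `load_tool` above (it assumes the default NLLB checkpoint can be downloaded; the printed translation is indicative):

```py
from transformers.agents import load_tool

# "translation" resolves to TranslationTool via TOOL_MAPPING in tools.py,
# backed by facebook/nllb-200-distilled-600M by default.
translator = load_tool("translation")

# encode() maps the plain-English names to NLLB codes internally,
# e.g. "English" -> "eng_Latn" and "French" -> "fra_Latn".
print(translator("This is a super nice API!", src_lang="English", tgt_lang="French"))
```
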
.venv/Lib/site-packages/transformers/benchmark/benchmark.py ADDED
@@ -0,0 +1,270 @@
+ # coding=utf-8
+ # Copyright 2018 The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Benchmarking the library on inference and training in PyTorch.
+ """
+
+ import timeit
+ from typing import Callable, Optional, Tuple
+
+ from ..configuration_utils import PretrainedConfig
+ from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
+ from ..utils import is_py3nvml_available, is_torch_available, logging
+ from .benchmark_utils import (
+     Benchmark,
+     Memory,
+     MemorySummary,
+     measure_peak_memory_cpu,
+     start_memory_tracing,
+     stop_memory_tracing,
+ )
+
+
+ if is_torch_available():
+     import torch
+
+     from .benchmark_args import PyTorchBenchmarkArguments
+
+
+ if is_py3nvml_available():
+     import py3nvml.py3nvml as nvml
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PyTorchBenchmark(Benchmark):
+     args: PyTorchBenchmarkArguments
+     configs: PretrainedConfig
+     framework: str = "PyTorch"
+
+     @property
+     def framework_version(self):
+         return torch.__version__
+
+     def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
+         _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
+         return self._measure_speed(_inference)
+
+     def _inference_memory(
+         self, model_name: str, batch_size: int, sequence_length: int
+     ) -> Tuple[Memory, Optional[MemorySummary]]:
+         _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
+         return self._measure_memory(_inference)
+
+     def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
+         _train = self._prepare_train_func(model_name, batch_size, sequence_length)
+         return self._measure_speed(_train)
+
+     def _train_memory(
+         self, model_name: str, batch_size: int, sequence_length: int
+     ) -> Tuple[Memory, Optional[MemorySummary]]:
+         _train = self._prepare_train_func(model_name, batch_size, sequence_length)
+         return self._measure_memory(_train)
+
+     def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
+         config = self.config_dict[model_name]
+
+         if self.args.torchscript:
+             config.torchscript = True
+
+         has_model_class_in_config = (
+             hasattr(config, "architectures")
+             and isinstance(config.architectures, list)
+             and len(config.architectures) > 0
+         )
+         if not self.args.only_pretrain_model and has_model_class_in_config:
+             try:
+                 model_class = config.architectures[0]
+                 transformers_module = __import__("transformers", fromlist=[model_class])
+                 model_cls = getattr(transformers_module, model_class)
+                 model = model_cls(config)
+             except ImportError:
+                 raise ImportError(
+                     f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+                     " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+                 )
+         else:
+             model = MODEL_MAPPING[config.__class__](config)
+
+         model.eval()
+         model.to(self.args.device)
+
+         # encoder-decoder has vocab size saved differently
+         vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
+         input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
+
+         if self.args.fp16:
+             logger.info("Running inference in Mixed Precision...")
+             if not self.args.is_gpu:
+                 raise ValueError("Mixed precision is possible only for GPU.")
+             # amp seems to have memory leaks so that memory usage
+             # is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
+             model.half()
+
+         if self.args.torchscript:
+             with torch.no_grad():
+                 inference_model = torch.jit.trace(model, input_ids)
+         else:
+             inference_model = model
+
+         def encoder_decoder_forward():
+             with torch.no_grad():
+                 outputs = inference_model(input_ids, decoder_input_ids=input_ids)
+             return outputs
+
+         def encoder_forward():
+             with torch.no_grad():
+                 outputs = inference_model(input_ids)
+             return outputs
+
+         _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
+         return _forward
+
+     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
+         config = self.config_dict[model_name]
+
+         has_model_class_in_config = (
+             hasattr(config, "architectures")
+             and isinstance(config.architectures, list)
+             and len(config.architectures) > 0
+         )
+         if not self.args.only_pretrain_model and has_model_class_in_config:
+             try:
+                 model_class = config.architectures[0]
+                 transformers_module = __import__("transformers", fromlist=[model_class])
+                 model_cls = getattr(transformers_module, model_class)
+                 model = model_cls(config)
+             except ImportError:
+                 raise ImportError(
+                     f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
+                     " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+                 )
+         else:
+             model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
+
+         if self.args.torchscript:
+             raise NotImplementedError("Training for torchscript is currently not implemented")
+         else:
+             train_model = model
+
+         model.train()
+         model.to(self.args.device)
+
+         # encoder-decoder has vocab size saved differently
+         vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
+         input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
+
+         if self.args.fp16:
+             logger.info("Running training in Mixed Precision...")
+             if not self.args.is_gpu:
+                 raise ValueError("Mixed precision is possible only for GPU.")
+
+             # amp seems to have memory leaks so that memory usage
+             # is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
+             model.half()
+
+         def compute_loss_and_backprop_encoder():
+             loss = train_model(input_ids, labels=input_ids)[0]
+             loss.backward()
+             return loss
+
+         def compute_loss_and_backprop_encoder_decoder():
+             loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
+             loss.backward()
+             return loss
+
+         _train = (
+             compute_loss_and_backprop_encoder_decoder
+             if config.is_encoder_decoder
+             else compute_loss_and_backprop_encoder
+         )
+         return _train
+
+     def _measure_speed(self, func) -> float:
+         try:
+             if self.args.is_tpu or self.args.torchscript:
+                 # run an additional 5 times to stabilize compilation for tpu and torchscript
+                 logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
+                 timeit.repeat(
+                     func,
+                     repeat=1,
+                     number=5,
+                 )
+
+             # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
+             runtimes = timeit.repeat(
+                 func,
+                 repeat=self.args.repeat,
+                 number=10,
+             )
+
+             if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
+                 import torch_xla.debug.metrics as met
+
+                 self.print_fn(met.metrics_report())
+
+             return min(runtimes) / 10.0
+         except RuntimeError as e:
+             self.print_fn(f"Doesn't fit on GPU. {e}")
+             return "N/A"
+
+     def _measure_memory(self, func: Callable[[], None]) -> Tuple[Memory, Optional[MemorySummary]]:
+         try:
+             if self.args.trace_memory_line_by_line:
+                 trace = start_memory_tracing("transformers")
+
+             if self.args.is_tpu:
+                 # tpu
+                 raise NotImplementedError(
+                     "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
+                     " `--no-memory` or `args.memory=False`"
+                 )
+             elif self.args.is_gpu:
+                 if not is_py3nvml_available():
+                     logger.warning(
+                         "py3nvml not installed, we won't log GPU memory usage. "
+                         "Install py3nvml (pip install py3nvml) to log information about GPU."
+                     )
+                     memory = "N/A"
+                 else:
+                     logger.info(
+                         "Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
+                         " on the same GPU."
+                     )
+                     # init nvml
+                     nvml.nvmlInit()
+                     func()
+                     handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
+                     meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
+                     max_bytes_in_use = meminfo.used
+                     memory = Memory(max_bytes_in_use)
+                     # shutdown nvml
+                     nvml.nvmlShutdown()
+             else:
+                 # cpu
+                 memory_bytes = measure_peak_memory_cpu(func)
+                 memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
+
+             if self.args.trace_memory_line_by_line:
+                 summary = stop_memory_tracing(trace)
+             else:
+                 summary = None
+
+             return memory, summary
+         except RuntimeError as e:
+             self.print_fn(f"Doesn't fit on GPU. {e}")
+             return "N/A", None
.venv/Lib/site-packages/transformers/benchmark/benchmark_args.py ADDED
@@ -0,0 +1,124 @@
+ # coding=utf-8
+ # Copyright 2018 The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, field
+ from typing import Tuple
+
+ from ..utils import (
+     cached_property,
+     is_torch_available,
+     is_torch_xla_available,
+     is_torch_xpu_available,
+     logging,
+     requires_backends,
+ )
+ from .benchmark_args_utils import BenchmarkArguments
+
+
+ if is_torch_available():
+     import torch
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class PyTorchBenchmarkArguments(BenchmarkArguments):
+     deprecated_args = [
+         "no_inference",
+         "no_cuda",
+         "no_tpu",
+         "no_speed",
+         "no_memory",
+         "no_env_print",
+         "no_multi_process",
+     ]
+
+     def __init__(self, **kwargs):
+         """
+         This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
+         deleted
+         """
+         for deprecated_arg in self.deprecated_args:
+             if deprecated_arg in kwargs:
+                 positive_arg = deprecated_arg[3:]
+                 kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
+                 logger.warning(
+                     f"{deprecated_arg} is deprecated. Please use --no_{positive_arg} or"
+                     f" {positive_arg}={kwargs[positive_arg]}"
+                 )
+
+         self.torchscript = kwargs.pop("torchscript", self.torchscript)
+         self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
+         self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
+         super().__init__(**kwargs)
+
+     torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
+     torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
+     fp16_opt_level: str = field(
+         default="O1",
+         metadata={
+             "help": (
+                 "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+                 "See details at https://nvidia.github.io/apex/amp.html"
+             )
+         },
+     )
+
+     @cached_property
+     def _setup_devices(self) -> Tuple["torch.device", int]:
+         requires_backends(self, ["torch"])
+         logger.info("PyTorch: setting up devices")
+         if not self.cuda:
+             device = torch.device("cpu")
+             n_gpu = 0
+         elif is_torch_xla_available():
+             device = xm.xla_device()
+             n_gpu = 0
+         elif is_torch_xpu_available():
+             device = torch.device("xpu")
+             n_gpu = torch.xpu.device_count()
+         else:
+             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             n_gpu = torch.cuda.device_count()
+         return device, n_gpu
+
+     @property
+     def is_tpu(self):
+         return is_torch_xla_available() and self.tpu
+
+     @property
+     def device_idx(self) -> int:
+         requires_backends(self, ["torch"])
+         # TODO(PVP): currently only single GPU is supported
+         return torch.cuda.current_device()
+
+     @property
+     def device(self) -> "torch.device":
+         requires_backends(self, ["torch"])
+         return self._setup_devices[0]
+
+     @property
+     def n_gpu(self):
+         requires_backends(self, ["torch"])
+         return self._setup_devices[1]
+
+     @property
+     def is_gpu(self):
+         return self.n_gpu > 0
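
A quick sketch of the legacy-flag remapping performed in `__init__` above: the `no_` prefix is stripped and the value negated before the dataclass initializer runs (model name illustrative):

```py
from transformers import PyTorchBenchmarkArguments

# Passing the deprecated flag logs a warning and is rewritten to cuda=False.
args = PyTorchBenchmarkArguments(models=["google-bert/bert-base-uncased"], no_cuda=True)
print(args.cuda)  # False
```
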
.venv/Lib/site-packages/transformers/benchmark/benchmark_args_tf.py ADDED
@@ -0,0 +1,136 @@
+ # coding=utf-8
+ # Copyright 2018 The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from ..utils import cached_property, is_tf_available, logging, requires_backends
+ from .benchmark_args_utils import BenchmarkArguments
+
+
+ if is_tf_available():
+     import tensorflow as tf
+
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class TensorFlowBenchmarkArguments(BenchmarkArguments):
+     deprecated_args = [
+         "no_inference",
+         "no_cuda",
+         "no_tpu",
+         "no_speed",
+         "no_memory",
+         "no_env_print",
+         "no_multi_process",
+     ]
+
+     def __init__(self, **kwargs):
+         """
+         This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
+         deleted
+         """
+         for deprecated_arg in self.deprecated_args:
+             if deprecated_arg in kwargs:
+                 positive_arg = deprecated_arg[3:]
+                 kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
+                 logger.warning(
+                     f"{deprecated_arg} is deprecated. Please use --no-{positive_arg} or"
+                     f" {positive_arg}={kwargs[positive_arg]}"
+                 )
+         self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
+         self.device_idx = kwargs.pop("device_idx", self.device_idx)
+         self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
+         self.use_xla = kwargs.pop("use_xla", self.use_xla)
+         super().__init__(**kwargs)
+
+     tpu_name: str = field(
+         default=None,
+         metadata={"help": "Name of TPU"},
+     )
+     device_idx: int = field(
+         default=0,
+         metadata={"help": "CPU / GPU device index. Defaults to 0."},
+     )
+     eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager mode."})
+     use_xla: bool = field(
+         default=False,
+         metadata={
+             "help": "Benchmark models using XLA JIT compilation. Note that `eager_mode` has to be set to `False`."
+         },
+     )
+
+     @cached_property
+     def _setup_tpu(self) -> Optional["tf.distribute.cluster_resolver.TPUClusterResolver"]:
+         requires_backends(self, ["tf"])
+         tpu = None
+         if self.tpu:
+             try:
+                 if self.tpu_name:
+                     tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
+                 else:
+                     tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
+             except ValueError:
+                 tpu = None
+         return tpu
+
+     @cached_property
+     def _setup_strategy(self) -> "tf.distribute.Strategy":
+         requires_backends(self, ["tf"])
+         if self.is_tpu:
+             tf.config.experimental_connect_to_cluster(self._setup_tpu)
+             tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
+
+             strategy = tf.distribute.TPUStrategy(self._setup_tpu)
+         else:
+             # currently no multi gpu is allowed
+             if self.is_gpu:
+                 # TODO: Currently only single GPU is supported
+                 tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
+                 strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
+             else:
+                 tf.config.set_visible_devices([], "GPU")  # disable GPU
+                 strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
+
+         return strategy
+
+     @property
+     def is_tpu(self) -> bool:
+         requires_backends(self, ["tf"])
+         return self._setup_tpu is not None
+
+     @property
+     def strategy(self) -> "tf.distribute.Strategy":
+         requires_backends(self, ["tf"])
+         return self._setup_strategy
+
+     @property
+     def gpu_list(self):
+         requires_backends(self, ["tf"])
+         return tf.config.list_physical_devices("GPU")
+
+     @property
+     def n_gpu(self) -> int:
+         requires_backends(self, ["tf"])
+         if self.cuda:
+             return len(self.gpu_list)
+         return 0
+
+     @property
+     def is_gpu(self) -> bool:
+         return self.n_gpu > 0
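
For orientation, a small sketch of how the lazily resolved `strategy` above is consumed (requires TensorFlow; the model name is illustrative):

```py
from transformers import TensorFlowBenchmarkArguments

args = TensorFlowBenchmarkArguments(
    models=["google-bert/bert-base-uncased"], cuda=False, tpu=False
)
# With cuda/tpu disabled this is a OneDeviceStrategy pinned to the CPU;
# with a reachable TPU it would resolve to a TPUStrategy instead.
with args.strategy.scope():
    pass  # build and run the benchmarked model under the chosen device strategy
```
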
.venv/Lib/site-packages/transformers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from argparse import ArgumentParser
+
+
+ class BaseTransformersCLICommand(ABC):
+     @staticmethod
+     @abstractmethod
+     def register_subcommand(parser: ArgumentParser):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def run(self):
+         raise NotImplementedError()
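
The two abstract methods define the full contract for a subcommand; a minimal, hypothetical implementation (not part of the library) might look like this:

```py
from argparse import ArgumentParser

from transformers.commands import BaseTransformersCLICommand


class HelloCommand(BaseTransformersCLICommand):
    """Hypothetical example command, for illustration only."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello", help="Print a greeting.")
        hello_parser.add_argument("--name", type=str, default="world")
        # The factory stored in `func` receives the parsed args and returns the command.
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self._name = name

    def run(self):
        print(f"Hello, {self._name}!")
```
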
.venv/Lib/site-packages/transformers/commands/run.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import ArgumentParser
+
+ from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
+ from ..utils import logging
+ from . import BaseTransformersCLICommand
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def try_infer_format_from_ext(path: str):
+     if not path:
+         return "pipe"
+
+     for ext in PipelineDataFormat.SUPPORTED_FORMATS:
+         if path.endswith(ext):
+             return ext
+
+     raise Exception(
+         f"Unable to determine file format from file extension {path}. "
+         f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
+     )
+
+
+ def run_command_factory(args):
+     nlp = pipeline(
+         task=args.task,
+         model=args.model if args.model else None,
+         config=args.config,
+         tokenizer=args.tokenizer,
+         device=args.device,
+     )
+     format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
+     reader = PipelineDataFormat.from_str(
+         format=format,
+         output_path=args.output,
+         input_path=args.input,
+         column=args.column if args.column else nlp.default_input_names,
+         overwrite=args.overwrite,
+     )
+     return RunCommand(nlp, reader)
+
+
+ class RunCommand(BaseTransformersCLICommand):
+     def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
+         self._nlp = nlp
+         self._reader = reader
+
+     @staticmethod
+     def register_subcommand(parser: ArgumentParser):
+         run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
+         run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
+         run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
+         run_parser.add_argument("--output", type=str, help="Path to the file where results will be written.")
+         run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
+         run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
+         run_parser.add_argument(
+             "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
+         )
+         run_parser.add_argument(
+             "--column",
+             type=str,
+             help="Name of the column to use as input. (For multi-column inputs such as QA, use column1,column2)",
+         )
+         run_parser.add_argument(
+             "--format",
+             type=str,
+             default="infer",
+             choices=PipelineDataFormat.SUPPORTED_FORMATS,
+             help="Input format to read from",
+         )
+         run_parser.add_argument(
+             "--device",
+             type=int,
+             default=-1,
+             help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
+         )
+         run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
+         run_parser.set_defaults(func=run_command_factory)
+
+     def run(self):
+         nlp, outputs = self._nlp, []
+
+         for entry in self._reader:
+             output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
+             if isinstance(output, dict):
+                 outputs.append(output)
+             else:
+                 outputs += output
+
+         # Saving data
+         if self._nlp.binary_output:
+             binary_path = self._reader.save_binary(outputs)
+             logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
+         else:
+             self._reader.save(outputs)
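
Driving `RunCommand` programmatically mirrors what `transformers-cli run` does on the command line; a sketch with hypothetical file names (the `csv` input format is inferred from the extension):

```py
from argparse import Namespace

from transformers.commands.run import run_command_factory

args = Namespace(
    task="sentiment-analysis",
    model=None,  # fall back to the task's default model
    config=None,
    tokenizer=None,
    device=-1,  # CPU
    input="reviews.csv",
    output="predictions.json",
    format="infer",
    column="text",
    overwrite=True,
)
run_command_factory(args).run()
```
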
.venv/Lib/site-packages/transformers/commands/serving.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import ArgumentParser, Namespace
+ from typing import Any, List, Optional
+
+ from ..pipelines import Pipeline, get_supported_tasks, pipeline
+ from ..utils import logging
+ from . import BaseTransformersCLICommand
+
+
+ try:
+     from fastapi import Body, FastAPI, HTTPException
+     from fastapi.routing import APIRoute
+     from pydantic import BaseModel
+     from starlette.responses import JSONResponse
+     from uvicorn import run
+
+     _serve_dependencies_installed = True
+ except (ImportError, AttributeError):
+     BaseModel = object
+
+     def Body(*x, **y):
+         pass
+
+     _serve_dependencies_installed = False
+
+
+ logger = logging.get_logger("transformers-cli/serving")
+
+
+ def serve_command_factory(args: Namespace):
+     """
+     Factory function used to instantiate the serving server from provided command line arguments.
+
+     Returns: ServeCommand
+     """
+     nlp = pipeline(
+         task=args.task,
+         model=args.model if args.model else None,
+         config=args.config,
+         tokenizer=args.tokenizer,
+         device=args.device,
+     )
+     return ServeCommand(nlp, args.host, args.port, args.workers)
+
+
+ class ServeModelInfoResult(BaseModel):
+     """
+     Expose model information
+     """
+
+     infos: dict
+
+
+ class ServeTokenizeResult(BaseModel):
+     """
+     Tokenize result model
+     """
+
+     tokens: List[str]
+     tokens_ids: Optional[List[int]]
+
+
+ class ServeDeTokenizeResult(BaseModel):
+     """
+     DeTokenize result model
+     """
+
+     text: str
+
+
+ class ServeForwardResult(BaseModel):
+     """
+     Forward result model
+     """
+
+     output: Any
+
+
+ class ServeCommand(BaseTransformersCLICommand):
+     @staticmethod
+     def register_subcommand(parser: ArgumentParser):
+         """
+         Register this command to argparse so it's available for the transformer-cli
+
+         Args:
+             parser: Root parser to register command-specific arguments
+         """
+         serve_parser = parser.add_parser(
+             "serve", help="CLI tool to run inference requests through REST endpoints."
+         )
+         serve_parser.add_argument(
+             "--task",
+             type=str,
+             choices=get_supported_tasks(),
+             help="The task to run the pipeline on",
+         )
+         serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
+         serve_parser.add_argument("--port", type=int, default=8888, help="Port the server will listen on.")
+         serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
+         serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
+         serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
+         serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
+         serve_parser.add_argument(
+             "--device",
+             type=int,
+             default=-1,
+             help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
+         )
+         serve_parser.set_defaults(func=serve_command_factory)
+
+     def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
+         self._pipeline = pipeline
+
+         self.host = host
+         self.port = port
+         self.workers = workers
+
+         if not _serve_dependencies_installed:
+             raise RuntimeError(
+                 "Using serve command requires FastAPI and uvicorn. "
+                 'Please install transformers with [serving]: pip install "transformers[serving]". '
+                 "Or install FastAPI and uvicorn separately."
+             )
+         else:
+             logger.info(f"Serving model over {host}:{port}")
+             self._app = FastAPI(
+                 routes=[
+                     APIRoute(
+                         "/",
+                         self.model_info,
+                         response_model=ServeModelInfoResult,
+                         response_class=JSONResponse,
+                         methods=["GET"],
+                     ),
+                     APIRoute(
+                         "/tokenize",
+                         self.tokenize,
+                         response_model=ServeTokenizeResult,
+                         response_class=JSONResponse,
+                         methods=["POST"],
+                     ),
+                     APIRoute(
+                         "/detokenize",
+                         self.detokenize,
+                         response_model=ServeDeTokenizeResult,
+                         response_class=JSONResponse,
+                         methods=["POST"],
+                     ),
+                     APIRoute(
+                         "/forward",
+                         self.forward,
+                         response_model=ServeForwardResult,
+                         response_class=JSONResponse,
+                         methods=["POST"],
+                     ),
+                 ],
+                 timeout=600,
+             )
+
+     def run(self):
+         run(self._app, host=self.host, port=self.port, workers=self.workers)
+
+     def model_info(self):
+         return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
+
+     def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
+         """
+         Tokenize the provided input and optionally return the corresponding token ids:
+         - **text_input**: String to tokenize.
+         - **return_ids**: Boolean flag indicating whether the tokens should also be converted to their integer
+           mapping.
+         """
+         try:
+             tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
+
+             if return_ids:
+                 tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
+                 return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
+             else:
+                 return ServeTokenizeResult(tokens=tokens_txt)
+
+         except Exception as e:
+             raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
+
+     def detokenize(
+         self,
+         tokens_ids: List[int] = Body(None, embed=True),
+         skip_special_tokens: bool = Body(False, embed=True),
+         cleanup_tokenization_spaces: bool = Body(True, embed=True),
+     ):
+         """
+         Detokenize the provided token ids into readable text:
+         - **tokens_ids**: List of token ids.
+         - **skip_special_tokens**: Flag indicating whether special tokens should be dropped from the decoded text.
+         - **cleanup_tokenization_spaces**: Flag indicating whether to remove all leading/trailing spaces and
+           intermediate ones.
+         """
+         try:
+             decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
+             return ServeDeTokenizeResult(text=decoded_str)
+         except Exception as e:
+             raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
+
+     async def forward(self, inputs=Body(None, embed=True)):
+         """
+         Run the provided **inputs** through the underlying pipeline and return the model output.
+         """
+
+         # Check we don't have empty string
+         if len(inputs) == 0:
+             return ServeForwardResult(output=[])
+
+         try:
+             # Forward through the model
+             output = self._pipeline(inputs)
+             return ServeForwardResult(output=output)
+         except Exception as e:
+             raise HTTPException(500, {"error": str(e)})
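
With the server started via `transformers-cli serve --task sentiment-analysis` on the defaults above (localhost:8888), the endpoints can be exercised like this; a sketch, with the output shape indicative:

```py
import requests

# Body(..., embed=True) means each parameter is a key in the JSON body.
resp = requests.post(
    "http://localhost:8888/tokenize",
    json={"text_input": "Hello world", "return_ids": True},
)
print(resp.json())  # {"tokens": [...], "tokens_ids": [...]}
```
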
.venv/Lib/site-packages/transformers/commands/train.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from argparse import ArgumentParser, Namespace
17
+
18
+ from ..data import SingleSentenceClassificationProcessor as Processor
19
+ from ..pipelines import TextClassificationPipeline
20
+ from ..utils import is_tf_available, is_torch_available, logging
21
+ from . import BaseTransformersCLICommand
22
+
23
+
24
+ if not is_tf_available() and not is_torch_available():
25
+ raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
26
+
27
+ # TF training parameters
28
+ USE_XLA = False
29
+ USE_AMP = False
30
+
31
+
32
+ def train_command_factory(args: Namespace):
33
+ """
34
+ Factory function used to instantiate training command from provided command line arguments.
35
+
36
+ Returns: TrainCommand
37
+ """
38
+ return TrainCommand(args)
39
+
40
+
41
+ class TrainCommand(BaseTransformersCLICommand):
42
+ @staticmethod
43
+ def register_subcommand(parser: ArgumentParser):
44
+ """
45
+ Register this command to argparse so it's available for the transformer-cli
46
+
47
+ Args:
48
+ parser: Root parser to register command-specific arguments
49
+ """
50
+ train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
51
+
52
+ train_parser.add_argument(
53
+ "--train_data",
54
+ type=str,
55
+ required=True,
56
+ help="path to train (and optionally evaluation) dataset as a csv with tab separated labels and sentences.",
57
+ )
58
+ train_parser.add_argument(
59
+ "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
60
+ )
61
+ train_parser.add_argument(
62
+ "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
63
+ )
64
+ train_parser.add_argument(
65
+ "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
66
+ )
67
+ train_parser.add_argument(
68
+ "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
69
+ )
70
+
71
+ train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
72
+ train_parser.add_argument(
73
+ "--validation_split",
74
+ type=float,
75
+ default=0.1,
76
+ help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
77
+ )
78
+
79
+ train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
80
+
81
+ train_parser.add_argument(
82
+ "--task", type=str, default="text_classification", help="Task to train the model on."
83
+ )
84
+ train_parser.add_argument(
85
+ "--model", type=str, default="google-bert/bert-base-uncased", help="Model's name or path to stored model."
86
+ )
87
+ train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
88
+ train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
89
+ train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
90
+ train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
91
+ train_parser.set_defaults(func=train_command_factory)
92
+
93
+ def __init__(self, args: Namespace):
94
+ self.logger = logging.get_logger("transformers-cli/training")
95
+
96
+ self.framework = "tf" if is_tf_available() else "torch"
97
+
98
+ os.makedirs(args.output, exist_ok=True)
99
+ self.output = args.output
100
+
101
+ self.column_label = args.column_label
102
+ self.column_text = args.column_text
103
+ self.column_id = args.column_id
104
+
105
+ self.logger.info(f"Loading {args.task} pipeline for {args.model}")
106
+ if args.task == "text_classification":
107
+ self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
108
+ elif args.task == "token_classification":
109
+ raise NotImplementedError
110
+ elif args.task == "question_answering":
111
+ raise NotImplementedError
112
+
113
+ self.logger.info(f"Loading dataset from {args.train_data}")
114
+ self.train_dataset = Processor.create_from_csv(
115
+ args.train_data,
116
+ column_label=args.column_label,
117
+ column_text=args.column_text,
118
+ column_id=args.column_id,
119
+ skip_first_row=args.skip_first_row,
120
+ )
121
+ self.valid_dataset = None
122
+ if args.validation_data:
123
+ self.logger.info(f"Loading validation dataset from {args.validation_data}")
124
+ self.valid_dataset = Processor.create_from_csv(
125
+ args.validation_data,
126
+ column_label=args.column_label,
127
+ column_text=args.column_text,
128
+ column_id=args.column_id,
129
+ skip_first_row=args.skip_first_row,
130
+ )
131
+
132
+ self.validation_split = args.validation_split
133
+ self.train_batch_size = args.train_batch_size
134
+ self.valid_batch_size = args.valid_batch_size
135
+ self.learning_rate = args.learning_rate
136
+ self.adam_epsilon = args.adam_epsilon
137
+
138
+ def run(self):
139
+ if self.framework == "tf":
140
+ return self.run_tf()
141
+ return self.run_torch()
142
+
143
+ def run_torch(self):
144
+ raise NotImplementedError
145
+
146
+ def run_tf(self):
147
+ self.pipeline.fit(
148
+ self.train_dataset,
149
+ validation_data=self.valid_dataset,
150
+ validation_split=self.validation_split,
151
+ learning_rate=self.learning_rate,
152
+ adam_epsilon=self.adam_epsilon,
153
+ train_batch_size=self.train_batch_size,
154
+ valid_batch_size=self.valid_batch_size,
155
+ )
156
+
157
+ # Save trained pipeline
158
+ self.pipeline.save_pretrained(self.output)
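+
+ # Illustrative usage sketch: drive this subcommand programmatically, mirroring
+ # the dispatch in transformers_cli.main(). Assumptions: the enclosing class is
+ # TrainCommand (its upstream name) and ./data/train.csv is a hypothetical
+ # tab-separated CSV of labels and sentences.
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser("transformers-cli")
+ TrainCommand.register_subcommand(parser.add_subparsers())
+ args = parser.parse_args(["train", "--train_data", "./data/train.csv"])
+ service = args.func(args)  # train_command_factory builds the command instance
+ service.run()              # dispatches to run_tf() or run_torch()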
.venv/Lib/site-packages/transformers/commands/transformers_cli.py ADDED
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .add_new_model_like import AddNewModelLikeCommand
19
+ from .convert import ConvertCommand
20
+ from .download import DownloadCommand
21
+ from .env import EnvironmentCommand
22
+ from .lfs import LfsCommands
23
+ from .pt_to_tf import PTtoTFCommand
24
+ from .run import RunCommand
25
+ from .serving import ServeCommand
26
+ from .user import UserCommands
27
+
28
+
29
+ def main():
30
+ parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
31
+ commands_parser = parser.add_subparsers(help="transformers-cli command helpers")
32
+
33
+ # Register commands
34
+ ConvertCommand.register_subcommand(commands_parser)
35
+ DownloadCommand.register_subcommand(commands_parser)
36
+ EnvironmentCommand.register_subcommand(commands_parser)
37
+ RunCommand.register_subcommand(commands_parser)
38
+ ServeCommand.register_subcommand(commands_parser)
39
+ UserCommands.register_subcommand(commands_parser)
40
+ AddNewModelLikeCommand.register_subcommand(commands_parser)
41
+ LfsCommands.register_subcommand(commands_parser)
42
+ PTtoTFCommand.register_subcommand(commands_parser)
43
+
44
+ # Let's go
45
+ args = parser.parse_args()
46
+
47
+ if not hasattr(args, "func"):
48
+ parser.print_help()
49
+ exit(1)
50
+
51
+ # Run
52
+ service = args.func(args)
53
+ service.run()
54
+
55
+
56
+ if __name__ == "__main__":
57
+ main()
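+
+ # Illustrative sketch of the contract main() relies on: each subcommand exposes
+ # a static register_subcommand(parser) that calls set_defaults(func=...) with a
+ # factory, and the object that factory returns has a run() method. HelloCommand
+ # is a hypothetical example, not part of the library.
+ class HelloCommand:
+     @staticmethod
+     def register_subcommand(parser):
+         hello_parser = parser.add_parser("hello", help="Print a greeting.")
+         hello_parser.add_argument("--name", type=str, default="world")
+         hello_parser.set_defaults(func=lambda args: HelloCommand(args))
+
+     def __init__(self, args):
+         self.name = args.name
+
+     def run(self):
+         print(f"Hello, {self.name}!")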
.venv/Lib/site-packages/transformers/commands/user.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import subprocess
16
+ from argparse import ArgumentParser
17
+ from typing import List, Union
18
+
19
+ from huggingface_hub.hf_api import HfFolder, create_repo, whoami
20
+ from requests.exceptions import HTTPError
21
+
22
+ from . import BaseTransformersCLICommand
23
+
24
+
25
+ class UserCommands(BaseTransformersCLICommand):
26
+ @staticmethod
27
+ def register_subcommand(parser: ArgumentParser):
28
+ login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
29
+ login_parser.set_defaults(func=lambda args: LoginCommand(args))
30
+ whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
31
+ whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
32
+ logout_parser = parser.add_parser("logout", help="Log out")
33
+ logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
34
+
35
+ # new system: git-based repo system
36
+ repo_parser = parser.add_parser(
37
+ "repo",
38
+ help="Deprecated: use `huggingface-cli` instead. Commands to interact with your huggingface.co repos.",
39
+ )
40
+ repo_subparsers = repo_parser.add_subparsers(
41
+ help="Deprecated: use `huggingface-cli` instead. huggingface.co repos related commands"
42
+ )
43
+ repo_create_parser = repo_subparsers.add_parser(
44
+ "create", help="Deprecated: use `huggingface-cli` instead. Create a new repo on huggingface.co"
45
+ )
46
+ repo_create_parser.add_argument(
47
+ "name",
48
+ type=str,
49
+ help="Name for your model's repo. Will be namespaced under your username to build the model id.",
50
+ )
51
+ repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
52
+ repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
53
+ repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
54
+
55
+
56
+ class ANSI:
57
+ """
58
+ Helper for en.wikipedia.org/wiki/ANSI_escape_code
59
+ """
60
+
61
+ _bold = "\u001b[1m"
62
+ _red = "\u001b[31m"
63
+ _gray = "\u001b[90m"
64
+ _reset = "\u001b[0m"
65
+
66
+ @classmethod
67
+ def bold(cls, s):
68
+ return f"{cls._bold}{s}{cls._reset}"
69
+
70
+ @classmethod
71
+ def red(cls, s):
72
+ return f"{cls._bold}{cls._red}{s}{cls._reset}"
73
+
74
+ @classmethod
75
+ def gray(cls, s):
76
+ return f"{cls._gray}{s}{cls._reset}"
77
+
78
+
79
+ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
80
+ """
81
+ Inspired by:
82
+
83
+ - stackoverflow.com/a/8356620/593036
84
+ - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
85
+ """
86
+ col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
87
+ row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
88
+ lines = []
89
+ lines.append(row_format.format(*headers))
90
+ lines.append(row_format.format(*["-" * w for w in col_widths]))
91
+ for row in rows:
92
+ lines.append(row_format.format(*row))
93
+ return "\n".join(lines)
94
+
95
+
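+ # Illustrative sketch of tabulate's output (values made up; trailing padding
+ # omitted). str() is applied when measuring widths, so mixed types are fine,
+ # though note that non-string cells are right-aligned by str.format.
+ example = tabulate([["bert-base", "110M"], ["gpt2", "124M"]], headers=["model", "params"])
+ # model     params
+ # --------- ------
+ # bert-base 110M
+ # gpt2      124M
+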
96
+ class BaseUserCommand:
97
+ def __init__(self, args):
98
+ self.args = args
99
+
100
+
101
+ class LoginCommand(BaseUserCommand):
102
+ def run(self):
103
+ print(
104
+ ANSI.red(
105
+ "ERROR! `huggingface-cli login` uses an outdated login mechanism "
106
+ "that is not compatible with the Hugging Face Hub backend anymore. "
107
+ "Please use `huggingface-cli login instead."
108
+ )
109
+ )
110
+
111
+
112
+ class WhoamiCommand(BaseUserCommand):
113
+ def run(self):
114
+ print(
115
+ ANSI.red(
116
+ "WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
117
+ "`huggingface-cli whoami` instead."
118
+ )
119
+ )
120
+ token = HfFolder.get_token()
121
+ if token is None:
122
+ print("Not logged in")
123
+ exit()
124
+ try:
125
+ user, orgs = whoami(token)
126
+ print(user)
127
+ if orgs:
128
+ print(ANSI.bold("orgs: "), ",".join(orgs))
129
+ except HTTPError as e:
130
+ print(e)
131
+ print(ANSI.red(e.response.text))
132
+ exit(1)
133
+
134
+
135
+ class LogoutCommand(BaseUserCommand):
136
+ def run(self):
137
+ print(
138
+ ANSI.red(
139
+ "ERROR! `transformers-cli logout` uses an outdated logout mechanism "
140
+ "that is not compatible with the Hugging Face Hub backend anymore. "
141
+ "Please use `huggingface-cli logout instead."
142
+ )
143
+ )
144
+
145
+
146
+ class RepoCreateCommand(BaseUserCommand):
147
+ def run(self):
148
+ print(
149
+ ANSI.red(
150
+ "WARNING! Managing repositories through transformers-cli is deprecated. "
151
+ "Please use `huggingface-cli` instead."
152
+ )
153
+ )
154
+ token = HfFolder.get_token()
155
+ if token is None:
156
+ print("Not logged in")
157
+ exit(1)
158
+ try:
159
+ stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
160
+ print(ANSI.gray(stdout.strip()))
161
+ except FileNotFoundError:
162
+ print("Looks like you do not have git installed, please install.")
163
+
164
+ try:
165
+ stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
166
+ print(ANSI.gray(stdout.strip()))
167
+ except FileNotFoundError:
168
+ print(
169
+ ANSI.red(
170
+ "Looks like you do not have git-lfs installed, please install."
171
+ " You can install from https://git-lfs.github.com/."
172
+ " Then run `git lfs install` (you only have to do this once)."
173
+ )
174
+ )
175
+ print("")
176
+
177
+ user, _ = whoami(token)
178
+ namespace = self.args.organization if self.args.organization is not None else user
179
+ full_name = f"{namespace}/{self.args.name}"
180
+ print(f"You are about to create {ANSI.bold(full_name)}")
181
+
182
+ if not self.args.yes:
183
+ choice = input("Proceed? [Y/n] ").lower()
184
+ if not (choice == "" or choice == "y" or choice == "yes"):
185
+ print("Abort")
186
+ exit()
187
+ try:
188
+ url = create_repo(repo_id=full_name, token=token)
189
+ except HTTPError as e:
190
+ print(e)
191
+ print(ANSI.red(e.response.text))
192
+ exit(1)
193
+ print("\nYour repo now lives at:")
194
+ print(f" {ANSI.bold(url)}")
195
+ print("\nYou can clone it locally with the command below, and commit/push as usual.")
196
+ print(f"\n git clone {url}")
197
+ print("")
.venv/Lib/site-packages/transformers/data/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .data_collator import (
16
+ DataCollatorForLanguageModeling,
17
+ DataCollatorForPermutationLanguageModeling,
18
+ DataCollatorForSeq2Seq,
19
+ DataCollatorForSOP,
20
+ DataCollatorForTokenClassification,
21
+ DataCollatorForWholeWordMask,
22
+ DataCollatorWithFlattening,
23
+ DataCollatorWithPadding,
24
+ DefaultDataCollator,
25
+ default_data_collator,
26
+ )
27
+ from .metrics import glue_compute_metrics, xnli_compute_metrics
28
+ from .processors import (
29
+ DataProcessor,
30
+ InputExample,
31
+ InputFeatures,
32
+ SingleSentenceClassificationProcessor,
33
+ SquadExample,
34
+ SquadFeatures,
35
+ SquadV1Processor,
36
+ SquadV2Processor,
37
+ glue_convert_examples_to_features,
38
+ glue_output_modes,
39
+ glue_processors,
40
+ glue_tasks_num_labels,
41
+ squad_convert_examples_to_features,
42
+ xnli_output_modes,
43
+ xnli_processors,
44
+ xnli_tasks_num_labels,
45
+ )
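+
+ # These re-exports let downstream code import from transformers.data directly,
+ # e.g. (sketch): from transformers.data import DataCollatorWithPadding, SquadV2Processor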
.venv/Lib/site-packages/transformers/data/data_collator.py ADDED
@@ -0,0 +1,1653 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ import warnings
17
+ from collections.abc import Mapping
18
+ from dataclasses import dataclass
19
+ from random import randint
20
+ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+
24
+ from ..models.bert import BertTokenizer, BertTokenizerFast
25
+ from ..tokenization_utils_base import PreTrainedTokenizerBase
26
+ from ..utils import PaddingStrategy
27
+
28
+
29
+ InputDataClass = NewType("InputDataClass", Any)
30
+
31
+ """
32
+ A DataCollator is a function that takes a list of samples from a Dataset and collates them into a batch, as a dictionary
33
+ of PyTorch/TensorFlow tensors or NumPy arrays.
34
+ """
35
+ DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
36
+
37
+
38
+ class DataCollatorMixin:
39
+ def __call__(self, features, return_tensors=None):
40
+ if return_tensors is None:
41
+ return_tensors = self.return_tensors
42
+ if return_tensors == "tf":
43
+ return self.tf_call(features)
44
+ elif return_tensors == "pt":
45
+ return self.torch_call(features)
46
+ elif return_tensors == "np":
47
+ return self.numpy_call(features)
48
+ else:
49
+ raise ValueError(f"Framework '{return_tensors}' not recognized!")
50
+
51
+
52
+ def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
53
+ """
54
+ Pads without triggering the warning that padding via the pad function is sub-optimal when using a fast tokenizer.
55
+ """
56
+
57
+ # To avoid errors when using Feature extractors
58
+ if not hasattr(tokenizer, "deprecation_warnings"):
59
+ return tokenizer.pad(*pad_args, **pad_kwargs)
60
+
61
+ # Save the state of the warning, then disable it
62
+ warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
63
+ tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
64
+
65
+ try:
66
+ padded = tokenizer.pad(*pad_args, **pad_kwargs)
67
+ finally:
68
+ # Restore the state of the warning.
69
+ tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
70
+
71
+ return padded
72
+
73
+
74
+ def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
75
+ """
76
+ Very simple data collator that collates batches of dict-like objects and performs special handling for
77
+ potential keys named:
78
+
79
+ - `label`: handles a single value (int or float) per object
80
+ - `label_ids`: handles a list of values per object
81
+
82
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
83
+ to the model. See glue and ner for examples of how it's useful.
84
+ """
85
+
86
+ # In this function we'll make the assumption that all `features` in the batch
87
+ # have the same attributes.
88
+ # So we will look at the first element as a proxy for what attributes exist
89
+ # on the whole batch.
90
+
91
+ if return_tensors == "pt":
92
+ return torch_default_data_collator(features)
93
+ elif return_tensors == "tf":
94
+ return tf_default_data_collator(features)
95
+ elif return_tensors == "np":
96
+ return numpy_default_data_collator(features)
97
+
98
+
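+ # Illustrative sketch of the label handling above (assumes PyTorch is installed;
+ # feature values are made up):
+ features = [
+     {"input_ids": [101, 2023, 102], "label": 1},
+     {"input_ids": [101, 2028, 102], "label": 0},
+ ]
+ batch = default_data_collator(features)
+ # batch["labels"] -> tensor([1, 0]) with dtype torch.long (int labels map to long)
+ # batch["input_ids"] -> LongTensor of shape (2, 3); the "label" key is renamed "labels"
+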
99
+ @dataclass
100
+ class DefaultDataCollator(DataCollatorMixin):
101
+ """
102
+ Very simple data collator that collates batches of dict-like objects and performs special handling for
103
+ potential keys named:
104
+
105
+ - `label`: handles a single value (int or float) per object
106
+ - `label_ids`: handles a list of values per object
107
+
108
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
109
+ to the model. See glue and ner for examples of how it's useful.
110
+
111
+ This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
112
+ helpful if you need to set a return_tensors value at initialization.
113
+
114
+ Args:
115
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
116
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
117
+ """
118
+
119
+ return_tensors: str = "pt"
120
+
121
+ def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
122
+ if return_tensors is None:
123
+ return_tensors = self.return_tensors
124
+ return default_data_collator(features, return_tensors)
125
+
126
+
127
+ def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
128
+ import torch
129
+
130
+ if not isinstance(features[0], Mapping):
131
+ features = [vars(f) for f in features]
132
+ first = features[0]
133
+ batch = {}
134
+
135
+ # Special handling for labels.
136
+ # Ensure that tensor is created with the correct type
137
+ # (it should be automatically the case, but let's make sure of it.)
138
+ if "label" in first and first["label"] is not None:
139
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
140
+ dtype = torch.long if isinstance(label, int) else torch.float
141
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
142
+ elif "label_ids" in first and first["label_ids"] is not None:
143
+ if isinstance(first["label_ids"], torch.Tensor):
144
+ batch["labels"] = torch.stack([f["label_ids"] for f in features])
145
+ else:
146
+ dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
147
+ batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
148
+
149
+ # Handling of all other possible keys.
150
+ # Again, we will use the first element to figure out which key/values are not None for this model.
151
+ for k, v in first.items():
152
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
153
+ if isinstance(v, torch.Tensor):
154
+ batch[k] = torch.stack([f[k] for f in features])
155
+ elif isinstance(v, np.ndarray):
156
+ batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
157
+ else:
158
+ batch[k] = torch.tensor([f[k] for f in features])
159
+
160
+ return batch
161
+
162
+
163
+ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
164
+ import tensorflow as tf
165
+
166
+ if not isinstance(features[0], Mapping):
167
+ features = [vars(f) for f in features]
168
+ first = features[0]
169
+ batch = {}
170
+
171
+ # Special handling for labels.
172
+ # Ensure that tensor is created with the correct type
173
+ # (it should be automatically the case, but let's make sure of it.)
174
+ if "label" in first and first["label"] is not None:
175
+ label_col_name = "label"
176
+ elif "label_ids" in first and first["label_ids"] is not None:
177
+ label_col_name = "label_ids"
178
+ elif "labels" in first and first["labels"] is not None:
179
+ label_col_name = "labels"
180
+ else:
181
+ label_col_name = None
182
+ if label_col_name is not None:
183
+ if isinstance(first[label_col_name], tf.Tensor):
184
+ dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
185
+ elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
186
+ dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
187
+ elif isinstance(first[label_col_name], (tuple, list)):
188
+ dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
189
+ else:
190
+ dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
191
+ batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
192
+ # Handling of all other possible keys.
193
+ # Again, we will use the first element to figure out which key/values are not None for this model.
194
+ for k, v in first.items():
195
+ if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
196
+ if isinstance(v, (tf.Tensor, np.ndarray)):
197
+ batch[k] = tf.stack([f[k] for f in features])
198
+ else:
199
+ batch[k] = tf.convert_to_tensor([f[k] for f in features])
200
+
201
+ return batch
202
+
203
+
204
+ def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
205
+ if not isinstance(features[0], Mapping):
206
+ features = [vars(f) for f in features]
207
+ first = features[0]
208
+ batch = {}
209
+
210
+ # Special handling for labels.
211
+ # Ensure that tensor is created with the correct type
212
+ # (it should be automatically the case, but let's make sure of it.)
213
+ if "label" in first and first["label"] is not None:
214
+ label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
215
+ dtype = np.int64 if isinstance(label, int) else np.float32
216
+ batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
217
+ elif "label_ids" in first and first["label_ids"] is not None:
218
+ if isinstance(first["label_ids"], np.ndarray):
219
+ batch["labels"] = np.stack([f["label_ids"] for f in features])
220
+ else:
221
+ dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
222
+ batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
223
+
224
+ # Handling of all other possible keys.
225
+ # Again, we will use the first element to figure out which key/values are not None for this model.
226
+ for k, v in first.items():
227
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
228
+ if isinstance(v, np.ndarray):
229
+ batch[k] = np.stack([f[k] for f in features])
230
+ else:
231
+ batch[k] = np.array([f[k] for f in features])
232
+
233
+ return batch
234
+
235
+
236
+ @dataclass
237
+ class DataCollatorWithPadding:
238
+ """
239
+ Data collator that will dynamically pad the inputs received.
240
+
241
+ Args:
242
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
243
+ The tokenizer used for encoding the data.
244
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
245
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
246
+ among:
247
+
248
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
249
+ sequence is provided).
250
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
251
+ acceptable input length for the model if that argument is not provided.
252
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
253
+ max_length (`int`, *optional*):
254
+ Maximum length of the returned list and optionally padding length (see above).
255
+ pad_to_multiple_of (`int`, *optional*):
256
+ If set will pad the sequence to a multiple of the provided value.
257
+
258
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
259
+ 7.5 (Volta).
260
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
261
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
262
+ """
263
+
264
+ tokenizer: PreTrainedTokenizerBase
265
+ padding: Union[bool, str, PaddingStrategy] = True
266
+ max_length: Optional[int] = None
267
+ pad_to_multiple_of: Optional[int] = None
268
+ return_tensors: str = "pt"
269
+
270
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
271
+ batch = pad_without_fast_tokenizer_warning(
272
+ self.tokenizer,
273
+ features,
274
+ padding=self.padding,
275
+ max_length=self.max_length,
276
+ pad_to_multiple_of=self.pad_to_multiple_of,
277
+ return_tensors=self.return_tensors,
278
+ )
279
+ if "label" in batch:
280
+ batch["labels"] = batch["label"]
281
+ del batch["label"]
282
+ if "label_ids" in batch:
283
+ batch["labels"] = batch["label_ids"]
284
+ del batch["label_ids"]
285
+ return batch
286
+
287
+
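+ # Illustrative usage sketch (the checkpoint name is just an example):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+ features = [tokenizer("short text"), tokenizer("a somewhat longer example sentence")]
+ batch = collator(features)
+ # every tensor is padded to the batch's longest length, rounded up to a multiple of 8
+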
288
+ @dataclass
289
+ class DataCollatorForTokenClassification(DataCollatorMixin):
290
+ """
291
+ Data collator that will dynamically pad the inputs received, as well as the labels.
292
+
293
+ Args:
294
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
295
+ The tokenizer used for encoding the data.
296
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
297
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
298
+ among:
299
+
300
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
301
+ sequence is provided).
302
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
303
+ acceptable input length for the model if that argument is not provided.
304
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
305
+ max_length (`int`, *optional*):
306
+ Maximum length of the returned list and optionally padding length (see above).
307
+ pad_to_multiple_of (`int`, *optional*):
308
+ If set will pad the sequence to a multiple of the provided value.
309
+
310
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
311
+ 7.5 (Volta).
312
+ label_pad_token_id (`int`, *optional*, defaults to -100):
313
+ The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
314
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
315
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
316
+ """
317
+
318
+ tokenizer: PreTrainedTokenizerBase
319
+ padding: Union[bool, str, PaddingStrategy] = True
320
+ max_length: Optional[int] = None
321
+ pad_to_multiple_of: Optional[int] = None
322
+ label_pad_token_id: int = -100
323
+ return_tensors: str = "pt"
324
+
325
+ def torch_call(self, features):
326
+ import torch
327
+
328
+ label_name = "label" if "label" in features[0].keys() else "labels"
329
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
330
+
331
+ no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
332
+
333
+ batch = pad_without_fast_tokenizer_warning(
334
+ self.tokenizer,
335
+ no_labels_features,
336
+ padding=self.padding,
337
+ max_length=self.max_length,
338
+ pad_to_multiple_of=self.pad_to_multiple_of,
339
+ return_tensors="pt",
340
+ )
341
+
342
+ if labels is None:
343
+ return batch
344
+
345
+ sequence_length = batch["input_ids"].shape[1]
346
+ padding_side = self.tokenizer.padding_side
347
+
348
+ def to_list(tensor_or_iterable):
349
+ if isinstance(tensor_or_iterable, torch.Tensor):
350
+ return tensor_or_iterable.tolist()
351
+ return list(tensor_or_iterable)
352
+
353
+ if padding_side == "right":
354
+ batch[label_name] = [
355
+ to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
356
+ ]
357
+ else:
358
+ batch[label_name] = [
359
+ [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
360
+ ]
361
+
362
+ batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
363
+ return batch
364
+
365
+ def tf_call(self, features):
366
+ import tensorflow as tf
367
+
368
+ label_name = "label" if "label" in features[0].keys() else "labels"
369
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
370
+ batch = pad_without_fast_tokenizer_warning(
371
+ self.tokenizer,
372
+ features,
373
+ padding=self.padding,
374
+ max_length=self.max_length,
375
+ pad_to_multiple_of=self.pad_to_multiple_of,
376
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
377
+ return_tensors="tf" if labels is None else None,
378
+ )
379
+
380
+ if labels is None:
381
+ return batch
382
+
383
+ sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
384
+ padding_side = self.tokenizer.padding_side
385
+ if padding_side == "right":
386
+ batch["labels"] = [
387
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
388
+ ]
389
+ else:
390
+ batch["labels"] = [
391
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
392
+ ]
393
+
394
+ batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
395
+ return batch
396
+
397
+ def numpy_call(self, features):
398
+ label_name = "label" if "label" in features[0].keys() else "labels"
399
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
400
+ batch = pad_without_fast_tokenizer_warning(
401
+ self.tokenizer,
402
+ features,
403
+ padding=self.padding,
404
+ max_length=self.max_length,
405
+ pad_to_multiple_of=self.pad_to_multiple_of,
406
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
407
+ return_tensors="np" if labels is None else None,
408
+ )
409
+
410
+ if labels is None:
411
+ return batch
412
+
413
+ sequence_length = np.array(batch["input_ids"]).shape[1]
414
+ padding_side = self.tokenizer.padding_side
415
+ if padding_side == "right":
416
+ batch["labels"] = [
417
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
418
+ ]
419
+ else:
420
+ batch["labels"] = [
421
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
422
+ ]
423
+
424
+ batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
425
+ return batch
426
+
427
+
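+ # Illustrative usage sketch (checkpoint name and label ids are just examples):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ collator = DataCollatorForTokenClassification(tokenizer)
+ features = [
+     {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},
+     {"input_ids": [101, 7592, 2088, 102], "labels": [0, 1, 2, 0]},
+ ]
+ batch = collator(features)
+ # the shorter label list is right-padded with label_pad_token_id (-100) to match
+ # the padded input_ids length
+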
428
+ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
429
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
430
+ import torch
431
+
432
+ # Tensorize if necessary.
433
+ if isinstance(examples[0], (list, tuple, np.ndarray)):
434
+ examples = [torch.tensor(e, dtype=torch.long) for e in examples]
435
+
436
+ length_of_first = examples[0].size(0)
437
+
438
+ # Check if padding is necessary.
439
+
440
+ are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
441
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
442
+ if not isinstance(examples, torch.Tensor):
443
+ return torch.stack(examples, dim=0)
444
+
445
+ # If yes, check if we have a `pad_token`.
446
+ if tokenizer.pad_token is None:
447
+ raise ValueError(
448
+ "You are attempting to pad samples but the tokenizer you are using"
449
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
450
+ )
451
+
452
+ # Creating the full tensor and filling it with our data.
453
+ max_length = max(x.size(0) for x in examples)
454
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
455
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
456
+ result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
457
+ for i, example in enumerate(examples):
458
+ if tokenizer.padding_side == "right":
459
+ result[i, : example.shape[0]] = example
460
+ else:
461
+ result[i, -example.shape[0] :] = example
462
+ return result
463
+
464
+
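+ # Worked example of the pad_to_multiple_of rounding above: the guard only fires
+ # when max_length is not already a multiple, so 13 rounds up to 16 while an
+ # exact multiple such as 16 is left unchanged.
+ assert ((13 // 8) + 1) * 8 == 16
+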
465
+ def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
466
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
467
+ import tensorflow as tf
468
+
469
+ # Tensorize if necessary.
470
+ if isinstance(examples[0], (list, tuple)):
471
+ examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]
472
+
473
+ # Check if padding is necessary.
474
+ length_of_first = len(examples[0])
475
+ are_tensors_same_length = all(len(x) == length_of_first for x in examples)
476
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
477
+ return tf.stack(examples, axis=0)
478
+
479
+ # If yes, check if we have a `pad_token`.
480
+ if tokenizer.pad_token is None:
481
+ raise ValueError(
482
+ "You are attempting to pad samples but the tokenizer you are using"
483
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
484
+ )
485
+
486
+ # Creating the full tensor and filling it with our data.
487
+ max_length = max(len(x) for x in examples)
488
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
489
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
490
+ # result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
491
+ result = []
492
+ rank = tf.rank(examples[0])
493
+ paddings = np.zeros((rank, 2), dtype=np.int32)
494
+ for example in examples:
495
+ if tokenizer.padding_side == "right":
496
+ paddings[0, 1] = max_length - len(example)
497
+ else:
498
+ paddings[0, 0] = max_length - len(example)
499
+ result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
500
+ return tf.stack(result, axis=0)
501
+
502
+
503
+ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
504
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
505
+ # Tensorize if necessary.
506
+ if isinstance(examples[0], (list, tuple)):
507
+ examples = [np.array(e, dtype=np.int64) for e in examples]
508
+
509
+ # Check if padding is necessary.
510
+ length_of_first = len(examples[0])
511
+ are_tensors_same_length = all(len(x) == length_of_first for x in examples)
512
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
513
+ return np.stack(examples, axis=0)
514
+
515
+ # If yes, check if we have a `pad_token`.
516
+ if tokenizer.pad_token is None:
517
+ raise ValueError(
518
+ "You are attempting to pad samples but the tokenizer you are using"
519
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
520
+ )
521
+
522
+ # Creating the full tensor and filling it with our data.
523
+ max_length = max(len(x) for x in examples)
524
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
525
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
526
+ result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
527
+ for i, example in enumerate(examples):
528
+ if tokenizer.padding_side == "right":
529
+ result[i, : example.shape[0]] = example
530
+ else:
531
+ result[i, -example.shape[0] :] = example
532
+ return result
533
+
534
+
535
+ def tolist(x):
536
+ if isinstance(x, list):
537
+ return x
538
+ elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
539
+ x = x.numpy()
540
+ return x.tolist()
541
+
542
+
543
+ @dataclass
544
+ class DataCollatorForSeq2Seq:
545
+ """
546
+ Data collator that will dynamically pad the inputs received, as well as the labels.
547
+
548
+ Args:
549
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
550
+ The tokenizer used for encoding the data.
551
+ model ([`PreTrainedModel`], *optional*):
552
+ The model that is being trained. If set and the model has a *prepare_decoder_input_ids_from_labels* method, it is used to
553
+ prepare the *decoder_input_ids*
554
+
555
+ This is useful when using *label_smoothing* to avoid calculating loss twice.
556
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
557
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
558
+ among:
559
+
560
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
561
+ sequence is provided).
562
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
563
+ acceptable input length for the model if that argument is not provided.
564
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
565
+ max_length (`int`, *optional*):
566
+ Maximum length of the returned list and optionally padding length (see above).
567
+ pad_to_multiple_of (`int`, *optional*):
568
+ If set will pad the sequence to a multiple of the provided value.
569
+
570
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
571
+ 7.5 (Volta).
572
+ label_pad_token_id (`int`, *optional*, defaults to -100):
573
+ The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
574
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
575
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
576
+ """
577
+
578
+ tokenizer: PreTrainedTokenizerBase
579
+ model: Optional[Any] = None
580
+ padding: Union[bool, str, PaddingStrategy] = True
581
+ max_length: Optional[int] = None
582
+ pad_to_multiple_of: Optional[int] = None
583
+ label_pad_token_id: int = -100
584
+ return_tensors: str = "pt"
585
+
586
+ def __call__(self, features, return_tensors=None):
587
+ if return_tensors is None:
588
+ return_tensors = self.return_tensors
589
+
590
+ label_name = "label" if "label" in features[0].keys() else "labels"
591
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
592
+ # reconvert list[None] to None if necessary
593
+ # this might occur when we pass {..., "labels": None}
594
+ if labels is not None and all(label is None for label in labels):
595
+ labels = None
596
+ non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
597
+
598
+ # run through tokenizer without labels to ensure no side effects
599
+ batch = pad_without_fast_tokenizer_warning(
600
+ self.tokenizer,
601
+ non_labels_features,
602
+ padding=self.padding,
603
+ max_length=self.max_length,
604
+ pad_to_multiple_of=self.pad_to_multiple_of,
605
+ return_tensors=return_tensors,
606
+ )
607
+
608
+ # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
609
+ no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
610
+ if labels is not None:
611
+ if no_padding:
612
+ if isinstance(features[0][label_name], list):
613
+ batch["labels"] = list(labels)
614
+ else:
615
+ batch["labels"] = [np.concatenate([label, []]) for label in labels]
616
+ else:
617
+ max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
618
+ max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
619
+ if self.pad_to_multiple_of is not None:
620
+ max_label_length = (
621
+ (max_label_length + self.pad_to_multiple_of - 1)
622
+ // self.pad_to_multiple_of
623
+ * self.pad_to_multiple_of
624
+ )
625
+
626
+ padding_side = self.tokenizer.padding_side
627
+ if isinstance(features[0][label_name], list):
628
+ batch["labels"] = [
629
+ label + [self.label_pad_token_id] * (max_label_length - len(label))
630
+ if padding_side == "right"
631
+ else [self.label_pad_token_id] * (max_label_length - len(label)) + label
632
+ for label in labels
633
+ ]
634
+ else:
635
+ batch["labels"] = [
636
+ np.concatenate(
637
+ [
638
+ label,
639
+ np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
640
+ ]
641
+ )
642
+ if padding_side == "right"
643
+ else np.concatenate(
644
+ [
645
+ np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
646
+ label,
647
+ ]
648
+ )
649
+ for label in labels
650
+ ]
651
+
652
+ # convert the manually padded labels to the tensor type requested by the `return_tensors` argument
653
+ if batch.get("labels", None) is not None:
654
+ if return_tensors == "pt":
655
+ import torch
656
+
657
+ batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
658
+ elif return_tensors == "tf":
659
+ import tensorflow as tf
660
+
661
+ batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
662
+ else:
663
+ batch["labels"] = np.array(batch["labels"], dtype=np.int64)
664
+ else:
665
+ batch["labels"] = None
666
+
667
+ # prepare decoder_input_ids
668
+ if (
669
+ labels is not None
670
+ and self.model is not None
671
+ and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
672
+ ):
673
+ decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
674
+ batch["decoder_input_ids"] = decoder_input_ids
675
+
676
+ return batch
677
+
678
+
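+ # Illustrative usage sketch (the checkpoint name is just an example):
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
+ collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+ features = [
+     {"input_ids": tokenizer("translate English to German: hello").input_ids,
+      "labels": tokenizer("hallo").input_ids},
+ ]
+ batch = collator(features)
+ # labels are padded with -100, and decoder_input_ids are derived from the padded
+ # labels via model.prepare_decoder_input_ids_from_labels
+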
679
+ @dataclass
680
+ class DataCollatorForLanguageModeling(DataCollatorMixin):
681
+ """
682
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
683
+ are not all of the same length.
684
+
685
+ Args:
686
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
687
+ The tokenizer used for encoding the data.
688
+ mlm (`bool`, *optional*, defaults to `True`):
689
+ Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
690
+ with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
691
+ tokens and the value to predict for the masked token.
692
+ mlm_probability (`float`, *optional*, defaults to 0.15):
693
+ The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
694
+ pad_to_multiple_of (`int`, *optional*):
695
+ If set will pad the sequence to a multiple of the provided value.
696
+ return_tensors (`str`):
697
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
698
+
699
+ <Tip>
700
+
701
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
702
+ BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
703
+ [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
704
+
705
+ </Tip>"""
706
+
707
+ tokenizer: PreTrainedTokenizerBase
708
+ mlm: bool = True
709
+ mlm_probability: float = 0.15
710
+ pad_to_multiple_of: Optional[int] = None
711
+ tf_experimental_compile: bool = False
712
+ return_tensors: str = "pt"
713
+
714
+ def __post_init__(self):
715
+ if self.mlm and self.tokenizer.mask_token is None:
716
+ raise ValueError(
717
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
718
+ "You should pass `mlm=False` to train on causal language modeling instead."
719
+ )
720
+ if self.tf_experimental_compile:
721
+ import tensorflow as tf
722
+
723
+ self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
724
+
725
+ @staticmethod
726
+ def tf_bernoulli(shape, probability):
727
+ import tensorflow as tf
728
+
729
+ prob_matrix = tf.fill(shape, probability)
730
+ return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
731
+
732
+ def tf_mask_tokens(
733
+ self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
734
+ ) -> Tuple[Any, Any]:
735
+ """
736
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
737
+ """
738
+ import tensorflow as tf
739
+
740
+ mask_token_id = tf.cast(mask_token_id, inputs.dtype)
741
+
742
+ input_shape = tf.shape(inputs)
743
+ # 1 for a special token, 0 for a normal token in the special tokens mask
744
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
745
+ masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
746
+ # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
747
+ labels = tf.where(masked_indices, inputs, -100)
748
+
749
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
750
+ indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
751
+
752
+ inputs = tf.where(indices_replaced, mask_token_id, inputs)
753
+
754
+ # 10% of the time, we replace masked input tokens with random word
755
+ indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
756
+ random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
757
+
758
+ inputs = tf.where(indices_random, random_words, inputs)
759
+
760
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
761
+ return inputs, labels
762
+
763
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
764
+ import tensorflow as tf
765
+
766
+ # Handle dict or lists with proper padding and conversion to tensor.
767
+ if isinstance(examples[0], Mapping):
768
+ batch = pad_without_fast_tokenizer_warning(
769
+ self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
770
+ )
771
+ else:
772
+ batch = {
773
+ "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
774
+ }
775
+
776
+ # If special token mask has been preprocessed, pop it from the dict.
777
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
778
+ if self.mlm:
779
+ if special_tokens_mask is None:
780
+ special_tokens_mask = [
781
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
782
+ for val in batch["input_ids"].numpy().tolist()
783
+ ]
784
+ # Cannot directly create as bool
785
+ special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
786
+ else:
787
+ special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
788
+ batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
789
+ tf.cast(batch["input_ids"], tf.int64),
790
+ special_tokens_mask=special_tokens_mask,
791
+ mask_token_id=self.tokenizer.mask_token_id,
792
+ vocab_size=len(self.tokenizer),
793
+ )
794
+ else:
795
+ labels = batch["input_ids"]
796
+ if self.tokenizer.pad_token_id is not None:
797
+ # Replace self.tokenizer.pad_token_id with -100
798
+ labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
799
+ else:
800
+ labels = tf.identity(labels) # Makes a copy, just in case
801
+ batch["labels"] = labels
802
+ return batch
803
+
804
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
805
+ # Handle dict or lists with proper padding and conversion to tensor.
806
+ if isinstance(examples[0], Mapping):
807
+ batch = pad_without_fast_tokenizer_warning(
808
+ self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
809
+ )
810
+ else:
811
+ batch = {
812
+ "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
813
+ }
814
+
815
+ # If special token mask has been preprocessed, pop it from the dict.
816
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
817
+ if self.mlm:
818
+ batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
819
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
820
+ )
821
+ else:
822
+ labels = batch["input_ids"].clone()
823
+ if self.tokenizer.pad_token_id is not None:
824
+ labels[labels == self.tokenizer.pad_token_id] = -100
825
+ batch["labels"] = labels
826
+ return batch
827
+
828
+ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
829
+ """
830
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
831
+ """
832
+ import torch
833
+
834
+ labels = inputs.clone()
835
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
836
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
837
+ if special_tokens_mask is None:
838
+ special_tokens_mask = [
839
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
840
+ ]
841
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
842
+ else:
843
+ special_tokens_mask = special_tokens_mask.bool()
844
+
845
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
846
+ masked_indices = torch.bernoulli(probability_matrix).bool()
847
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
848
+
849
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
850
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
851
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
852
+
853
+ # 10% of the time, we replace masked input tokens with random word
854
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
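+ # (0.5 of the ~20% of masked positions not already replaced by [MASK] = 10% overall)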
855
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
856
+ inputs[indices_random] = random_words[indices_random]
857
+
858
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
859
+ return inputs, labels
860
+
861
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
862
+ # Handle dict or lists with proper padding and conversion to tensor.
863
+ if isinstance(examples[0], Mapping):
864
+ batch = pad_without_fast_tokenizer_warning(
865
+ self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
866
+ )
867
+ else:
868
+ batch = {
869
+ "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
870
+ }
871
+
872
+ # If special token mask has been preprocessed, pop it from the dict.
873
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
874
+ if self.mlm:
875
+ batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
876
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
877
+ )
878
+ else:
879
+ labels = np.copy(batch["input_ids"])
880
+ if self.tokenizer.pad_token_id is not None:
881
+ labels[labels == self.tokenizer.pad_token_id] = -100
882
+ batch["labels"] = labels
883
+ return batch
884
+
885
+ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
886
+ """
887
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
888
+ """
889
+ labels = np.copy(inputs)
890
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
891
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
892
+ if special_tokens_mask is None:
893
+ special_tokens_mask = [
894
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
895
+ ]
896
+ special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
897
+ else:
898
+ special_tokens_mask = special_tokens_mask.astype(bool)
899
+
900
+ probability_matrix[special_tokens_mask] = 0
901
+ # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
902
+ masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
903
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
904
+
905
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
906
+ indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
907
+ inputs[indices_replaced] = self.tokenizer.mask_token_id
908
+
909
+ # 10% of the time, we replace masked input tokens with random word
910
+ # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
911
+ indices_random = (
912
+ np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
913
+ )
914
+ random_words = np.random.randint(
915
+ low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
916
+ )
917
+ inputs[indices_random] = random_words
918
+
919
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
920
+ return inputs, labels
921
+
922
+
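+ # Illustrative usage sketch (the checkpoint name is just an example):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
+ examples = [tokenizer("the quick brown fox", return_special_tokens_mask=True)]
+ batch = collator(examples)
+ # batch["labels"] is -100 everywhere except the ~15% of non-special positions
+ # sampled for masking; of those, ~80% become [MASK], ~10% random ids, ~10% unchanged
+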
+ @dataclass
+ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
+     """
+     Data collator used for language modeling that masks entire words.
+
+     - collates batches of tensors, honoring their tokenizer's pad_token
+     - preprocesses batches for masked language modeling
+
+     <Tip>
+
+     This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
+     that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
+     produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
+
+     </Tip>"""
+
+     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         if isinstance(examples[0], Mapping):
+             input_ids = [e["input_ids"] for e in examples]
+         else:
+             input_ids = examples
+             examples = [{"input_ids": e} for e in examples]
+
+         batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+
+         mask_labels = []
+         for e in examples:
+             ref_tokens = []
+             for id in tolist(e["input_ids"]):
+                 token = self.tokenizer._convert_id_to_token(id)
+                 ref_tokens.append(token)
+
+             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
+             if "chinese_ref" in e:
+                 ref_pos = tolist(e["chinese_ref"])
+                 len_seq = len(e["input_ids"])
+                 for i in range(len_seq):
+                     if i in ref_pos:
+                         ref_tokens[i] = "##" + ref_tokens[i]
+             mask_labels.append(self._whole_word_mask(ref_tokens))
+         batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+         inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
+         return {"input_ids": inputs, "labels": labels}
+
+     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         import tensorflow as tf
+
+         if isinstance(examples[0], Mapping):
+             input_ids = [e["input_ids"] for e in examples]
+         else:
+             input_ids = examples
+             examples = [{"input_ids": e} for e in examples]
+
+         batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+
+         mask_labels = []
+         for e in examples:
+             ref_tokens = []
+             for id in tolist(e["input_ids"]):
+                 token = self.tokenizer._convert_id_to_token(id)
+                 ref_tokens.append(token)
+
+             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
+             if "chinese_ref" in e:
+                 ref_pos = tolist(e["chinese_ref"])
+                 len_seq = len(e["input_ids"])
+                 for i in range(len_seq):
+                     if i in ref_pos:
+                         ref_tokens[i] = "##" + ref_tokens[i]
+             mask_labels.append(self._whole_word_mask(ref_tokens))
+         batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+         inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
+         return {"input_ids": inputs, "labels": labels}
+
+     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         if isinstance(examples[0], Mapping):
+             input_ids = [e["input_ids"] for e in examples]
+         else:
+             input_ids = examples
+             examples = [{"input_ids": e} for e in examples]
+
+         batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+
+         mask_labels = []
+         for e in examples:
+             ref_tokens = []
+             for id in tolist(e["input_ids"]):
+                 token = self.tokenizer._convert_id_to_token(id)
+                 ref_tokens.append(token)
+
+             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
+             if "chinese_ref" in e:
+                 ref_pos = tolist(e["chinese_ref"])
+                 len_seq = len(e["input_ids"])
+                 for i in range(len_seq):
+                     if i in ref_pos:
+                         ref_tokens[i] = "##" + ref_tokens[i]
+             mask_labels.append(self._whole_word_mask(ref_tokens))
+         batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
+         inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
+         return {"input_ids": inputs, "labels": labels}
+
+     def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
+         """
+         Get 0/1 labels for masked tokens with whole word mask proxy
+         """
+         if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
+             warnings.warn(
+                 "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
+                 "Please refer to the documentation for more information."
+             )
+
+         cand_indexes = []
+         for i, token in enumerate(input_tokens):
+             if token == "[CLS]" or token == "[SEP]":
+                 continue
+
+             if len(cand_indexes) >= 1 and token.startswith("##"):
+                 cand_indexes[-1].append(i)
+             else:
+                 cand_indexes.append([i])
+
+         random.shuffle(cand_indexes)
+         num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
+         masked_lms = []
+         covered_indexes = set()
+         for index_set in cand_indexes:
+             if len(masked_lms) >= num_to_predict:
+                 break
+             # If adding a whole-word mask would exceed the maximum number of
+             # predictions, then just skip this candidate.
+             if len(masked_lms) + len(index_set) > num_to_predict:
+                 continue
+             is_any_index_covered = False
+             for index in index_set:
+                 if index in covered_indexes:
+                     is_any_index_covered = True
+                     break
+             if is_any_index_covered:
+                 continue
+             for index in index_set:
+                 covered_indexes.add(index)
+                 masked_lms.append(index)
+
+         if len(covered_indexes) != len(masked_lms):
+             raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
+         mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
+         return mask_labels
+
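# A small sketch of the candidate grouping performed by _whole_word_mask above
# (illustrative; the toy tokens are assumptions). "##" pieces are merged into the
# preceding word so that a word is masked or kept as a whole.
tokens = ["[CLS]", "un", "##break", "##able", "code", "[SEP]"]
cand_indexes = []
for i, token in enumerate(tokens):
    if token in ("[CLS]", "[SEP]"):
        continue
    if cand_indexes and token.startswith("##"):
        cand_indexes[-1].append(i)
    else:
        cand_indexes.append([i])
print(cand_indexes)  # [[1, 2, 3], [4]]
# With mlm_probability=0.15, num_to_predict = max(1, round(6 * 0.15)) = 1, so the
# three-piece word would exceed the budget and is skipped; only "code" can be
# drawn, giving mask_labels = [0, 0, 0, 0, 1, 0].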
+     def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
+         """
+         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
+         'mask_labels' means we use whole word masking (wwm); we directly mask indices according to its reference.
+         """
+         import torch
+
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
+                 " --mlm flag if you want to use this tokenizer."
+             )
+         labels = inputs.clone()
+         # We sample a few tokens in each sequence for masked-LM training (with probability `args.mlm_probability`, which defaults to 0.15 in BERT/RoBERTa)
+
+         probability_matrix = mask_labels
+
+         special_tokens_mask = [
+             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+         ]
+         probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
+         if self.tokenizer.pad_token is not None:
+             padding_mask = labels.eq(self.tokenizer.pad_token_id)
+             probability_matrix.masked_fill_(padding_mask, value=0.0)
+
+         masked_indices = probability_matrix.bool()
+         labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+         indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+         # 10% of the time, we replace masked input tokens with random word
+         indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+         random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
+         inputs[indices_random] = random_words[indices_random]
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return inputs, labels
+
+     def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
+         """
+         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
+         'mask_labels' means we use whole word masking (wwm); we directly mask indices according to its reference.
+         """
+         import tensorflow as tf
+
+         input_shape = tf.shape(inputs)
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
+                 " --mlm flag if you want to use this tokenizer."
+             )
+         labels = tf.identity(inputs)
+         # We sample a few tokens in each sequence for masked-LM training (with probability `args.mlm_probability`, which defaults to 0.15 in BERT/RoBERTa)
+
+         masked_indices = tf.cast(mask_labels, tf.bool)
+
+         special_tokens_mask = [
+             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
+         ]
+         masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
+         if self.tokenizer.pad_token is not None:
+             padding_mask = inputs == self.tokenizer.pad_token_id
+             masked_indices = masked_indices & ~padding_mask
+
+         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
+         labels = tf.where(masked_indices, inputs, -100)
+
+         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+         indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
+
+         inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
+
+         # 10% of the time, we replace masked input tokens with random word
+         indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
+         random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
+         inputs = tf.where(indices_random, random_words, inputs)
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return inputs, labels
+
+     def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
+         """
+         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
+         'mask_labels' means we use whole word masking (wwm); we directly mask indices according to its reference.
+         """
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
+                 " --mlm flag if you want to use this tokenizer."
+             )
+         labels = np.copy(inputs)
+         # We sample a few tokens in each sequence for masked-LM training (with probability `args.mlm_probability`, which defaults to 0.15 in BERT/RoBERTa)
+
+         masked_indices = mask_labels.astype(bool)
+
+         special_tokens_mask = [
+             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
+         ]
+         masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
+         if self.tokenizer.pad_token is not None:
+             padding_mask = labels == self.tokenizer.pad_token_id
+             masked_indices[padding_mask] = 0
+
+         labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+         indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
+         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+         # 10% of the time, we replace masked input tokens with random word
+         # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+         indices_random = (
+             np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
+         )
+         random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
+         inputs[indices_random] = random_words[indices_random]
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return inputs, labels
+
+
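# Usage sketch (illustrative; the checkpoint and text are assumptions). The collator
# expects a BERT-like tokenizer; an optional "chinese_ref" key per example lists
# positions that continue a word and should be treated as "##" sub-words.
from transformers import AutoTokenizer, DataCollatorForWholeWordMask

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)
batch = collator([tokenizer("whole word masking keeps sub-words together")])
# Note that, as the code above shows, the output only carries "input_ids" and
# "labels"; other tokenizer outputs such as "attention_mask" are not propagated.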
1195
+ @dataclass
1196
+ class DataCollatorForSOP(DataCollatorForLanguageModeling):
1197
+ """
1198
+ Data collator used for sentence order prediction task.
1199
+
1200
+ - collates batches of tensors, honoring their tokenizer's pad_token
1201
+ - preprocesses batches for both masked language modeling and sentence order prediction
1202
+ """
1203
+
1204
+ def __init__(self, *args, **kwargs):
1205
+ warnings.warn(
1206
+ "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
1207
+ "DataCollatorForLanguageModeling instead.",
1208
+ FutureWarning,
1209
+ )
1210
+
1211
+ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
1212
+ import torch
1213
+ from torch.nn.utils.rnn import pad_sequence
1214
+
1215
+ input_ids = [example["input_ids"] for example in examples]
1216
+ input_ids = _torch_collate_batch(input_ids, self.tokenizer)
1217
+ input_ids, labels, attention_mask = self.mask_tokens(input_ids)
1218
+
1219
+ token_type_ids = [example["token_type_ids"] for example in examples]
1220
+ # size of segment_ids varied because randomness, padding zero to the end as the original implementation
1221
+ token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
1222
+
1223
+ sop_label_list = [example["sentence_order_label"] for example in examples]
1224
+ sentence_order_label = torch.stack(sop_label_list)
1225
+
1226
+ return {
1227
+ "input_ids": input_ids,
1228
+ "labels": labels,
1229
+ "attention_mask": attention_mask,
1230
+ "token_type_ids": token_type_ids,
1231
+ "sentence_order_label": sentence_order_label,
1232
+ }
1233
+
1234
+ def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
1235
+ """
1236
+ Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
1237
+ original. N-gram not applied yet.
1238
+ """
1239
+ import torch
1240
+
1241
+ if self.tokenizer.mask_token is None:
1242
+ raise ValueError(
1243
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1244
+ " --mlm flag if you want to use this tokenizer."
1245
+ )
1246
+
1247
+ labels = inputs.clone()
1248
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
1249
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
1250
+ special_tokens_mask = [
1251
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1252
+ ]
1253
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
1254
+ if self.tokenizer.pad_token is not None:
1255
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1256
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
1257
+ masked_indices = torch.bernoulli(probability_matrix).bool()
1258
+ # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
1259
+ attention_mask = (~masked_indices).float()
1260
+ if self.tokenizer.pad_token is not None:
1261
+ attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
1262
+ attention_mask.masked_fill_(attention_padding_mask, value=1.0)
1263
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
1264
+
1265
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1266
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
1267
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1268
+
1269
+ # 10% of the time, we replace masked input tokens with random word
1270
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1271
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
1272
+ inputs[indices_random] = random_words[indices_random]
1273
+
1274
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1275
+ return inputs, labels, attention_mask
1276
+
1277
+
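# A minimal sketch of the attention-mask inversion used in mask_tokens above
# (illustrative): masked_indices marks masked positions with True, while the
# ALBERT-style attention_mask marks attended positions with 1.0.
import torch

masked_indices = torch.tensor([[False, True, False, True]])
attention_mask = (~masked_indices).float()
print(attention_mask)  # tensor([[1., 0., 1., 0.]])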
+ @dataclass
+ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
+     """
+     Data collator used for permutation language modeling.
+
+     - collates batches of tensors, honoring their tokenizer's pad_token
+     - preprocesses batches for permutation language modeling with procedures specific to XLNet
+     """
+
+     tokenizer: PreTrainedTokenizerBase
+     plm_probability: float = 1 / 6
+     max_span_length: int = 5  # maximum length of a span of masked tokens
+     return_tensors: str = "pt"
+
+     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         if isinstance(examples[0], Mapping):
+             examples = [e["input_ids"] for e in examples]
+         batch = _torch_collate_batch(examples, self.tokenizer)
+         inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
+         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
+
+     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         if isinstance(examples[0], Mapping):
+             examples = [e["input_ids"] for e in examples]
+         batch = _tf_collate_batch(examples, self.tokenizer)
+         inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
+         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
+
+     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         if isinstance(examples[0], Mapping):
+             examples = [e["input_ids"] for e in examples]
+         batch = _numpy_collate_batch(examples, self.tokenizer)
+         inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
+         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
+
+     def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
+         """
+         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
+
+         0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+         1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+         2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
+            masked
+         3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
+            span_length]` and mask tokens `start_index:start_index + span_length`
+         4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
+            sequence to be processed), repeat from Step 1.
+         """
+         import torch
+
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for permutation language modeling."
+                 " Please add a mask token if you want to use this tokenizer."
+             )
+
+         if inputs.size(1) % 2 != 0:
+             raise ValueError(
+                 "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
+                 " relevant comments in source code for details."
+             )
+
+         labels = inputs.clone()
+         # Creating the mask and target_mapping tensors
+         masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
+         target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
+
+         for i in range(labels.size(0)):
+             # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+             cur_len = 0
+             max_len = labels.size(1)
+
+             while cur_len < max_len:
+                 # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+                 span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
+                 # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
+                 context_length = int(span_length / self.plm_probability)
+                 # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
+                 start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
+                 masked_indices[i, start_index : start_index + span_length] = 1
+                 # Set `cur_len = cur_len + context_length`
+                 cur_len += context_length
+
+             # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
+             # the i-th prediction corresponds to the i-th token.
+             target_mapping[i] = torch.eye(labels.size(1))
+
+         special_tokens_mask = torch.tensor(
+             [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
+             dtype=torch.bool,
+         )
+         masked_indices.masked_fill_(special_tokens_mask, value=0.0)
+         if self.tokenizer.pad_token is not None:
+             padding_mask = labels.eq(self.tokenizer.pad_token_id)
+             masked_indices.masked_fill_(padding_mask, value=0.0)
+
+         # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
+         non_func_mask = ~(padding_mask | special_tokens_mask)
+
+         inputs[masked_indices] = self.tokenizer.mask_token_id
+         labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+         perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
+
+         for i in range(labels.size(0)):
+             # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
+             # determine which tokens a given token can attend to (encoded in `perm_mask`).
+             # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
+             # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
+             # we assume that reused length is half of sequence length and permutation length is equal to reused length.
+             # This requires that the sequence length be even.
+
+             # Create a linear factorisation order
+             perm_index = torch.arange(labels.size(1))
+             # Split this into two halves, assuming that half the sequence is reused each time
+             perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
+             # Permute the two halves such that they do not cross over
+             perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
+             # Flatten this out into the desired permuted factorisation order
+             perm_index = torch.flatten(perm_index.transpose(0, 1))
+             # Set the permutation indices of non-masked (non-functional) tokens to the
+             # smallest index (-1) so that:
+             # (1) They can be seen by all other positions
+             # (2) They cannot see masked positions, so there won't be information leak
+             perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
+             # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
+             # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
+             # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
+             perm_mask[i] = (
+                 perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
+             ) & masked_indices[i]
+
+         return inputs.long(), perm_mask, target_mapping, labels.long()
+
+     def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
+         """
+         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
+
+         0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+         1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+         2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
+            masked
+         3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
+            span_length]` and mask tokens `start_index:start_index + span_length`
+         4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
+            sequence to be processed), repeat from Step 1.
+         """
+         import tensorflow as tf
+
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for permutation language modeling."
+                 " Please add a mask token if you want to use this tokenizer."
+             )
+
+         if tf.shape(inputs)[1] % 2 != 0:
+             raise ValueError(
+                 "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
+                 " relevant comments in source code for details."
+             )
+
+         labels = tf.identity(inputs)
+         # Creating the mask and target_mapping tensors
+         masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
+         labels_shape = tf.shape(labels)
+         target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)
+
+         for i in range(len(labels)):
+             # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+             cur_len = 0
+             max_len = tf.shape(labels)[1]
+
+             while cur_len < max_len:
+                 # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+                 span_length = randint(1, self.max_span_length + 1)
+                 # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
+                 context_length = int(span_length / self.plm_probability)
+                 # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
+                 start_index = cur_len + randint(0, context_length - span_length + 1)
+                 masked_indices[i, start_index : start_index + span_length] = 1
+                 # Set `cur_len = cur_len + context_length`
+                 cur_len += context_length
+
+             # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
+             # the i-th prediction corresponds to the i-th token.
+             target_mapping[i] = np.eye(labels_shape[1])
+         masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
+         target_mapping = tf.convert_to_tensor(target_mapping)
+         special_tokens_mask = tf.convert_to_tensor(
+             [
+                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
+                 for val in labels.numpy().tolist()
+             ],
+         )
+         special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
+         masked_indices = masked_indices & ~special_tokens_mask
+         if self.tokenizer.pad_token is not None:
+             padding_mask = labels == self.tokenizer.pad_token_id
+             masked_indices = masked_indices & ~padding_mask
+
+         # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
+         non_func_mask = ~(padding_mask | special_tokens_mask)
+
+         inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
+         labels = tf.where(masked_indices, labels, -100)  # We only compute loss on masked tokens
+
+         perm_mask = []
+
+         for i in range(len(labels)):
+             # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
+             # determine which tokens a given token can attend to (encoded in `perm_mask`).
+             # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
+             # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
+             # we assume that reused length is half of sequence length and permutation length is equal to reused length.
+             # This requires that the sequence length be even.
+
+             # Create a linear factorisation order
+             # tf.range is the equivalent of torch.arange
+             perm_index = tf.range(labels_shape[1])
+             # Split this into two halves, assuming that half the sequence is reused each time
+             perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
+             # Permute the two halves such that they do not cross over
+             perm_index = tf.random.shuffle(perm_index)  # Shuffles along the first dimension
+             # Flatten this out into the desired permuted factorisation order
+             perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
+             # Set the permutation indices of non-masked (non-functional) tokens to the
+             # smallest index (-1) so that:
+             # (1) They can be seen by all other positions
+             # (2) They cannot see masked positions, so there won't be information leak
+             perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
+             # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
+             # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
+             # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
+             perm_mask.append(
+                 (tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
+                 & masked_indices[i]
+             )
+         perm_mask = tf.stack(perm_mask, axis=0)
+
+         return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
+
+     def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
+         """
+         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
+
+         0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+         1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+         2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
+            masked
+         3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
+            span_length]` and mask tokens `start_index:start_index + span_length`
+         4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
+            sequence to be processed), repeat from Step 1.
+         """
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for permutation language modeling."
+                 " Please add a mask token if you want to use this tokenizer."
+             )
+
+         if inputs.shape[1] % 2 != 0:
+             raise ValueError(
+                 "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
+                 " relevant comments in source code for details."
+             )
+
+         labels = np.copy(inputs)
+         # Creating the mask and target_mapping tensors
+         masked_indices = np.full(labels.shape, 0, dtype=bool)
+         target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
+
+         for i in range(labels.shape[0]):
+             # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+             cur_len = 0
+             max_len = labels.shape[1]
+
+             while cur_len < max_len:
+                 # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+                 span_length = randint(1, self.max_span_length + 1)
+                 # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
+                 context_length = int(span_length / self.plm_probability)
+                 # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
+                 start_index = cur_len + randint(0, context_length - span_length + 1)
+                 masked_indices[i, start_index : start_index + span_length] = 1
+                 # Set `cur_len = cur_len + context_length`
+                 cur_len += context_length
+
+             # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
+             # the i-th prediction corresponds to the i-th token.
+             target_mapping[i] = np.eye(labels.shape[1])
+
+         special_tokens_mask = np.array(
+             [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
+             dtype=bool,
+         )
+         masked_indices[special_tokens_mask] = 0
+         if self.tokenizer.pad_token is not None:
+             padding_mask = labels == self.tokenizer.pad_token_id
+             masked_indices[padding_mask] = 0.0
+
+         # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
+         non_func_mask = ~(padding_mask | special_tokens_mask)
+
+         inputs[masked_indices] = self.tokenizer.mask_token_id
+         labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+         perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
+
+         for i in range(labels.shape[0]):
+             # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
+             # determine which tokens a given token can attend to (encoded in `perm_mask`).
+             # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
+             # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
+             # we assume that reused length is half of sequence length and permutation length is equal to reused length.
+             # This requires that the sequence length be even.
+
+             # Create a linear factorisation order
+             perm_index = np.arange(labels.shape[1])
+             # Split this into two halves, assuming that half the sequence is reused each time
+             perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
+             # Permute the two halves such that they do not cross over
+             np.random.shuffle(perm_index)
+             # Flatten this out into the desired permuted factorisation order
+             perm_index = perm_index.T.flatten()
+             # Set the permutation indices of non-masked (non-functional) tokens to the
+             # smallest index (-1) so that:
+             # (1) They can be seen by all other positions
+             # (2) They cannot see masked positions, so there won't be information leak
+             perm_index[~masked_indices[i] & non_func_mask[i]] = -1
+             # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
+             # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
+             # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
+             perm_mask[i] = (
+                 perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
+             ) & masked_indices[i]
+
+         return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
+
+
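# A standalone sketch of the span-sampling loop documented in the *_mask_tokens
# docstrings above (illustrative; it mirrors steps 1-4 for one sequence). Spans of
# 1..max_span_length tokens are masked inside non-overlapping contexts of length
# span_length / plm_probability.
import numpy as np
from random import randint

max_len, max_span_length, plm_probability = 16, 5, 1 / 6
masked = np.zeros(max_len, dtype=bool)
cur_len = 0
while cur_len < max_len:
    span_length = randint(1, max_span_length)                         # step 1
    context_length = int(span_length / plm_probability)               # step 2
    start_index = cur_len + randint(0, context_length - span_length)  # step 3
    masked[start_index : start_index + span_length] = True
    cur_len += context_length                                         # step 4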
+ @dataclass
+ class DataCollatorWithFlattening(DefaultDataCollator):
+     """
+     Data collator used for the padding-free approach. Does the following:
+
+     - concatenates the entire mini-batch into a single long sequence of shape [1, total_tokens]
+     - uses `separator_id` to separate sequences within the concatenated `labels`; the default value is -100
+     - no padding will be added; returns `input_ids`, `labels` and `position_ids`
+     """
+
+     def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.return_position_ids = return_position_ids
+         self.separator_id = separator_id
+         warnings.warn(
+             "Using `DataCollatorWithFlattening` will flatten the entire mini batch into a single long sequence. "
+             "Make sure your attention computation is able to handle it!"
+         )
+
+     def __call__(self, features, return_tensors=None, separator_id=None):
+         if return_tensors is None:
+             return_tensors = self.return_tensors
+         if separator_id is None:
+             separator_id = self.separator_id
+         is_labels_provided = "labels" in features[0]
+         ret = {"input_ids": [], "labels": []}
+         if self.return_position_ids:
+             ret.update({"position_ids": []})
+         for idx in range(0, len(features)):
+             ret["input_ids"] += features[idx]["input_ids"]
+             if is_labels_provided:
+                 ret["labels"] += [separator_id] + features[idx]["labels"][1:]
+             else:
+                 ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
+             if self.return_position_ids:
+                 ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
+         return default_data_collator([ret], return_tensors)
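# Usage sketch (illustrative): two short examples packed into one sequence.
# Position ids restart per original example, and the first label of each example
# is replaced by separator_id so no loss is computed across example boundaries.
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening()
batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
# batch["input_ids"]    -> tensor([[1, 2, 3, 4, 5]])
# batch["labels"]       -> tensor([[-100, 2, 3, -100, 5]])
# batch["position_ids"] -> tensor([[0, 1, 2, 0, 1]])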
.venv/Lib/site-packages/transformers/data/datasets/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .glue import GlueDataset, GlueDataTrainingArguments
+ from .language_modeling import (
+     LineByLineTextDataset,
+     LineByLineWithRefDataset,
+     LineByLineWithSOPTextDataset,
+     TextDataset,
+     TextDatasetForNextSentencePrediction,
+ )
+ from .squad import SquadDataset, SquadDataTrainingArguments
.venv/Lib/site-packages/transformers/data/datasets/glue.py ADDED
@@ -0,0 +1,161 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import time
+ import warnings
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from typing import List, Optional, Union
+
+ import torch
+ from filelock import FileLock
+ from torch.utils.data import Dataset
+
+ from ...tokenization_utils_base import PreTrainedTokenizerBase
+ from ...utils import logging
+ from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
+ from ..processors.utils import InputFeatures
+
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class GlueDataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+
+     Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
+     line.
+     """
+
+     task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+     data_dir: str = field(
+         metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
+     )
+     max_seq_length: int = field(
+         default=128,
+         metadata={
+             "help": (
+                 "The maximum total input sequence length after tokenization. Sequences longer "
+                 "than this will be truncated, sequences shorter will be padded."
+             )
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+
+     def __post_init__(self):
+         self.task_name = self.task_name.lower()
+
+
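# Sketch of the HfArgumentParser flow mentioned in the docstring above
# (illustrative; the paths are assumptions):
from transformers import HfArgumentParser

parser = HfArgumentParser(GlueDataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--task_name", "MRPC", "--data_dir", "./glue_data/MRPC"]
)
# __post_init__ lower-cases the task name: data_args.task_name == "mrpc"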
+ class Split(Enum):
+     train = "train"
+     dev = "dev"
+     test = "test"
+
+
+ class GlueDataset(Dataset):
+     """
+     This will be superseded by a framework-agnostic approach soon.
+     """
+
+     args: GlueDataTrainingArguments
+     output_mode: str
+     features: List[InputFeatures]
+
+     def __init__(
+         self,
+         args: GlueDataTrainingArguments,
+         tokenizer: PreTrainedTokenizerBase,
+         limit_length: Optional[int] = None,
+         mode: Union[str, Split] = Split.train,
+         cache_dir: Optional[str] = None,
+     ):
+         warnings.warn(
+             "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
+             "library. You can have a look at this example script for pointers: "
+             "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
+             FutureWarning,
+         )
+         self.args = args
+         self.processor = glue_processors[args.task_name]()
+         self.output_mode = glue_output_modes[args.task_name]
+         if isinstance(mode, str):
+             try:
+                 mode = Split[mode]
+             except KeyError:
+                 raise KeyError("mode is not a valid split name")
+         # Load data features from cache or dataset file
+         cached_features_file = os.path.join(
+             cache_dir if cache_dir is not None else args.data_dir,
+             f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
+         )
+         label_list = self.processor.get_labels()
+         if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
+             "RobertaTokenizer",
+             "RobertaTokenizerFast",
+             "XLMRobertaTokenizer",
+             "BartTokenizer",
+             "BartTokenizerFast",
+         ):
+             # HACK(label indices are swapped in RoBERTa pretrained model)
+             label_list[1], label_list[2] = label_list[2], label_list[1]
+         self.label_list = label_list
+
+         # Make sure only the first process in distributed training processes the dataset,
+         # and the others will use the cache.
+         lock_path = cached_features_file + ".lock"
+         with FileLock(lock_path):
+             if os.path.exists(cached_features_file) and not args.overwrite_cache:
+                 start = time.time()
+                 self.features = torch.load(cached_features_file)
+                 logger.info(
+                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+                 )
+             else:
+                 logger.info(f"Creating features from dataset file at {args.data_dir}")
+
+                 if mode == Split.dev:
+                     examples = self.processor.get_dev_examples(args.data_dir)
+                 elif mode == Split.test:
+                     examples = self.processor.get_test_examples(args.data_dir)
+                 else:
+                     examples = self.processor.get_train_examples(args.data_dir)
+                 if limit_length is not None:
+                     examples = examples[:limit_length]
+                 self.features = glue_convert_examples_to_features(
+                     examples,
+                     tokenizer,
+                     max_length=args.max_seq_length,
+                     label_list=label_list,
+                     output_mode=self.output_mode,
+                 )
+                 start = time.time()
+                 torch.save(self.features, cached_features_file)
+                 # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+                 logger.info(
+                     f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+                 )
+
+     def __len__(self):
+         return len(self.features)
+
+     def __getitem__(self, i) -> InputFeatures:
+         return self.features[i]
+
+     def get_labels(self):
+         return self.label_list
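# Usage sketch (illustrative; the checkpoint and data_dir are assumptions, and the
# class itself is deprecated in favor of the 🤗 Datasets library):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")
train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")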
.venv/Lib/site-packages/transformers/data/datasets/language_modeling.py ADDED
@@ -0,0 +1,530 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import os
+ import pickle
+ import random
+ import time
+ import warnings
+ from typing import Dict, List, Optional
+
+ import torch
+ from filelock import FileLock
+ from torch.utils.data import Dataset
+
+ from ...tokenization_utils import PreTrainedTokenizer
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ DEPRECATION_WARNING = (
+     "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
+     "library. You can have a look at this example script for pointers: {0}"
+ )
+
+
+ class TextDataset(Dataset):
+     """
+     This will be superseded by a framework-agnostic approach soon.
+     """
+
+     def __init__(
+         self,
+         tokenizer: PreTrainedTokenizer,
+         file_path: str,
+         block_size: int,
+         overwrite_cache=False,
+         cache_dir: Optional[str] = None,
+     ):
+         warnings.warn(
+             DEPRECATION_WARNING.format(
+                 "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
+             ),
+             FutureWarning,
+         )
+         if os.path.isfile(file_path) is False:
+             raise ValueError(f"Input file path {file_path} not found")
+
+         block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
+
+         directory, filename = os.path.split(file_path)
+         cached_features_file = os.path.join(
+             cache_dir if cache_dir is not None else directory,
+             f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
+         )
+
+         # Make sure only the first process in distributed training processes the dataset,
+         # and the others will use the cache.
+         lock_path = cached_features_file + ".lock"
+         with FileLock(lock_path):
+             if os.path.exists(cached_features_file) and not overwrite_cache:
+                 start = time.time()
+                 with open(cached_features_file, "rb") as handle:
+                     self.examples = pickle.load(handle)
+                 logger.info(
+                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+                 )
+
+             else:
+                 logger.info(f"Creating features from dataset file at {directory}")
+
+                 self.examples = []
+                 with open(file_path, encoding="utf-8") as f:
+                     text = f.read()
+
+                 tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+
+                 for i in range(0, len(tokenized_text) - block_size + 1, block_size):  # Truncate in blocks of block_size
+                     self.examples.append(
+                         tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
+                     )
+                 # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
+                 # If your dataset is small, first you should look for a bigger one :-) and second you
+                 # can change this behavior by adding (model specific) padding.
+
+                 start = time.time()
+                 with open(cached_features_file, "wb") as handle:
+                     pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+                 logger.info(
+                     f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+                 )
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, i) -> torch.Tensor:
+         return torch.tensor(self.examples[i], dtype=torch.long)
+
+
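# Usage sketch (illustrative; the checkpoint and corpus path are assumptions):
# TextDataset chunks the tokenized corpus into blocks of
# block_size - num_special_tokens_to_add(pair=False) tokens, dropping the last
# partial block, and pairs naturally with DataCollatorForLanguageModeling.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = TextDataset(tokenizer=tokenizer, file_path="corpus.txt", block_size=128)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)
batch = collator([dataset[i] for i in range(min(4, len(dataset)))])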
+ class LineByLineTextDataset(Dataset):
+     """
+     This will be superseded by a framework-agnostic approach soon.
+     """
+
+     def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
+         warnings.warn(
+             DEPRECATION_WARNING.format(
+                 "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
+             ),
+             FutureWarning,
+         )
+         if os.path.isfile(file_path) is False:
+             raise ValueError(f"Input file path {file_path} not found")
+         # Here, we do not cache the features, operating under the assumption
+         # that we will soon use fast multithreaded tokenizers from the
+         # `tokenizers` repo everywhere =)
+         logger.info(f"Creating features from dataset file at {file_path}")
+
+         with open(file_path, encoding="utf-8") as f:
+             lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
+
+         batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
+         self.examples = batch_encoding["input_ids"]
+         self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, i) -> Dict[str, torch.tensor]:
+         return self.examples[i]
+
+
+ class LineByLineWithRefDataset(Dataset):
+     """
+     This will be superseded by a framework-agnostic approach soon.
+     """
+
+     def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
+         warnings.warn(
+             DEPRECATION_WARNING.format(
+                 "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
+             ),
+             FutureWarning,
+         )
+         if os.path.isfile(file_path) is False:
+             raise ValueError(f"Input file path {file_path} not found")
+         if os.path.isfile(ref_path) is False:
+             raise ValueError(f"Ref file path {ref_path} not found")
+         # Here, we do not cache the features, operating under the assumption
+         # that we will soon use fast multithreaded tokenizers from the
+         # `tokenizers` repo everywhere =)
+         logger.info(f"Creating features from dataset file at {file_path}")
+         logger.info(f"Use ref segment results at {ref_path}")
+         with open(file_path, encoding="utf-8") as f:
+             data = f.readlines()  # use readlines() so that the '\u2029' delimiter does not split a line
+         data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
+         # Get ref info from file
+         with open(ref_path, encoding="utf-8") as f:
+             ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
+         if len(data) != len(ref):
+             raise ValueError(
+                 f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
+                 f"while length of {ref_path} is {len(ref)}"
+             )
+
+         batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
+         self.examples = batch_encoding["input_ids"]
+         self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
+
+         n = len(self.examples)
+         for i in range(n):
+             self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, i) -> Dict[str, torch.tensor]:
+         return self.examples[i]
+
+
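# Sketch of the expected ref-file format (illustrative; the positions shown are
# assumptions): one JSON list per text line, parallel to the corpus file, giving
# the token positions that continue a word and should receive the "##" prefix
# during whole word masking.
#
#     corpus.txt : 我喜欢写代码
#     ref.txt    : [2, 5]
#
# dataset = LineByLineWithRefDataset(
#     tokenizer=tokenizer, file_path="corpus.txt", block_size=128, ref_path="ref.txt"
# )
# dataset[0]["chinese_ref"]  # tensor([2, 5]), consumed by DataCollatorForWholeWordMask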
194
+ class LineByLineWithSOPTextDataset(Dataset):
195
+ """
196
+ Dataset for sentence order prediction task, prepare sentence pairs for SOP task
197
+ """
198
+
199
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
200
+ warnings.warn(
201
+ DEPRECATION_WARNING.format(
202
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
203
+ ),
204
+ FutureWarning,
205
+ )
206
+ if os.path.isdir(file_dir) is False:
207
+ raise ValueError(f"{file_dir} is not a directory")
208
+ logger.info(f"Creating features from dataset file folder at {file_dir}")
209
+ self.examples = []
210
+ # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
211
+ # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
212
+ for file_name in os.listdir(file_dir):
213
+ file_path = os.path.join(file_dir, file_name)
214
+ if os.path.isfile(file_path) is False:
215
+ raise ValueError(f"{file_path} is not a file")
216
+ article_open = False
217
+ with open(file_path, encoding="utf-8") as f:
218
+ original_lines = f.readlines()
219
+ article_lines = []
220
+ for line in original_lines:
221
+ if "<doc id=" in line:
222
+ article_open = True
223
+ elif "</doc>" in line:
224
+ article_open = False
225
+ document = [
226
+ tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
227
+ for line in article_lines[1:]
228
+ if (len(line) > 0 and not line.isspace())
229
+ ]
230
+
231
+ examples = self.create_examples_from_document(document, block_size, tokenizer)
232
+ self.examples.extend(examples)
233
+ article_lines = []
234
+ else:
235
+ if article_open:
236
+ article_lines.append(line)
237
+
238
+ logger.info("Dataset parse finished.")
239
+
240
+ def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
241
+ """Creates examples for a single document."""
242
+
243
+ # Account for special tokens
244
+ max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
245
+
246
+ # We *usually* want to fill up the entire sequence since we are padding
247
+ # to `block_size` anyways, so short sequences are generally wasted
248
+ # computation. However, we *sometimes*
249
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
250
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
251
+ # The `target_seq_length` is just a rough target however, whereas
252
+ # `block_size` is a hard limit.
253
+ target_seq_length = max_num_tokens
254
+ if random.random() < short_seq_prob:
255
+ target_seq_length = random.randint(2, max_num_tokens)
256
+
257
+ # We DON'T just concatenate all of the tokens from a document into a long
258
+ # sequence and choose an arbitrary split point because this would make the
259
+ # next sentence prediction task too easy. Instead, we split the input into
260
+ # segments "A" and "B" based on the actual "sentences" provided by the user
261
+ # input.
262
+ examples = []
263
+ current_chunk = [] # a buffer storing the current working segments
264
+ current_length = 0
265
+ i = 0
266
+ while i < len(document):
267
+ segment = document[i] # get a segment
268
+ if not segment:
269
+ i += 1
270
+ continue
271
+ current_chunk.append(segment) # add a segment to current chunk
272
+ current_length += len(segment) # overall token length
273
+ # once the current length reaches the target length or the end of the document, start building tokens a and b
274
+ if i == len(document) - 1 or current_length >= target_seq_length:
275
+ if current_chunk:
276
+ # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
277
+ a_end = 1
278
+ # if the current chunk has two or more segments, pick a random split point for the `A` (first) sentence
279
+ if len(current_chunk) >= 2:
280
+ a_end = random.randint(1, len(current_chunk) - 1)
281
+ # token a
282
+ tokens_a = []
283
+ for j in range(a_end):
284
+ tokens_a.extend(current_chunk[j])
285
+
286
+ # token b
287
+ tokens_b = []
288
+ for j in range(a_end, len(current_chunk)):
289
+ tokens_b.extend(current_chunk[j])
290
+
291
+ if len(tokens_a) == 0 or len(tokens_b) == 0:
292
+ continue
293
+
294
+ # switch tokens_a and tokens_b randomly
295
+ if random.random() < 0.5:
296
+ is_next = False
297
+ tokens_a, tokens_b = tokens_b, tokens_a
298
+ else:
299
+ is_next = True
300
+
301
+ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
302
+ """Truncates a pair of sequences to a maximum sequence length."""
303
+ while True:
304
+ total_length = len(tokens_a) + len(tokens_b)
305
+ if total_length <= max_num_tokens:
306
+ break
307
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
308
+ if not (len(trunc_tokens) >= 1):
309
+ raise ValueError("Sequence length to be truncated must be no less than one")
310
+ # We want to sometimes truncate from the front and sometimes from the
311
+ # back to add more randomness and avoid biases.
312
+ if random.random() < 0.5:
313
+ del trunc_tokens[0]
314
+ else:
315
+ trunc_tokens.pop()
316
+
317
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
318
+ if not (len(tokens_a) >= 1):
319
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
320
+ if not (len(tokens_b) >= 1):
321
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
322
+
323
+ # add special tokens
324
+ input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
325
+ # add token type ids, 0 for sentence a, 1 for sentence b
326
+ token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
327
+
328
+ example = {
329
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
330
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
331
+ "sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
332
+ }
333
+ examples.append(example)
334
+ current_chunk = [] # clear current chunk
335
+ current_length = 0 # reset current text length
336
+ i += 1 # go to next line
337
+ return examples
338
+
339
+ def __len__(self):
340
+ return len(self.examples)
341
+
342
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
343
+ return self.examples[i]
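+ # A minimal usage sketch (illustrative paths; the directory must contain
+ # WikiExtractor-style files with articles wrapped in <doc id=...> ... </doc>):
+ # dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir="./dataset", block_size=512)
+ # dataset[0] -> {"input_ids": ..., "token_type_ids": ..., "sentence_order_label": 0 (in order) or 1 (swapped)}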
344
+
345
+
346
+ class TextDatasetForNextSentencePrediction(Dataset):
347
+ """
348
+ This will be superseded by a framework-agnostic approach soon.
349
+ """
350
+
351
+ def __init__(
352
+ self,
353
+ tokenizer: PreTrainedTokenizer,
354
+ file_path: str,
355
+ block_size: int,
356
+ overwrite_cache=False,
357
+ short_seq_probability=0.1,
358
+ nsp_probability=0.5,
359
+ ):
360
+ warnings.warn(
361
+ DEPRECATION_WARNING.format(
362
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
363
+ ),
364
+ FutureWarning,
365
+ )
366
+ if not os.path.isfile(file_path):
367
+ raise ValueError(f"Input file path {file_path} not found")
368
+
369
+ self.short_seq_probability = short_seq_probability
370
+ self.nsp_probability = nsp_probability
371
+
372
+ directory, filename = os.path.split(file_path)
373
+ cached_features_file = os.path.join(
374
+ directory,
375
+ f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
376
+ )
377
+
378
+ self.tokenizer = tokenizer
379
+
380
+ # Make sure only the first process in distributed training processes the dataset,
381
+ # and the others will use the cache.
382
+ lock_path = cached_features_file + ".lock"
383
+
384
+ # Input file format:
385
+ # (1) One sentence per line. These should ideally be actual sentences, not
386
+ # entire paragraphs or arbitrary spans of text. (Because we use the
387
+ # sentence boundaries for the "next sentence prediction" task).
388
+ # (2) Blank lines between documents. Document boundaries are needed so
389
+ # that the "next sentence prediction" task doesn't span between documents.
390
+ #
391
+ # Example:
392
+ # I am very happy.
393
+ # Here is the second sentence.
394
+ #
395
+ # A new document.
396
+
397
+ with FileLock(lock_path):
398
+ if os.path.exists(cached_features_file) and not overwrite_cache:
399
+ start = time.time()
400
+ with open(cached_features_file, "rb") as handle:
401
+ self.examples = pickle.load(handle)
402
+ logger.info(
403
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
404
+ )
405
+ else:
406
+ logger.info(f"Creating features from dataset file at {directory}")
407
+
408
+ self.documents = [[]]
409
+ with open(file_path, encoding="utf-8") as f:
410
+ while True:
411
+ line = f.readline()
412
+ if not line:
413
+ break
414
+ line = line.strip()
415
+
416
+ # Empty lines are used as document delimiters
417
+ if not line and len(self.documents[-1]) != 0:
418
+ self.documents.append([])
419
+ tokens = tokenizer.tokenize(line)
420
+ tokens = tokenizer.convert_tokens_to_ids(tokens)
421
+ if tokens:
422
+ self.documents[-1].append(tokens)
423
+
424
+ logger.info(f"Creating examples from {len(self.documents)} documents.")
425
+ self.examples = []
426
+ for doc_index, document in enumerate(self.documents):
427
+ self.create_examples_from_document(document, doc_index, block_size)
428
+
429
+ start = time.time()
430
+ with open(cached_features_file, "wb") as handle:
431
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
432
+ logger.info(
433
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
434
+ )
435
+
436
+ def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
437
+ """Creates examples for a single document."""
438
+
439
+ max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
440
+
441
+ # We *usually* want to fill up the entire sequence since we are padding
442
+ # to `block_size` anyways, so short sequences are generally wasted
443
+ # computation. However, we *sometimes*
444
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
445
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
446
+ # The `target_seq_length` is just a rough target however, whereas
447
+ # `block_size` is a hard limit.
448
+ target_seq_length = max_num_tokens
449
+ if random.random() < self.short_seq_probability:
450
+ target_seq_length = random.randint(2, max_num_tokens)
451
+
452
+ current_chunk = [] # a buffer storing the current working segments
453
+ current_length = 0
454
+ i = 0
455
+
456
+ while i < len(document):
457
+ segment = document[i]
458
+ current_chunk.append(segment)
459
+ current_length += len(segment)
460
+ if i == len(document) - 1 or current_length >= target_seq_length:
461
+ if current_chunk:
462
+ # `a_end` is how many segments from `current_chunk` go into the `A`
463
+ # (first) sentence.
464
+ a_end = 1
465
+ if len(current_chunk) >= 2:
466
+ a_end = random.randint(1, len(current_chunk) - 1)
467
+
468
+ tokens_a = []
469
+ for j in range(a_end):
470
+ tokens_a.extend(current_chunk[j])
471
+
472
+ tokens_b = []
473
+
474
+ if len(current_chunk) == 1 or random.random() < self.nsp_probability:
475
+ is_random_next = True
476
+ target_b_length = target_seq_length - len(tokens_a)
477
+
478
+ # This should rarely go for more than one iteration for large
479
+ # corpora. However, just to be careful, we try to make sure that
480
+ # the random document is not the same as the document
481
+ # we're processing.
482
+ for _ in range(10):
483
+ random_document_index = random.randint(0, len(self.documents) - 1)
484
+ if random_document_index != doc_index:
485
+ break
486
+
487
+ random_document = self.documents[random_document_index]
488
+ random_start = random.randint(0, len(random_document) - 1)
489
+ for j in range(random_start, len(random_document)):
490
+ tokens_b.extend(random_document[j])
491
+ if len(tokens_b) >= target_b_length:
492
+ break
493
+ # We didn't actually use these segments so we "put them back" so
494
+ # they don't go to waste.
495
+ num_unused_segments = len(current_chunk) - a_end
496
+ i -= num_unused_segments
497
+ # Actual next
498
+ else:
499
+ is_random_next = False
500
+ for j in range(a_end, len(current_chunk)):
501
+ tokens_b.extend(current_chunk[j])
502
+
503
+ if not (len(tokens_a) >= 1):
504
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
505
+ if not (len(tokens_b) >= 1):
506
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
507
+
508
+ # add special tokens
509
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
510
+ # add token type ids, 0 for sentence a, 1 for sentence b
511
+ token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
512
+
513
+ example = {
514
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
515
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
516
+ "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
517
+ }
518
+
519
+ self.examples.append(example)
520
+
521
+ current_chunk = []
522
+ current_length = 0
523
+
524
+ i += 1
525
+
526
+ def __len__(self):
527
+ return len(self.examples)
528
+
529
+ def __getitem__(self, i):
530
+ return self.examples[i]
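
For orientation, a minimal end-to-end sketch of the NSP dataset above (the corpus path and checkpoint name are illustrative; the file must follow the one-sentence-per-line, blank-line-between-documents format described in the comments):

from transformers import BertTokenizer, TextDatasetForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = TextDatasetForNextSentencePrediction(tokenizer=tokenizer, file_path="corpus.txt", block_size=128)
example = dataset[0]
# keys: input_ids, token_type_ids, next_sentence_label (0 = actual next sentence, 1 = random)
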
.venv/Lib/site-packages/transformers/data/datasets/squad.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import torch
22
+ from filelock import FileLock
23
+ from torch.utils.data import Dataset
24
+
25
+ from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
26
+ from ...tokenization_utils import PreTrainedTokenizer
27
+ from ...utils import logging
28
+ from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
34
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
35
+
36
+
37
+ @dataclass
38
+ class SquadDataTrainingArguments:
39
+ """
40
+ Arguments pertaining to what data we are going to input our model for training and eval.
41
+ """
42
+
43
+ model_type: str = field(
44
+ default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
45
+ )
46
+ data_dir: str = field(
47
+ default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
48
+ )
49
+ max_seq_length: int = field(
50
+ default=128,
51
+ metadata={
52
+ "help": (
53
+ "The maximum total input sequence length after tokenization. Sequences longer "
54
+ "than this will be truncated, sequences shorter will be padded."
55
+ )
56
+ },
57
+ )
58
+ doc_stride: int = field(
59
+ default=128,
60
+ metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
61
+ )
62
+ max_query_length: int = field(
63
+ default=64,
64
+ metadata={
65
+ "help": (
66
+ "The maximum number of tokens for the question. Questions longer than this will "
67
+ "be truncated to this length."
68
+ )
69
+ },
70
+ )
71
+ max_answer_length: int = field(
72
+ default=30,
73
+ metadata={
74
+ "help": (
75
+ "The maximum length of an answer that can be generated. This is needed because the start "
76
+ "and end predictions are not conditioned on one another."
77
+ )
78
+ },
79
+ )
80
+ overwrite_cache: bool = field(
81
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
82
+ )
83
+ version_2_with_negative: bool = field(
84
+ default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
85
+ )
86
+ null_score_diff_threshold: float = field(
87
+ default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
88
+ )
89
+ n_best_size: int = field(
90
+ default=20, metadata={"help": "The total number of n-best predictions to generate in the nbest_predictions.json output file."}
91
+ )
92
+ lang_id: int = field(
93
+ default=0,
94
+ metadata={
95
+ "help": (
96
+ "language id of input for language-specific xlm models (see"
97
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
98
+ )
99
+ },
100
+ )
101
+ threads: int = field(default=1, metadata={"help": "Number of threads to use when converting examples to features."})
102
+
103
+
104
+ class Split(Enum):
105
+ train = "train"
106
+ dev = "dev"
107
+
108
+
109
+ class SquadDataset(Dataset):
110
+ """
111
+ This will be superseded by a framework-agnostic approach soon.
112
+ """
113
+
114
+ args: SquadDataTrainingArguments
115
+ features: List[SquadFeatures]
116
+ mode: Split
117
+ is_language_sensitive: bool
118
+
119
+ def __init__(
120
+ self,
121
+ args: SquadDataTrainingArguments,
122
+ tokenizer: PreTrainedTokenizer,
123
+ limit_length: Optional[int] = None,
124
+ mode: Union[str, Split] = Split.train,
125
+ is_language_sensitive: Optional[bool] = False,
126
+ cache_dir: Optional[str] = None,
127
+ dataset_format: Optional[str] = "pt",
128
+ ):
129
+ self.args = args
130
+ self.is_language_sensitive = is_language_sensitive
131
+ self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
132
+ if isinstance(mode, str):
133
+ try:
134
+ mode = Split[mode]
135
+ except KeyError:
136
+ raise KeyError("mode is not a valid split name")
137
+ self.mode = mode
138
+ # Load data features from cache or dataset file
139
+ version_tag = "v2" if args.version_2_with_negative else "v1"
140
+ cached_features_file = os.path.join(
141
+ cache_dir if cache_dir is not None else args.data_dir,
142
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
143
+ )
144
+
145
+ # Make sure only the first process in distributed training processes the dataset,
146
+ # and the others will use the cache.
147
+ lock_path = cached_features_file + ".lock"
148
+ with FileLock(lock_path):
149
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
150
+ start = time.time()
151
+ self.old_features = torch.load(cached_features_file)
152
+
153
+ # Legacy cache files have only features, while new cache files
154
+ # will have dataset and examples also.
155
+ self.features = self.old_features["features"]
156
+ self.dataset = self.old_features.get("dataset", None)
157
+ self.examples = self.old_features.get("examples", None)
158
+ logger.info(
159
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
160
+ )
161
+
162
+ if self.dataset is None or self.examples is None:
163
+ logger.warning(
164
+ f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
165
+ " future run"
166
+ )
167
+ else:
168
+ if mode == Split.dev:
169
+ self.examples = self.processor.get_dev_examples(args.data_dir)
170
+ else:
171
+ self.examples = self.processor.get_train_examples(args.data_dir)
172
+
173
+ self.features, self.dataset = squad_convert_examples_to_features(
174
+ examples=self.examples,
175
+ tokenizer=tokenizer,
176
+ max_seq_length=args.max_seq_length,
177
+ doc_stride=args.doc_stride,
178
+ max_query_length=args.max_query_length,
179
+ is_training=mode == Split.train,
180
+ threads=args.threads,
181
+ return_dataset=dataset_format,
182
+ )
183
+
184
+ start = time.time()
185
+ torch.save(
186
+ {"features": self.features, "dataset": self.dataset, "examples": self.examples},
187
+ cached_features_file,
188
+ )
189
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
190
+ logger.info(
191
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
192
+ )
193
+
194
+ def __len__(self):
195
+ return len(self.features)
196
+
197
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
198
+ # Convert to Tensors and build dataset
199
+ feature = self.features[i]
200
+
201
+ input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
202
+ attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
203
+ token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
204
+ cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
205
+ p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
206
+ is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
207
+
208
+ inputs = {
209
+ "input_ids": input_ids,
210
+ "attention_mask": attention_mask,
211
+ "token_type_ids": token_type_ids,
212
+ }
213
+
214
+ if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
215
+ del inputs["token_type_ids"]
216
+
217
+ if self.args.model_type in ["xlnet", "xlm"]:
218
+ inputs.update({"cls_index": cls_index, "p_mask": p_mask})
219
+ if self.args.version_2_with_negative:
220
+ inputs.update({"is_impossible": is_impossible})
221
+ if self.is_language_sensitive:
222
+ inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
223
+
224
+ if self.mode == Split.train:
225
+ start_positions = torch.tensor(feature.start_position, dtype=torch.long)
226
+ end_positions = torch.tensor(feature.end_position, dtype=torch.long)
227
+ inputs.update({"start_positions": start_positions, "end_positions": end_positions})
228
+
229
+ return inputs
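
A minimal sketch of building the dataset above (the data directory is an assumption and must contain the SQuAD json files, e.g. train-v1.1.json and dev-v1.1.json, or the v2.0 files when version_2_with_negative is set):

from transformers import AutoTokenizer
from transformers.data.datasets import SquadDataset, SquadDataTrainingArguments

args = SquadDataTrainingArguments(model_type="bert", data_dir="./squad_data")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dataset = SquadDataset(args, tokenizer=tokenizer, mode="train")
inputs = train_dataset[0]  # input_ids, attention_mask, token_type_ids, start_positions, end_positions
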
.venv/Lib/site-packages/transformers/data/metrics/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import warnings
14
+
15
+ from ...utils import is_sklearn_available, requires_backends
16
+
17
+
18
+ if is_sklearn_available():
19
+ from scipy.stats import pearsonr, spearmanr
20
+ from sklearn.metrics import f1_score, matthews_corrcoef
21
+
22
+
23
+ DEPRECATION_WARNING = (
24
+ "This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
25
+ "library. You can have a look at this example script for pointers: "
26
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
27
+ )
28
+
29
+
30
+ def simple_accuracy(preds, labels):
31
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
32
+ requires_backends(simple_accuracy, "sklearn")
33
+ return (preds == labels).mean()
34
+
35
+
36
+ def acc_and_f1(preds, labels):
37
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
38
+ requires_backends(acc_and_f1, "sklearn")
39
+ acc = simple_accuracy(preds, labels)
40
+ f1 = f1_score(y_true=labels, y_pred=preds)
41
+ return {
42
+ "acc": acc,
43
+ "f1": f1,
44
+ "acc_and_f1": (acc + f1) / 2,
45
+ }
46
+
47
+
48
+ def pearson_and_spearman(preds, labels):
49
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
50
+ requires_backends(pearson_and_spearman, "sklearn")
51
+ pearson_corr = pearsonr(preds, labels)[0]
52
+ spearman_corr = spearmanr(preds, labels)[0]
53
+ return {
54
+ "pearson": pearson_corr,
55
+ "spearmanr": spearman_corr,
56
+ "corr": (pearson_corr + spearman_corr) / 2,
57
+ }
58
+
59
+
60
+ def glue_compute_metrics(task_name, preds, labels):
61
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
62
+ requires_backends(glue_compute_metrics, "sklearn")
63
+ if len(preds) != len(labels):
+ raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
64
+ if task_name == "cola":
65
+ return {"mcc": matthews_corrcoef(labels, preds)}
66
+ elif task_name == "sst-2":
67
+ return {"acc": simple_accuracy(preds, labels)}
68
+ elif task_name == "mrpc":
69
+ return acc_and_f1(preds, labels)
70
+ elif task_name == "sts-b":
71
+ return pearson_and_spearman(preds, labels)
72
+ elif task_name == "qqp":
73
+ return acc_and_f1(preds, labels)
74
+ elif task_name == "mnli":
75
+ return {"mnli/acc": simple_accuracy(preds, labels)}
76
+ elif task_name == "mnli-mm":
77
+ return {"mnli-mm/acc": simple_accuracy(preds, labels)}
78
+ elif task_name == "qnli":
79
+ return {"acc": simple_accuracy(preds, labels)}
80
+ elif task_name == "rte":
81
+ return {"acc": simple_accuracy(preds, labels)}
82
+ elif task_name == "wnli":
83
+ return {"acc": simple_accuracy(preds, labels)}
84
+ elif task_name == "hans":
85
+ return {"acc": simple_accuracy(preds, labels)}
86
+ else:
87
+ raise KeyError(task_name)
88
+
89
+
90
+ def xnli_compute_metrics(task_name, preds, labels):
91
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
92
+ requires_backends(xnli_compute_metrics, "sklearn")
93
+ if len(preds) != len(labels):
94
+ raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
95
+ if task_name == "xnli":
96
+ return {"acc": simple_accuracy(preds, labels)}
97
+ else:
98
+ raise KeyError(task_name)
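
A worked example of the metric helpers above (requires scikit-learn and scipy; the arrays are illustrative):

import numpy as np

from transformers.data.metrics import glue_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])
print(glue_compute_metrics("mrpc", preds, labels))
# {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775} (up to float formatting)
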
.venv/Lib/site-packages/transformers/data/metrics/squad_metrics.py ADDED
@@ -0,0 +1,779 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to
16
+ update `find_best_threshold` scripts for SQuAD V2.0
17
+
18
+ In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
19
+ additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted
20
+ probability that a question is unanswerable.
21
+ """
22
+
23
+ import collections
24
+ import json
25
+ import math
26
+ import re
27
+ import string
28
+
29
+ from ...models.bert import BasicTokenizer
30
+ from ...utils import logging
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def normalize_answer(s):
37
+ """Lower text and remove punctuation, articles and extra whitespace."""
38
+
39
+ def remove_articles(text):
40
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
41
+ return re.sub(regex, " ", text)
42
+
43
+ def white_space_fix(text):
44
+ return " ".join(text.split())
45
+
46
+ def remove_punc(text):
47
+ exclude = set(string.punctuation)
48
+ return "".join(ch for ch in text if ch not in exclude)
49
+
50
+ def lower(text):
51
+ return text.lower()
52
+
53
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
54
+
55
+
56
+ def get_tokens(s):
57
+ if not s:
58
+ return []
59
+ return normalize_answer(s).split()
60
+
61
+
62
+ def compute_exact(a_gold, a_pred):
63
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
64
+
65
+
66
+ def compute_f1(a_gold, a_pred):
67
+ gold_toks = get_tokens(a_gold)
68
+ pred_toks = get_tokens(a_pred)
69
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
70
+ num_same = sum(common.values())
71
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
72
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
73
+ return int(gold_toks == pred_toks)
74
+ if num_same == 0:
75
+ return 0
76
+ precision = 1.0 * num_same / len(pred_toks)
77
+ recall = 1.0 * num_same / len(gold_toks)
78
+ f1 = (2 * precision * recall) / (precision + recall)
79
+ return f1
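+ # Worked example: compute_f1("The cat sat.", "a cat sat") == 1.0, because
+ # normalize_answer() lowercases, strips punctuation and drops articles, so both
+ # answers reduce to the same token bag {"cat", "sat"}.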
80
+
81
+
82
+ def get_raw_scores(examples, preds):
83
+ """
84
+ Computes the exact and f1 scores from the examples and the model predictions
85
+ """
86
+ exact_scores = {}
87
+ f1_scores = {}
88
+
89
+ for example in examples:
90
+ qas_id = example.qas_id
91
+ gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
92
+
93
+ if not gold_answers:
94
+ # For unanswerable questions, only correct answer is empty string
95
+ gold_answers = [""]
96
+
97
+ if qas_id not in preds:
98
+ print(f"Missing prediction for {qas_id}")
99
+ continue
100
+
101
+ prediction = preds[qas_id]
102
+ exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
103
+ f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
104
+
105
+ return exact_scores, f1_scores
106
+
107
+
108
+ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
109
+ new_scores = {}
110
+ for qid, s in scores.items():
111
+ pred_na = na_probs[qid] > na_prob_thresh
112
+ if pred_na:
113
+ new_scores[qid] = float(not qid_to_has_ans[qid])
114
+ else:
115
+ new_scores[qid] = s
116
+ return new_scores
117
+
118
+
119
+ def make_eval_dict(exact_scores, f1_scores, qid_list=None):
120
+ if not qid_list:
121
+ total = len(exact_scores)
122
+ return collections.OrderedDict(
123
+ [
124
+ ("exact", 100.0 * sum(exact_scores.values()) / total),
125
+ ("f1", 100.0 * sum(f1_scores.values()) / total),
126
+ ("total", total),
127
+ ]
128
+ )
129
+ else:
130
+ total = len(qid_list)
131
+ return collections.OrderedDict(
132
+ [
133
+ ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
134
+ ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
135
+ ("total", total),
136
+ ]
137
+ )
138
+
139
+
140
+ def merge_eval(main_eval, new_eval, prefix):
141
+ for k in new_eval:
142
+ main_eval[f"{prefix}_{k}"] = new_eval[k]
143
+
144
+
145
+ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
146
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
147
+ cur_score = num_no_ans
148
+ best_score = cur_score
149
+ best_thresh = 0.0
150
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
151
+ for i, qid in enumerate(qid_list):
152
+ if qid not in scores:
153
+ continue
154
+ if qid_to_has_ans[qid]:
155
+ diff = scores[qid]
156
+ else:
157
+ if preds[qid]:
158
+ diff = -1
159
+ else:
160
+ diff = 0
161
+ cur_score += diff
162
+ if cur_score > best_score:
163
+ best_score = cur_score
164
+ best_thresh = na_probs[qid]
165
+
166
+ has_ans_score, has_ans_cnt = 0, 0
167
+ for qid in qid_list:
168
+ if not qid_to_has_ans[qid]:
169
+ continue
170
+ has_ans_cnt += 1
171
+
172
+ if qid not in scores:
173
+ continue
174
+ has_ans_score += scores[qid]
175
+
176
+ return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
177
+
178
+
179
+ def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
180
+ best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
181
+ best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
182
+ main_eval["best_exact"] = best_exact
183
+ main_eval["best_exact_thresh"] = exact_thresh
184
+ main_eval["best_f1"] = best_f1
185
+ main_eval["best_f1_thresh"] = f1_thresh
186
+ main_eval["has_ans_exact"] = has_ans_exact
187
+ main_eval["has_ans_f1"] = has_ans_f1
188
+
189
+
190
+ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
191
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
192
+ cur_score = num_no_ans
193
+ best_score = cur_score
194
+ best_thresh = 0.0
195
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
196
+ for _, qid in enumerate(qid_list):
197
+ if qid not in scores:
198
+ continue
199
+ if qid_to_has_ans[qid]:
200
+ diff = scores[qid]
201
+ else:
202
+ if preds[qid]:
203
+ diff = -1
204
+ else:
205
+ diff = 0
206
+ cur_score += diff
207
+ if cur_score > best_score:
208
+ best_score = cur_score
209
+ best_thresh = na_probs[qid]
210
+ return 100.0 * best_score / len(scores), best_thresh
211
+
212
+
213
+ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
214
+ best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
215
+ best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
216
+
217
+ main_eval["best_exact"] = best_exact
218
+ main_eval["best_exact_thresh"] = exact_thresh
219
+ main_eval["best_f1"] = best_f1
220
+ main_eval["best_f1_thresh"] = f1_thresh
221
+
222
+
223
+ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
224
+ qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
225
+ has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
226
+ no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
227
+
228
+ if no_answer_probs is None:
229
+ no_answer_probs = {k: 0.0 for k in preds}
230
+
231
+ exact, f1 = get_raw_scores(examples, preds)
232
+
233
+ exact_threshold = apply_no_ans_threshold(
234
+ exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
235
+ )
236
+ f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
237
+
238
+ evaluation = make_eval_dict(exact_threshold, f1_threshold)
239
+
240
+ if has_answer_qids:
241
+ has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
242
+ merge_eval(evaluation, has_ans_eval, "HasAns")
243
+
244
+ if no_answer_qids:
245
+ no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
246
+ merge_eval(evaluation, no_ans_eval, "NoAns")
247
+
248
+ if no_answer_probs:
249
+ find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
250
+
251
+ return evaluation
252
+
253
+
254
+ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
255
+ """Project the tokenized prediction back to the original text."""
256
+
257
+ # When we created the data, we kept track of the alignment between original
258
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
259
+ # now `orig_text` contains the span of our original text corresponding to the
260
+ # span that we predicted.
261
+ #
262
+ # However, `orig_text` may contain extra characters that we don't want in
263
+ # our prediction.
264
+ #
265
+ # For example, let's say:
266
+ # pred_text = steve smith
267
+ # orig_text = Steve Smith's
268
+ #
269
+ # We don't want to return `orig_text` because it contains the extra "'s".
270
+ #
271
+ # We don't want to return `pred_text` because it's already been normalized
272
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
273
+ # our tokenizer does additional normalization like stripping accent
274
+ # characters).
275
+ #
276
+ # What we really want to return is "Steve Smith".
277
+ #
278
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
279
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
280
+ # can fail in certain cases in which case we just return `orig_text`.
281
+
282
+ def _strip_spaces(text):
283
+ ns_chars = []
284
+ ns_to_s_map = collections.OrderedDict()
285
+ for i, c in enumerate(text):
286
+ if c == " ":
287
+ continue
288
+ ns_to_s_map[len(ns_chars)] = i
289
+ ns_chars.append(c)
290
+ ns_text = "".join(ns_chars)
291
+ return (ns_text, ns_to_s_map)
292
+
293
+ # We first tokenize `orig_text`, strip whitespace from the result
294
+ # and `pred_text`, and check if they are the same length. If they are
295
+ # NOT the same length, the heuristic has failed. If they are the same
296
+ # length, we assume the characters are one-to-one aligned.
297
+ tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
298
+
299
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
300
+
301
+ start_position = tok_text.find(pred_text)
302
+ if start_position == -1:
303
+ if verbose_logging:
304
+ logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
305
+ return orig_text
306
+ end_position = start_position + len(pred_text) - 1
307
+
308
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
309
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
310
+
311
+ if len(orig_ns_text) != len(tok_ns_text):
312
+ if verbose_logging:
313
+ logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
314
+ return orig_text
315
+
316
+ # We then project the characters in `pred_text` back to `orig_text` using
317
+ # the character-to-character alignment.
318
+ tok_s_to_ns_map = {}
319
+ for i, tok_index in tok_ns_to_s_map.items():
320
+ tok_s_to_ns_map[tok_index] = i
321
+
322
+ orig_start_position = None
323
+ if start_position in tok_s_to_ns_map:
324
+ ns_start_position = tok_s_to_ns_map[start_position]
325
+ if ns_start_position in orig_ns_to_s_map:
326
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
327
+
328
+ if orig_start_position is None:
329
+ if verbose_logging:
330
+ logger.info("Couldn't map start position")
331
+ return orig_text
332
+
333
+ orig_end_position = None
334
+ if end_position in tok_s_to_ns_map:
335
+ ns_end_position = tok_s_to_ns_map[end_position]
336
+ if ns_end_position in orig_ns_to_s_map:
337
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
338
+
339
+ if orig_end_position is None:
340
+ if verbose_logging:
341
+ logger.info("Couldn't map end position")
342
+ return orig_text
343
+
344
+ output_text = orig_text[orig_start_position : (orig_end_position + 1)]
345
+ return output_text
346
+
347
+
348
+ def _get_best_indexes(logits, n_best_size):
349
+ """Get the n-best logits from a list."""
350
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
351
+
352
+ best_indexes = []
353
+ for i in range(len(index_and_score)):
354
+ if i >= n_best_size:
355
+ break
356
+ best_indexes.append(index_and_score[i][0])
357
+ return best_indexes
358
+
359
+
360
+ def _compute_softmax(scores):
361
+ """Compute softmax probability over raw logits."""
362
+ if not scores:
363
+ return []
364
+
365
+ max_score = None
366
+ for score in scores:
367
+ if max_score is None or score > max_score:
368
+ max_score = score
369
+
370
+ exp_scores = []
371
+ total_sum = 0.0
372
+ for score in scores:
373
+ x = math.exp(score - max_score)
374
+ exp_scores.append(x)
375
+ total_sum += x
376
+
377
+ probs = []
378
+ for score in exp_scores:
379
+ probs.append(score / total_sum)
380
+ return probs
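+ # Worked example: scores = [1.0, 2.0] -> shifted to [-1.0, 0.0]; exp() gives
+ # [0.3679, 1.0], which sums to 1.3679, so probs ~= [0.269, 0.731]. Subtracting
+ # the max first keeps exp() from overflowing on large logits.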
381
+
382
+
383
+ def compute_predictions_logits(
384
+ all_examples,
385
+ all_features,
386
+ all_results,
387
+ n_best_size,
388
+ max_answer_length,
389
+ do_lower_case,
390
+ output_prediction_file,
391
+ output_nbest_file,
392
+ output_null_log_odds_file,
393
+ verbose_logging,
394
+ version_2_with_negative,
395
+ null_score_diff_threshold,
396
+ tokenizer,
397
+ ):
398
+ """Write final predictions to the json file and log-odds of null if needed."""
399
+ if output_prediction_file:
400
+ logger.info(f"Writing predictions to: {output_prediction_file}")
401
+ if output_nbest_file:
402
+ logger.info(f"Writing nbest to: {output_nbest_file}")
403
+ if output_null_log_odds_file and version_2_with_negative:
404
+ logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
405
+
406
+ example_index_to_features = collections.defaultdict(list)
407
+ for feature in all_features:
408
+ example_index_to_features[feature.example_index].append(feature)
409
+
410
+ unique_id_to_result = {}
411
+ for result in all_results:
412
+ unique_id_to_result[result.unique_id] = result
413
+
414
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
415
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
416
+ )
417
+
418
+ all_predictions = collections.OrderedDict()
419
+ all_nbest_json = collections.OrderedDict()
420
+ scores_diff_json = collections.OrderedDict()
421
+
422
+ for example_index, example in enumerate(all_examples):
423
+ features = example_index_to_features[example_index]
424
+
425
+ prelim_predictions = []
426
+ # keep track of the minimum score of null start+end of position 0
427
+ score_null = 1000000 # large and positive
428
+ min_null_feature_index = 0 # the paragraph slice with min null score
429
+ null_start_logit = 0 # the start logit at the slice with min null score
430
+ null_end_logit = 0 # the end logit at the slice with min null score
431
+ for feature_index, feature in enumerate(features):
432
+ result = unique_id_to_result[feature.unique_id]
433
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
434
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
435
+ # if we could have irrelevant answers, get the min score of irrelevant
436
+ if version_2_with_negative:
437
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
438
+ if feature_null_score < score_null:
439
+ score_null = feature_null_score
440
+ min_null_feature_index = feature_index
441
+ null_start_logit = result.start_logits[0]
442
+ null_end_logit = result.end_logits[0]
443
+ for start_index in start_indexes:
444
+ for end_index in end_indexes:
445
+ # We could hypothetically create invalid predictions, e.g., predict
446
+ # that the start of the span is in the question. We throw out all
447
+ # invalid predictions.
448
+ if start_index >= len(feature.tokens):
449
+ continue
450
+ if end_index >= len(feature.tokens):
451
+ continue
452
+ if start_index not in feature.token_to_orig_map:
453
+ continue
454
+ if end_index not in feature.token_to_orig_map:
455
+ continue
456
+ if not feature.token_is_max_context.get(start_index, False):
457
+ continue
458
+ if end_index < start_index:
459
+ continue
460
+ length = end_index - start_index + 1
461
+ if length > max_answer_length:
462
+ continue
463
+ prelim_predictions.append(
464
+ _PrelimPrediction(
465
+ feature_index=feature_index,
466
+ start_index=start_index,
467
+ end_index=end_index,
468
+ start_logit=result.start_logits[start_index],
469
+ end_logit=result.end_logits[end_index],
470
+ )
471
+ )
472
+ if version_2_with_negative:
473
+ prelim_predictions.append(
474
+ _PrelimPrediction(
475
+ feature_index=min_null_feature_index,
476
+ start_index=0,
477
+ end_index=0,
478
+ start_logit=null_start_logit,
479
+ end_logit=null_end_logit,
480
+ )
481
+ )
482
+ prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
483
+
484
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
485
+ "NbestPrediction", ["text", "start_logit", "end_logit"]
486
+ )
487
+
488
+ seen_predictions = {}
489
+ nbest = []
490
+ for pred in prelim_predictions:
491
+ if len(nbest) >= n_best_size:
492
+ break
493
+ feature = features[pred.feature_index]
494
+ if pred.start_index > 0: # this is a non-null prediction
495
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
496
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
497
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
498
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
499
+
500
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
501
+
502
+ # tok_text = " ".join(tok_tokens)
503
+ #
504
+ # # De-tokenize WordPieces that have been split off.
505
+ # tok_text = tok_text.replace(" ##", "")
506
+ # tok_text = tok_text.replace("##", "")
507
+
508
+ # Clean whitespace
509
+ tok_text = tok_text.strip()
510
+ tok_text = " ".join(tok_text.split())
511
+ orig_text = " ".join(orig_tokens)
512
+
513
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
514
+ if final_text in seen_predictions:
515
+ continue
516
+
517
+ seen_predictions[final_text] = True
518
+ else:
519
+ final_text = ""
520
+ seen_predictions[final_text] = True
521
+
522
+ nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
523
+ # if we didn't include the empty option in the n-best, include it
524
+ if version_2_with_negative:
525
+ if "" not in seen_predictions:
526
+ nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
527
+
528
+ # In very rare edge cases we could only have single null prediction.
529
+ # So we just create a nonce prediction in this case to avoid failure.
530
+ if len(nbest) == 1:
531
+ nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
532
+
533
+ # In very rare edge cases we could have no valid predictions. So we
534
+ # just create a nonce prediction in this case to avoid failure.
535
+ if not nbest:
536
+ nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
537
+
538
+ if len(nbest) < 1:
539
+ raise ValueError("No valid predictions")
540
+
541
+ total_scores = []
542
+ best_non_null_entry = None
543
+ for entry in nbest:
544
+ total_scores.append(entry.start_logit + entry.end_logit)
545
+ if not best_non_null_entry:
546
+ if entry.text:
547
+ best_non_null_entry = entry
548
+
549
+ probs = _compute_softmax(total_scores)
550
+
551
+ nbest_json = []
552
+ for i, entry in enumerate(nbest):
553
+ output = collections.OrderedDict()
554
+ output["text"] = entry.text
555
+ output["probability"] = probs[i]
556
+ output["start_logit"] = entry.start_logit
557
+ output["end_logit"] = entry.end_logit
558
+ nbest_json.append(output)
559
+
560
+ if len(nbest_json) < 1:
561
+ raise ValueError("No valid predictions")
562
+
563
+ if not version_2_with_negative:
564
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
565
+ else:
566
+ # predict "" iff the null score - the score of best non-null > threshold
567
+ score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
568
+ scores_diff_json[example.qas_id] = score_diff
569
+ if score_diff > null_score_diff_threshold:
570
+ all_predictions[example.qas_id] = ""
571
+ else:
572
+ all_predictions[example.qas_id] = best_non_null_entry.text
573
+ all_nbest_json[example.qas_id] = nbest_json
574
+
575
+ if output_prediction_file:
576
+ with open(output_prediction_file, "w") as writer:
577
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
578
+
579
+ if output_nbest_file:
580
+ with open(output_nbest_file, "w") as writer:
581
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
582
+
583
+ if output_null_log_odds_file and version_2_with_negative:
584
+ with open(output_null_log_odds_file, "w") as writer:
585
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
586
+
587
+ return all_predictions
588
+
589
+
590
+ def compute_predictions_log_probs(
591
+ all_examples,
592
+ all_features,
593
+ all_results,
594
+ n_best_size,
595
+ max_answer_length,
596
+ output_prediction_file,
597
+ output_nbest_file,
598
+ output_null_log_odds_file,
599
+ start_n_top,
600
+ end_n_top,
601
+ version_2_with_negative,
602
+ tokenizer,
603
+ verbose_logging,
604
+ ):
605
+ """
606
+ XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
607
+ null if needed.
608
+
609
+ Requires utils_squad_evaluate.py
610
+ """
611
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
612
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
613
+ )
614
+
615
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
616
+ "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
617
+ )
618
+
619
+ logger.info(f"Writing predictions to: {output_prediction_file}")
620
+
621
+ example_index_to_features = collections.defaultdict(list)
622
+ for feature in all_features:
623
+ example_index_to_features[feature.example_index].append(feature)
624
+
625
+ unique_id_to_result = {}
626
+ for result in all_results:
627
+ unique_id_to_result[result.unique_id] = result
628
+
629
+ all_predictions = collections.OrderedDict()
630
+ all_nbest_json = collections.OrderedDict()
631
+ scores_diff_json = collections.OrderedDict()
632
+
633
+ for example_index, example in enumerate(all_examples):
634
+ features = example_index_to_features[example_index]
635
+
636
+ prelim_predictions = []
637
+ # keep track of the minimum score of null start+end of position 0
638
+ score_null = 1000000 # large and positive
639
+
640
+ for feature_index, feature in enumerate(features):
641
+ result = unique_id_to_result[feature.unique_id]
642
+
643
+ cur_null_score = result.cls_logits
644
+
645
+ # if we could have irrelevant answers, get the min score of irrelevant
646
+ score_null = min(score_null, cur_null_score)
647
+
648
+ for i in range(start_n_top):
649
+ for j in range(end_n_top):
650
+ start_log_prob = result.start_logits[i]
651
+ start_index = result.start_top_index[i]
652
+
653
+ j_index = i * end_n_top + j
654
+
655
+ end_log_prob = result.end_logits[j_index]
656
+ end_index = result.end_top_index[j_index]
657
+
658
+ # We could hypothetically create invalid predictions, e.g., predict
659
+ # that the start of the span is in the question. We throw out all
660
+ # invalid predictions.
661
+ if start_index >= feature.paragraph_len - 1:
662
+ continue
663
+ if end_index >= feature.paragraph_len - 1:
664
+ continue
665
+
666
+ if not feature.token_is_max_context.get(start_index, False):
667
+ continue
668
+ if end_index < start_index:
669
+ continue
670
+ length = end_index - start_index + 1
671
+ if length > max_answer_length:
672
+ continue
673
+
674
+ prelim_predictions.append(
675
+ _PrelimPrediction(
676
+ feature_index=feature_index,
677
+ start_index=start_index,
678
+ end_index=end_index,
679
+ start_log_prob=start_log_prob,
680
+ end_log_prob=end_log_prob,
681
+ )
682
+ )
683
+
684
+ prelim_predictions = sorted(
685
+ prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
686
+ )
687
+
688
+ seen_predictions = {}
689
+ nbest = []
690
+ for pred in prelim_predictions:
691
+ if len(nbest) >= n_best_size:
692
+ break
693
+ feature = features[pred.feature_index]
694
+
695
+ # XLNet un-tokenizer
696
+ # Let's keep it simple for now and see if we need all this later.
697
+ #
698
+ # tok_start_to_orig_index = feature.tok_start_to_orig_index
699
+ # tok_end_to_orig_index = feature.tok_end_to_orig_index
700
+ # start_orig_pos = tok_start_to_orig_index[pred.start_index]
701
+ # end_orig_pos = tok_end_to_orig_index[pred.end_index]
702
+ # paragraph_text = example.paragraph_text
703
+ # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
704
+
705
+ # Previously used Bert untokenizer
706
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
707
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
708
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
709
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
710
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
711
+
712
+ # Clean whitespace
713
+ tok_text = tok_text.strip()
714
+ tok_text = " ".join(tok_text.split())
715
+ orig_text = " ".join(orig_tokens)
716
+
717
+ if hasattr(tokenizer, "do_lower_case"):
718
+ do_lower_case = tokenizer.do_lower_case
719
+ else:
720
+ do_lower_case = tokenizer.do_lowercase_and_remove_accent
721
+
722
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
723
+
724
+ if final_text in seen_predictions:
725
+ continue
726
+
727
+ seen_predictions[final_text] = True
728
+
729
+ nbest.append(
730
+ _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
731
+ )
732
+
733
+ # In very rare edge cases we could have no valid predictions. So we
734
+ # just create a nonce prediction in this case to avoid failure.
735
+ if not nbest:
736
+ nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
737
+
738
+ total_scores = []
739
+ best_non_null_entry = None
740
+ for entry in nbest:
741
+ total_scores.append(entry.start_log_prob + entry.end_log_prob)
742
+ if not best_non_null_entry:
743
+ best_non_null_entry = entry
744
+
745
+ probs = _compute_softmax(total_scores)
746
+
747
+ nbest_json = []
748
+ for i, entry in enumerate(nbest):
749
+ output = collections.OrderedDict()
750
+ output["text"] = entry.text
751
+ output["probability"] = probs[i]
752
+ output["start_log_prob"] = entry.start_log_prob
753
+ output["end_log_prob"] = entry.end_log_prob
754
+ nbest_json.append(output)
755
+
756
+ if len(nbest_json) < 1:
757
+ raise ValueError("No valid predictions")
758
+ if best_non_null_entry is None:
759
+ raise ValueError("No valid predictions")
760
+
761
+ score_diff = score_null
762
+ scores_diff_json[example.qas_id] = score_diff
763
+ # note(zhiliny): always predict best_non_null_entry
764
+ # and the evaluation script will search for the best threshold
765
+ all_predictions[example.qas_id] = best_non_null_entry.text
766
+
767
+ all_nbest_json[example.qas_id] = nbest_json
768
+
769
+ with open(output_prediction_file, "w") as writer:
770
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
771
+
772
+ with open(output_nbest_file, "w") as writer:
773
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
774
+
775
+ if version_2_with_negative:
776
+ with open(output_null_log_odds_file, "w") as writer:
777
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
778
+
779
+ return all_predictions
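
A sketch of running the evaluator above end to end (names are illustrative; `examples` come from a SQuAD processor and `preds` maps each qas_id to a predicted answer string, e.g. the output of compute_predictions_logits):

from transformers.data.metrics.squad_metrics import squad_evaluate
from transformers.data.processors.squad import SquadV2Processor

examples = SquadV2Processor().get_dev_examples("./squad_data")
preds = {example.qas_id: "" for example in examples}  # stand-in predictions
results = squad_evaluate(examples, preds)
print(results["exact"], results["f1"])
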
.venv/Lib/site-packages/transformers/data/processors/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
16
+ from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
17
+ from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
18
+ from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
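
These re-exports let downstream code import the processors from one place; a brief sketch (the data directory is an assumption):

from transformers.data.processors import SquadV2Processor

processor = SquadV2Processor()
dev_examples = processor.get_dev_examples("./squad_data")  # reads dev-v2.0.json by default
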
.venv/Lib/site-packages/transformers/data/processors/glue.py ADDED
@@ -0,0 +1,643 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """GLUE processors and helpers"""
17
+
18
+ import os
19
+ import warnings
20
+ from dataclasses import asdict
21
+ from enum import Enum
22
+ from typing import List, Optional, Union
23
+
24
+ from ...tokenization_utils import PreTrainedTokenizer
25
+ from ...utils import is_tf_available, logging
26
+ from .utils import DataProcessor, InputExample, InputFeatures
27
+
28
+
29
+ if is_tf_available():
30
+ import tensorflow as tf
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ DEPRECATION_WARNING = (
35
+ "This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
36
+ "library. You can have a look at this example script for pointers: "
37
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
38
+ )
39
+
40
+
41
+ def glue_convert_examples_to_features(
42
+ examples: Union[List[InputExample], "tf.data.Dataset"],
43
+ tokenizer: PreTrainedTokenizer,
44
+ max_length: Optional[int] = None,
45
+ task=None,
46
+ label_list=None,
47
+ output_mode=None,
48
+ ):
49
+ """
50
+ Loads a data file into a list of `InputFeatures`
51
+
52
+ Args:
53
+ examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
54
+ tokenizer: Instance of a tokenizer that will tokenize the examples
55
+ max_length: Maximum example length. Defaults to the tokenizer's model_max_length.
56
+ task: GLUE task
57
+ label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
58
+ output_mode: String indicating the output mode. Either `regression` or `classification`
59
+
60
+ Returns:
61
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
62
+ features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
63
+ can be fed to the model.
64
+
65
+ """
66
+ warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
67
+ if is_tf_available() and isinstance(examples, tf.data.Dataset):
68
+ if task is None:
69
+ raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
70
+ return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
71
+ return _glue_convert_examples_to_features(
72
+ examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
73
+ )
74
+
75
+
76
+ if is_tf_available():
77
+
78
+ def _tf_glue_convert_examples_to_features(
79
+ examples: tf.data.Dataset,
80
+ tokenizer: PreTrainedTokenizer,
81
+ task: str,
82
+ max_length: Optional[int] = None,
83
+ ) -> tf.data.Dataset:
84
+ """
85
+ Returns:
86
+ A `tf.data.Dataset` containing the task-specific features.
87
+
88
+ """
89
+ processor = glue_processors[task]()
90
+ examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
91
+ features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
92
+ label_type = tf.float32 if task == "sts-b" else tf.int64
93
+
94
+ def gen():
95
+ for ex in features:
96
+ d = {k: v for k, v in asdict(ex).items() if v is not None}
97
+ label = d.pop("label")
98
+ yield (d, label)
99
+
100
+ input_names = tokenizer.model_input_names
101
+
102
+ return tf.data.Dataset.from_generator(
103
+ gen,
104
+ ({k: tf.int32 for k in input_names}, label_type),
105
+ ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
106
+ )
107
+
108
+
109
+ def _glue_convert_examples_to_features(
110
+ examples: List[InputExample],
111
+ tokenizer: PreTrainedTokenizer,
112
+ max_length: Optional[int] = None,
113
+ task=None,
114
+ label_list=None,
115
+ output_mode=None,
116
+ ):
117
+ if max_length is None:
118
+ max_length = tokenizer.model_max_length
119
+
120
+ if task is not None:
121
+ processor = glue_processors[task]()
122
+ if label_list is None:
123
+ label_list = processor.get_labels()
124
+ logger.info(f"Using label list {label_list} for task {task}")
125
+ if output_mode is None:
126
+ output_mode = glue_output_modes[task]
127
+ logger.info(f"Using output mode {output_mode} for task {task}")
128
+
129
+ label_map = {label: i for i, label in enumerate(label_list)}
130
+
131
+ def label_from_example(example: InputExample) -> Union[int, float, None]:
132
+ if example.label is None:
133
+ return None
134
+ if output_mode == "classification":
135
+ return label_map[example.label]
136
+ elif output_mode == "regression":
137
+ return float(example.label)
138
+ raise KeyError(output_mode)
139
+
140
+ labels = [label_from_example(example) for example in examples]
141
+
142
+ batch_encoding = tokenizer(
143
+ [(example.text_a, example.text_b) for example in examples],
144
+ max_length=max_length,
145
+ padding="max_length",
146
+ truncation=True,
147
+ )
148
+
149
+ features = []
150
+ for i in range(len(examples)):
151
+ inputs = {k: batch_encoding[k][i] for k in batch_encoding}
152
+
153
+ feature = InputFeatures(**inputs, label=labels[i])
154
+ features.append(feature)
155
+
156
+ for i, example in enumerate(examples[:5]):
157
+ logger.info("*** Example ***")
158
+ logger.info(f"guid: {example.guid}")
159
+ logger.info(f"features: {features[i]}")
160
+
161
+ return features
162
+
163
+
164
+ class OutputMode(Enum):
165
+ classification = "classification"
166
+ regression = "regression"
167
+
168
+
169
+ class MrpcProcessor(DataProcessor):
170
+ """Processor for the MRPC data set (GLUE version)."""
171
+
172
+ def __init__(self, *args, **kwargs):
173
+ super().__init__(*args, **kwargs)
174
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
175
+
176
+ def get_example_from_tensor_dict(self, tensor_dict):
177
+ """See base class."""
178
+ return InputExample(
179
+ tensor_dict["idx"].numpy(),
180
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
181
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
182
+ str(tensor_dict["label"].numpy()),
183
+ )
184
+
185
+ def get_train_examples(self, data_dir):
186
+ """See base class."""
187
+ logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
188
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
189
+
190
+ def get_dev_examples(self, data_dir):
191
+ """See base class."""
192
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
193
+
194
+ def get_test_examples(self, data_dir):
195
+ """See base class."""
196
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
197
+
198
+ def get_labels(self):
199
+ """See base class."""
200
+ return ["0", "1"]
201
+
202
+ def _create_examples(self, lines, set_type):
203
+ """Creates examples for the training, dev and test sets."""
204
+ examples = []
205
+ for i, line in enumerate(lines):
206
+ if i == 0:
207
+ continue
208
+ guid = f"{set_type}-{i}"
209
+ text_a = line[3]
210
+ text_b = line[4]
211
+ label = None if set_type == "test" else line[0]
212
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
213
+ return examples
214
+
215
+
216
+ class MnliProcessor(DataProcessor):
217
+ """Processor for the MultiNLI data set (GLUE version)."""
218
+
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
222
+
223
+ def get_example_from_tensor_dict(self, tensor_dict):
224
+ """See base class."""
225
+ return InputExample(
226
+ tensor_dict["idx"].numpy(),
227
+ tensor_dict["premise"].numpy().decode("utf-8"),
228
+ tensor_dict["hypothesis"].numpy().decode("utf-8"),
229
+ str(tensor_dict["label"].numpy()),
230
+ )
231
+
232
+ def get_train_examples(self, data_dir):
233
+ """See base class."""
234
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
235
+
236
+ def get_dev_examples(self, data_dir):
237
+ """See base class."""
238
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
239
+
240
+ def get_test_examples(self, data_dir):
241
+ """See base class."""
242
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
243
+
244
+ def get_labels(self):
245
+ """See base class."""
246
+ return ["contradiction", "entailment", "neutral"]
247
+
248
+ def _create_examples(self, lines, set_type):
249
+ """Creates examples for the training, dev and test sets."""
250
+ examples = []
251
+ for i, line in enumerate(lines):
252
+ if i == 0:
253
+ continue
254
+ guid = f"{set_type}-{line[0]}"
255
+ text_a = line[8]
256
+ text_b = line[9]
257
+ label = None if set_type.startswith("test") else line[-1]
258
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
259
+ return examples
260
+
261
+
262
+ class MnliMismatchedProcessor(MnliProcessor):
263
+ """Processor for the MultiNLI Mismatched data set (GLUE version)."""
264
+
265
+ def __init__(self, *args, **kwargs):
266
+ super().__init__(*args, **kwargs)
267
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
268
+
269
+ def get_dev_examples(self, data_dir):
270
+ """See base class."""
271
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
272
+
273
+ def get_test_examples(self, data_dir):
274
+ """See base class."""
275
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
276
+
277
+
278
+ class ColaProcessor(DataProcessor):
279
+ """Processor for the CoLA data set (GLUE version)."""
280
+
281
+ def __init__(self, *args, **kwargs):
282
+ super().__init__(*args, **kwargs)
283
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
284
+
285
+ def get_example_from_tensor_dict(self, tensor_dict):
286
+ """See base class."""
287
+ return InputExample(
288
+ tensor_dict["idx"].numpy(),
289
+ tensor_dict["sentence"].numpy().decode("utf-8"),
290
+ None,
291
+ str(tensor_dict["label"].numpy()),
292
+ )
293
+
294
+ def get_train_examples(self, data_dir):
295
+ """See base class."""
296
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
297
+
298
+ def get_dev_examples(self, data_dir):
299
+ """See base class."""
300
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
301
+
302
+ def get_test_examples(self, data_dir):
303
+ """See base class."""
304
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
305
+
306
+ def get_labels(self):
307
+ """See base class."""
308
+ return ["0", "1"]
309
+
310
+ def _create_examples(self, lines, set_type):
311
+ """Creates examples for the training, dev and test sets."""
312
+ test_mode = set_type == "test"
313
+ if test_mode:
314
+ lines = lines[1:]
315
+ text_index = 1 if test_mode else 3
316
+ examples = []
317
+ for i, line in enumerate(lines):
318
+ guid = f"{set_type}-{i}"
319
+ text_a = line[text_index]
320
+ label = None if test_mode else line[1]
321
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
322
+ return examples
323
+
324
+
325
+ class Sst2Processor(DataProcessor):
326
+ """Processor for the SST-2 data set (GLUE version)."""
327
+
328
+ def __init__(self, *args, **kwargs):
329
+ super().__init__(*args, **kwargs)
330
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
331
+
332
+ def get_example_from_tensor_dict(self, tensor_dict):
333
+ """See base class."""
334
+ return InputExample(
335
+ tensor_dict["idx"].numpy(),
336
+ tensor_dict["sentence"].numpy().decode("utf-8"),
337
+ None,
338
+ str(tensor_dict["label"].numpy()),
339
+ )
340
+
341
+ def get_train_examples(self, data_dir):
342
+ """See base class."""
343
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
344
+
345
+ def get_dev_examples(self, data_dir):
346
+ """See base class."""
347
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
348
+
349
+ def get_test_examples(self, data_dir):
350
+ """See base class."""
351
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
352
+
353
+ def get_labels(self):
354
+ """See base class."""
355
+ return ["0", "1"]
356
+
357
+ def _create_examples(self, lines, set_type):
358
+ """Creates examples for the training, dev and test sets."""
359
+ examples = []
360
+ text_index = 1 if set_type == "test" else 0
361
+ for i, line in enumerate(lines):
362
+ if i == 0:
363
+ continue
364
+ guid = f"{set_type}-{i}"
365
+ text_a = line[text_index]
366
+ label = None if set_type == "test" else line[1]
367
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
368
+ return examples
369
+
370
+
371
+ class StsbProcessor(DataProcessor):
372
+ """Processor for the STS-B data set (GLUE version)."""
373
+
374
+ def __init__(self, *args, **kwargs):
375
+ super().__init__(*args, **kwargs)
376
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
377
+
378
+ def get_example_from_tensor_dict(self, tensor_dict):
379
+ """See base class."""
380
+ return InputExample(
381
+ tensor_dict["idx"].numpy(),
382
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
383
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
384
+ str(tensor_dict["label"].numpy()),
385
+ )
386
+
387
+ def get_train_examples(self, data_dir):
388
+ """See base class."""
389
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
390
+
391
+ def get_dev_examples(self, data_dir):
392
+ """See base class."""
393
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
394
+
395
+ def get_test_examples(self, data_dir):
396
+ """See base class."""
397
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
398
+
399
+ def get_labels(self):
400
+ """See base class."""
401
+ return [None]
402
+
403
+ def _create_examples(self, lines, set_type):
404
+ """Creates examples for the training, dev and test sets."""
405
+ examples = []
406
+ for i, line in enumerate(lines):
407
+ if i == 0:
408
+ continue
409
+ guid = f"{set_type}-{line[0]}"
410
+ text_a = line[7]
411
+ text_b = line[8]
412
+ label = None if set_type == "test" else line[-1]
413
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
414
+ return examples
415
+
416
+
417
+ class QqpProcessor(DataProcessor):
418
+ """Processor for the QQP data set (GLUE version)."""
419
+
420
+ def __init__(self, *args, **kwargs):
421
+ super().__init__(*args, **kwargs)
422
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
423
+
424
+ def get_example_from_tensor_dict(self, tensor_dict):
425
+ """See base class."""
426
+ return InputExample(
427
+ tensor_dict["idx"].numpy(),
428
+ tensor_dict["question1"].numpy().decode("utf-8"),
429
+ tensor_dict["question2"].numpy().decode("utf-8"),
430
+ str(tensor_dict["label"].numpy()),
431
+ )
432
+
433
+ def get_train_examples(self, data_dir):
434
+ """See base class."""
435
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
436
+
437
+ def get_dev_examples(self, data_dir):
438
+ """See base class."""
439
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
440
+
441
+ def get_test_examples(self, data_dir):
442
+ """See base class."""
443
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
444
+
445
+ def get_labels(self):
446
+ """See base class."""
447
+ return ["0", "1"]
448
+
449
+ def _create_examples(self, lines, set_type):
450
+ """Creates examples for the training, dev and test sets."""
451
+ test_mode = set_type == "test"
452
+ q1_index = 1 if test_mode else 3
453
+ q2_index = 2 if test_mode else 4
454
+ examples = []
455
+ for i, line in enumerate(lines):
456
+ if i == 0:
457
+ continue
458
+ guid = f"{set_type}-{line[0]}"
459
+ try:
460
+ text_a = line[q1_index]
461
+ text_b = line[q2_index]
462
+ label = None if test_mode else line[5]
463
+ except IndexError:
464
+ continue
465
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
466
+ return examples
467
+
468
+
469
+ class QnliProcessor(DataProcessor):
470
+ """Processor for the QNLI data set (GLUE version)."""
471
+
472
+ def __init__(self, *args, **kwargs):
473
+ super().__init__(*args, **kwargs)
474
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
475
+
476
+ def get_example_from_tensor_dict(self, tensor_dict):
477
+ """See base class."""
478
+ return InputExample(
479
+ tensor_dict["idx"].numpy(),
480
+ tensor_dict["question"].numpy().decode("utf-8"),
481
+ tensor_dict["sentence"].numpy().decode("utf-8"),
482
+ str(tensor_dict["label"].numpy()),
483
+ )
484
+
485
+ def get_train_examples(self, data_dir):
486
+ """See base class."""
487
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
488
+
489
+ def get_dev_examples(self, data_dir):
490
+ """See base class."""
491
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
492
+
493
+ def get_test_examples(self, data_dir):
494
+ """See base class."""
495
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
496
+
497
+ def get_labels(self):
498
+ """See base class."""
499
+ return ["entailment", "not_entailment"]
500
+
501
+ def _create_examples(self, lines, set_type):
502
+ """Creates examples for the training, dev and test sets."""
503
+ examples = []
504
+ for i, line in enumerate(lines):
505
+ if i == 0:
506
+ continue
507
+ guid = f"{set_type}-{line[0]}"
508
+ text_a = line[1]
509
+ text_b = line[2]
510
+ label = None if set_type == "test" else line[-1]
511
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
512
+ return examples
513
+
514
+
515
+ class RteProcessor(DataProcessor):
516
+ """Processor for the RTE data set (GLUE version)."""
517
+
518
+ def __init__(self, *args, **kwargs):
519
+ super().__init__(*args, **kwargs)
520
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
521
+
522
+ def get_example_from_tensor_dict(self, tensor_dict):
523
+ """See base class."""
524
+ return InputExample(
525
+ tensor_dict["idx"].numpy(),
526
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
527
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
528
+ str(tensor_dict["label"].numpy()),
529
+ )
530
+
531
+ def get_train_examples(self, data_dir):
532
+ """See base class."""
533
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
534
+
535
+ def get_dev_examples(self, data_dir):
536
+ """See base class."""
537
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
538
+
539
+ def get_test_examples(self, data_dir):
540
+ """See base class."""
541
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
542
+
543
+ def get_labels(self):
544
+ """See base class."""
545
+ return ["entailment", "not_entailment"]
546
+
547
+ def _create_examples(self, lines, set_type):
548
+ """Creates examples for the training, dev and test sets."""
549
+ examples = []
550
+ for i, line in enumerate(lines):
551
+ if i == 0:
552
+ continue
553
+ guid = f"{set_type}-{line[0]}"
554
+ text_a = line[1]
555
+ text_b = line[2]
556
+ label = None if set_type == "test" else line[-1]
557
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
558
+ return examples
559
+
560
+
561
+ class WnliProcessor(DataProcessor):
562
+ """Processor for the WNLI data set (GLUE version)."""
563
+
564
+ def __init__(self, *args, **kwargs):
565
+ super().__init__(*args, **kwargs)
566
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
567
+
568
+ def get_example_from_tensor_dict(self, tensor_dict):
569
+ """See base class."""
570
+ return InputExample(
571
+ tensor_dict["idx"].numpy(),
572
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
573
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
574
+ str(tensor_dict["label"].numpy()),
575
+ )
576
+
577
+ def get_train_examples(self, data_dir):
578
+ """See base class."""
579
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
580
+
581
+ def get_dev_examples(self, data_dir):
582
+ """See base class."""
583
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
584
+
585
+ def get_test_examples(self, data_dir):
586
+ """See base class."""
587
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
588
+
589
+ def get_labels(self):
590
+ """See base class."""
591
+ return ["0", "1"]
592
+
593
+ def _create_examples(self, lines, set_type):
594
+ """Creates examples for the training, dev and test sets."""
595
+ examples = []
596
+ for i, line in enumerate(lines):
597
+ if i == 0:
598
+ continue
599
+ guid = f"{set_type}-{line[0]}"
600
+ text_a = line[1]
601
+ text_b = line[2]
602
+ label = None if set_type == "test" else line[-1]
603
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
604
+ return examples
605
+
606
+
607
+ glue_tasks_num_labels = {
608
+ "cola": 2,
609
+ "mnli": 3,
610
+ "mrpc": 2,
611
+ "sst-2": 2,
612
+ "sts-b": 1,
613
+ "qqp": 2,
614
+ "qnli": 2,
615
+ "rte": 2,
616
+ "wnli": 2,
617
+ }
618
+
619
+ glue_processors = {
620
+ "cola": ColaProcessor,
621
+ "mnli": MnliProcessor,
622
+ "mnli-mm": MnliMismatchedProcessor,
623
+ "mrpc": MrpcProcessor,
624
+ "sst-2": Sst2Processor,
625
+ "sts-b": StsbProcessor,
626
+ "qqp": QqpProcessor,
627
+ "qnli": QnliProcessor,
628
+ "rte": RteProcessor,
629
+ "wnli": WnliProcessor,
630
+ }
631
+
632
+ glue_output_modes = {
633
+ "cola": "classification",
634
+ "mnli": "classification",
635
+ "mnli-mm": "classification",
636
+ "mrpc": "classification",
637
+ "sst-2": "classification",
638
+ "sts-b": "regression",
639
+ "qqp": "classification",
640
+ "qnli": "classification",
641
+ "rte": "classification",
642
+ "wnli": "classification",
643
+ }
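
Taken together, the processors and helpers above form a small (deprecated) pipeline. A minimal end-to-end sketch, assuming an MRPC directory in the TSV layout the processors expect (the path and tokenizer checkpoint are placeholders):

```python
from transformers import AutoTokenizer
from transformers.data.processors.glue import (
    glue_convert_examples_to_features,
    glue_output_modes,
    glue_processors,
)

# Placeholders: any tokenizer checkpoint and a GLUE-formatted MRPC directory.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = glue_processors["mrpc"]()  # emits the FutureWarning defined above

examples = processor.get_dev_examples("/path/to/MRPC")
features = glue_convert_examples_to_features(
    examples, tokenizer, max_length=128, task="mrpc"
)
print(glue_output_modes["mrpc"], len(features))
```
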
.venv/Lib/site-packages/transformers/data/processors/squad.py ADDED
@@ -0,0 +1,845 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from functools import partial
18
+ from multiprocessing import Pool, cpu_count
19
+
20
+ import numpy as np
21
+ from tqdm import tqdm
22
+
23
+ from ...models.bert.tokenization_bert import whitespace_tokenize
24
+ from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
25
+ from ...utils import is_tf_available, is_torch_available, logging
26
+ from .utils import DataProcessor
27
+
28
+
29
+ # Store the tokenizers that insert 2 separator tokens
30
+ MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
31
+
32
+
33
+ if is_torch_available():
34
+ import torch
35
+ from torch.utils.data import TensorDataset
36
+
37
+ if is_tf_available():
38
+ import tensorflow as tf
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
44
+ """Returns tokenized answer spans that better match the annotated answer."""
45
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
46
+
47
+ for new_start in range(input_start, input_end + 1):
48
+ for new_end in range(input_end, new_start - 1, -1):
49
+ text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
50
+ if text_span == tok_answer_text:
51
+ return (new_start, new_end)
52
+
53
+ return (input_start, input_end)
54
+
55
+
56
+ def _check_is_max_context(doc_spans, cur_span_index, position):
57
+ """Check if this is the 'max context' doc span for the token."""
58
+ best_score = None
59
+ best_span_index = None
60
+ for span_index, doc_span in enumerate(doc_spans):
61
+ end = doc_span.start + doc_span.length - 1
62
+ if position < doc_span.start:
63
+ continue
64
+ if position > end:
65
+ continue
66
+ num_left_context = position - doc_span.start
67
+ num_right_context = end - position
68
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
69
+ if best_score is None or score > best_score:
70
+ best_score = score
71
+ best_span_index = span_index
72
+
73
+ return cur_span_index == best_span_index
74
+
75
+
76
+ def _new_check_is_max_context(doc_spans, cur_span_index, position):
77
+ """Check if this is the 'max context' doc span for the token."""
78
+ # if len(doc_spans) == 1:
79
+ # return True
80
+ best_score = None
81
+ best_span_index = None
82
+ for span_index, doc_span in enumerate(doc_spans):
83
+ end = doc_span["start"] + doc_span["length"] - 1
84
+ if position < doc_span["start"]:
85
+ continue
86
+ if position > end:
87
+ continue
88
+ num_left_context = position - doc_span["start"]
89
+ num_right_context = end - position
90
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
91
+ if best_score is None or score > best_score:
92
+ best_score = score
93
+ best_span_index = span_index
94
+
95
+ return cur_span_index == best_span_index
96
+
97
+
98
+ def _is_whitespace(c):
99
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
100
+ return True
101
+ return False
102
+
103
+
104
+ def squad_convert_example_to_features(
105
+ example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
106
+ ):
107
+ features = []
108
+ if is_training and not example.is_impossible:
109
+ # Get start and end position
110
+ start_position = example.start_position
111
+ end_position = example.end_position
112
+
113
+ # If the answer cannot be found in the text, then skip this example.
114
+ actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
115
+ cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
116
+ if actual_text.find(cleaned_answer_text) == -1:
117
+ logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
118
+ return []
119
+
120
+ tok_to_orig_index = []
121
+ orig_to_tok_index = []
122
+ all_doc_tokens = []
123
+ for i, token in enumerate(example.doc_tokens):
124
+ orig_to_tok_index.append(len(all_doc_tokens))
125
+ if tokenizer.__class__.__name__ in [
126
+ "RobertaTokenizer",
127
+ "LongformerTokenizer",
128
+ "BartTokenizer",
129
+ "RobertaTokenizerFast",
130
+ "LongformerTokenizerFast",
131
+ "BartTokenizerFast",
132
+ ]:
133
+ sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
134
+ else:
135
+ sub_tokens = tokenizer.tokenize(token)
136
+ for sub_token in sub_tokens:
137
+ tok_to_orig_index.append(i)
138
+ all_doc_tokens.append(sub_token)
139
+
140
+ if is_training and not example.is_impossible:
141
+ tok_start_position = orig_to_tok_index[example.start_position]
142
+ if example.end_position < len(example.doc_tokens) - 1:
143
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
144
+ else:
145
+ tok_end_position = len(all_doc_tokens) - 1
146
+
147
+ (tok_start_position, tok_end_position) = _improve_answer_span(
148
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
149
+ )
150
+
151
+ spans = []
152
+
153
+ truncated_query = tokenizer.encode(
154
+ example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
155
+ )
156
+
157
+ # Tokenizers that insert 2 SEP tokens between <context> & <question> need special handling
159
+ # in the way they compute the mask of added tokens.
159
+ tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
160
+ sequence_added_tokens = (
161
+ tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
162
+ if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
163
+ else tokenizer.model_max_length - tokenizer.max_len_single_sentence
164
+ )
165
+ sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
166
+
167
+ span_doc_tokens = all_doc_tokens
168
+ while len(spans) * doc_stride < len(all_doc_tokens):
169
+ # Define the side we want to truncate / pad and the text/pair sorting
170
+ if tokenizer.padding_side == "right":
171
+ texts = truncated_query
172
+ pairs = span_doc_tokens
173
+ truncation = TruncationStrategy.ONLY_SECOND.value
174
+ else:
175
+ texts = span_doc_tokens
176
+ pairs = truncated_query
177
+ truncation = TruncationStrategy.ONLY_FIRST.value
178
+
179
+ encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
180
+ texts,
181
+ pairs,
182
+ truncation=truncation,
183
+ padding=padding_strategy,
184
+ max_length=max_seq_length,
185
+ return_overflowing_tokens=True,
186
+ stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
187
+ return_token_type_ids=True,
188
+ )
189
+
190
+ paragraph_len = min(
191
+ len(all_doc_tokens) - len(spans) * doc_stride,
192
+ max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
193
+ )
194
+
195
+ if tokenizer.pad_token_id in encoded_dict["input_ids"]:
196
+ if tokenizer.padding_side == "right":
197
+ non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
198
+ else:
199
+ last_padding_id_position = (
200
+ len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
201
+ )
202
+ non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
203
+
204
+ else:
205
+ non_padded_ids = encoded_dict["input_ids"]
206
+
207
+ tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
208
+
209
+ token_to_orig_map = {}
210
+ for i in range(paragraph_len):
211
+ index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
212
+ token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
213
+
214
+ encoded_dict["paragraph_len"] = paragraph_len
215
+ encoded_dict["tokens"] = tokens
216
+ encoded_dict["token_to_orig_map"] = token_to_orig_map
217
+ encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
218
+ encoded_dict["token_is_max_context"] = {}
219
+ encoded_dict["start"] = len(spans) * doc_stride
220
+ encoded_dict["length"] = paragraph_len
221
+
222
+ spans.append(encoded_dict)
223
+
224
+ if "overflowing_tokens" not in encoded_dict or (
225
+ "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
226
+ ):
227
+ break
228
+ span_doc_tokens = encoded_dict["overflowing_tokens"]
229
+
230
+ for doc_span_index in range(len(spans)):
231
+ for j in range(spans[doc_span_index]["paragraph_len"]):
232
+ is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
233
+ index = (
234
+ j
235
+ if tokenizer.padding_side == "left"
236
+ else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
237
+ )
238
+ spans[doc_span_index]["token_is_max_context"][index] = is_max_context
239
+
240
+ for span in spans:
241
+ # Identify the position of the CLS token
242
+ cls_index = span["input_ids"].index(tokenizer.cls_token_id)
243
+
244
+ # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in an answer)
246
+ # The original TF implementation also keeps the classification token (set to 0)
246
+ p_mask = np.ones_like(span["token_type_ids"])
247
+ if tokenizer.padding_side == "right":
248
+ p_mask[len(truncated_query) + sequence_added_tokens :] = 0
249
+ else:
250
+ p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
251
+
252
+ pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
253
+ special_token_indices = np.asarray(
254
+ tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
255
+ ).nonzero()
256
+
257
+ p_mask[pad_token_indices] = 1
258
+ p_mask[special_token_indices] = 1
259
+
260
+ # Set the cls index to 0: the CLS index can be used for impossible answers
261
+ p_mask[cls_index] = 0
262
+
263
+ span_is_impossible = example.is_impossible
264
+ start_position = 0
265
+ end_position = 0
266
+ if is_training and not span_is_impossible:
267
+ # For training, if our document chunk does not contain an annotation
268
+ # we throw it out, since there is nothing to predict.
269
+ doc_start = span["start"]
270
+ doc_end = span["start"] + span["length"] - 1
271
+ out_of_span = False
272
+
273
+ if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
274
+ out_of_span = True
275
+
276
+ if out_of_span:
277
+ start_position = cls_index
278
+ end_position = cls_index
279
+ span_is_impossible = True
280
+ else:
281
+ if tokenizer.padding_side == "left":
282
+ doc_offset = 0
283
+ else:
284
+ doc_offset = len(truncated_query) + sequence_added_tokens
285
+
286
+ start_position = tok_start_position - doc_start + doc_offset
287
+ end_position = tok_end_position - doc_start + doc_offset
288
+
289
+ features.append(
290
+ SquadFeatures(
291
+ span["input_ids"],
292
+ span["attention_mask"],
293
+ span["token_type_ids"],
294
+ cls_index,
295
+ p_mask.tolist(),
296
+ example_index=0,  # Cannot set unique_id and example_index here; they are set after the multiprocessing step.
297
+ unique_id=0,
298
+ paragraph_len=span["paragraph_len"],
299
+ token_is_max_context=span["token_is_max_context"],
300
+ tokens=span["tokens"],
301
+ token_to_orig_map=span["token_to_orig_map"],
302
+ start_position=start_position,
303
+ end_position=end_position,
304
+ is_impossible=span_is_impossible,
305
+ qas_id=example.qas_id,
306
+ )
307
+ )
308
+ return features
309
+
310
+
311
+ def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
312
+ global tokenizer
313
+ tokenizer = tokenizer_for_convert
314
+
315
+
316
+ def squad_convert_examples_to_features(
317
+ examples,
318
+ tokenizer,
319
+ max_seq_length,
320
+ doc_stride,
321
+ max_query_length,
322
+ is_training,
323
+ padding_strategy="max_length",
324
+ return_dataset=False,
325
+ threads=1,
326
+ tqdm_enabled=True,
327
+ ):
328
+ """
329
+ Converts a list of examples into a list of features that can be directly given as input to a model. It is
330
+ model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
331
+
332
+ Args:
333
+ examples: list of [`~data.processors.squad.SquadExample`]
334
+ tokenizer: an instance of a child of [`PreTrainedTokenizer`]
335
+ max_seq_length: The maximum sequence length of the inputs.
336
+ doc_stride: The stride used when the context is too large and is split across several features.
337
+ max_query_length: The maximum length of the query.
338
+ is_training: whether to create features for model evaluation or model training.
339
+ padding_strategy: Defaults to "max_length". The padding strategy to use.
341
+ return_dataset: Defaults to False. Either 'pt' or 'tf'.
342
+ if 'pt': returns a torch.utils.data.TensorDataset, if 'tf': returns a tf.data.Dataset
343
+ threads: number of worker processes used to convert the examples in parallel.
343
+
344
+
345
+ Returns:
346
+ list of [`~data.processors.squad.SquadFeatures`]
347
+
348
+ Example:
349
+
350
+ ```python
351
+ processor = SquadV2Processor()
352
+ examples = processor.get_dev_examples(data_dir)
353
+
354
+ features = squad_convert_examples_to_features(
355
+ examples=examples,
356
+ tokenizer=tokenizer,
357
+ max_seq_length=args.max_seq_length,
358
+ doc_stride=args.doc_stride,
359
+ max_query_length=args.max_query_length,
360
+ is_training=not evaluate,
361
+ )
362
+ ```"""
363
+ # Defining helper methods
364
+ features = []
365
+
366
+ threads = min(threads, cpu_count())
367
+ with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
368
+ annotate_ = partial(
369
+ squad_convert_example_to_features,
370
+ max_seq_length=max_seq_length,
371
+ doc_stride=doc_stride,
372
+ max_query_length=max_query_length,
373
+ padding_strategy=padding_strategy,
374
+ is_training=is_training,
375
+ )
376
+ features = list(
377
+ tqdm(
378
+ p.imap(annotate_, examples, chunksize=32),
379
+ total=len(examples),
380
+ desc="convert squad examples to features",
381
+ disable=not tqdm_enabled,
382
+ )
383
+ )
384
+
385
+ new_features = []
386
+ unique_id = 1000000000
387
+ example_index = 0
388
+ for example_features in tqdm(
389
+ features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
390
+ ):
391
+ if not example_features:
392
+ continue
393
+ for example_feature in example_features:
394
+ example_feature.example_index = example_index
395
+ example_feature.unique_id = unique_id
396
+ new_features.append(example_feature)
397
+ unique_id += 1
398
+ example_index += 1
399
+ features = new_features
400
+ del new_features
401
+ if return_dataset == "pt":
402
+ if not is_torch_available():
403
+ raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
404
+
405
+ # Convert to Tensors and build dataset
406
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
407
+ all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
408
+ all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
409
+ all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
410
+ all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
411
+ all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
412
+
413
+ if not is_training:
414
+ all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
415
+ dataset = TensorDataset(
416
+ all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
417
+ )
418
+ else:
419
+ all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
420
+ all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
421
+ dataset = TensorDataset(
422
+ all_input_ids,
423
+ all_attention_masks,
424
+ all_token_type_ids,
425
+ all_start_positions,
426
+ all_end_positions,
427
+ all_cls_index,
428
+ all_p_mask,
429
+ all_is_impossible,
430
+ )
431
+
432
+ return features, dataset
433
+ elif return_dataset == "tf":
434
+ if not is_tf_available():
435
+ raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
436
+
437
+ def gen():
438
+ for i, ex in enumerate(features):
439
+ if ex.token_type_ids is None:
440
+ yield (
441
+ {
442
+ "input_ids": ex.input_ids,
443
+ "attention_mask": ex.attention_mask,
444
+ "feature_index": i,
445
+ "qas_id": ex.qas_id,
446
+ },
447
+ {
448
+ "start_positions": ex.start_position,
449
+ "end_positions": ex.end_position,
450
+ "cls_index": ex.cls_index,
451
+ "p_mask": ex.p_mask,
452
+ "is_impossible": ex.is_impossible,
453
+ },
454
+ )
455
+ else:
456
+ yield (
457
+ {
458
+ "input_ids": ex.input_ids,
459
+ "attention_mask": ex.attention_mask,
460
+ "token_type_ids": ex.token_type_ids,
461
+ "feature_index": i,
462
+ "qas_id": ex.qas_id,
463
+ },
464
+ {
465
+ "start_positions": ex.start_position,
466
+ "end_positions": ex.end_position,
467
+ "cls_index": ex.cls_index,
468
+ "p_mask": ex.p_mask,
469
+ "is_impossible": ex.is_impossible,
470
+ },
471
+ )
472
+
473
+ # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
474
+ if "token_type_ids" in tokenizer.model_input_names:
475
+ train_types = (
476
+ {
477
+ "input_ids": tf.int32,
478
+ "attention_mask": tf.int32,
479
+ "token_type_ids": tf.int32,
480
+ "feature_index": tf.int64,
481
+ "qas_id": tf.string,
482
+ },
483
+ {
484
+ "start_positions": tf.int64,
485
+ "end_positions": tf.int64,
486
+ "cls_index": tf.int64,
487
+ "p_mask": tf.int32,
488
+ "is_impossible": tf.int32,
489
+ },
490
+ )
491
+
492
+ train_shapes = (
493
+ {
494
+ "input_ids": tf.TensorShape([None]),
495
+ "attention_mask": tf.TensorShape([None]),
496
+ "token_type_ids": tf.TensorShape([None]),
497
+ "feature_index": tf.TensorShape([]),
498
+ "qas_id": tf.TensorShape([]),
499
+ },
500
+ {
501
+ "start_positions": tf.TensorShape([]),
502
+ "end_positions": tf.TensorShape([]),
503
+ "cls_index": tf.TensorShape([]),
504
+ "p_mask": tf.TensorShape([None]),
505
+ "is_impossible": tf.TensorShape([]),
506
+ },
507
+ )
508
+ else:
509
+ train_types = (
510
+ {"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
511
+ {
512
+ "start_positions": tf.int64,
513
+ "end_positions": tf.int64,
514
+ "cls_index": tf.int64,
515
+ "p_mask": tf.int32,
516
+ "is_impossible": tf.int32,
517
+ },
518
+ )
519
+
520
+ train_shapes = (
521
+ {
522
+ "input_ids": tf.TensorShape([None]),
523
+ "attention_mask": tf.TensorShape([None]),
524
+ "feature_index": tf.TensorShape([]),
525
+ "qas_id": tf.TensorShape([]),
526
+ },
527
+ {
528
+ "start_positions": tf.TensorShape([]),
529
+ "end_positions": tf.TensorShape([]),
530
+ "cls_index": tf.TensorShape([]),
531
+ "p_mask": tf.TensorShape([None]),
532
+ "is_impossible": tf.TensorShape([]),
533
+ },
534
+ )
535
+
536
+ return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
537
+ else:
538
+ return features
539
+
540
+
541
+ class SquadProcessor(DataProcessor):
542
+ """
543
+ Processor for the SQuAD data set. overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and
544
+ version 2.0 of SQuAD, respectively.
545
+ """
546
+
547
+ train_file = None
548
+ dev_file = None
549
+
550
+ def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
551
+ if not evaluate:
552
+ answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
553
+ answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
554
+ answers = []
555
+ else:
556
+ answers = [
557
+ {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
558
+ for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
559
+ ]
560
+
561
+ answer = None
562
+ answer_start = None
563
+
564
+ return SquadExample(
565
+ qas_id=tensor_dict["id"].numpy().decode("utf-8"),
566
+ question_text=tensor_dict["question"].numpy().decode("utf-8"),
567
+ context_text=tensor_dict["context"].numpy().decode("utf-8"),
568
+ answer_text=answer,
569
+ start_position_character=answer_start,
570
+ title=tensor_dict["title"].numpy().decode("utf-8"),
571
+ answers=answers,
572
+ )
573
+
574
+ def get_examples_from_dataset(self, dataset, evaluate=False):
575
+ """
576
+ Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.
577
+
578
+ Args:
579
+ dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
580
+ evaluate: Boolean specifying if in evaluation mode or in training mode
581
+
582
+ Returns:
583
+ List of SquadExample
584
+
585
+ Examples:
586
+
587
+ ```python
588
+ >>> import tensorflow_datasets as tfds
589
+
590
+ >>> dataset = tfds.load("squad")
591
+
592
+ >>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
593
+ >>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
594
+ ```"""
595
+
596
+ if evaluate:
597
+ dataset = dataset["validation"]
598
+ else:
599
+ dataset = dataset["train"]
600
+
601
+ examples = []
602
+ for tensor_dict in tqdm(dataset):
603
+ examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
604
+
605
+ return examples
606
+
607
+ def get_train_examples(self, data_dir, filename=None):
608
+ """
609
+ Returns the training examples from the data directory.
610
+
611
+ Args:
612
+ data_dir: Directory containing the data files used for training and evaluating.
613
+ filename: None by default, specify this if the training file has a different name than the original one
614
+ which is `train-v1.1.json` or `train-v2.0.json` for SQuAD versions 1.1 and 2.0 respectively.
615
+
616
+ """
617
+ if data_dir is None:
618
+ data_dir = ""
619
+
620
+ if self.train_file is None:
621
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
622
+
623
+ with open(
624
+ os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
625
+ ) as reader:
626
+ input_data = json.load(reader)["data"]
627
+ return self._create_examples(input_data, "train")
628
+
629
+ def get_dev_examples(self, data_dir, filename=None):
630
+ """
631
+ Returns the evaluation examples from the data directory.
632
+
633
+ Args:
634
+ data_dir: Directory containing the data files used for training and evaluating.
635
+ filename: None by default, specify this if the evaluation file has a different name than the original one
636
+ which is `dev-v1.1.json` or `dev-v2.0.json` for SQuAD versions 1.1 and 2.0 respectively.
637
+ """
638
+ if data_dir is None:
639
+ data_dir = ""
640
+
641
+ if self.dev_file is None:
642
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
643
+
644
+ with open(
645
+ os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
646
+ ) as reader:
647
+ input_data = json.load(reader)["data"]
648
+ return self._create_examples(input_data, "dev")
649
+
650
+ def _create_examples(self, input_data, set_type):
651
+ is_training = set_type == "train"
652
+ examples = []
653
+ for entry in tqdm(input_data):
654
+ title = entry["title"]
655
+ for paragraph in entry["paragraphs"]:
656
+ context_text = paragraph["context"]
657
+ for qa in paragraph["qas"]:
658
+ qas_id = qa["id"]
659
+ question_text = qa["question"]
660
+ start_position_character = None
661
+ answer_text = None
662
+ answers = []
663
+
664
+ is_impossible = qa.get("is_impossible", False)
665
+ if not is_impossible:
666
+ if is_training:
667
+ answer = qa["answers"][0]
668
+ answer_text = answer["text"]
669
+ start_position_character = answer["answer_start"]
670
+ else:
671
+ answers = qa["answers"]
672
+
673
+ example = SquadExample(
674
+ qas_id=qas_id,
675
+ question_text=question_text,
676
+ context_text=context_text,
677
+ answer_text=answer_text,
678
+ start_position_character=start_position_character,
679
+ title=title,
680
+ is_impossible=is_impossible,
681
+ answers=answers,
682
+ )
683
+ examples.append(example)
684
+ return examples
685
+
686
+
687
+ class SquadV1Processor(SquadProcessor):
688
+ train_file = "train-v1.1.json"
689
+ dev_file = "dev-v1.1.json"
690
+
691
+
692
+ class SquadV2Processor(SquadProcessor):
693
+ train_file = "train-v2.0.json"
694
+ dev_file = "dev-v2.0.json"
695
+
696
+
697
+ class SquadExample:
698
+ """
699
+ A single training/test example for the Squad dataset, as loaded from disk.
700
+
701
+ Args:
702
+ qas_id: The example's unique identifier
703
+ question_text: The question string
704
+ context_text: The context string
705
+ answer_text: The answer string
706
+ start_position_character: The character position of the start of the answer
707
+ title: The title of the example
708
+ answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
709
+ is_impossible: False by default, set to True if the example has no possible answer.
710
+ """
711
+
712
+ def __init__(
713
+ self,
714
+ qas_id,
715
+ question_text,
716
+ context_text,
717
+ answer_text,
718
+ start_position_character,
719
+ title,
720
+ answers=None,
721
+ is_impossible=False,
722
+ ):
723
+ self.qas_id = qas_id
724
+ self.question_text = question_text
725
+ self.context_text = context_text
726
+ self.answer_text = answer_text
727
+ self.title = title
728
+ self.is_impossible = is_impossible
729
+ self.answers = answers if answers is not None else []
730
+
731
+ self.start_position, self.end_position = 0, 0
732
+
733
+ doc_tokens = []
734
+ char_to_word_offset = []
735
+ prev_is_whitespace = True
736
+
737
+ # Split on whitespace so that different tokens may be attributed to their original position.
738
+ for c in self.context_text:
739
+ if _is_whitespace(c):
740
+ prev_is_whitespace = True
741
+ else:
742
+ if prev_is_whitespace:
743
+ doc_tokens.append(c)
744
+ else:
745
+ doc_tokens[-1] += c
746
+ prev_is_whitespace = False
747
+ char_to_word_offset.append(len(doc_tokens) - 1)
748
+
749
+ self.doc_tokens = doc_tokens
750
+ self.char_to_word_offset = char_to_word_offset
751
+
752
+ # Start and end positions only have a value during evaluation.
753
+ if start_position_character is not None and not is_impossible:
754
+ self.start_position = char_to_word_offset[start_position_character]
755
+ self.end_position = char_to_word_offset[
756
+ min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
757
+ ]
758
+
759
+
760
+ class SquadFeatures:
761
+ """
762
+ Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
763
+ [`~data.processors.squad.SquadExample`] using the
764
+ [`~transformers.data.processors.squad.squad_convert_examples_to_features`] method.
765
+
766
+ Args:
767
+ input_ids: Indices of input sequence tokens in the vocabulary.
768
+ attention_mask: Mask to avoid performing attention on padding token indices.
769
+ token_type_ids: Segment token indices to indicate first and second portions of the inputs.
770
+ cls_index: the index of the CLS token.
771
+ p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
772
+ Mask with 1 for tokens that cannot be in the answer and 0 for tokens that can be in an answer.
773
+ example_index: the index of the example
774
+ unique_id: The unique Feature identifier
775
+ paragraph_len: The length of the context
776
+ token_is_max_context:
777
+ List of booleans identifying which tokens have their maximum context in this feature object. If a token
778
+ does not have its maximum context in this feature object, it means that another feature object has more
779
+ information related to that token and should be prioritized over this feature for that token.
780
+ tokens: list of tokens corresponding to the input ids
781
+ token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
782
+ start_position: start of the answer token index
783
+ end_position: end of the answer token index
784
+ encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
785
+ """
786
+
787
+ def __init__(
788
+ self,
789
+ input_ids,
790
+ attention_mask,
791
+ token_type_ids,
792
+ cls_index,
793
+ p_mask,
794
+ example_index,
795
+ unique_id,
796
+ paragraph_len,
797
+ token_is_max_context,
798
+ tokens,
799
+ token_to_orig_map,
800
+ start_position,
801
+ end_position,
802
+ is_impossible,
803
+ qas_id: str = None,
804
+ encoding: BatchEncoding = None,
805
+ ):
806
+ self.input_ids = input_ids
807
+ self.attention_mask = attention_mask
808
+ self.token_type_ids = token_type_ids
809
+ self.cls_index = cls_index
810
+ self.p_mask = p_mask
811
+
812
+ self.example_index = example_index
813
+ self.unique_id = unique_id
814
+ self.paragraph_len = paragraph_len
815
+ self.token_is_max_context = token_is_max_context
816
+ self.tokens = tokens
817
+ self.token_to_orig_map = token_to_orig_map
818
+
819
+ self.start_position = start_position
820
+ self.end_position = end_position
821
+ self.is_impossible = is_impossible
822
+ self.qas_id = qas_id
823
+
824
+ self.encoding = encoding
825
+
826
+
827
+ class SquadResult:
828
+ """
829
+ Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
830
+
831
+ Args:
832
+ unique_id: The unique identifier corresponding to that example.
833
+ start_logits: The logits corresponding to the start of the answer
834
+ end_logits: The logits corresponding to the end of the answer
835
+ """
836
+
837
+ def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
838
+ self.start_logits = start_logits
839
+ self.end_logits = end_logits
840
+ self.unique_id = unique_id
841
+
842
+ if start_top_index:
843
+ self.start_top_index = start_top_index
844
+ self.end_top_index = end_top_index
845
+ self.cls_logits = cls_logits
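
A short sketch tying the classes above together for SQuAD v2.0, mirroring the docstring example (the data directory and checkpoint are placeholders; `return_dataset="pt"` requires PyTorch):

```python
from transformers import AutoTokenizer
from transformers.data.processors.squad import (
    SquadV2Processor,
    squad_convert_examples_to_features,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV2Processor()

# Placeholder directory containing dev-v2.0.json.
examples = processor.get_dev_examples("/path/to/squad")
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
)
print(len(features), dataset[0][0].shape)  # input_ids of the first feature
```
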
.venv/Lib/site-packages/transformers/data/processors/utils.py ADDED
@@ -0,0 +1,349 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import csv
+import dataclasses
+import json
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+from ...utils import is_tf_available, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class InputExample:
+    """
+    A single training/test example for simple sequence classification.
+
+    Args:
+        guid: Unique id for the example.
+        text_a: string. The untokenized text of the first sequence. For single
+            sequence tasks, only this sequence must be specified.
+        text_b: (Optional) string. The untokenized text of the second sequence.
+            Only needs to be specified for sequence pair tasks.
+        label: (Optional) string. The label of the example. This should be
+            specified for train and dev examples, but not for test examples.
+    """
+
+    guid: str
+    text_a: str
+    text_b: Optional[str] = None
+    label: Optional[str] = None
+
+    def to_json_string(self):
+        """Serializes this instance to a JSON string."""
+        return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
+
+
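As a quick aside, constructing and serializing one of these examples looks like this (the values are made up):

```python
example = InputExample(guid="train-0", text_a="The cat sat.", label="neutral")
print(example.to_json_string())
# {
#   "guid": "train-0",
#   "text_a": "The cat sat.",
#   "text_b": null,
#   "label": "neutral"
# }
```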
+@dataclass(frozen=True)
+class InputFeatures:
+    """
+    A single set of features of data. Property names are the same names as the corresponding inputs to a model.
+
+    Args:
+        input_ids: Indices of input sequence tokens in the vocabulary.
+        attention_mask: Mask to avoid performing attention on padding token indices.
+            Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
+            tokens.
+        token_type_ids: (Optional) Segment token indices to indicate first and second
+            portions of the inputs. Only some models use them.
+        label: (Optional) Label corresponding to the input. Int for classification problems,
+            float for regression problems.
+    """
+
+    input_ids: List[int]
+    attention_mask: Optional[List[int]] = None
+    token_type_ids: Optional[List[int]] = None
+    label: Optional[Union[int, float]] = None
+
+    def to_json_string(self):
+        """Serializes this instance to a JSON string."""
+        return json.dumps(dataclasses.asdict(self)) + "\n"
+
+
+class DataProcessor:
+    """Base class for data converters for sequence classification data sets."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """
+        Gets an example from a dict with tensorflow tensors.
+
+        Args:
+            tensor_dict: Keys and values should match the corresponding Glue
+                tensorflow_dataset examples.
+        """
+        raise NotImplementedError()
+
+    def get_train_examples(self, data_dir):
+        """Gets a collection of [`InputExample`] for the train set."""
+        raise NotImplementedError()
+
+    def get_dev_examples(self, data_dir):
+        """Gets a collection of [`InputExample`] for the dev set."""
+        raise NotImplementedError()
+
+    def get_test_examples(self, data_dir):
+        """Gets a collection of [`InputExample`] for the test set."""
+        raise NotImplementedError()
+
+    def get_labels(self):
+        """Gets the list of labels for this data set."""
+        raise NotImplementedError()
+
+    def tfds_map(self, example):
+        """
+        Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
+        examples to the correct format.
+        """
+        if len(self.get_labels()) > 1:
+            example.label = self.get_labels()[int(example.label)]
+        return example
+
+    @classmethod
+    def _read_tsv(cls, input_file, quotechar=None):
+        """Reads a tab separated value file."""
+        with open(input_file, "r", encoding="utf-8-sig") as f:
+            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
+
+
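Concrete processors override these stubs; a minimal hypothetical subclass for a TSV layout of `label<TAB>text` per row (the file name and label set are assumptions for illustration):

```python
import os


class MyBinaryProcessor(DataProcessor):
    """Hypothetical processor for a two-label TSV dataset (not part of this diff)."""

    def get_train_examples(self, data_dir):
        lines = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return [
            InputExample(guid=f"train-{i}", text_a=line[1], label=line[0])
            for i, line in enumerate(lines)
        ]

    def get_labels(self):
        return ["0", "1"]
```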
+class SingleSentenceClassificationProcessor(DataProcessor):
+    """Generic processor for a single sentence classification data set."""
+
+    def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
+        self.labels = [] if labels is None else labels
+        self.examples = [] if examples is None else examples
+        self.mode = mode
+        self.verbose = verbose
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
+        return self.examples[idx]
+
+    @classmethod
+    def create_from_csv(
+        cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
+    ):
+        processor = cls(**kwargs)
+        processor.add_examples_from_csv(
+            file_name,
+            split_name=split_name,
+            column_label=column_label,
+            column_text=column_text,
+            column_id=column_id,
+            skip_first_row=skip_first_row,
+            overwrite_labels=True,
+            overwrite_examples=True,
+        )
+        return processor
+
+    @classmethod
+    def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
+        processor = cls(**kwargs)
+        processor.add_examples(texts_or_text_and_labels, labels=labels)
+        return processor
+
+    def add_examples_from_csv(
+        self,
+        file_name,
+        split_name="",
+        column_label=0,
+        column_text=1,
+        column_id=None,
+        skip_first_row=False,
+        overwrite_labels=False,
+        overwrite_examples=False,
+    ):
+        lines = self._read_tsv(file_name)
+        if skip_first_row:
+            lines = lines[1:]
+        texts = []
+        labels = []
+        ids = []
+        for i, line in enumerate(lines):
+            texts.append(line[column_text])
+            labels.append(line[column_label])
+            if column_id is not None:
+                ids.append(line[column_id])
+            else:
+                guid = f"{split_name}-{i}" if split_name else str(i)
+                ids.append(guid)
+
+        return self.add_examples(
+            texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
+        )
+
+    def add_examples(
+        self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
+    ):
+        if labels is not None and len(texts_or_text_and_labels) != len(labels):
+            raise ValueError(
+                f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
+            )
+        if ids is not None and len(texts_or_text_and_labels) != len(ids):
+            raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
+        if ids is None:
+            ids = [None] * len(texts_or_text_and_labels)
+        if labels is None:
+            labels = [None] * len(texts_or_text_and_labels)
+        examples = []
+        added_labels = set()
+        for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
+            if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
+                text, label = text_or_text_and_label
+            else:
+                text = text_or_text_and_label
+            added_labels.add(label)
+            examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
+
+        # Update examples
+        if overwrite_examples:
+            self.examples = examples
+        else:
+            self.examples.extend(examples)
+
+        # Update labels
+        if overwrite_labels:
+            self.labels = list(added_labels)
+        else:
+            self.labels = list(set(self.labels).union(added_labels))
+
+        return self.examples
+
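A sketch of feeding raw `(text, label)` pairs through the factory method above (the data values are illustrative):

```python
processor = SingleSentenceClassificationProcessor.create_from_examples(
    [("I loved it", "pos"), ("Terrible", "neg")]
)
print(len(processor))            # 2
print(sorted(processor.labels))  # ['neg', 'pos']
print(processor[0].text_a)       # I loved it
```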
+    def get_features(
+        self,
+        tokenizer,
+        max_length=None,
+        pad_on_left=False,
+        pad_token=0,
+        mask_padding_with_zero=True,
+        return_tensors=None,
+    ):
+        """
+        Convert the examples into a list of `InputFeatures`.
+
+        Args:
+            tokenizer: Instance of a tokenizer that will tokenize the examples
+            max_length: Maximum example length
+            pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
+            pad_token: Padding token
+            mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
+                and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
+                values)
+
+        Returns:
+            A list of task-specific `InputFeatures` which can be fed to the model, or, when `return_tensors` is set
+            to `"tf"` or `"pt"`, a `tf.data.Dataset` / `torch.utils.data.TensorDataset` containing the same features.
+        """
+        if max_length is None:
+            max_length = tokenizer.max_len
+
+        label_map = {label: i for i, label in enumerate(self.labels)}
+
+        all_input_ids = []
+        for ex_index, example in enumerate(self.examples):
+            if ex_index % 10000 == 0:
+                logger.info(f"Tokenizing example {ex_index}")
+
+            input_ids = tokenizer.encode(
+                example.text_a,
+                add_special_tokens=True,
+                max_length=min(max_length, tokenizer.max_len),
+            )
+            all_input_ids.append(input_ids)
+
+        batch_length = max(len(input_ids) for input_ids in all_input_ids)
+
+        features = []
+        for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
+            if ex_index % 10000 == 0:
+                logger.info(f"Writing example {ex_index}/{len(self.examples)}")
+            # The mask has 1 for real tokens and 0 for padding tokens. Only real
+            # tokens are attended to.
+            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+            # Zero-pad up to the sequence length.
+            padding_length = batch_length - len(input_ids)
+            if pad_on_left:
+                input_ids = ([pad_token] * padding_length) + input_ids
+                attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
+            else:
+                input_ids = input_ids + ([pad_token] * padding_length)
+                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+
+            if len(input_ids) != batch_length:
+                raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
+            if len(attention_mask) != batch_length:
+                raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
+
+            if self.mode == "classification":
+                label = label_map[example.label]
+            elif self.mode == "regression":
+                label = float(example.label)
+            else:
+                raise ValueError(self.mode)
+
+            if ex_index < 5 and self.verbose:
+                logger.info("*** Example ***")
+                logger.info(f"guid: {example.guid}")
+                logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
+                logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
+                logger.info(f"label: {example.label} (id = {label})")
+
+            features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
+
+        if return_tensors is None:
+            return features
+        elif return_tensors == "tf":
+            if not is_tf_available():
+                raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
+            import tensorflow as tf
+
+            def gen():
+                for ex in features:
+                    yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
+
+            dataset = tf.data.Dataset.from_generator(
+                gen,
+                ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
+                ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
+            )
+            return dataset
+        elif return_tensors == "pt":
+            if not is_torch_available():
+                raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
+            import torch
+            from torch.utils.data import TensorDataset
+
+            all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+            all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+            if self.mode == "classification":
+                all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
+            elif self.mode == "regression":
+                all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
+
+            dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
+            return dataset
+        else:
+            raise ValueError("return_tensors should be one of 'tf' or 'pt'")
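An end-to-end sketch tying the processor to a tokenizer. Note that `get_features` reads the legacy `tokenizer.max_len` attribute, which newer tokenizers renamed to `model_max_length`, so the sketch shims it (the checkpoint and data are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Legacy attribute expected by get_features; newer tokenizers only expose
# model_max_length, so mirror it if needed.
if not hasattr(tokenizer, "max_len"):
    tokenizer.max_len = tokenizer.model_max_length

processor = SingleSentenceClassificationProcessor.create_from_examples(
    [("I loved it", "pos"), ("Terrible", "neg")]
)
features = processor.get_features(tokenizer, max_length=32)
print(features[0].input_ids)       # padded to the longest sequence in the batch
print(features[0].attention_mask)  # 1 for real tokens, 0 for padding
```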
.venv/Lib/site-packages/transformers/data/processors/xnli.py ADDED
@@ -0,0 +1,96 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""XNLI utils (dataset loading and evaluation)"""
+
+import os
+
+from ...utils import logging
+from .utils import DataProcessor, InputExample
+
+
+logger = logging.get_logger(__name__)
+
+
+class XnliProcessor(DataProcessor):
+    """
+    Processor for the XNLI dataset. Adapted from
+    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
+    """
+
+    def __init__(self, language, train_language=None):
+        self.language = language
+        self.train_language = train_language
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        lg = self.language if self.train_language is None else self.train_language
+        lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
+        examples = []
+        for i, line in enumerate(lines):
+            if i == 0:
+                continue
+            guid = f"train-{i}"
+            text_a = line[0]
+            text_b = line[1]
+            label = "contradiction" if line[2] == "contradictory" else line[2]
+            if not isinstance(text_a, str):
+                raise TypeError(f"Training input {text_a} is not a string")
+            if not isinstance(text_b, str):
+                raise TypeError(f"Training input {text_b} is not a string")
+            if not isinstance(label, str):
+                raise TypeError(f"Training label {label} is not a string")
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
+        examples = []
+        for i, line in enumerate(lines):
+            if i == 0:
+                continue
+            language = line[0]
+            if language != self.language:
+                continue
+            guid = f"test-{i}"
+            text_a = line[6]
+            text_b = line[7]
+            label = line[1]
+            if not isinstance(text_a, str):
+                raise TypeError(f"Test input {text_a} is not a string")
+            if not isinstance(text_b, str):
+                raise TypeError(f"Test input {text_b} is not a string")
+            if not isinstance(label, str):
+                raise TypeError(f"Test label {label} is not a string")
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+    def get_labels(self):
+        """See base class."""
+        return ["contradiction", "entailment", "neutral"]
+
+
+xnli_processors = {
+    "xnli": XnliProcessor,
+}
+
+xnli_output_modes = {
+    "xnli": "classification",
+}
+
+xnli_tasks_num_labels = {
+    "xnli": 3,
+}
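Usage follows the base-class interface; a sketch assuming the XNLI archives have been extracted under a data directory in the layout the paths above expect:

```python
processor = XnliProcessor(language="de", train_language="en")
train_examples = processor.get_train_examples("/data/xnli")  # English MultiNLI training split
test_examples = processor.get_test_examples("/data/xnli")    # keeps only rows whose language is "de"
print(processor.get_labels())  # ['contradiction', 'entailment', 'neutral']
```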
.venv/Lib/site-packages/transformers/generation/__init__.py ADDED
@@ -0,0 +1,352 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
+
+
+_import_structure = {
+    "configuration_utils": [
+        "BaseWatermarkingConfig",
+        "CompileConfig",
+        "GenerationConfig",
+        "GenerationMode",
+        "SynthIDTextWatermarkingConfig",
+        "WatermarkingConfig",
+    ],
+    "streamers": ["TextIteratorStreamer", "TextStreamer"],
+}
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["beam_constraints"] = [
+        "Constraint",
+        "ConstraintListState",
+        "DisjunctiveConstraint",
+        "PhrasalConstraint",
+    ]
+    _import_structure["beam_search"] = [
+        "BeamHypotheses",
+        "BeamScorer",
+        "BeamSearchScorer",
+        "ConstrainedBeamSearchScorer",
+    ]
+    _import_structure["candidate_generator"] = [
+        "AssistedCandidateGenerator",
+        "CandidateGenerator",
+        "EarlyExitCandidateGenerator",
+        "PromptLookupCandidateGenerator",
+    ]
+    _import_structure["logits_process"] = [
+        "AlternatingCodebooksLogitsProcessor",
+        "ClassifierFreeGuidanceLogitsProcessor",
+        "EncoderNoRepeatNGramLogitsProcessor",
+        "EncoderRepetitionPenaltyLogitsProcessor",
+        "EpsilonLogitsWarper",
+        "EtaLogitsWarper",
+        "ExponentialDecayLengthPenalty",
+        "ForcedBOSTokenLogitsProcessor",
+        "ForcedEOSTokenLogitsProcessor",
+        "HammingDiversityLogitsProcessor",
+        "InfNanRemoveLogitsProcessor",
+        "LogitNormalization",
+        "LogitsProcessor",
+        "LogitsProcessorList",
+        "LogitsWarper",
+        "MinLengthLogitsProcessor",
+        "MinNewTokensLengthLogitsProcessor",
+        "MinPLogitsWarper",
+        "NoBadWordsLogitsProcessor",
+        "NoRepeatNGramLogitsProcessor",
+        "PrefixConstrainedLogitsProcessor",
+        "RepetitionPenaltyLogitsProcessor",
+        "SequenceBiasLogitsProcessor",
+        "SuppressTokensLogitsProcessor",
+        "SuppressTokensAtBeginLogitsProcessor",
+        "SynthIDTextWatermarkLogitsProcessor",
+        "TemperatureLogitsWarper",
+        "TopKLogitsWarper",
+        "TopPLogitsWarper",
+        "TypicalLogitsWarper",
+        "UnbatchedClassifierFreeGuidanceLogitsProcessor",
+        "WhisperTimeStampLogitsProcessor",
+        "WatermarkLogitsProcessor",
+    ]
+    _import_structure["stopping_criteria"] = [
+        "MaxNewTokensCriteria",
+        "MaxLengthCriteria",
+        "MaxTimeCriteria",
+        "ConfidenceCriteria",
+        "EosTokenCriteria",
+        "StoppingCriteria",
+        "StoppingCriteriaList",
+        "validate_stopping_criteria",
+        "StopStringCriteria",
+    ]
+    _import_structure["utils"] = [
+        "GenerationMixin",
+        "GreedySearchEncoderDecoderOutput",
+        "GreedySearchDecoderOnlyOutput",
+        "SampleEncoderDecoderOutput",
+        "SampleDecoderOnlyOutput",
+        "BeamSearchEncoderDecoderOutput",
+        "BeamSearchDecoderOnlyOutput",
+        "BeamSampleEncoderDecoderOutput",
+        "BeamSampleDecoderOnlyOutput",
+        "ContrastiveSearchEncoderDecoderOutput",
+        "ContrastiveSearchDecoderOnlyOutput",
+        "GenerateBeamDecoderOnlyOutput",
+        "GenerateBeamEncoderDecoderOutput",
+        "GenerateDecoderOnlyOutput",
+        "GenerateEncoderDecoderOutput",
+    ]
+    _import_structure["watermarking"] = [
+        "WatermarkDetector",
+        "WatermarkDetectorOutput",
+        "BayesianDetectorModel",
+        "BayesianDetectorConfig",
+        "SynthIDTextWatermarkDetector",
+    ]
+
+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tf_logits_process"] = [
+        "TFForcedBOSTokenLogitsProcessor",
+        "TFForcedEOSTokenLogitsProcessor",
+        "TFForceTokensLogitsProcessor",
+        "TFLogitsProcessor",
+        "TFLogitsProcessorList",
+        "TFLogitsWarper",
+        "TFMinLengthLogitsProcessor",
+        "TFNoBadWordsLogitsProcessor",
+        "TFNoRepeatNGramLogitsProcessor",
+        "TFRepetitionPenaltyLogitsProcessor",
+        "TFSuppressTokensAtBeginLogitsProcessor",
+        "TFSuppressTokensLogitsProcessor",
+        "TFTemperatureLogitsWarper",
+        "TFTopKLogitsWarper",
+        "TFTopPLogitsWarper",
+    ]
+    _import_structure["tf_utils"] = [
+        "TFGenerationMixin",
+        "TFGreedySearchDecoderOnlyOutput",
+        "TFGreedySearchEncoderDecoderOutput",
+        "TFSampleEncoderDecoderOutput",
+        "TFSampleDecoderOnlyOutput",
+        "TFBeamSearchEncoderDecoderOutput",
+        "TFBeamSearchDecoderOnlyOutput",
+        "TFBeamSampleEncoderDecoderOutput",
+        "TFBeamSampleDecoderOnlyOutput",
+        "TFContrastiveSearchEncoderDecoderOutput",
+        "TFContrastiveSearchDecoderOnlyOutput",
+    ]
+
+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["flax_logits_process"] = [
+        "FlaxForcedBOSTokenLogitsProcessor",
+        "FlaxForcedEOSTokenLogitsProcessor",
+        "FlaxForceTokensLogitsProcessor",
+        "FlaxLogitsProcessor",
+        "FlaxLogitsProcessorList",
+        "FlaxLogitsWarper",
+        "FlaxMinLengthLogitsProcessor",
+        "FlaxSuppressTokensAtBeginLogitsProcessor",
+        "FlaxSuppressTokensLogitsProcessor",
+        "FlaxTemperatureLogitsWarper",
+        "FlaxTopKLogitsWarper",
+        "FlaxTopPLogitsWarper",
+        "FlaxWhisperTimeStampLogitsProcessor",
+        "FlaxNoRepeatNGramLogitsProcessor",
+    ]
+    _import_structure["flax_utils"] = [
+        "FlaxGenerationMixin",
+        "FlaxGreedySearchOutput",
+        "FlaxSampleOutput",
+        "FlaxBeamSearchOutput",
+    ]
+
+if TYPE_CHECKING:
+    from .configuration_utils import (
+        BaseWatermarkingConfig,
+        CompileConfig,
+        GenerationConfig,
+        GenerationMode,
+        SynthIDTextWatermarkingConfig,
+        WatermarkingConfig,
+    )
+    from .streamers import TextIteratorStreamer, TextStreamer
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
+        from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
+        from .candidate_generator import (
+            AssistedCandidateGenerator,
+            CandidateGenerator,
+            EarlyExitCandidateGenerator,
+            PromptLookupCandidateGenerator,
+        )
+        from .logits_process import (
+            AlternatingCodebooksLogitsProcessor,
+            ClassifierFreeGuidanceLogitsProcessor,
+            EncoderNoRepeatNGramLogitsProcessor,
+            EncoderRepetitionPenaltyLogitsProcessor,
+            EpsilonLogitsWarper,
+            EtaLogitsWarper,
+            ExponentialDecayLengthPenalty,
+            ForcedBOSTokenLogitsProcessor,
+            ForcedEOSTokenLogitsProcessor,
+            HammingDiversityLogitsProcessor,
+            InfNanRemoveLogitsProcessor,
+            LogitNormalization,
+            LogitsProcessor,
+            LogitsProcessorList,
+            LogitsWarper,
+            MinLengthLogitsProcessor,
+            MinNewTokensLengthLogitsProcessor,
+            MinPLogitsWarper,
+            NoBadWordsLogitsProcessor,
+            NoRepeatNGramLogitsProcessor,
+            PrefixConstrainedLogitsProcessor,
+            RepetitionPenaltyLogitsProcessor,
+            SequenceBiasLogitsProcessor,
+            SuppressTokensAtBeginLogitsProcessor,
+            SuppressTokensLogitsProcessor,
+            SynthIDTextWatermarkLogitsProcessor,
+            TemperatureLogitsWarper,
+            TopKLogitsWarper,
+            TopPLogitsWarper,
+            TypicalLogitsWarper,
+            UnbatchedClassifierFreeGuidanceLogitsProcessor,
+            WatermarkLogitsProcessor,
+            WhisperTimeStampLogitsProcessor,
+        )
+        from .stopping_criteria import (
+            ConfidenceCriteria,
+            EosTokenCriteria,
+            MaxLengthCriteria,
+            MaxNewTokensCriteria,
+            MaxTimeCriteria,
+            StoppingCriteria,
+            StoppingCriteriaList,
+            StopStringCriteria,
+            validate_stopping_criteria,
+        )
+        from .utils import (
+            BeamSampleDecoderOnlyOutput,
+            BeamSampleEncoderDecoderOutput,
+            BeamSearchDecoderOnlyOutput,
+            BeamSearchEncoderDecoderOutput,
+            ContrastiveSearchDecoderOnlyOutput,
+            ContrastiveSearchEncoderDecoderOutput,
+            GenerateBeamDecoderOnlyOutput,
+            GenerateBeamEncoderDecoderOutput,
+            GenerateDecoderOnlyOutput,
+            GenerateEncoderDecoderOutput,
+            GenerationMixin,
+            GreedySearchDecoderOnlyOutput,
+            GreedySearchEncoderDecoderOutput,
+            SampleDecoderOnlyOutput,
+            SampleEncoderDecoderOutput,
+        )
+        from .watermarking import (
+            BayesianDetectorConfig,
+            BayesianDetectorModel,
+            SynthIDTextWatermarkDetector,
+            WatermarkDetector,
+            WatermarkDetectorOutput,
+        )
+
+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tf_logits_process import (
+            TFForcedBOSTokenLogitsProcessor,
+            TFForcedEOSTokenLogitsProcessor,
+            TFForceTokensLogitsProcessor,
+            TFLogitsProcessor,
+            TFLogitsProcessorList,
+            TFLogitsWarper,
+            TFMinLengthLogitsProcessor,
+            TFNoBadWordsLogitsProcessor,
+            TFNoRepeatNGramLogitsProcessor,
+            TFRepetitionPenaltyLogitsProcessor,
+            TFSuppressTokensAtBeginLogitsProcessor,
+            TFSuppressTokensLogitsProcessor,
+            TFTemperatureLogitsWarper,
+            TFTopKLogitsWarper,
+            TFTopPLogitsWarper,
+        )
+        from .tf_utils import (
+            TFBeamSampleDecoderOnlyOutput,
+            TFBeamSampleEncoderDecoderOutput,
+            TFBeamSearchDecoderOnlyOutput,
+            TFBeamSearchEncoderDecoderOutput,
+            TFContrastiveSearchDecoderOnlyOutput,
+            TFContrastiveSearchEncoderDecoderOutput,
+            TFGenerationMixin,
+            TFGreedySearchDecoderOnlyOutput,
+            TFGreedySearchEncoderDecoderOutput,
+            TFSampleDecoderOnlyOutput,
+            TFSampleEncoderDecoderOutput,
+        )
+
+    try:
+        if not is_flax_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .flax_logits_process import (
+            FlaxForcedBOSTokenLogitsProcessor,
+            FlaxForcedEOSTokenLogitsProcessor,
+            FlaxForceTokensLogitsProcessor,
+            FlaxLogitsProcessor,
+            FlaxLogitsProcessorList,
+            FlaxLogitsWarper,
+            FlaxMinLengthLogitsProcessor,
+            FlaxNoRepeatNGramLogitsProcessor,
+            FlaxSuppressTokensAtBeginLogitsProcessor,
+            FlaxSuppressTokensLogitsProcessor,
+            FlaxTemperatureLogitsWarper,
+            FlaxTopKLogitsWarper,
+            FlaxTopPLogitsWarper,
+            FlaxWhisperTimeStampLogitsProcessor,
+        )
+        from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
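The `else` branch above swaps the module object for a `_LazyModule`, so the torch/TF/Flax submodules listed in `_import_structure` are only imported when one of their names is first accessed. A small illustration:

```python
# Importing the package itself is cheap; no backend is touched yet.
import transformers.generation

# Resolving this attribute triggers the actual import of
# transformers.generation.configuration_utils (and nothing else).
from transformers.generation import GenerationConfig

config = GenerationConfig(max_new_tokens=20, do_sample=True, top_p=0.9)
```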
.venv/Lib/site-packages/transformers/generation/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (6.45 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-39.pyc ADDED
Binary file (16.3 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/beam_search.cpython-39.pyc ADDED
Binary file (28.8 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-39.pyc ADDED
Binary file (25.4 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-39.pyc ADDED
Binary file (65.9 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/logits_process.cpython-39.pyc ADDED
Binary file (122 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-39.pyc ADDED
Binary file (23.7 kB)
.venv/Lib/site-packages/transformers/generation/__pycache__/utils.cpython-39.pyc ADDED
Binary file (130 kB)