ReactSeq / onmt /transforms /insert_mask_before_placeholder.py
Oopstom's picture
Upload 313 files
c668e80 verified
from onmt.constants import DefaultTokens
from onmt.transforms import register_transform
from onmt.utils.logging import logger
from .transform import Transform
@register_transform(name="insert_mask_before_placeholder")
class InsertMaskBeforePlaceholdersTransform(Transform):
    """Insert ``DefaultTokens.MASK_BEFORE`` between prompt and response.

    In an LM finetuning example the source tokens hold the prompt followed
    by the response. This transform places the placeholder right after the
    response pattern, which is necessary to enable the
    'zero-out prompt loss' mechanism.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Register the options controlling mask_before placeholder insertion."""
        option_group = parser.add_argument_group(
            "Transform/InsertMaskBeforePlaceholdersTransform"
        )
        option_group.add(
            "--response_pattern",
            "-response_pattern",
            type=str,
            default="Response : ⦅newline⦆",
            help="Response patten to locate the end of the prompt",
        )

    def _parse_opts(self):
        """Copy the configured response pattern from the options onto self."""
        self.response_pattern = self.opts.response_pattern

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Insert the mask_before placeholder into ``example``.

        The tokens in ``example["src"]`` are joined with single spaces and
        split on the response pattern. When the pattern occurs exactly once,
        the placeholder is inserted immediately after it and both
        ``example["src"]`` and ``example["tgt"]`` are replaced by the
        re-tokenized (space-split) result. Otherwise an info message is
        logged and the example is returned unchanged.

        ``is_train``, ``stats`` and ``kwargs`` are accepted but not used here.

        Returns:
            The same ``example`` mapping, possibly modified in place.
        """
        joined = " ".join(example["src"])
        pattern = self.response_pattern

        # Exactly two pieces means the pattern occurs exactly once.
        pieces = joined.split(pattern)
        if len(pieces) != 2:
            logger.info("The mask_before could not be inserted")
            return example

        prompt, answer = pieces
        marked = "".join([prompt, pattern, DefaultTokens.MASK_BEFORE, answer])

        tokens = marked.split(" ")
        example["src"] = tokens
        # Keep tgt as a separate list object so src and tgt are not aliased.
        example["tgt"] = list(tokens)
        return example