|
""" |
|
T5 model trained to generate text from text
---------------------------------------------------------------------
|
|
|
""" |
|
import json |
|
import os |
|
|
|
import torch |
|
import transformers |
|
|
|
from textattack.model_args import TEXTATTACK_MODELS |
|
from textattack.models.tokenizers import T5Tokenizer |
|
|
|
|
|
class T5ForTextToText(torch.nn.Module):
    """A T5 model trained to generate text from text.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    For usage information, see HuggingFace Transformers documentation section
    on text-to-text with T5:
    https://huggingface.co/transformers/usage.html.

    Args:
        mode (string): Name of the T5 task to use (passed to the tokenizer),
            e.g. ``"english_to_german"``.
        output_max_length (int): The max length of the sequence to be generated.
            Between 1 and infinity.
        input_max_length (int): Max length of the input sequence.
        num_beams (int): Number of beams for beam search. Must be between 1 and
            infinity. 1 means no beam search.
        early_stopping (bool): if set to `True` beam search is stopped when at
            least `num_beams` sentences finished per batch. Defaults to `True`.
    """

    # File name of the wrapper's own settings, stored next to the model weights.
    _CONFIG_FILE_NAME = "t5-wrapper-config.json"

    def __init__(
        self,
        mode="english_to_german",
        output_max_length=20,
        input_max_length=64,
        num_beams=1,
        early_stopping=True,
    ):
        super().__init__()
        # The "t5-base" checkpoint is always loaded here, whatever `mode` is.
        self.model = transformers.T5ForConditionalGeneration.from_pretrained("t5-base")
        self.model.eval()
        # NOTE(review): the tokenizer receives `output_max_length`, and
        # `input_max_length` is only stored, never used in this class --
        # confirm against T5Tokenizer whether `input_max_length` was intended.
        self.tokenizer = T5Tokenizer(mode, max_length=output_max_length)
        self.mode = mode
        self.output_max_length = output_max_length
        self.input_max_length = input_max_length
        self.num_beams = num_beams
        self.early_stopping = early_stopping

    def __call__(self, *args, **kwargs):
        """Generate output sequences and decode them to strings.

        Positional and keyword arguments are forwarded to the wrapped model's
        ``generate``. ``max_length``, ``num_beams`` and ``early_stopping``
        default to the values stored on this wrapper; a caller may override
        any of them by passing the keyword explicitly.

        Returns:
            list: one decoded output per generated id sequence.
        """
        generate_kwargs = {
            "max_length": self.output_max_length,
            "num_beams": self.num_beams,
            "early_stopping": self.early_stopping,
        }
        # Caller-supplied values win; merging also avoids the duplicate-keyword
        # TypeError that passing e.g. `num_beams=` would otherwise trigger.
        generate_kwargs.update(kwargs)

        output_ids_list = self.model.generate(*args, **generate_kwargs)
        return [self.tokenizer.decode(ids) for ids in output_ids_list]

    def save_pretrained(self, output_dir):
        """Save the wrapper settings and the underlying model to `output_dir`.

        Writes ``t5-wrapper-config.json`` with the wrapper's settings, then
        delegates to the wrapped model's ``save_pretrained``.

        Args:
            output_dir (str): Target directory; created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        config = {
            "mode": self.mode,
            "output_max_length": self.output_max_length,
            "input_max_length": self.input_max_length,
            "num_beams": self.num_beams,
            "early_stopping": self.early_stopping,
        }

        with open(os.path.join(output_dir, self._CONFIG_FILE_NAME), "w") as f:
            json.dump(config, f)
        self.model.save_pretrained(output_dir)

    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained T5 model by name or from path.

        Args:
            name_or_path (str): Name of the model (e.g. "t5-en-de") or model saved via `save_pretrained`.

        Returns:
            T5ForTextToText: the loaded wrapper.
        """
        if name_or_path in TEXTATTACK_MODELS:
            return cls(TEXTATTACK_MODELS[name_or_path])

        config_path = os.path.join(name_or_path, cls._CONFIG_FILE_NAME)
        with open(config_path, "r") as f:
            config = json.load(f)

        # Configs written by earlier versions misspelled this key; migrate it
        # so that `early_stopping` is always set on the loaded instance.
        if "early_stoppping" in config:
            legacy_value = config.pop("early_stoppping")
            config.setdefault("early_stopping", legacy_value)

        t5 = cls.__new__(cls)
        # __init__ is skipped (it would load "t5-base"), but the nn.Module
        # base must still be initialised before a sub-module can be assigned.
        torch.nn.Module.__init__(t5)
        for key, value in config.items():
            setattr(t5, key, value)
        t5.model = transformers.T5ForConditionalGeneration.from_pretrained(
            name_or_path
        )
        t5.model.eval()
        t5.tokenizer = T5Tokenizer(t5.mode, max_length=t5.output_max_length)
        return t5

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()
|
|