# wiki-vae/train.py: SageMaker training entry point
import logging
import sys
import os
import inspect
from typing import Optional, Any
from dataclasses import dataclass, field, make_dataclass
from transformers import Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser
from datasets import load_from_disk
from funnel_vae.src.funnel_vae import FunnelVae
from funnel_vae.src.config import FunnelVaeConfig
@dataclass
class BaseArgs:
# hyperparameters sent by the client are passed as command-line arguments to the script.
model_name: str
epochs: int = 3
per_device_train_batch_size: int = 32
per_device_eval_batch_size: int = 64
warmup_steps: int = 500
    learning_rate: float = 5e-5
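    # SageMaker injects these paths and settings via environment variables
    # inside the training container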
output_data_dir: str = os.environ["SM_OUTPUT_DATA_DIR"]
model_dir: str = os.environ["SM_MODEL_DIR"]
n_gpus: str = os.environ["SM_NUM_GPUS"]
training_dir: str = os.environ["SM_CHANNEL_TRAIN"]
test_dir: str = os.environ["SM_CHANNEL_TEST"]
# ModelArguments: built dynamically so every FunnelVaeConfig argument is exposed as a CLI flag
fields = [
(
'tokenizer_name', Optional[str], field(
default='t5-base', metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
),
] + [
    # pull the relevant FunnelVaeConfig.__init__ arguments, keeping their defaults
    (
        name, type(info.default) if info.default is not None else Any, field(
            default=info.default, metadata={"help": f"Has default {info.default}, see FunnelVaeConfig docstring for more info."}
        )
    )
    for name, info in inspect.signature(FunnelVaeConfig.__init__).parameters.items()
    if name not in ['self', 'kwargs', 'use_extra_logs', 'cache_dir']
]
# order fields so that those defaulting to None come first
start_f = list(filter(lambda f: f[2].default is None, fields))
end_f = list(filter(lambda f: f[2].default is not None, fields))
ModelArguments = make_dataclass('ModelArguments', start_f + end_f)
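
# A rough sketch of the generated dataclass (fields beyond tokenizer_name depend
# on the installed FunnelVaeConfig; 'latent_size' here is a hypothetical example):
#
#   @dataclass
#   class ModelArguments:
#       tokenizer_name: Optional[str] = 't5-base'
#       latent_size: int = 32
#       ...
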
@dataclass
class DataArguments:
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
text_column: Optional[str] = field(default=None, metadata={"help": "Use this dataset column as 'text'."})
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
mlm_probability: float = field(
default=0.0, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
)
validation_name: str = field(
default="validation",
metadata={"help": "Name of the set to run evaluation on."},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
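            # e.g. (hypothetical paths): `--train_file data/train.json` passes these
            # checks, while `--train_file data/train.parquet` fails the assertion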
if __name__ == "__main__":
    parser = HfArgumentParser((BaseArgs, ModelArguments, DataArguments))
    args, model_args, data_args = parser.parse_args_into_dataclasses()
# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.getLevelName("INFO"),
handlers=[logging.StreamHandler(sys.stdout)],
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
    # load datasets from the SageMaker input channels (copied from S3 at job start)
train_dataset = load_from_disk(args.training_dir)
test_dataset = load_from_disk(args.test_dir)
logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
logger.info(f" loaded test_dataset length is: {len(test_dataset)}")
    # build the model config from the parsed model arguments
    # (tokenizer_name is a CLI-only field, not a FunnelVaeConfig argument)
    config = FunnelVaeConfig(**{k: v for k, v in model_args.__dict__.items() if k != 'tokenizer_name'})
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=True)
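    # keep the encoder/decoder sub-configs and the top-level config in sync with
    # the tokenizer's vocabulary so the embedding sizes match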
vocab_size = len(tokenizer)
config.funnel.vocab_size = vocab_size
config.t5.vocab_size = vocab_size
config.vocab_size = vocab_size
    # train from scratch with the freshly built config
    model = FunnelVae(config)
    # define training args from the SageMaker hyperparameters
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=args.learning_rate,
    )
# create Trainer instance
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
)
# train model
trainer.train()
# evaluate model
eval_result = trainer.evaluate(eval_dataset=test_dataset)
    # write eval results to a file that ends up in the S3 output location
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        logger.info("***** Eval results *****")
for key, value in sorted(eval_result.items()):
writer.write(f"{key} = {value}\n")
    # save the model to SM_MODEL_DIR, which SageMaker uploads to S3
trainer.save_model(args.model_dir)
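
# A minimal launcher sketch (run outside this script, not part of it), assuming
# the sagemaker SDK and an execution role; instance type, framework versions and
# hyperparameter values are illustrative only:
#
#   from sagemaker.huggingface import HuggingFace
#
#   estimator = HuggingFace(
#       entry_point="train.py",
#       instance_type="ml.p3.2xlarge",
#       instance_count=1,
#       role=role,  # your SageMaker execution role (assumed defined)
#       transformers_version="4.6",
#       pytorch_version="1.7",
#       py_version="py36",
#       hyperparameters={"model_name": "funnel-vae", "epochs": 3},
#   )
#   # channel names map to SM_CHANNEL_TRAIN / SM_CHANNEL_TEST above
#   estimator.fit({"train": "s3://...", "test": "s3://..."})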