import argparse
import inspect
import logging
import os
import sys
from dataclasses import dataclass, field, make_dataclass
from typing import Any, Optional

from transformers import Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser
from datasets import load_from_disk

from funnel_vae.src.funnel_vae import FunnelVae
from funnel_vae.src.config import FunnelVaeConfig


@dataclass
class BaseArgs:
    """Training hyperparameters and input/output locations for the script.

    The directory fields and ``n_gpus`` take their defaults from ``SM_*``
    environment variables (presumably set by the SageMaker training
    container -- confirm). A variable that is not set yields ``None``
    instead of failing at import time, so the module can still be imported
    and the value supplied on the command line.
    """

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    model_name: str
    epochs: int = 3
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    warmup_steps: int = 500
    # Annotated as float so the parsed value has the same type as the default.
    learning_rate: float = 5e-5

    # Locations read from the environment when the class is defined.
    output_data_dir: Optional[str] = os.environ.get("SM_OUTPUT_DATA_DIR")
    model_dir: Optional[str] = os.environ.get("SM_MODEL_DIR")
    n_gpus: Optional[str] = os.environ.get("SM_NUM_GPUS")
    training_dir: Optional[str] = os.environ.get("SM_CHANNEL_TRAIN")
    test_dir: Optional[str] = os.environ.get("SM_CHANNEL_TEST")


# ModelArguments
# Field specs are (name, type, dataclasses.field) triples as expected by
# `make_dataclass`. One explicit `tokenizer_name` spec is followed by one spec
# per parameter of `FunnelVaeConfig.__init__`.
# NOTE(review): a parameter without a default reports `inspect.Parameter.empty`
# (not None) as `info.default`; confirm every remaining FunnelVaeConfig
# parameter has a default.
fields = [
    (
        'tokenizer_name',
        Optional[str],
        field(
            default='t5-base',
            metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
        ),
    ),
] + [
    (
        name,
        type(info.default) if info.default is not None else Any,
        field(
            default=info.default,
            metadata={"help": f"Has default {info.default}, see FunnelVaeConfig docstring for more info."},
        ),
    )
    # get relevant model arguments with defaults
    for name, info in inspect.signature(FunnelVaeConfig.__init__).parameters.items()
    if name not in ['self', 'kwargs', 'use_extra_logs', 'cache_dir']
]

# Order the specs so that those whose default is None come first, followed by
# those with a concrete default (every spec still carries a default).
start_f = [spec for spec in fields if spec[2].default is None]
end_f = [spec for spec in fields if spec[2].default is not None]
ModelArguments = make_dataclass('ModelArguments', start_f + end_f)


@dataclass
class DataArguments:
    """Command-line options describing the data to train and evaluate on.

    Raises:
        ValueError: from ``__post_init__`` when no dataset name or file is
            given, or when a given file does not end in csv, json or txt.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(default=None, metadata={"help": "Use this dataset column as 'text'."})
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.0, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    validation_name: str = field(
        default="validation",
        metadata={"help": "Name of the set to run evaluation on."},
    )

    def __post_init__(self):
        """Validate that a data source is given and that files have a supported extension."""
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Raise explicitly rather than `assert`, so the check also runs under `python -O`.
        for arg_name in ("train_file", "validation_file"):
            path = getattr(self, arg_name)
            if path is None:
                continue
            extension = path.split(".")[-1]
            if extension not in ("csv", "json", "txt"):
                raise ValueError(f"`{arg_name}` should be a csv, json or txt file.")
def main():
    """Parse the command line, train a FunnelVae with `Trainer`, evaluate and save it.

    Steps, in order: configure logging, parse `BaseArgs`, `ModelArguments` and
    `DataArguments`, load the train/test datasets from disk, build the model
    from a config whose vocabulary size matches the tokenizer, train, evaluate,
    write the evaluation metrics to ``eval_results.txt`` under
    ``output_data_dir`` and save the model to ``model_dir``.

    Raises:
        ValueError: if any of the required directory arguments is unset.
    """
    # Set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # `TrainingArguments` is deliberately not handed to the parser: it declares
    # options that `BaseArgs` also declares (per_device_*_batch_size,
    # warmup_steps, learning_rate), and it is built explicitly further down.
    parser = HfArgumentParser((BaseArgs, ModelArguments, DataArguments))
    # `_data_args` is parsed (and validated by its __post_init__) so the data
    # options stay accepted on the command line; nothing below reads it yet.
    args, model_args, _data_args = parser.parse_args_into_dataclasses()

    # Fail early with a readable message instead of deep inside a library call.
    required_dirs = ("training_dir", "test_dir", "model_dir", "output_data_dir")
    unset = [name for name in required_dirs if getattr(args, name) is None]
    if unset:
        raise ValueError(f"Missing required directory argument(s): {', '.join(unset)}")

    # load datasets
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(" loaded train_dataset length is: %s", len(train_dataset))
    logger.info(" loaded test_dataset length is: %s", len(test_dataset))

    # init model
    # NOTE(review): every ModelArguments field (including `tokenizer_name`) is
    # forwarded as a keyword; confirm FunnelVaeConfig.from_pretrained accepts this.
    config = FunnelVaeConfig.from_pretrained(**vars(model_args))

    # Fall back to the model name when no separate tokenizer name is given.
    tokenizer_name = model_args.tokenizer_name or args.model_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)

    # Keep every vocabulary size in the config in sync with the tokenizer.
    vocab_size = len(tokenizer)
    config.funnel.vocab_size = vocab_size
    config.t5.vocab_size = vocab_size
    config.vocab_size = vocab_size

    # The model is built from the adjusted config, and the tokenizer that sized
    # the vocabulary is the one handed to the Trainer.
    # NOTE(review): if pretrained weights should be loaded, add that call here
    # with an explicit checkpoint path.
    model = FunnelVae(config)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=float(args.learning_rate),
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # writes eval result to file which can be accessed later in s3 output
    results_path = os.path.join(args.output_data_dir, "eval_results.txt")
    with open(results_path, "w", encoding="utf-8") as writer:
        logger.info("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # Saves the model to s3
    trainer.save_model(args.model_dir)


if __name__ == "__main__":
    main()