import argparse
import logging
import os
import sys

from datasets import load_from_disk
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train-batch-size", type=int, default=32)
    parser.add_argument("--eval-batch-size", type=int, default=64)
    parser.add_argument("--save-strategy", type=str, default="no")
    parser.add_argument("--save-steps", type=int, default=500)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--learning-rate", type=float, default=5e-5)

    # Data, model, and output directories
    parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n-gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--train-dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--valid-dir", type=str, default=os.environ["SM_CHANNEL_VALID"])

    # parse_known_args() tolerates any extra arguments that SageMaker may inject
    args, _ = parser.parse_known_args()
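
    # For reference, a SageMaker client typically launches this script through
    # the HuggingFace estimator, e.g. (sketch; the role, versions, instance
    # type, and channel inputs are illustrative):
    #
    #   huggingface_estimator = HuggingFace(
    #       entry_point="train.py",
    #       instance_type="ml.p3.2xlarge",
    #       role=role,
    #       transformers_version="4.6",
    #       pytorch_version="1.7",
    #       py_version="py36",
    #       hyperparameters={"epochs": 3, "model-name": "distilbert-base-uncased"},
    #   )
    #   huggingface_estimator.fit({"train": train_input, "valid": valid_input})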

    # Set up logging so that logger.info() actually reaches stdout (and CloudWatch)
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load the datasets from the SageMaker input channels
    train_dataset = load_from_disk(args.train_dir)
    valid_dataset = load_from_disk(args.valid_dir)
    logger.info(f"Loaded train_dataset, length is: {len(train_dataset)}")
    logger.info(f"Loaded valid_dataset, length is: {len(valid_dataset)}")

    # Compute metrics function for binary classification
    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

    # Download the model from the model hub
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)

    # Download the tokenizer too: it will be saved in the model artifact
    # and used at prediction time
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Define the training arguments
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=args.learning_rate,
    )
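
    # With the default save-strategy of "no", no intermediate checkpoints are
    # written; only the final trainer.save_model() call below persists the model.
    # evaluation_strategy="epoch" runs a full evaluation pass at the end of
    # each epoch, reporting the metrics computed by compute_metrics().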

    # Create the Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )
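
    # Passing the tokenizer to the Trainer has two effects: it is saved along
    # with the model by trainer.save_model(), and it enables dynamic padding
    # of batches through the default DataCollatorWithPadding.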

    # Train the model
    trainer.train()

    # Evaluate the model on the validation set
    eval_result = trainer.evaluate(eval_dataset=valid_dataset)

    # Write the evaluation results to a file, which SageMaker copies to the S3 output location
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # Save the model, which SageMaker copies to S3 as the model artifact
    trainer.save_model(args.model_dir)
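
    # At prediction time, the saved artifact can be reloaded with a pipeline,
    # e.g. (sketch; /opt/ml/model is where SageMaker unpacks the artifact):
    #
    #   from transformers import pipeline
    #   classifier = pipeline("text-classification", model="/opt/ml/model")
    #   classifier("This movie was great!")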