# File-viewer header residue, kept as a comment so the module parses:
# author: bvishnu123 · commit 1212df0 ("setup", verified) · size 2.41 kB
import argparse
from pathlib import Path
from .dataset import JobDataset, HuggingFaceJobDataset
from .utils import compute_metrics
from .models import Model, DistilBERTBaseModel
def train_model_from_cli(args):
    """Train and evaluate a model configured by parsed command-line arguments.

    Args:
        args: An ``argparse.Namespace`` (or any object with the same
            attributes) providing ``model``, ``model_dir``, ``experiment``,
            ``subsample``, ``learning_rate``, ``batch_size``, ``epochs``,
            ``weight_decay`` and ``save_steps``.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.

    Side effects:
        Creates the directory ``<model_dir>/<model title>/<experiment>``
        (including parents) and prints the result of ``model.evaluate``.
    """
    # Maps CLI model names to the directory title used under ``model_dir``.
    # Validating up front gives a clear error for unknown names instead of an
    # UnboundLocalError on ``model_title`` further down, and does so before
    # any directory is created.
    model_titles = {"distilbert": "DistilBERTBase"}
    model_name = args.model
    if model_name not in model_titles:
        raise ValueError(
            f"Unsupported model {model_name!r}; expected one of {sorted(model_titles)}"
        )
    model_title = model_titles[model_name]

    model_path = Path(args.model_dir, model_title, args.experiment)
    model_path.mkdir(parents=True, exist_ok=True)
    # NOTE(review): ``model_path`` is created but not passed to the model or to
    # its training arguments here, so checkpoints may be written elsewhere —
    # confirm where DistilBERTBaseModel saves its output.

    subsample = args.subsample
    training_args = {
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.batch_size,
        "num_train_epochs": args.epochs,
        "weight_decay": args.weight_decay,
        "save_steps": args.save_steps,
    }

    dataset = HuggingFaceJobDataset()
    model = DistilBERTBaseModel()
    model.set_training_args(**training_args)
    model.fit(dataset, subsample=subsample)
    print(model.evaluate(subsample=subsample))
def _str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because
    any non-empty string is truthy. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, matching what ``type=bool`` produced for it.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}"
    )


def main():
    """Parse command-line arguments and train the selected model."""
    parser = argparse.ArgumentParser(description='Trains the fake job detector model.')
    parser.add_argument("model", type=str, choices=["distilbert"], help="Which model to train.")
    parser.add_argument("--model_dir", type=str, default="./models", help="Where to store the models after training.")
    parser.add_argument("--experiment", type=str, default="base", help="Name of experiment.")

    distilbert_group = parser.add_argument_group("DistilBERT training arguments")
    distilbert_group.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate of model.")
    distilbert_group.add_argument("--batch_size", type=int, default=16, help="Batch size when training or evaluating the model.")
    distilbert_group.add_argument("--epochs", type=int, default=3, help="Number of epochs to train the model.")
    distilbert_group.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay induced.")
    distilbert_group.add_argument("--save_steps", type=int, default=5, help="Number of training steps in between checkpoints.")
    # nargs="?" + const=True lets a bare ``--subsample`` mean True while still
    # accepting the old ``--subsample True`` / ``--subsample False`` forms,
    # which are now parsed correctly instead of always yielding True.
    distilbert_group.add_argument(
        "--subsample",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether or not to use only a subsample (true/false; a bare flag means true).",
    )

    args = parser.parse_args()
    train_model_from_cli(args)
# Run the command-line interface only when this module is executed, not when
# it is imported. Because of the relative imports at the top of the file, it
# must be run as part of its package (``python -m <package>.<module>``).
if __name__ == "__main__":
    main()