"""Command-line entry point for training the fake job detector model."""

import argparse
from pathlib import Path

from .dataset import JobDataset, HuggingFaceJobDataset
from .utils import compute_metrics
from .models import Model, DistilBERTBaseModel


# Accepted spellings for boolean command-line values (compared lower-cased).
# The empty string stays falsy, matching what ``bool("")`` used to produce.
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", ""})


def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` into a real bool.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value (including ``"False"``) becomes ``True``. This
    converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (e.g. true/false), got {value!r}."
    )


def train_model_from_cli(args):
    """Train and evaluate the model described by parsed CLI arguments.

    Reads ``model``, ``model_dir``, ``experiment``, ``subsample``,
    ``learning_rate``, ``batch_size``, ``epochs``, ``weight_decay`` and
    ``save_steps`` from ``args``, creates the experiment directory, fits the
    model on the dataset and prints the result of ``model.evaluate``.

    Args:
        args: Namespace as produced by the parser built in this module.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    model_name = args.model
    if model_name != "distilbert":
        # Fail loudly instead of silently doing nothing for an unknown model.
        raise ValueError(f"Unsupported model: {model_name!r}")

    model_title = "DistilBERTBase"
    model_path = Path(args.model_dir, model_title, args.experiment)
    # NOTE(review): the directory is created but not handed to the model
    # here -- confirm where the model actually writes its checkpoints.
    model_path.mkdir(parents=True, exist_ok=True)

    subsample = args.subsample
    training_args = {
        "learning_rate": args.learning_rate,
        # The single --batch_size option drives both training and evaluation.
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.batch_size,
        "num_train_epochs": args.epochs,
        "weight_decay": args.weight_decay,
        "save_steps": args.save_steps,
    }

    dataset = HuggingFaceJobDataset()
    model = DistilBERTBaseModel()
    model.set_training_args(**training_args)
    model.fit(dataset, subsample=subsample)
    print(model.evaluate(subsample=subsample))


def _build_parser():
    """Build the argument parser for the training command."""
    parser = argparse.ArgumentParser(
        description='Trains the fake job detector model.'
    )
    parser.add_argument(
        "model", type=str, choices=["distilbert"],
        help="Which model to train.",
    )
    parser.add_argument(
        "--model_dir", type=str, default="./models",
        help="Where to store the models after training.",
    )
    parser.add_argument(
        "--experiment", type=str, default="base",
        help="Name of experiment.",
    )

    distilbert_group = parser.add_argument_group("DistilBERT training arguments")
    distilbert_group.add_argument(
        "--learning_rate", type=float, default=2e-5,
        help="Learning rate of model.",
    )
    distilbert_group.add_argument(
        "--batch_size", type=int, default=16,
        help="Batch size when training or evaluating the model.",
    )
    distilbert_group.add_argument(
        "--epochs", type=int, default=3,
        help="Number of epochs to train the model.",
    )
    distilbert_group.add_argument(
        "--weight_decay", type=float, default=0.01,
        help="Weight decay induced.",
    )
    distilbert_group.add_argument(
        "--save_steps", type=int, default=5,
        help="Number of training steps in between checkpoints.",
    )
    # type=bool would treat "--subsample False" as True (non-empty string),
    # so the value is parsed explicitly. nargs="?" with const=True also lets
    # a bare "--subsample" flag enable subsampling.
    distilbert_group.add_argument(
        "--subsample", type=_str_to_bool, nargs="?", const=True, default=False,
        help="Whether or not to use only a subsample.",
    )
    return parser


def main(argv=None):
    """Parse command-line arguments and run training.

    Args:
        argv: Optional list of argument strings; when ``None`` the parser
            falls back to ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    train_model_from_cli(args)


if __name__ == "__main__":
    main()