|
import argparse |
|
from pathlib import Path |
|
|
|
from .dataset import JobDataset, HuggingFaceJobDataset |
|
from .utils import compute_metrics |
|
from .models import Model, DistilBERTBaseModel |
|
|
|
|
|
def train_model_from_cli(args):
    """Train and evaluate the model selected by the parsed CLI arguments.

    ``model``, ``model_dir`` and ``experiment`` are always read from ``args``.
    When ``model`` is ``"distilbert"`` the directory
    ``<model_dir>/DistilBERTBase/<experiment>`` is created (missing parents
    included, an existing directory is accepted), a ``DistilBERTBaseModel``
    is configured from the training options on ``args``, fitted on a
    ``HuggingFaceJobDataset`` and the value returned by ``evaluate`` is
    printed. Any other model name results in no further work.

    Args:
        args: Namespace-like object exposing ``model``, ``model_dir`` and
            ``experiment``; for the DistilBERT branch also ``subsample``,
            ``learning_rate``, ``batch_size``, ``epochs``, ``weight_decay``
            and ``save_steps``.

    Returns:
        None.
    """
    selected_model = args.model
    output_root = args.model_dir
    run_name = args.experiment

    if selected_model != "distilbert":
        return

    experiment_dir = Path(output_root) / "DistilBERTBase" / run_name
    experiment_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the directory is only created here; this function does
    # not hand it to the model -- confirm where the model writes its output.

    use_subsample = args.subsample

    # (training option name, attribute on ``args``) in the order they are read.
    # The single batch size is applied to both training and evaluation.
    option_sources = (
        ("learning_rate", "learning_rate"),
        ("per_device_train_batch_size", "batch_size"),
        ("per_device_eval_batch_size", "batch_size"),
        ("num_train_epochs", "epochs"),
        ("weight_decay", "weight_decay"),
        ("save_steps", "save_steps"),
    )
    trainer_options = {
        option: getattr(args, attribute) for option, attribute in option_sources
    }

    jobs = HuggingFaceJobDataset()

    classifier = DistilBERTBaseModel()
    classifier.set_training_args(**trainer_options)
    classifier.fit(jobs, subsample=use_subsample)
    metrics = classifier.evaluate(subsample=use_subsample)
    print(metrics)
|
|
|
|
|
def _parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` calls ``type`` on the raw string, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so ``type=bool``
    cannot express a false value. This converter accepts the usual spellings
    case-insensitively and rejects everything else.

    Args:
        value: The raw option value (a ``bool`` is passed through unchanged).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (e.g. true/false), got {value!r}"
    )


def main(argv=None):
    """Parse the command line and launch training.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour of a plain
            ``main()`` call.

    Command line:
        model            Which model to train (only ``distilbert``).
        --model_dir      Where to store trained models (default ``./models``).
        --experiment     Experiment name (default ``base``).
        --learning_rate  float, default ``2e-5``.
        --batch_size     int, default ``16``.
        --epochs         int, default ``3``.
        --weight_decay   float, default ``0.01``.
        --save_steps     int, default ``5``.
        --subsample      boolean (``true``/``false`` ...), default ``False``.
    """
    parser = argparse.ArgumentParser(description='Trains the fake job detector model.')
    parser.add_argument(
        "model", type=str, choices=["distilbert"], help="Which model to train."
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="./models",
        help="Where to store the models after training.",
    )
    parser.add_argument(
        "--experiment", type=str, default="base", help="Name of experiment."
    )

    distilbert_group = parser.add_argument_group("DistilBERT training arguments")
    distilbert_group.add_argument(
        "--learning_rate", type=float, default=2e-5, help="Learning rate of model."
    )
    distilbert_group.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size when training or evaluating the model.",
    )
    distilbert_group.add_argument(
        "--epochs", type=int, default=3, help="Number of epochs to train the model."
    )
    distilbert_group.add_argument(
        "--weight_decay", type=float, default=0.01, help="Weight decay induced."
    )
    distilbert_group.add_argument(
        "--save_steps",
        type=int,
        default=5,
        help="Number of training steps in between checkpoints.",
    )
    # Not ``type=bool``: that would turn "--subsample False" into True.
    distilbert_group.add_argument(
        "--subsample",
        type=_parse_bool_flag,
        default=False,
        help="Whether or not to use only a subsample.",
    )

    args = parser.parse_args(argv)
    train_model_from_cli(args)
|
|
|
|
|
# Script entry point. The module uses package-relative imports, so it has to
# be started as part of its package (for example with ``python -m``).
if __name__ == "__main__":
    main()