|
import argparse |
|
|
|
|
|
|
|
class ModelEnum:
    """Names of the models the CLI can load, as accepted by ``--model``."""

    SVM = "svm"  # SVM baseline
    DISTILBERT = "distilbert"  # fine-tuned DistilBERT

    @classmethod
    def choices(cls):
        """Return a new list of every accepted ``--model`` value, SVM first."""
        ordered = (cls.SVM, cls.DISTILBERT)
        return list(ordered)
|
|
|
|
|
# Top-level CLI parser; each task is exposed as a subcommand.
parser = argparse.ArgumentParser(description="CLI for predicting if a job is fake based on the title and description")
subparsers = parser.add_subparsers(title="subcommands", dest="subcommand")
# Without this, invoking the script with no subcommand parses successfully,
# leaves ``args.subcommand`` as None and the script exits silently doing
# nothing. Set as an attribute (not the ``required=`` keyword) so it also
# works on interpreters that predate that keyword.
subparsers.required = True

# "predict": classify one job posting (title + description) with a trained model.
predict_parser = subparsers.add_parser("predict", help="Make predictions using a trained model")
predict_parser.add_argument("-m", "--model", choices=ModelEnum.choices(), required=True, help="Model to choose between SVM baseline and fine-tuned DistilBERT")
predict_parser.add_argument("-f", "--file", required=True, help="Path to trained model folder")
predict_parser.add_argument("--title", required=True, help="Job title to classify")
predict_parser.add_argument("--description", required=True, help="Job description to classify")

args = parser.parse_args()
|
|
|
def _preview(text, limit=50):
    """Return ``text`` cut to ``limit`` characters, with ``...`` appended if it was cut."""
    return text[:limit] + ("..." if len(text) > limit else "")


if args.subcommand == "predict":
    print(
        f"===\n\nPredicting with {args.model} using\n"
        f"title '{_preview(args.title)}' and\n"
        f"description '{_preview(args.description)}'"
    )

    if args.model == ModelEnum.SVM:
        print("Loading SVM model...")
        # Deferred import: the package is only loaded once a prediction is requested.
        from fake_job_detector.models import BaselineSVMModel

        model = BaselineSVMModel()
        model.load_model(args.file)

    elif args.model == ModelEnum.DISTILBERT:
        print("Loading DistilBERT model for CPU inference...")
        from fake_job_detector.models import DistilBERTBaseModel

        model = DistilBERTBaseModel(pretrained_model=args.file, cpu=True)

    # ``--model`` is restricted by argparse ``choices`` to ModelEnum values, so
    # one of the branches above always ran and ``model`` is bound here.
    # The model is called with (title, description); a truthy result is reported as fake.
    print(f"===\n\nJob is {'fake' if model(args.title, args.description) else 'real'}")

else:
    # No subcommand was given: show usage instead of exiting silently.
    parser.print_help()
|
|