bvishnu123's picture
setup
1212df0 verified
raw
history blame
1.84 kB
import argparse
# Models that the CLI knows how to load.
class ModelEnum:
    """String identifiers for the selectable classifiers.

    These are plain ``str`` class attributes, not ``enum.Enum`` members, so
    they compare equal to the raw ``--model`` value that argparse returns.
    """

    SVM = "svm"  # the SVM baseline
    DISTILBERT = "distilbert"  # the fine-tuned DistilBERT model

    @classmethod
    def choices(cls):
        """Return a new list of every accepted ``--model`` value, SVM first."""
        return [getattr(cls, attr) for attr in ("SVM", "DISTILBERT")]
# Define the CLI parser
parser = argparse.ArgumentParser(description="CLI for predicting if a job is fake based on the title and description")
subparsers = parser.add_subparsers(title="subcommands", dest="subcommand")
# Prediction sub-command
predict_parser = subparsers.add_parser("predict", help="Make predictions using a trained model")
predict_parser.add_argument("-m", "--model", choices=ModelEnum.choices(), required=True, help="Model to choose between SVM baseline and fine-tuned DistilBERT")
predict_parser.add_argument("-f", "--file", required=True, help="Path to trained model folder")
predict_parser.add_argument("--title", required=True, help="Job title to classify")
predict_parser.add_argument("--description", required=True, help="Job description to classify")
# Parse the arguments
args = parser.parse_args()
if args.subcommand == "predict":
print(f"""===\n\nPredicting with {args.model} using
title '{args.title[:50]}{'...' if len(args.title) > 50 else ''}' and
description '{args.description[:50]}{'...' if len(args.description) > 50 else ''}'""")
if args.model == ModelEnum.SVM:
print("Loading SVM model...")
from fake_job_detector.models import BaselineSVMModel
model = BaselineSVMModel()
model.load_model(args.file)
elif args.model == ModelEnum.DISTILBERT:
print("Loading DistilBERT model for CPU inference...")
from fake_job_detector.models import DistilBERTBaseModel
model = DistilBERTBaseModel(pretrained_model=args.file, cpu=True)
print(f"===\n\nJob is {'fake' if model(args.title, args.description) else 'real'}")