File size: 1,840 Bytes
1212df0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse


# Identifiers of the models this CLI knows how to load.
class ModelEnum:
    """Namespace of string constants naming the selectable models."""

    SVM = "svm"
    DISTILBERT = "distilbert"

    @classmethod
    def choices(cls):
        """Return a fresh list of every model identifier, SVM first.

        Suitable for passing as argparse ``choices=``.
        """
        identifiers = (cls.SVM, cls.DISTILBERT)
        return list(identifiers)

# Define the CLI parser.
parser = argparse.ArgumentParser(description="CLI for predicting if a job is fake based on the title and description")
subparsers = parser.add_subparsers(title="subcommands", dest="subcommand")
# A subcommand is mandatory. Without one, args.subcommand would be None, no
# branch below would run, and the script would exit silently with status 0.
# Setting the attribute (instead of passing required=True to add_subparsers)
# keeps this working on Python 3.6; dest= above gives argparse a name to put
# in the resulting "required" error message.
subparsers.required = True

# "predict" sub-command: classify a single job posting with a trained model.
predict_parser = subparsers.add_parser("predict", help="Make predictions using a trained model")
predict_parser.add_argument("-m", "--model", choices=ModelEnum.choices(), required=True, help="Model to choose between SVM baseline and fine-tuned DistilBERT")
predict_parser.add_argument("-f", "--file", required=True, help="Path to trained model folder")
predict_parser.add_argument("--title", required=True, help="Job title to classify")
predict_parser.add_argument("--description", required=True, help="Job description to classify")

# Parse sys.argv. On invalid or missing arguments argparse prints usage and
# exits with status 2.
args = parser.parse_args()

def _truncate(text, limit=50):
    """Return *text* cut to at most *limit* characters, adding "..." if it was cut."""
    return text[:limit] + ("..." if len(text) > limit else "")


if args.subcommand == "predict":
    # Echo the (truncated) inputs. Separate literals are joined by implicit
    # concatenation rather than written as one triple-quoted f-string: in the
    # triple-quoted form, the source indentation of the continuation lines
    # leaked into the printed text as stray leading spaces.
    print(
        "===\n\n"
        f"Predicting with {args.model} using\n"
        f"title '{_truncate(args.title)}' and\n"
        f"description '{_truncate(args.description)}'"
    )

    # Each backend is imported inside its own branch (presumably so only the
    # chosen model's dependencies have to be importable).
    if args.model == ModelEnum.SVM:
        print("Loading SVM model...")
        from fake_job_detector.models import BaselineSVMModel
        model = BaselineSVMModel()
        model.load_model(args.file)
    elif args.model == ModelEnum.DISTILBERT:
        print("Loading DistilBERT model for CPU inference...")
        from fake_job_detector.models import DistilBERTBaseModel
        model = DistilBERTBaseModel(pretrained_model=args.file, cpu=True)
    else:
        # Unreachable while --model's choices come from ModelEnum.choices().
        # Guards against a new identifier being added there without a loading
        # branch here, which would otherwise surface as a NameError on `model`.
        raise SystemExit(f"Unsupported model: {args.model}")

    # NOTE(review): assumes calling the model with (title, description) returns
    # a truthy value for a fake posting; confirm against fake_job_detector.models.
    is_fake = model(args.title, args.description)
    print(f"===\n\nJob is {'fake' if is_fake else 'real'}")