RvanB commited on
Commit
5381b52
1 Parent(s): a36bc1b

Change package entrypoint to parent CLI script

Browse files
marcai/cli.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from . import train, predict, process, find_matches
3
+
4
+
5
+ def main():
6
+ parser = argparse.ArgumentParser(
7
+ description="Command-line interface for marcai package"
8
+ )
9
+ subparsers = parser.add_subparsers(required=True)
10
+
11
+ train_parser = subparsers.add_parser(
12
+ "train", parents=[train.args_parser()], help="Train a model", add_help=False
13
+ )
14
+ predict_parser = subparsers.add_parser(
15
+ "predict",
16
+ parents=[predict.args_parser()],
17
+ help="Make predictions using a trained model",
18
+ add_help=False,
19
+ )
20
+ process_parser = subparsers.add_parser(
21
+ "process", parents=[process.args_parser()], help="Process data", add_help=False
22
+ )
23
+ find_matches_parser = subparsers.add_parser(
24
+ "find_matches",
25
+ parents=[find_matches.args_parser()],
26
+ help="Find matches in data",
27
+ add_help=False,
28
+ )
29
+
30
+ train_parser.set_defaults(func=train.main)
31
+ predict_parser.set_defaults(func=predict.main)
32
+ process_parser.set_defaults(func=process.main)
33
+ find_matches_parser.set_defaults(func=find_matches.main)
34
+
35
+ args = parser.parse_args()
36
+
37
+ args.func(args)
38
+
39
+
40
+
41
+ if __name__ == "__main__":
42
+ main()
marcai/find_matches.py CHANGED
@@ -10,7 +10,7 @@ from marcai.utils import load_config
10
  from marcai.utils.parsing import load_records, record_dict
11
 
12
 
13
- def main():
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
16
  parser.add_argument(
@@ -32,7 +32,12 @@ def main():
32
  parser.add_argument("-o", "--output", help="Output file", required=True)
33
  parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)
34
 
35
- args = parser.parse_args()
 
 
 
 
 
36
 
37
  config_path = f"{args.model_dir}/config.yaml"
38
  model_onnx = f"{args.model_dir}/model.onnx"
 
10
  from marcai.utils.parsing import load_records, record_dict
11
 
12
 
13
+ def args_parser():
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
16
  parser.add_argument(
 
32
  parser.add_argument("-o", "--output", help="Output file", required=True)
33
  parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)
34
 
35
+ return parser
36
+
37
+
38
+ def main():
39
+
40
+ args = args_parser().parse_args()
41
 
42
  config_path = f"{args.model_dir}/config.yaml"
43
  model_onnx = f"{args.model_dir}/model.onnx"
marcai/predict.py CHANGED
@@ -23,8 +23,7 @@ def predict_onnx(model_onnx_path, data):
23
 
24
  return ort_outs
25
 
26
-
27
- def main():
28
  parser = argparse.ArgumentParser()
29
  parser.add_argument(
30
  "-i", "--input", help="Path to preprocessed data file", required=True
@@ -42,8 +41,12 @@ def main():
42
  default=1024,
43
  type=int,
44
  )
 
 
 
 
45
 
46
- args = parser.parse_args()
47
 
48
  config_path = f"{args.model_dir}/config.yaml"
49
  model_onnx = f"{args.model_dir}/model.onnx"
 
23
 
24
  return ort_outs
25
 
26
+ def args_parser():
 
27
  parser = argparse.ArgumentParser()
28
  parser.add_argument(
29
  "-i", "--input", help="Path to preprocessed data file", required=True
 
41
  default=1024,
42
  type=int,
43
  )
44
+ return parser
45
+
46
+
47
+ def main():
48
 
49
+ args = args_parser().parse_args()
50
 
51
  config_path = f"{args.model_dir}/config.yaml"
52
  model_onnx = f"{args.model_dir}/model.onnx"
marcai/process.py CHANGED
@@ -190,7 +190,7 @@ def process(df0, df1):
190
  return result_df
191
 
192
 
193
- def parse_args():
194
  parser = argparse.ArgumentParser(
195
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
196
  )
@@ -217,13 +217,13 @@ def parse_args():
217
  default=1,
218
  )
219
 
220
- return parser.parse_args()
221
 
222
 
223
  def main():
224
 
225
  start = time.time()
226
- args = parse_args()
227
 
228
  # Load records
229
  print("Loading records...")
 
190
  return result_df
191
 
192
 
193
+ def args_parser():
194
  parser = argparse.ArgumentParser(
195
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
196
  )
 
217
  default=1,
218
  )
219
 
220
+ return parser
221
 
222
 
223
  def main():
224
 
225
  start = time.time()
226
+ args = args_parser().parse_args()
227
 
228
  # Load records
229
  print("Loading records...")
marcai/train.py CHANGED
@@ -88,12 +88,14 @@ def train(name=None):
88
  archive.add(save_dir, arcname=os.path.basename(save_dir))
89
 
90
 
91
- def main():
92
  parser = argparse.ArgumentParser()
93
- parser.add_argument(
94
- "-n", "--run-name", help="Name for training run"
95
- )
96
- args = parser.parse_args()
 
 
97
 
98
  train(args.run_name)
99
 
 
88
  archive.add(save_dir, arcname=os.path.basename(save_dir))
89
 
90
 
91
+ def args_parser():
92
  parser = argparse.ArgumentParser()
93
+ parser.add_argument("-n", "--run-name", help="Name for training run", required=True)
94
+ return parser
95
+
96
+ def main():
97
+
98
+ args = args_parser().parse_args()
99
 
100
  train(args.run_name)
101
 
setup.cfg CHANGED
@@ -7,8 +7,5 @@ packages = find:
7
 
8
  [options.entry_points]
9
  console_scripts =
10
- process = marcai:process.main
11
- predict = marcai:predict.main
12
- train = marcai:train.main
13
- find_matches = marcai:find_matches.main
14
-
 
7
 
8
  [options.entry_points]
9
  console_scripts =
10
+ marc-ai = marcai:cli.main
11
+