RvanB committed on
Commit
d29e6b9
1 Parent(s): 5381b52

Fix CLI argument passing

Browse files
marcai/cli.py CHANGED
@@ -33,9 +33,7 @@ def main():
33
  find_matches_parser.set_defaults(func=find_matches.main)
34
 
35
  args = parser.parse_args()
36
-
37
  args.func(args)
38
-
39
 
40
 
41
  if __name__ == "__main__":
 
33
  find_matches_parser.set_defaults(func=find_matches.main)
34
 
35
  args = parser.parse_args()
 
36
  args.func(args)
 
37
 
38
 
39
  if __name__ == "__main__":
marcai/find_matches.py CHANGED
@@ -35,10 +35,7 @@ def args_parser():
35
  return parser
36
 
37
 
38
- def main():
39
-
40
- args = args_parser().parse_args()
41
-
42
  config_path = f"{args.model_dir}/config.yaml"
43
  model_onnx = f"{args.model_dir}/model.onnx"
44
 
@@ -59,9 +56,9 @@ def main():
59
  with open(args.pair_indices, "r") as indices_file:
60
  reader = csv.reader(indices_file)
61
  # Process records
62
- for df in tqdm(multiprocess_pairs(
63
- records_df, reader, args.chunksize, args.processes
64
- )):
65
  input_df = df[config["model"]["features"]]
66
  prediction = predict_onnx(model_onnx, input_df)
67
  df.loc[:, "prediction"] = prediction.squeeze()
@@ -74,6 +71,8 @@ def main():
74
  written = True
75
  else:
76
  df.to_csv(args.output, index=False, mode="a", header=False)
77
-
 
78
  if __name__ == "__main__":
79
- main()
 
 
35
  return parser
36
 
37
 
38
+ def main(args):
 
 
 
39
  config_path = f"{args.model_dir}/config.yaml"
40
  model_onnx = f"{args.model_dir}/model.onnx"
41
 
 
56
  with open(args.pair_indices, "r") as indices_file:
57
  reader = csv.reader(indices_file)
58
  # Process records
59
+ for df in tqdm(
60
+ multiprocess_pairs(records_df, reader, args.chunksize, args.processes)
61
+ ):
62
  input_df = df[config["model"]["features"]]
63
  prediction = predict_onnx(model_onnx, input_df)
64
  df.loc[:, "prediction"] = prediction.squeeze()
 
71
  written = True
72
  else:
73
  df.to_csv(args.output, index=False, mode="a", header=False)
74
+
75
+
76
  if __name__ == "__main__":
77
+ args = args_parser().parse_args()
78
+ main(args)
marcai/predict.py CHANGED
@@ -44,10 +44,7 @@ def args_parser():
44
  return parser
45
 
46
 
47
- def main():
48
-
49
- args = args_parser().parse_args()
50
-
51
  config_path = f"{args.model_dir}/config.yaml"
52
  model_onnx = f"{args.model_dir}/model.onnx"
53
 
@@ -75,4 +72,5 @@ def main():
75
 
76
 
77
  if __name__ == "__main__":
78
- main()
 
 
44
  return parser
45
 
46
 
47
+ def main(args):
 
 
 
48
  config_path = f"{args.model_dir}/config.yaml"
49
  model_onnx = f"{args.model_dir}/model.onnx"
50
 
 
72
 
73
 
74
  if __name__ == "__main__":
75
+ args = args_parser().parse_args()
76
+ main(args)
marcai/process.py CHANGED
@@ -47,7 +47,7 @@ def multiprocess_pairs(
47
 
48
  for future in done:
49
  # Get job's output
50
- df = future.result()
51
 
52
  # Yield output
53
  yield df
@@ -58,7 +58,7 @@ def multiprocess_pairs(
58
 
59
  if pairs_chunk is None:
60
  break
61
-
62
  indices = np.array(pairs_chunk).astype(int)
63
 
64
  left_indices = indices[:, 0]
@@ -127,11 +127,7 @@ def process(df0, df1):
127
  result_df["author"] = comps.maximum(authors, null_value=0.5)
128
 
129
  # Weighted title comparison
130
- weights = {
131
- "title_a": 1,
132
- "raw": 0,
133
- "title_p": 1
134
- }
135
 
136
  result_df["title_agg"] = comps.column_aggregate_similarity(
137
  df0[weights.keys()], df1[weights.keys()], weights.values(), null_value=0
@@ -142,8 +138,6 @@ def process(df0, df1):
142
  df0["title"], df1["title"], null_value=0.5
143
  )
144
 
145
-
146
-
147
  # Token set similarity
148
  result_df["title_tokenset"] = comps.token_set_similarity(
149
  df0["title"], df1["title"], null_value=0
@@ -220,10 +214,8 @@ def args_parser():
220
  return parser
221
 
222
 
223
- def main():
224
-
225
  start = time.time()
226
- args = args_parser().parse_args()
227
 
228
  # Load records
229
  print("Loading records...")
@@ -258,4 +250,5 @@ def main():
258
 
259
 
260
  if __name__ == "__main__":
261
- main()
 
 
47
 
48
  for future in done:
49
  # Get job's output
50
+ df = future.result()
51
 
52
  # Yield output
53
  yield df
 
58
 
59
  if pairs_chunk is None:
60
  break
61
+
62
  indices = np.array(pairs_chunk).astype(int)
63
 
64
  left_indices = indices[:, 0]
 
127
  result_df["author"] = comps.maximum(authors, null_value=0.5)
128
 
129
  # Weighted title comparison
130
+ weights = {"title_a": 1, "raw": 0, "title_p": 1}
 
 
 
 
131
 
132
  result_df["title_agg"] = comps.column_aggregate_similarity(
133
  df0[weights.keys()], df1[weights.keys()], weights.values(), null_value=0
 
138
  df0["title"], df1["title"], null_value=0.5
139
  )
140
 
 
 
141
  # Token set similarity
142
  result_df["title_tokenset"] = comps.token_set_similarity(
143
  df0["title"], df1["title"], null_value=0
 
214
  return parser
215
 
216
 
217
+ def main(args):
 
218
  start = time.time()
 
219
 
220
  # Load records
221
  print("Loading records...")
 
250
 
251
 
252
  if __name__ == "__main__":
253
+ args = args_parser().parse_args()
254
+ main(args)
marcai/train.py CHANGED
@@ -93,12 +93,11 @@ def args_parser():
93
  parser.add_argument("-n", "--run-name", help="Name for training run", required=True)
94
  return parser
95
 
96
- def main():
97
-
98
- args = args_parser().parse_args()
99
 
 
100
  train(args.run_name)
101
 
102
 
103
  if __name__ == "__main__":
104
- main()
 
 
93
  parser.add_argument("-n", "--run-name", help="Name for training run", required=True)
94
  return parser
95
 
 
 
 
96
 
97
+ def main(args):
98
  train(args.run_name)
99
 
100
 
101
  if __name__ == "__main__":
102
+ args = args_parser().parse_args()
103
+ main(args)