RvanB committed on
Commit
c19ce61
1 Parent(s): c7a5b48

Formatting and commenting, fix console scripts

Browse files
Files changed (5) hide show
  1. demo/app.py +14 -14
  2. marcai/find_matches.py +6 -5
  3. marcai/process.py +1 -3
  4. marcai/train.py +7 -5
  5. setup.cfg +2 -2
demo/app.py CHANGED
@@ -1,45 +1,45 @@
 
 
1
  import gradio as gr
2
- import pymarc
3
- from marcai.process import process
4
- from marcai.utils.parsing import record_dict
5
  import pandas as pd
 
 
6
  from marcai.predict import predict_onnx
 
7
  from marcai.utils import load_config
8
- import os
9
 
10
  demo_dir = os.path.dirname(os.path.realpath(__file__))
11
 
 
12
  def compare(file1, file2):
 
13
  record1 = pymarc.parse_xml_to_array(file1)[0]
14
  record2 = pymarc.parse_xml_to_array(file2)[0]
15
 
 
16
  df1 = pd.DataFrame.from_dict([record_dict(record1)])
17
  df2 = pd.DataFrame.from_dict([record_dict(record2)])
18
 
19
  df = process(df1, df2)
20
 
21
- # Load model config
22
  config = load_config(os.path.join(demo_dir, "config.yaml"))
23
- model_onnx = os.path.join(demo_dir, "model.onnx")
24
 
25
  # Run ONNX model
 
26
  input_df = df[config["model"]["features"]]
27
- prediction = predict_onnx(model_onnx, input_df)
28
-
29
- prediction = prediction.item()
30
 
31
  return {"match": prediction, "not match": 1 - prediction}
32
 
33
 
34
  interface = gr.Interface(
35
  fn=compare,
36
- inputs=[
37
- gr.File(label="MARC XML File 1"),
38
- gr.File(label="MARC XML File 2")
39
- ],
40
  outputs=gr.Label(label="Classification"),
41
  title="MARC Record Matcher",
42
  description="Upload two MARC XML files with one record each.",
43
- allow_flagging="never"
44
  )
45
  interface.launch()
 
1
+ import os
2
+
3
  import gradio as gr
 
 
 
4
  import pandas as pd
5
+ import pymarc
6
+
7
  from marcai.predict import predict_onnx
8
+ from marcai.process import process
9
  from marcai.utils import load_config
10
+ from marcai.utils.parsing import record_dict
11
 
12
  demo_dir = os.path.dirname(os.path.realpath(__file__))
13
 
14
+
15
  def compare(file1, file2):
16
+ # Load records
17
  record1 = pymarc.parse_xml_to_array(file1)[0]
18
  record2 = pymarc.parse_xml_to_array(file2)[0]
19
 
20
+ # Turn into dataframes
21
  df1 = pd.DataFrame.from_dict([record_dict(record1)])
22
  df2 = pd.DataFrame.from_dict([record_dict(record2)])
23
 
24
  df = process(df1, df2)
25
 
26
+ # Load config
27
  config = load_config(os.path.join(demo_dir, "config.yaml"))
 
28
 
29
  # Run ONNX model
30
+ model_onnx = os.path.join(demo_dir, "model.onnx")
31
  input_df = df[config["model"]["features"]]
32
+ prediction = predict_onnx(model_onnx, input_df).item()
 
 
33
 
34
  return {"match": prediction, "not match": 1 - prediction}
35
 
36
 
37
  interface = gr.Interface(
38
  fn=compare,
39
+ inputs=[gr.File(label="MARC XML File 1"), gr.File(label="MARC XML File 2")],
 
 
 
40
  outputs=gr.Label(label="Classification"),
41
  title="MARC Record Matcher",
42
  description="Upload two MARC XML files with one record each.",
43
+ allow_flagging="never",
44
  )
45
  interface.launch()
marcai/find_matches.py CHANGED
@@ -1,13 +1,14 @@
1
  import argparse
2
- from process import multiprocess_pairs
3
- from predict import predict_onnx
4
- from tqdm import tqdm
5
  import pandas as pd
 
6
 
7
- from marcai.utils.parsing import load_records, record_dict
 
8
  from marcai.utils import load_config
 
9
 
10
- import csv
11
 
12
  def main():
13
  parser = argparse.ArgumentParser()
 
1
  import argparse
2
+ import csv
3
+
 
4
  import pandas as pd
5
+ from tqdm import tqdm
6
 
7
+ from marcai.predict import predict_onnx
8
+ from marcai.process import multiprocess_pairs
9
  from marcai.utils import load_config
10
+ from marcai.utils.parsing import load_records, record_dict
11
 
 
12
 
13
  def main():
14
  parser = argparse.ArgumentParser()
marcai/process.py CHANGED
@@ -1,8 +1,8 @@
1
  import argparse
2
  import concurrent.futures
3
  import csv
4
- import itertools
5
  import time
 
6
 
7
  import numpy as np
8
  import pandas as pd
@@ -12,8 +12,6 @@ import marcai.processing.comparisons as comps
12
  import marcai.processing.normalizations as norms
13
  from marcai.utils.parsing import load_records, record_dict
14
 
15
- from multiprocessing import get_context
16
-
17
 
18
  def multiprocess_pairs(
19
  records_df,
 
1
  import argparse
2
  import concurrent.futures
3
  import csv
 
4
  import time
5
+ from multiprocessing import get_context
6
 
7
  import numpy as np
8
  import pandas as pd
 
12
  import marcai.processing.normalizations as norms
13
  from marcai.utils.parsing import load_records, record_dict
14
 
 
 
15
 
16
  def multiprocess_pairs(
17
  records_df,
marcai/train.py CHANGED
@@ -1,13 +1,15 @@
1
- import pytorch_lightning as lightning
2
- from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3
- import warnings
4
- import yaml
5
  import argparse
6
  import os
 
 
 
 
7
  import torch
 
 
 
8
  from marcai.pl import MARCDataModule, SimilarityVectorModel
9
  from marcai.utils import load_config
10
- import tarfile
11
 
12
 
13
  def train(name=None):
 
 
 
 
 
1
  import argparse
2
  import os
3
+ import tarfile
4
+ import warnings
5
+
6
+ import pytorch_lightning as lightning
7
  import torch
8
+ import yaml
9
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
10
+
11
  from marcai.pl import MARCDataModule, SimilarityVectorModel
12
  from marcai.utils import load_config
 
13
 
14
 
15
  def train(name=None):
setup.cfg CHANGED
@@ -10,5 +10,5 @@ console_scripts =
10
  process = marcai:process.main
11
  predict = marcai:predict.main
12
  train = marcai:train.main
13
- compare_records = marcai:compare_records.main
14
-
 
10
  process = marcai:process.main
11
  predict = marcai:predict.main
12
  train = marcai:train.main
13
+ find_matches = marcai:find_matches.main
14
+