RvanB commited on
Commit
d7838fe
1 Parent(s): 5334ec8

Update model and code

Browse files
Files changed (4) hide show
  1. app.py +1 -5
  2. config.json +15 -0
  3. config.yaml +0 -27
  4. model.safetensors +3 -0
app.py CHANGED
@@ -6,7 +6,6 @@ import pymarc
6
 
7
  from marcai.predict import predict
8
  from marcai.process import process
9
- from marcai.utils import load_config
10
  from marcai.utils.parsing import record_dict
11
  from marcai.pl import SimilarityVectorModel
12
 
@@ -23,12 +22,9 @@ def compare(file1, file2):
23
 
24
  df = process(df1, df2)
25
 
26
- # Load config
27
- config = load_config(os.path.join(root, "config.yaml"))
28
-
29
  model = SimilarityVectorModel.from_pretrained("cdlib/marc-match-ai")
30
 
31
- input_df = df[config["model"]["features"]]
32
 
33
  # Run model
34
  prediction = predict(model, input_df).item()
 
6
 
7
  from marcai.predict import predict
8
  from marcai.process import process
 
9
  from marcai.utils.parsing import record_dict
10
  from marcai.pl import SimilarityVectorModel
11
 
 
22
 
23
  df = process(df1, df2)
24
 
 
 
 
25
  model = SimilarityVectorModel.from_pretrained("cdlib/marc-match-ai")
26
 
27
+ input_df = df[model.features]
28
 
29
  # Run model
30
  prediction = predict(model, input_df).item()
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "batch_size": 512,
3
+ "features": [
4
+ "title_tokenset",
5
+ "title_agg",
6
+ "author",
7
+ "publisher",
8
+ "pub_date",
9
+ "pub_place",
10
+ "pagination"
11
+ ],
12
+ "lr": 0.006,
13
+ "optimizer": "Adam",
14
+ "weight_decay": 0.0
15
+ }
config.yaml DELETED
@@ -1,27 +0,0 @@
1
- model:
2
- # Inputs features
3
- features:
4
- - title_tokenset
5
- - title_agg
6
- - author
7
- - publisher
8
- - pub_date
9
- - pub_place
10
- - pagination
11
-
12
- # Training
13
- batch_size: 512
14
- weight_decay: 0.0
15
- max_epochs: -1
16
-
17
- # Disable early stopping with -1
18
- patience: 20
19
-
20
- lr: 0.006
21
- optimizer: Adam
22
- saved_models_dir: saved_models
23
-
24
- # Paths to dataset splits
25
- test_processed_path: data/test_processed.csv
26
- train_processed_path: data/train_processed.csv
27
- val_processed_path: data/val_processed.csv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42a2798326955686b9da8b88245e6c9f5f9ab34027a956d9eaac0c125a2751fc
3
+ size 10180