File size: 5,010 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse

import torch
from reader.data.relik_reader_sample import load_relik_reader_samples

from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderREModel,
)
from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
from relik.reader.utils.relation_matching_eval import StrongMatching

# Mapping from NYT/Freebase-style relation identifiers to short natural-language
# relation names. These names are substituted into each sample's candidate list
# and gold triplets before being fed to the reader (see eval() below).
dict_nyt = {
    "/people/person/nationality": "nationality",
    "/sports/sports_team/location": "sports team location",
    "/location/country/administrative_divisions": "administrative divisions",
    "/business/company/major_shareholders": "shareholders",
    "/people/ethnicity/people": "ethnicity",
    # NOTE(review): "distributi6on" looks like a typo for "distribution", but the
    # model may have been trained with this exact label string -- confirm against
    # the training vocabulary before changing it.
    "/people/ethnicity/geographic_distribution": "geographic distributi6on",
    "/business/company_shareholder/major_shareholder_of": "major shareholder",
    "/location/location/contains": "location",
    "/business/company/founders": "founders",
    "/business/person/company": "company",
    "/business/company/advisors": "advisor",
    "/people/deceased_person/place_of_death": "place of death",
    "/business/company/industry": "industry",
    "/people/person/ethnicity": "ethnic background",
    "/people/person/place_of_birth": "place of birth",
    "/location/administrative_division/country": "country of an administration division",
    "/people/person/place_lived": "place lived",
    "/sports/sports_team_location/teams": "sports team",
    "/people/person/children": "child",
    "/people/person/religion": "religion",
    "/location/neighborhood/neighborhood_of": "neighborhood",
    "/location/country/capital": "capital",
    "/business/company/place_founded": "company founded location",
    "/people/person/profession": "occupation",
}


def eval(model_path, data_path, is_eval, output_path=None):
    """Run a RelikReader RE model over a NYT-format dataset.

    Args:
        model_path: Either a PyTorch-Lightning ``.ckpt`` checkpoint path or a
            HuggingFace model directory / hub identifier.
        data_path: Path to a jsonl file readable by ``load_relik_reader_samples``.
        is_eval: If True, compute and print strong-matching metrics
            (prefixed ``test_``).
        output_path: Optional path; when given, predicted samples are written
            there, one JSON object per line.

    Note: this function shadows the ``eval`` builtin; kept for backward
    compatibility with existing callers.
    """
    if model_path.endswith(".ckpt"):
        # Lightning checkpoint: rebuild model and tokenizer from the stored
        # hyper-parameters. map_location="cpu" lets the checkpoint load even on
        # hosts without the GPU it was saved from; the reader moves the model
        # to the requested device afterwards.
        model_dict = torch.load(model_path, map_location="cpu")

        additional_special_symbols = model_dict["hyper_parameters"][
            "additional_special_symbols"
        ]
        from transformers import AutoTokenizer

        from relik.reader.utils.special_symbols import get_special_symbols_re

        # -1: the stored count includes one symbol not produced by
        # get_special_symbols_re -- presumably the base entity marker; verify
        # against the training config.
        special_symbols = get_special_symbols_re(additional_special_symbols - 1)
        tokenizer = AutoTokenizer.from_pretrained(
            model_dict["hyper_parameters"]["transformer_model"],
            additional_special_tokens=special_symbols,
            add_prefix_space=True,
        )
        config_model = RelikReaderConfig(
            model_dict["hyper_parameters"]["transformer_model"],
            len(special_symbols),
            training=False,
        )
        model = RelikReaderREModel(config_model)
        # Strip the Lightning module prefix so keys match the bare HF model.
        model_dict["state_dict"] = {
            k.replace("relik_reader_re_model.", ""): v
            for k, v in model_dict["state_dict"].items()
        }
        # strict=False: tolerates keys that exist only on one side (e.g.
        # training-only heads). Missing *model* weights would be silently
        # skipped too -- acceptable here because inference fails loudly later.
        model.load_state_dict(model_dict["state_dict"], strict=False)
        reader = RelikReaderForTripletExtraction(
            model, training=False, device="cuda", tokenizer=tokenizer
        )
    else:
        # HuggingFace model directory or hub identifier: load directly.
        model = RelikReaderREModel.from_pretrained(model_path)
        reader = RelikReaderForTripletExtraction(model, training=False, device="cuda")

    samples = list(load_relik_reader_samples(data_path))

    # Replace Freebase-style relation ids with the natural-language names the
    # model expects, both in the candidate list and in the gold triplets.
    for sample in samples:
        sample.candidates = [dict_nyt[cand] for cand in sample.candidates]
        sample.triplets = [
            {
                "subject": triplet["subject"],
                "relation": {
                    "name": dict_nyt[triplet["relation"]["name"]],
                    "type": triplet["relation"]["type"],
                },
                "object": triplet["object"],
            }
            for triplet in sample.triplets
        ]

    predicted_samples = reader.read(samples=samples, progress_bar=True)
    if is_eval:
        strong_matching_metric = StrongMatching()
        # reader.read may return a lazy iterable; materialize it so it can be
        # consumed by both the metric and the optional output writer below.
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)
    if output_path is not None:
        with open(output_path, "w", encoding="utf-8") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")


def main():
    """Command-line entry point: parse arguments and delegate to eval()."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base",
    )
    arg_parser.add_argument(
        "--data_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl",
    )
    arg_parser.add_argument("--is-eval", action="store_true")
    arg_parser.add_argument("--output_path", type=str, default=None)
    parsed = arg_parser.parse_args()

    eval(
        parsed.model_path,
        parsed.data_path,
        parsed.is_eval,
        parsed.output_path,
    )


if __name__ == "__main__":
    main()