import argparse

import torch

from relik.reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderREModel,
)
from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
from relik.reader.utils.relation_matching_eval import StrongMatching

# Mapping from NYT relation identifiers to the natural-language relation names
# expected by the reader.
dict_nyt = {
    "/people/person/nationality": "nationality",
    "/sports/sports_team/location": "sports team location",
    "/location/country/administrative_divisions": "administrative divisions",
    "/business/company/major_shareholders": "shareholders",
    "/people/ethnicity/people": "ethnicity",
    "/people/ethnicity/geographic_distribution": "geographic distribution",
    "/business/company_shareholder/major_shareholder_of": "major shareholder",
    "/location/location/contains": "location",
    "/business/company/founders": "founders",
    "/business/person/company": "company",
    "/business/company/advisors": "advisor",
    "/people/deceased_person/place_of_death": "place of death",
    "/business/company/industry": "industry",
    "/people/person/ethnicity": "ethnic background",
    "/people/person/place_of_birth": "place of birth",
    "/location/administrative_division/country": "country of an administration division",
    "/people/person/place_lived": "place lived",
    "/sports/sports_team_location/teams": "sports team",
    "/people/person/children": "child",
    "/people/person/religion": "religion",
    "/location/neighborhood/neighborhood_of": "neighborhood",
    "/location/country/capital": "capital",
    "/business/company/place_founded": "company founded location",
    "/people/person/profession": "occupation",
}


def eval(model_path, data_path, is_eval, output_path=None):
    if model_path.endswith(".ckpt"):
        # if it is a lightning checkpoint we load the model state dict and the tokenizer from the config
        model_dict = torch.load(model_path)

        additional_special_symbols = model_dict["hyper_parameters"][
            "additional_special_symbols"
        ]

        from transformers import AutoTokenizer

        from relik.reader.utils.special_symbols import get_special_symbols_re

        special_symbols = get_special_symbols_re(additional_special_symbols - 1)
        tokenizer = AutoTokenizer.from_pretrained(
            model_dict["hyper_parameters"]["transformer_model"],
            additional_special_tokens=special_symbols,
            add_prefix_space=True,
        )
        config_model = RelikReaderConfig(
            model_dict["hyper_parameters"]["transformer_model"],
            len(special_symbols),
            training=False,
        )
        model = RelikReaderREModel(config_model)
        # strip the lightning module prefix so the keys match the bare model
        model_dict["state_dict"] = {
            k.replace("relik_reader_re_model.", ""): v
            for k, v in model_dict["state_dict"].items()
        }
        model.load_state_dict(model_dict["state_dict"], strict=False)
        reader = RelikReaderForTripletExtraction(
            model, training=False, device="cuda", tokenizer=tokenizer
        )
    else:
        # if it is a huggingface model we load the model directly.
        # Note that it could even be a string from the hub
        model = RelikReaderREModel.from_pretrained(model_path)
        reader = RelikReaderForTripletExtraction(model, training=False, device="cuda")

    samples = list(load_relik_reader_samples(data_path))

    # map candidate relations and gold triplets to their natural-language names
    for sample in samples:
        sample.candidates = [dict_nyt[cand] for cand in sample.candidates]
        sample.triplets = [
            {
                "subject": triplet["subject"],
                "relation": {
                    "name": dict_nyt[triplet["relation"]["name"]],
                    "type": triplet["relation"]["type"],
                },
                "object": triplet["object"],
            }
            for triplet in sample.triplets
        ]

    predicted_samples = reader.read(samples=samples, progress_bar=True)

    if is_eval:
        strong_matching_metric = StrongMatching()
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)

    if output_path is not None:
        with open(output_path, "w") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl",
    )
    parser.add_argument("--is-eval", action="store_true")
    parser.add_argument("--output_path", type=str, default=None)
    args = parser.parse_args()
    eval(args.model_path, args.data_path, args.is_eval, args.output_path)


if __name__ == "__main__":
    main()
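# Example invocation (a sketch: the script name and all paths below are
# placeholders, not files shipped with the repository). --model_path can point
# to a lightning .ckpt, a local HF model directory, or a hub identifier.
#
#   python predict_re_nyt.py \
#       --model_path path/to/relik_re_reader \
#       --data_path path/to/nyt_test.jsonl \
#       --is-eval \
#       --output_path path/to/predictions.jsonl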