import collections
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Union

import torch
import transformers as tr
from tqdm import tqdm
from transformers import AutoConfig

from relik.common.log import get_console_logger, get_logger
from relik.reader.data.relik_reader_data_utils import batchify, flatten
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderSpanModel,
)
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.save_load_utilities import load_model_and_conf
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForSpanExtraction(torch.nn.Module):
    """
    Span-extraction reader.

    Wraps a ``RelikReaderSpanModel`` and adds input building, batching and
    decoding of the model outputs into per-sample predictions.
    """

    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        **kwargs,
    ) -> None:
        """
        Args:
            transformer_model: A model name/path or an already-built model.
                For a string whose config ``model_type`` contains
                ``"relik-reader"``, the reader is loaded with
                ``RelikReaderSpanModel.from_pretrained`` (``kwargs`` are
                forwarded). Any other string is used as the backbone of a new
                ``RelikReaderSpanModel``. A non-string value is used as-is.
            additional_special_symbols: Forwarded to ``RelikReaderConfig``
                when a new model is built.
            num_layers: Forwarded to ``RelikReaderConfig`` when a new model is
                built.
            activation: Forwarded to ``RelikReaderConfig`` when a new model is
                built.
            linears_hidden_size: Forwarded to ``RelikReaderConfig`` when a new
                model is built.
            use_last_k_layers: Forwarded to ``RelikReaderConfig`` when a new
                model is built.
            training: Forwarded to ``RelikReaderConfig`` when a new model is
                built.
            device: Device the module is moved to (CPU when ``None``).
            tokenizer: A tokenizer instance or a name/path. It is resolved
                lazily by the ``tokenizer`` property.
            **kwargs: Forwarded to ``RelikReaderSpanModel.from_pretrained``.
        """
        super().__init__()
        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik-reader" in config.model_type:
                transformer_model = RelikReaderSpanModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderSpanModel(reader_config)

        self.relik_reader_model = transformer_model

        self._tokenizer = tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        start_labels: torch.Tensor | None = None,
        end_labels: torch.Tensor | None = None,
        disambiguation_labels: torch.Tensor | None = None,
        relation_labels: torch.Tensor | None = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Delegate to the wrapped ``RelikReaderSpanModel``.

        Every argument is forwarded positionally, in declaration order.
        """
        return self.relik_reader_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        Run the model on one batch and attach the decoded predictions to each sample.

        Args:
            input_ids: Batched model input, passed to ``forward``.
            attention_mask: Batched model input, passed to ``forward``.
            token_type_ids: Batched model input, passed to ``forward``.
            prediction_mask: Batched model input, passed to ``forward``.
            special_symbols_mask: Batched model input, passed to ``forward``.
            sample: The samples of the batch, in the same order as the batch rows.
            top_k: Number of most probable candidates kept per predicted span.
                ``-1`` keeps all of them.
            *args: Unused.
            **kwargs: Must contain ``predictable_candidates``, the per-sample
                candidate lists indexed by the predicted class. Must also
                contain ``patch_offset``, the per-sample key under which
                predictions are stored in ``sample._d["patches"]``.
                NOTE(review): neither key is listed in ``fields_batcher``, so
                ``read`` does not appear to supply them to this method.
                TODO confirm that code path.

        Yields:
            Each sample, with ``_d["patches"][patch_offset]`` holding
            ``predicted_window_labels``, ``span_title_probabilities`` and
            ``predictable_candidates``.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            # Position 0 is skipped (the CLS slot that ``_build_input``
            # prepends), so these indices are shifted by one w.r.t. the rows.
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities, best first
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                num_classes = len(classes_probabilities_best_indices)
                # Use a local so the caller's `top_k` is not overwritten for
                # the following spans.
                k = num_classes if top_k == -1 else min(top_k, num_classes)
                titles_2_probs = []
                for i in range(k):
                    best_index = classes_probabilities_best_indices[i]
                    # NOTE(review): a best index of 0 maps to pred_cands[-1];
                    # confirm class 0 is never meant to be reported.
                    titles_2_probs.append(
                        (
                            pred_cands[best_index - 1],
                            classes_probabilities[best_index].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()
            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts

    def _build_input(self, text: List[str], candidates: List[str]) -> list[str]:
        """
        Build the word-level sequence fed to the tokenizer.

        The layout is ``[CLS] text [SEP] candidates [SEP]``. Each candidate
        contributes its special symbol followed by its title; ``NME_SYMBOL``
        contributes only itself.
        """
        candidates_symbols = get_special_symbols(len(candidates))
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        """
        Build the token <-> word alignment from a tokenizer offsets mapping.

        A token whose start offset is 0 opens a new word. Any other token is
        attached to the word of the previous token.

        Returns:
            ``token2word``: list mapping each token position to its word index.
            ``word2token``: dict mapping each word index to its token positions.
        """
        offsets_mapping = offsets_mapping.numpy()
        token2word = []
        word2token = {}
        count = 0  # number of continuation (non word-initial) tokens seen so far
        for i, offset in enumerate(offsets_mapping):
            if offset[0] == 0:
                token2word.append(i - count)
                word2token[i - count] = [i]
            else:
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(i)
                count += 1
        return token2word, word2token

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        """
        Convert token-level predictions of ``sample`` into word-level ones, in place.

        - ``predicted_entities`` become ``(start_word, end_word_exclusive, type)``
          tuples. ``type`` is taken from ``entity_candidates``, or is ``-1``
          when there are none.
        - ``predicted_relations`` become subject/relation/object dicts.
        - ``predicted_relations_probabilities`` is reset to ``None``.

        Token indices are shifted by one before the ``token2word`` lookup.
        This presumably accounts for the leading CLS token; TODO confirm.
        """
        triplets = []
        entities = []
        for entity in sample.predicted_entities:
            if sample.entity_candidates:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        sample.entity_candidates[entity[2]],
                    )
                )
            else:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        -1,
                    )
                )
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    def _collate_batch(self, current_batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate per-sample model inputs (as built in ``read``) into one batch.

        Only fields listed in ``fields_batcher`` are kept. A field with a
        batcher is padded/stacked by it; a field mapped to ``None`` is passed
        through as a plain list. Fields whose value is ``None`` for any element
        are dropped. Tensors are moved to the model device.
        """
        assert (
            len({len(elem["predictable_candidates"]) for elem in current_batch}) == 1
        ), " ".join(
            map(
                str,
                [len(elem["predictable_candidates"]) for elem in current_batch],
            )
        )

        # `fields_batcher` builds a new dict on every access: read it once.
        fields_batcher = self.fields_batcher

        de_values_by_field = {
            fn: [de[fn] for de in current_batch if fn in de] for fn in fields_batcher
        }
        # in case you provide fields batchers but in the batch
        # there are no elements for that field
        de_values_by_field = {
            fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
        }
        # every remaining field must hold one value per batch element
        assert len({len(v) for v in de_values_by_field.values()}) == 1

        # todo: maybe we should report the user about possible
        # fields filtering due to "None" instances
        de_values_by_field = {
            fn: fvs
            for fn, fvs in de_values_by_field.items()
            if all(fv is not None for fv in fvs)
        }

        batch_dict = dict()
        for field_name, field_values in de_values_by_field.items():
            batcher = fields_batcher[field_name]
            # Per-sample tensors carry a leading batch dim of size 1 (the
            # tokenizer is called with return_tensors="pt"); `fv[0]` drops it.
            field_batch = (
                batcher([fv[0] for fv in field_values])
                if batcher is not None
                else field_values
            )
            batch_dict[field_name] = field_batch

        device = self.device
        return {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch_dict.items()
        }

    @torch.no_grad()
    @torch.inference_mode()
    def read(
        self,
        text: List[str] | List[List[str]] | None = None,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        candidates: List[List[str]] | None = None,
        max_length: int | None = 1024,
        max_batch_size: int | None = 64,
        token_batch_size: int | None = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[RelikReaderSample]:
        """
        Reads the given text.

        Args:
            text: The text to read, already split into tokens. Either a single
                text (``List[str]``) or a batch of texts (``List[List[str]]``).
            samples: Pre-built samples. When given, ``text`` is ignored.
            input_ids: The input ids of the text (used when neither ``text``
                nor ``samples`` is given).
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities
                of the text.
            candidates: The candidates of the text. A single candidate list
                for a single text, or one list per text for a batch.
            max_length: The maximum length, in tokens, of each encoded sample
                (the tokenizer's ``model_max_length`` when falsy).
            max_batch_size: The maximum number of samples per batch.
            token_batch_size: The maximum number of (padded) tokens per batch.
            progress_bar: Whether to show a progress bar over the samples.
            *args: Forwarded to ``batch_predict`` on the ``input_ids`` path.
            **kwargs: Forwarded to ``batch_predict`` on the ``input_ids`` path.

        Returns:
            The processed samples, with their predictions attached.

        Raises:
            ValueError: If no input is given, if ``text`` comes without
                ``candidates``, or if ``text`` and ``candidates`` differ in
                length.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions when "
                "`input_ids` and `samples` is not provided."
            )

        if text is not None and samples is None:
            if candidates is None:
                raise ValueError("`candidates` must be provided together with `text`.")
            # Wrap a single tokenized text into a batch of one *before*
            # comparing lengths. Otherwise the number of tokens would be
            # compared against the number of candidates.
            if len(text) > 0 and isinstance(text[0], str):
                text = [text]
                candidates = [candidates]
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is not None:
            current_batch = []
            predictions = []
            current_cand_len = -1

            for sample in tqdm(samples, disable=not progress_bar):
                # NOTE(review): this mutates the caller's sample; reading the
                # same sample twice prepends NME_SYMBOL twice.
                sample.candidates = [NME_SYMBOL] + sample.candidates
                inputs_text = self._build_input(sample.tokens, sample.candidates)
                model_inputs = self.tokenizer(
                    inputs_text,
                    is_split_into_words=True,
                    add_special_tokens=False,
                    padding=False,
                    truncation=True,
                    max_length=max_length or self.tokenizer.model_max_length,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                )
                # ids above the base vocabulary size are the added special symbols
                model_inputs["special_symbols_mask"] = (
                    model_inputs["input_ids"] > self.tokenizer.vocab_size
                )
                # token type is 0 until the first special symbol, 1 from it onwards
                model_inputs["token_type_ids"] = (
                    torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
                ).long()
                # shift prediction_mask to the left
                model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
                    shifts=-1, dims=1
                )
                # first and last positions are always masked out of prediction
                model_inputs["prediction_mask"][:, -1] = 1
                model_inputs["prediction_mask"][:, 0] = 1

                # The tensors are (1, seq_len), so compare full shapes:
                # `len()` would only see the leading batch dim of 1.
                assert (
                    model_inputs["special_symbols_mask"].shape
                    == model_inputs["prediction_mask"].shape
                    == model_inputs["input_ids"].shape
                )

                model_inputs["sample"] = sample

                # compute cand_len using special_symbols_mask
                model_inputs["predictable_candidates"] = sample.candidates[
                    : model_inputs["special_symbols_mask"].sum().item()
                ]
                offsets = model_inputs.pop("offset_mapping")
                offsets = offsets[model_inputs["prediction_mask"] == 0]
                sample.token2word, sample.word2token = self._compute_offsets(offsets)

                # Token budget: use the sequence length (dim 1), not len()
                # (dim 0, always 1).
                future_max_len = max(
                    model_inputs["input_ids"].shape[1],
                    max(
                        (b["input_ids"].shape[1] for b in current_batch),
                        default=0,
                    ),
                )
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                # Flush the current batch when the candidate count changes,
                # the token budget would be exceeded, or the batch is full.
                if len(current_batch) > 0 and (
                    (
                        len(model_inputs["predictable_candidates"]) != current_cand_len
                        and current_cand_len != -1
                    )
                    or (
                        isinstance(token_batch_size, int)
                        and future_tokens_per_batch >= token_batch_size
                    )
                    or len(current_batch) == max_batch_size
                ):
                    batch_inputs = self._collate_batch(current_batch)
                    current_batch = []
                    predictions.extend(list(self.batch_predict(**batch_inputs)))

                current_cand_len = len(model_inputs["predictable_candidates"])
                current_batch.append(model_inputs)

            if current_batch:
                batch_inputs = self._collate_batch(current_batch)
                predictions.extend(list(self.batch_predict(**batch_inputs)))
        else:
            # Pass `special_symbols_mask_entities` by keyword. Positionally it
            # would bind to `batch_predict`'s `sample` parameter.
            predictions = list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    special_symbols_mask_entities=special_symbols_mask_entities,
                    **kwargs,
                )
            )
        return predictions

    @property
    def device(self) -> torch.device:
        """
        The device of the model (that of its first parameter).
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The tokenizer, resolved lazily on first access.

        A tokenizer given as a string is treated as a name/path and loaded with
        ``AutoTokenizer``. When none is given (or the string is empty), it is
        loaded from the wrapped model's ``config.name_or_path``.
        """
        if self._tokenizer is None or isinstance(self._tokenizer, str):
            name_or_path = (
                self._tokenizer or self.relik_reader_model.config.name_or_path
            )
            self._tokenizer = tr.AutoTokenizer.from_pretrained(name_or_path)
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        """
        Map each model-input field name to the function that batches it.

        ``None`` means the values are passed through as a list.
        ``token_type_ids`` is omitted when the model type contains
        ``"roberta"``. A new dict is built on every access.
        """
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_model.config.model_type:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: str | None = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model (and tokenizer) to ``output_dir / model_name``.

        Args:
            output_dir: The directory to save into; it is created if missing.
            model_name: The name of the model sub-directory
                (``"relik-reader-for-span-extraction"`` by default).
            push_to_hub: Whether to push the model to the hub.
            **kwargs: Forwarded to both ``save_pretrained`` calls.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik-reader-for-span-extraction"

        logger.info(f"Saving reader to {output_dir / model_name}")

        # save the model
        self.relik_reader_model.register_for_auto_class()
        self.relik_reader_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )

        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")


class RelikReader:
    """
    Entity-linking reader built from a saved model and its configuration,
    with prediction delegated to a ``RelikReaderPredictor``.
    """

    def __init__(self, model_path: str, predict_nmes: bool = False):
        """
        Args:
            model_path: Path/name passed to ``load_model_and_conf``.
            predict_nmes: Forwarded to ``RelikReaderPredictor``.
        """
        model, model_conf = load_model_and_conf(model_path)
        model.training = False
        model.eval()

        val_dataset_conf = model_conf.data.val_dataset
        val_dataset_conf.special_symbols = get_special_symbols(
            model_conf.model.entities_per_forward
        )
        val_dataset_conf.transformer_model = model_conf.model.model.transformer_model

        self.predictor = RelikReaderPredictor(
            model,
            dataset_conf=model_conf.data.val_dataset,
            predict_nmes=predict_nmes,
        )
        self.model_path = model_path

    def link_entities(
        self,
        dataset_path_or_samples: str | Iterator[RelikReaderSample],
        token_batch_size: int = 2048,
        progress_bar: bool = False,
    ) -> List[RelikReaderSample]:
        """
        Run the predictor on a dataset path or on an iterator of samples.

        A string is passed to the predictor as the first positional argument;
        anything else is passed as the second one.
        """
        data_input = (
            (dataset_path_or_samples, None)
            if isinstance(dataset_path_or_samples, str)
            else (None, dataset_path_or_samples)
        )
        return self.predictor.predict(
            *data_input,
            dataset_conf=None,
            token_batch_size=token_batch_size,
            progress_bar=progress_bar,
        )


def main():
    """Developer smoke test; the dataset path below is machine-specific."""
    rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True)
    predictions = rr.link_entities(
        "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl"
    )
    print(predictions)


if __name__ == "__main__":
    main()