# sam-pointer-bart-base-v0.3 / model_utils.py
import logging
from typing import Dict, List, Optional, Tuple
import gradio as gr
from annotation_utils import labeled_span_to_id
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
import torch
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

logger = logging.getLogger(__name__)


def _embed_text_annotations(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
text_layer_name: str,
) -> Dict[LabeledSpan, List[float]]:
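    """Compute one embedding per annotation in the given text layer: tokenize the document
    (with striding), run the encoder, and max-pool the token embeddings over each span."""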
    # work on a copy so that the original document is not modified
    document = document.copy()
    # tokenize_document does not yet consider predictions, so we move them into the main
    # annotation layer manually (note: predictions.clear() returns the removed annotations)
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
added_annotations = []
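    # tokenize with a sliding window (stride) so that text beyond max_length is still
    # covered; note that overlapping windows can embed the same annotation more than once
    # (see the overwrite warning below)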
tokenizer_kwargs = {
"max_length": 512,
"stride": 64,
"truncation": True,
"return_overflowing_tokens": True,
}
tokenized_documents = tokenize_document(
document,
tokenizer=tokenizer,
result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
partition_layer="labeled_partitions",
added_annotations=added_annotations,
strict_span_conversion=False,
**tokenizer_kwargs,
)
    # tokenize again to get tensors in the format the model expects; pad to the longest
    # window so that overflow windows of different lengths can be stacked into one batch
    model_inputs = tokenizer(document.text, return_tensors="pt", padding=True, **tokenizer_kwargs)
# this is added when using return_overflowing_tokens=True, but the model does not accept it
model_inputs.pop("overflow_to_sample_mapping", None)
assert len(model_inputs.encodings) == len(tokenized_documents)
    # run the encoder without tracking gradients; we only need the hidden states
    with torch.no_grad():
        model_output = model(**model_inputs)
# get embeddings for all text annotations
embeddings = {}
for batch_idx in range(len(model_output.last_hidden_state)):
text2tok_ann = added_annotations[batch_idx][text_layer_name]
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
for tok_ann in tokenized_documents[batch_idx].labeled_spans:
# skip "empty" annotations
if tok_ann.start == tok_ann.end:
continue
# use the max pooling strategy to get a single embedding for the annotation text
embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
dim=0
)[0]
text_ann = tok2text_ann[tok_ann]
if text_ann in embeddings:
logger.warning(
f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
)
embeddings[text_ann] = embedding
    return embeddings


def _annotate(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
pipeline: Pipeline,
embedding_model: Optional[PreTrainedModel] = None,
embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
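    """Annotate the document in place: run the prediction pipeline and, if an embedding
    model and tokenizer are provided, store one embedding per predicted span in
    document.metadata["embeddings"]."""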
# execute prediction pipeline
pipeline(document)
if embedding_model is not None and embedding_tokenizer is not None:
adu_embeddings = _embed_text_annotations(
document=document,
model=embedding_model,
tokenizer=embedding_tokenizer,
text_layer_name="labeled_spans",
)
# convert keys to str because JSON keys must be strings
adu_embeddings_dict = {
labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
}
document.metadata["embeddings"] = adu_embeddings_dict
else:
gr.Warning(
"No embedding model provided. Skipping embedding extraction. You can load an embedding "
"model in the 'Model Configuration' section."
        )


def create_and_annotate_document(
text: str,
doc_id: str,
models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
"""Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
text, annotate it, and add it to the index.
Parameters:
text: The text to process.
doc_id: The ID of the document.
models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
Returns:
The processed document.
"""
try:
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
id=doc_id, text=text, metadata={}
)
        # add a single partition spanning the whole text (the model only considers text
        # inside partitions)
document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
# annotate the document
_annotate(
document=document,
pipeline=models[0],
embedding_model=models[1],
embedding_tokenizer=models[2],
)
return document
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}") from e


def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
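    """Load a pytorch_ie prediction pipeline for the argumentation model from the given
    model name and optional revision."""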
try:
model = AutoPipeline.from_pretrained(
model_name,
device=-1,
num_workers=0,
taskmodule_kwargs=dict(revision=revision),
model_kwargs=dict(revision=revision),
)
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e
gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision})")
return model
def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
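    """Load a Hugging Face encoder model and its tokenizer for span embeddings."""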
try:
embedding_model = AutoModel.from_pretrained(model_name)
embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        raise gr.Error(f"Failed to load embedding model: {e}") from e
gr.Info(f"Loaded embedding model: model_name={model_name})")
return embedding_model, embedding_tokenizer
def load_models(
model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
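    """Load the argumentation pipeline and, if embedding_model_name is given and non-empty,
    the embedding model and tokenizer."""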
argumentation_model = load_argumentation_model(model_name, revision)
embedding_model = None
embedding_tokenizer = None
if embedding_model_name is not None and embedding_model_name.strip():
embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name)
return argumentation_model, embedding_model, embedding_tokenizer
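

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original app). The checkpoint
    # names below are placeholders; substitute the models you actually want to load. Note
    # that gr.Info/gr.Warning/gr.Error are meant to be called from within a running Gradio
    # app.
    argumentation_model, emb_model, emb_tokenizer = load_models(
        model_name="ArneBinder/sam-pointer-bart-base-v0.3",  # placeholder model id
        revision=None,
        embedding_model_name=None,  # e.g. a BERT-style encoder id, or None to skip embeddings
    )
    document = create_and_annotate_document(
        text="This is a claim. This sentence supports the claim.",
        doc_id="example-document",
        models=(argumentation_model, emb_model, emb_tokenizer),
    )
    print("predicted spans:", list(document.labeled_spans.predictions))
    print("predicted relations:", list(document.binary_relations.predictions))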