|
import logging |
|
from typing import Optional, Tuple |
|
|
|
import gradio as gr |
|
import torch |
|
from annotation_utils import labeled_span_to_id |
|
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel |
|
from pytorch_ie import Pipeline |
|
from pytorch_ie.annotations import LabeledSpan |
|
from pytorch_ie.auto import AutoPipeline |
|
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_pipeline: Pipeline,
    embedding_model: Optional[EmbeddingModel] = None,
) -> None:
    """Run the annotation pipeline on a document and optionally attach span embeddings.

    Args:
        document: The document to annotate. It is passed to ``annotation_pipeline``; when
            embeddings are computed, its ``metadata["embeddings"]`` entry is overwritten.
        annotation_pipeline: Pipeline called with the document. Its return value is ignored.
        embedding_model: Optional model used to embed the spans of the ``labeled_spans``
            layer. When omitted, a Gradio warning is shown and no embeddings are stored.
    """
    annotation_pipeline(document)

    # Guard clause: without an embedding model, only warn the user and stop here.
    if embedding_model is None:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
            "model in the 'Model Configuration' section."
        )
        return

    span_embeddings = embedding_model(document=document, span_layer_name="labeled_spans")

    # Key each embedding by its span id and convert it with `.tolist()` (presumably so the
    # metadata holds plain Python values; TODO confirm).
    embeddings_by_span_id = {}
    for span, embedding in span_embeddings.items():
        span_id = labeled_span_to_id(span)
        embeddings_by_span_id[span_id] = embedding.tolist()

    document.metadata["embeddings"] = embeddings_by_span_id
|
|
|
|
|
def create_document(
    text: str, doc_id: str
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Wrap raw text in a new document with one partition covering the whole text.

    Args:
        text: Raw text of the document.
        doc_id: Identifier assigned to the document.

    Returns:
        A new document with empty metadata. Its ``labeled_partitions`` layer holds a single
        span labeled ``"text"``, running from offset 0 to ``len(text)``.
    """
    new_document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id,
        text=text,
        metadata={},
    )
    whole_text_partition = LabeledSpan(start=0, end=len(text), label="text")
    new_document.labeled_partitions.append(whole_text_partition)
    return new_document
|
|
|
|
|
def _to_pipeline_device(device: str) -> int:
    """Translate a device string into the integer device index passed to the pipeline.

    Accepted values are ``"cpu"`` (-> ``-1``), ``"cuda"`` (-> ``0``), and ``"cuda:<index>"``
    with a non-negative decimal index (-> ``<index>``).

    Raises:
        gr.Error: If ``device`` is not one of the accepted forms.
    """
    if device == "cpu":
        return -1
    if device == "cuda":
        return 0
    prefix, separator, index = device.partition(":")
    # Only plain non-negative integers are accepted. This rejects e.g. "cuda:abc", which
    # would crash int(), and "cuda:-1", which would silently map to the CPU index -1.
    if prefix == "cuda" and separator and index.isdecimal():
        return int(index)
    raise gr.Error(f"Invalid device: {device}")


def load_argumentation_model(
    model_name: str,
    revision: Optional[str] = None,
    device: str = "cpu",
) -> Pipeline:
    """Load an argumentation pipeline via ``AutoPipeline.from_pretrained``.

    Args:
        model_name: Name or path of the pretrained pipeline.
        revision: Optional revision, forwarded to both the taskmodule and the model.
        device: ``"cpu"``, ``"cuda"`` or ``"cuda:<index>"``.

    Returns:
        The loaded pipeline.

    Raises:
        gr.Error: If ``device`` is invalid or the pipeline fails to load.
    """
    # Validate the device before the try block. Otherwise the "Invalid device" error would be
    # caught below and re-wrapped as a generic load failure.
    pipeline_device = _to_pipeline_device(device)

    try:
        model = AutoPipeline.from_pretrained(
            model_name,
            device=pipeline_device,
            num_workers=0,
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
    except Exception as e:
        # Chain the original exception so its traceback is preserved.
        raise gr.Error(f"Failed to load argumentation model: {e}") from e

    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision}")
    return model
|
|
|
|
|
def load_models(
    model_name: str,
    revision: Optional[str] = None,
    embedding_model_name: Optional[str] = None,
    embedding_max_length: int = 512,
    embedding_batch_size: int = 16,
    device: str = "cpu",
) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
    """Load the argumentation pipeline and, optionally, a Hugging Face embedding model.

    Args:
        model_name: Name or path of the argumentation pipeline.
        revision: Optional revision of the argumentation pipeline.
        embedding_model_name: Name of the embedding model. ``None`` or a blank string
            means no embedding model is loaded. Surrounding whitespace is stripped.
        embedding_max_length: ``max_length`` passed to the embedding model.
        embedding_batch_size: ``batch_size`` passed to the embedding model.
        device: Device string used for both models (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"``).

    Returns:
        A tuple ``(argumentation_model, embedding_model)``. ``embedding_model`` is ``None``
        when no embedding model name was given.

    Raises:
        gr.Error: If either model fails to load.
    """
    # Release unoccupied cached GPU memory held by PyTorch before loading new models.
    torch.cuda.empty_cache()
    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)

    embedding_model: Optional[EmbeddingModel] = None
    stripped_embedding_model_name = (embedding_model_name or "").strip()
    if stripped_embedding_model_name:
        try:
            embedding_model = HuggingfaceEmbeddingModel(
                stripped_embedding_model_name,
                device=device,
                max_length=embedding_max_length,
                batch_size=embedding_batch_size,
            )
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise gr.Error(f"Failed to load embedding model: {e}") from e

    return argumentation_model, embedding_model
|
|