import logging
from typing import Optional, Tuple

import gradio as gr
import torch
from annotation_utils import labeled_span_to_id
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

logger = logging.getLogger(__name__)


def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_pipeline: Pipeline,
    embedding_model: Optional[EmbeddingModel] = None,
) -> None:
    """Annotate a document with the provided pipeline. If an embedding model is provided, also
    extract embeddings for the labeled spans.

    Args:
        document: The document to annotate.
        annotation_pipeline: The pipeline to use for annotation.
        embedding_model: The embedding model to use for extracting text span embeddings.
    """

    # execute the prediction pipeline
    annotation_pipeline(document)

    if embedding_model is not None:
        text_span_embeddings = embedding_model(
            document=document,
            span_layer_name="labeled_spans",
        )
        # convert keys to str because JSON keys must be strings
        text_span_embeddings_dict = {
            labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
        }
        document.metadata["embeddings"] = text_span_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an "
            "embedding model in the 'Model Configuration' section."
        )


def create_document(
    text: str, doc_id: str
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
    provided text.

    Args:
        text: The text to process.
        doc_id: The ID of the document.

    Returns:
        The processed document.
    """

    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    # add a single partition spanning the whole text (the model only considers text in partitions)
    document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    return document


def load_argumentation_model(
    model_name: str,
    revision: Optional[str] = None,
    device: str = "cpu",
) -> Pipeline:
    """Load the argumentation annotation pipeline.

    Args:
        model_name: The name of the model to load.
        revision: The revision of the model, e.g. a branch name or commit hash.
        device: The device to load the model onto ("cpu", "cuda", or "cuda:<index>").

    Returns:
        The loaded pipeline.
    """
    try:
        # the Pipeline class expects an integer for the device:
        # the CUDA device index, or -1 for CPU
        if device == "cuda":
            pipeline_device = 0
        elif device.startswith("cuda:"):
            pipeline_device = int(device.split(":")[1])
        elif device == "cpu":
            pipeline_device = -1
        else:
            raise gr.Error(f"Invalid device: {device}")

        model = AutoPipeline.from_pretrained(
            model_name,
            device=pipeline_device,
            num_workers=0,
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}")
    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision}")
    return model


def load_models(
    model_name: str,
    revision: Optional[str] = None,
    embedding_model_name: Optional[str] = None,
    # embedding_model_revision: Optional[str] = None,
    embedding_max_length: int = 512,
    embedding_batch_size: int = 16,
    device: str = "cpu",
) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
    """Load the argumentation model and, if a model name is provided, the embedding model.

    Args:
        model_name: The name of the argumentation model to load.
        revision: The revision of the argumentation model.
        embedding_model_name: The name of the embedding model to load. If None or empty,
            no embedding model is loaded.
        embedding_max_length: The maximum input length for the embedding model.
        embedding_batch_size: The batch size for embedding extraction.
        device: The device to load the models onto.

    Returns:
        A tuple of the argumentation pipeline and the embedding model (None if no
        embedding model name was provided).
    """
    # free cached GPU memory held by any previously loaded models
    torch.cuda.empty_cache()
    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)

    embedding_model = None
    if embedding_model_name is not None and embedding_model_name.strip():
        try:
            embedding_model = HuggingfaceEmbeddingModel(
                embedding_model_name.strip(),
                # revision=embedding_model_revision,
                device=device,
                max_length=embedding_max_length,
                batch_size=embedding_batch_size,
            )
        except Exception as e:
            raise gr.Error(f"Failed to load embedding model: {e}")

    return argumentation_model, embedding_model