import json
import logging
from typing import Iterable, Optional, Sequence, Union

import gradio as gr
from hydra.utils import instantiate
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger

# this is required to dynamically load the PIE models
from pie_modules.models import *  # noqa: F403
from pie_modules.taskmodules import *  # noqa: F403
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

# this is required to dynamically load the PIE models
from pytorch_ie.models import *  # noqa: F403
from pytorch_ie.taskmodules import *  # noqa: F403

from src.utils import parse_config

logger = logging.getLogger(__name__)


def get_merger() -> SpansViaRelationMerger:
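    """Create a merger that combines spans connected via "parts_of_same" relations into
    multi-spans, converting the result to the corresponding multi-span document type."""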
    return SpansViaRelationMerger(
        relation_layer="binary_relations",
        link_relation_label="parts_of_same",
        create_multi_spans=True,
        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
        result_field_mapping={
            "labeled_spans": "labeled_multi_spans",
            "binary_relations": "binary_relations",
            "labeled_partitions": "labeled_partitions",
        },
    )


def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
]:
    """Annotate a document with the provided pipeline.

    Args:
        document: The document to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans connected via "parts_of_same"
            relations into single multi-spans.
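
    Example (illustrative sketch; assumes a pipeline loaded e.g. via
    load_argumentation_model):

        doc = create_document(text="Some argumentative text.", doc_id="doc-1")
        annotated = annotate_document(doc, pipeline, handle_parts_of_same=True)
        # spans linked via "parts_of_same" relations are merged into
        # labeled_multi_spans on the returned document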
    """

    # execute prediction pipeline
    result: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = argumentation_model(
        document, inplace=True
    )

    if handle_parts_of_same:
        merger = get_merger()
        result = merger(result)

    return result


def annotate_documents(
    documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
    Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
]:
    """Annotate a sequence of documents with the provided pipeline.

    Args:
        documents: The documents to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans connected via "parts_of_same"
            relations into single multi-spans.
    """
    # execute prediction pipeline
    result = argumentation_model(documents, inplace=True)

    if handle_parts_of_same:
        merger = get_merger()
        result = [merger(document) for document in result]

    return result


def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
    text.

    Args:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed document.
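
    Example (illustrative sketch; exact partition boundaries depend on
    RegexPartitioner's semantics):

        doc = create_document(
            text="First paragraph.\\n\\nSecond paragraph.",
            doc_id="doc-1",
            split_regex="\\n\\n",
        )
        for partition in doc.labeled_partitions:
            print(partition.start, partition.end)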
    """

    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    if split_regex is not None:
        partitioner = RegexPartitioner(
            pattern=split_regex, partition_layer_name="labeled_partitions"
        )
        document = partitioner(document)
    else:
        # add a single partition covering the whole text (the model only considers text inside partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    return document


def create_documents(
    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
    """Create a sequence of TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
    texts.

    Args:
        texts: The texts to process.
        doc_ids: The IDs of the documents.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed documents.
    """
    return [
        create_document(text=text, doc_id=doc_id, split_regex=split_regex)
        for text, doc_id in zip(texts, doc_ids)
    ]


def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
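    """Instantiate a PIE pipeline from a YAML config string via hydra.utils.instantiate.

    If the config targets pytorch_ie.auto.AutoPipeline.from_pretrained and contains a
    top-level "revision" entry, the revision is moved into both taskmodule_kwargs and
    model_kwargs, because it needs to be handled separately for the taskmodule and the
    model.

    Example config (illustrative sketch; the model name is a placeholder):

        _target_: pytorch_ie.auto.AutoPipeline.from_pretrained
        pretrained_model_name_or_path: my-org/my-argumentation-model
        revision: main

    Args:
        config_str: The pipeline config as a YAML string.
        **kwargs: Additional keyword arguments passed to hydra.utils.instantiate.

    Returns:
        The instantiated pipeline.

    Raises:
        gr.Error: If the config cannot be parsed or the pipeline fails to instantiate.
    """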
    try:
        config = parse_config(config_str, format="yaml")

        # for PIE AutoPipeline, we need to handle the revision separately for
        # the taskmodule and the model
        if (
            config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
            and "revision" in config
        ):
            revision = config.pop("revision")
            if "taskmodule_kwargs" not in config:
                config["taskmodule_kwargs"] = {}
            config["taskmodule_kwargs"]["revision"] = revision
            if "model_kwargs" not in config:
                config["model_kwargs"] = {}
            config["model_kwargs"]["revision"] = revision
        model = instantiate(config, **kwargs)
        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs})}")
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e

    return model


def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[Sequence[str]] = None,
) -> gr.Dropdown:
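    """Build a multi-select dropdown with the relation types the model can predict.

    Currently only pipelines whose taskmodule is a PointerNetworkTaskModuleForEnd2EndRE
    are supported; the relation types are read from its labels_per_layer mapping.

    Args:
        argumentation_model: The pipeline whose taskmodule defines the relation types.
        default: The relation types that should be preselected in the dropdown.

    Returns:
        The dropdown populated with the available relation types.
    """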
    if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
    else:
        raise gr.Error("Unsupported taskmodule for relation types")

    return gr.Dropdown(
        choices=relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )
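

# Typical end-to-end usage from a Gradio app (illustrative sketch; the config string
# and model name below are placeholders, not a real published model):
#
#   model = load_argumentation_model(
#       "_target_: pytorch_ie.auto.AutoPipeline.from_pretrained\n"
#       "pretrained_model_name_or_path: my-org/my-argumentation-model\n"
#       "revision: main",
#   )
#   docs = create_documents(texts=["Some argumentative text."], doc_ids=["doc-1"])
#   annotated = annotate_documents(docs, model, handle_parts_of_same=True)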