|
from itertools import islice
|
|
from typing import Tuple, List, Iterable, Optional, Dict, Callable, Any
|
|
|
|
from spacy.scorer import PRFScore
|
|
from thinc.types import Floats2d
|
|
import numpy
|
|
from spacy.training.example import Example
|
|
from thinc.api import Model, Optimizer
|
|
from spacy.tokens.doc import Doc
|
|
from spacy.pipeline.trainable_pipe import TrainablePipe
|
|
from spacy.vocab import Vocab
|
|
from spacy import Language
|
|
from thinc.model import set_dropout_rate
|
|
from wasabi import Printer
|
|
|
|
from typing import List, Tuple, Callable
|
|
|
|
import spacy
|
|
from spacy.tokens import Doc, Span
|
|
from thinc.types import Floats2d, Ints1d, Ragged, cast
|
|
from thinc.api import Model, Linear, chain, Logistic
|
|
|
|
@spacy.registry.architectures("rel_model.v1")
def create_relation_model(
    create_instance_tensor: Model[List[Doc], Floats2d],
    classification_layer: Model[Floats2d, Floats2d],
) -> Model[List[Doc], Floats2d]:
    """Compose the full relation-extraction model.

    The instance-tensor layer is chained into the classification layer, and
    the `get_instances` attribute of the instance-tensor layer is copied onto
    the combined model so callers can read it from the top-level model.

    create_instance_tensor: Layer mapping docs to one row per candidate pair.
    classification_layer: Layer mapping those rows to per-label scores.
    RETURNS: The chained model, carrying the `get_instances` attribute.
    """
    # Calling `chain` directly is what the `>>` operator binding resolves to.
    relation_model = chain(create_instance_tensor, classification_layer)
    instance_getter = create_instance_tensor.attrs["get_instances"]
    relation_model.attrs["get_instances"] = instance_getter
    return relation_model
|
|
|
|
|
|
@spacy.registry.architectures("rel_classification_layer.v1")
def create_classification_layer(
    nO: int = None, nI: int = None
) -> Model[Floats2d, Floats2d]:
    """Build the classification head: a linear layer followed by a logistic.

    nO: Output width passed to `Linear` (may be left as None).
    nI: Input width passed to `Linear` (may be left as None).
    RETURNS: The chained `Linear` -> `Logistic` model.
    """
    projection = Linear(nO=nO, nI=nI)
    squashing = Logistic()
    return chain(projection, squashing)
|
|
|
|
|
|
@spacy.registry.misc("rel_instance_generator.v1")
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
    """Create a function that lists candidate entity pairs for a doc.

    max_length: Largest allowed distance between the start offsets of the
        two entities. A falsy value (e.g. 0) yields no pairs at all.
    RETURNS: A callable mapping a doc to its list of (ent1, ent2) pairs.
    """

    def get_instances(doc: Doc) -> List[Tuple[Span, Span]]:
        """Return every ordered pair of distinct entities within range."""
        return [
            (head, tail)
            for head in doc.ents
            for tail in doc.ents
            if head != tail
            and max_length
            and abs(tail.start - head.start) <= max_length
        ]

    return get_instances
|
|
|
|
|
|
@spacy.registry.architectures("rel_instance_tensor.v1")
def create_tensors(
    tok2vec: Model[List[Doc], List[Floats2d]],
    pooling: Model[Ragged, Floats2d],
    get_instances: Callable[[Doc], List[Tuple[Span, Span]]],
) -> Model[List[Doc], Floats2d]:
    """Wrap tok2vec and pooling in a model producing one row per instance.

    tok2vec: Layer producing one token-vector array per doc.
    pooling: Layer reducing a ragged batch of entity tokens to one row each.
    get_instances: Callable listing the candidate entity pairs of a doc;
        stored in the model's attrs under "get_instances".
    RETURNS: A model named "instance_tensors" that uses `instance_forward`
        as its forward function and `instance_init` as its initializer.
    """
    sublayers = [tok2vec, pooling]
    named_refs = {"tok2vec": tok2vec, "pooling": pooling}
    extra_attrs = {"get_instances": get_instances}
    return Model(
        "instance_tensors",
        instance_forward,
        layers=sublayers,
        refs=named_refs,
        attrs=extra_attrs,
        init=instance_init,
    )
|
|
|
|
|
|
def instance_forward(model: Model[List[Doc], Floats2d], docs: List[Doc], is_train: bool) -> Tuple[Floats2d, Callable]:
    """Forward pass turning docs into one feature row per candidate relation.

    For every instance (entity pair) returned by the model's `get_instances`
    attribute, the token vectors of both entities are gathered, each entity
    is pooled to a single row, and the two pooled rows of a pair are merged
    into one row by reshaping.

    model: The "instance_tensors" model holding the "tok2vec" and "pooling"
        refs and the "get_instances" attribute.
    docs: The batch of docs to process.
    is_train: Forwarded to the tok2vec and pooling layers.
    RETURNS: The relation rows and the backprop callback.
    """
    pooling = model.get_ref("pooling")
    tok2vec = model.get_ref("tok2vec")
    get_instances = model.attrs["get_instances"]
    all_instances = [get_instances(doc) for doc in docs]
    tokvecs, bp_tokvecs = tok2vec(docs, is_train)

    ents = []
    lengths = []

    for instances, tokvec in zip(all_instances, tokvecs):
        token_indices = []
        for instance in instances:
            for ent in instance:
                token_indices.extend(range(ent.start, ent.end))
                lengths.append(ent.end - ent.start)
        # One array per doc, holding the token rows of all its entities.
        ents.append(tokvec[token_indices])
    lengths = cast(Ints1d, model.ops.asarray(lengths, dtype="int32"))
    entities = Ragged(model.ops.flatten(ents), lengths)
    pooled, bp_pooled = pooling(entities, is_train)

    # Reshape so that each pair of consecutive pooled rows (the two entities
    # of one instance) becomes a single row of twice the width.
    relations = model.ops.reshape2f(pooled, -1, pooled.shape[1] * 2)

    def backprop(d_relations: Floats2d) -> List[Doc]:
        """Route relation gradients back to the token vectors and tok2vec."""
        # Undo the pair-merging reshape: two pooled rows per relation row.
        d_pooled = model.ops.reshape2f(d_relations, d_relations.shape[0] * 2, -1)
        d_ents = bp_pooled(d_pooled).data
        d_tokvecs = []
        ent_index = 0
        for doc_nr, instances in enumerate(all_instances):
            shape = tokvecs[doc_nr].shape
            d_tokvec = model.ops.alloc2f(*shape)
            count_occ = model.ops.alloc2f(*shape)
            for instance in instances:
                for ent in instance:
                    n_tokens = ent.end - ent.start
                    # Use each token's own gradient row rather than only the
                    # first row of the entity, so the result does not rely on
                    # the pooling layer giving all tokens the same gradient.
                    d_tokvec[ent.start : ent.end] += d_ents[ent_index : ent_index + n_tokens]
                    count_occ[ent.start : ent.end] += 1
                    ent_index += n_tokens
            # Average over the number of times a token occurred in an
            # instance; the epsilon avoids dividing by zero for tokens that
            # are not part of any entity.
            d_tokvec /= count_occ + 0.00000000001
            d_tokvecs.append(d_tokvec)

        d_docs = bp_tokvecs(d_tokvecs)
        return d_docs

    return relations, backprop
|
|
|
|
|
|
def instance_init(model: Model, X: List[Doc] = None, Y: Floats2d = None) -> Model:
    """Initialize the instance-tensor model.

    Only the "tok2vec" ref is initialized, and only when sample docs are
    given; `Y` is accepted but not used.

    model: The model whose "tok2vec" ref should be initialized.
    X: Optional sample docs passed on to the tok2vec layer.
    Y: Unused.
    RETURNS: The same model instance.
    """
    embedder = model.get_ref("tok2vec")
    if X is None:
        return model
    embedder.initialize(X)
    return model
|
|
|
|
|
|
# Register the custom `rel` attribute on Doc (read and written below as
# `doc._.rel`), defaulting to an empty dict. `force=True` lets this run again
# without erroring if the extension is already registered.
Doc.set_extension("rel", default={}, force=True)
|
|
# Module-level wasabi printer used for the informational messages emitted
# by the relation extractor when no instances can be found.
msg = Printer()
|
|
|
|
|
|
@Language.factory(
    "relation_extractor",
    requires=["doc.ents", "token.ent_iob", "token.ent_type"],
    assigns=["doc._.rel"],
    default_score_weights={
        "rel_micro_p": None,
        "rel_micro_r": None,
        "rel_micro_f": None,
    },
)
def make_relation_extractor(
    nlp: Language, name: str, model: Model, *, threshold: float
):
    """Factory for the "relation_extractor" pipeline component.

    nlp: The pipeline; its vocab is handed to the component.
    name: The component instance name.
    model: The model scoring candidate relations.
    threshold: Score at or above which a predicted label counts as positive
        during scoring.
    RETURNS: The constructed RelationExtractor.
    """
    component = RelationExtractor(
        vocab=nlp.vocab,
        model=model,
        name=name,
        threshold=threshold,
    )
    return component
|
|
|
|
|
|
class RelationExtractor(TrainablePipe):
    """Trainable pipeline component that scores relations between entities.

    Candidate entity pairs ("instances") come from the `get_instances`
    attribute of the model. Predicted scores are stored in `doc._.rel`, a
    dict keyed by `(ent1.start, ent2.start)` whose values map each label to
    its score.
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "rel",
        *,
        threshold: float,
    ) -> None:
        """Initialize a relation extractor.

        vocab: The shared vocabulary.
        model: The model producing one row of label scores per instance.
        name: The component instance name, also used as the key in `losses`.
        threshold: Score cutoff used when scoring predictions.
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.cfg = {"labels": [], "threshold": threshold}

    @property
    def labels(self) -> Tuple[str, ...]:
        """Returns the labels currently added to the component."""
        return tuple(self.cfg["labels"])

    @property
    def threshold(self) -> float:
        """Returns the threshold above which a prediction is seen as 'True'."""
        return self.cfg["threshold"]

    def add_label(self, label: str) -> int:
        """Add a new label to the pipe.

        label: The label to add.
        RETURNS: 1 if the label was added, 0 if it was already present.
        RAISES: ValueError if `label` is not a string.
        """
        if not isinstance(label, str):
            raise ValueError("Only strings can be added as labels to the RelationExtractor")
        if label in self.labels:
            return 0
        self.cfg["labels"] = list(self.labels) + [label]
        return 1

    def __call__(self, doc: Doc) -> Doc:
        """Apply the pipe to a Doc.

        The doc is returned untouched when it yields no instances.
        """
        # Check that there are actually any candidate instances in this doc.
        total_instances = len(self.model.attrs["get_instances"](doc))
        if total_instances == 0:
            msg.info("Could not determine any instances in doc - returning doc as is.")
            return doc

        predictions = self.predict([doc])
        self.set_annotations([doc], predictions)
        return doc

    def predict(self, docs: Iterable[Doc]) -> Floats2d:
        """Apply the pipeline's model to a batch of docs, without modifying them.

        docs: The docs to score.
        RETURNS: The scores, one row per instance across all docs.
        """
        # Materialize once: the docs are iterated here and again by the
        # model, which would exhaust a one-shot iterable.
        docs = list(docs)
        get_instances = self.model.attrs["get_instances"]
        total_instances = sum(len(get_instances(doc)) for doc in docs)
        if total_instances == 0:
            msg.info("Could not determine any instances in any docs - can not make any predictions.")
        scores = self.model.predict(docs)
        return self.model.ops.asarray(scores)

    def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None:
        """Modify a batch of `Doc` objects, using pre-computed scores.

        Rows of `scores` are consumed in the order the instances are
        produced by `get_instances`, doc by doc.
        """
        c = 0
        get_instances = self.model.attrs["get_instances"]
        for doc in docs:
            for (e1, e2) in get_instances(doc):
                offset = (e1.start, e2.start)
                if offset not in doc._.rel:
                    doc._.rel[offset] = {}
                for j, label in enumerate(self.labels):
                    doc._.rel[offset][label] = scores[c, j]
                c += 1

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        set_annotations: bool = False,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples: The batch of training examples.
        drop: The dropout rate set on the model.
        set_annotations: Whether to write the predictions to the docs.
        sgd: Optional optimizer; when given, the model update is finished.
        losses: Optional dict to accumulate the loss in, keyed by name.
        RETURNS: The (possibly newly created) losses dict.
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        set_dropout_rate(self.model, drop)

        # Materialize once: the examples are iterated several times below,
        # which would silently yield nothing for a one-shot iterable.
        examples = list(examples)

        # Check that there are actually any candidate instances in this batch.
        get_instances = self.model.attrs["get_instances"]
        total_instances = sum(len(get_instances(eg.predicted)) for eg in examples)
        if total_instances == 0:
            msg.info("Could not determine any instances in doc.")
            return losses

        # Run the model.
        docs = [eg.predicted for eg in examples]
        predictions, backprop = self.model.begin_update(docs)
        loss, gradient = self.get_loss(examples, predictions)
        backprop(gradient)
        if sgd is not None:
            self.model.finish_update(sgd)
        losses[self.name] += loss
        if set_annotations:
            self.set_annotations(docs, predictions)
        return losses

    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, Floats2d]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples: The examples providing the gold-standard relations.
        scores: The predicted scores, one row per instance.
        RETURNS: The mean squared error and the gradient (scores - truths).
        RAISES: ValueError if no gold-standard instances can be determined.
        """
        # NOTE(review): truths are built from `eg.reference` while the scores
        # in `update` come from `eg.predicted`; this assumes both yield the
        # same instances in the same order - confirm against the corpus.
        truths = self._examples_to_truth(examples)
        if truths is None:
            # Without this check `scores - None` fails with an opaque TypeError.
            raise ValueError(
                "Could not determine any instances in the reference docs - "
                "can not compute a loss."
            )
        gradient = scores - truths
        mean_square_error = (gradient ** 2).sum(axis=1).mean()
        return float(mean_square_error), gradient

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
        labels: Optional[List[str]] = None,
    ):
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples: Callable returning the examples; called once to collect
            labels (unless `labels` is given) and once to draw a sample.
        nlp: The current pipeline (unused here).
        labels: Optional explicit labels; when omitted, the labels are read
            from the `_.rel` annotations of the reference docs.
        RAISES: ValueError if the sample contains no instances.
        """
        if labels is not None:
            for label in labels:
                self.add_label(label)
        else:
            for example in get_examples():
                relations = example.reference._.rel
                for label_dict in relations.values():
                    for label in label_dict.keys():
                        self.add_label(label)
        self._require_labels()

        # Initialize the model from a small sample of the examples.
        subbatch = list(islice(get_examples(), 10))
        doc_sample = [eg.reference for eg in subbatch]
        label_sample = self._examples_to_truth(subbatch)
        if label_sample is None:
            raise ValueError("Call begin_training with relevant entities and relations annotated in "
                             "at least a few reference examples!")
        self.model.initialize(X=doc_sample, Y=label_sample)

    def _examples_to_truth(self, examples: Iterable[Example]) -> Optional[numpy.ndarray]:
        """Build the gold-standard score matrix for a batch of examples.

        examples: The examples whose reference docs carry `_.rel` annotations.
        RETURNS: An array with one row per instance and one column per label
            (0 where a label is not annotated), or None if the reference
            docs yield no instances at all.
        """
        get_instances = self.model.attrs["get_instances"]
        # Collect the instances in a single pass so the examples are only
        # iterated once and `get_instances` runs once per reference doc.
        per_doc = []
        for eg in examples:
            gold_doc = eg.reference
            per_doc.append((gold_doc, get_instances(gold_doc)))

        # Check that there are actually any candidate instances in this batch.
        nr_instances = sum(len(instances) for _, instances in per_doc)
        if nr_instances == 0:
            return None

        labels = self.labels
        truths = numpy.zeros((nr_instances, len(labels)), dtype="f")
        c = 0
        for gold_doc, instances in per_doc:
            for (e1, e2) in instances:
                gold_label_dict = gold_doc._.rel.get((e1.start, e2.start), {})
                for j, label in enumerate(labels):
                    truths[c, j] = gold_label_dict.get(label, 0)
                c += 1

        truths = self.model.ops.asarray(truths)
        return truths

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples using the component's threshold."""
        return score_relations(examples, self.threshold)
|
|
|
|
|
|
def score_relations(examples: Iterable[Example], threshold: float) -> Dict[str, Any]:
    """Compute micro-averaged precision, recall and F-score for relations.

    Only entity pairs present in the predicted `_.rel` dict are scored; gold
    relations for pairs that are missing from the predictions are not
    counted. A gold label is positive when its value equals 1.0, and a
    predicted label is positive when its score is at least `threshold`.

    examples: The examples whose predicted and reference docs carry `_.rel`.
    threshold: Score cutoff for a positive prediction.
    RETURNS: A dict with "rel_micro_p", "rel_micro_r" and "rel_micro_f".
    """
    counts = PRFScore()
    for eg in examples:
        gold_rels = eg.reference._.rel
        pred_rels = eg.predicted._.rel
        for pair, label_scores in pred_rels.items():
            gold_for_pair = gold_rels.get(pair, {})
            positives = [name for (name, value) in gold_for_pair.items() if value == 1.0]
            for name, score in label_scores.items():
                is_predicted = score >= threshold
                is_gold = name in positives
                if is_predicted and is_gold:
                    counts.tp += 1
                elif is_predicted:
                    counts.fp += 1
                elif is_gold:
                    counts.fn += 1
    results = {
        "rel_micro_p": counts.precision,
        "rel_micro_r": counts.recall,
        "rel_micro_f": counts.fscore,
    }
    return results