|
from typing import List, Tuple, Callable |
|
|
|
import spacy |
|
from spacy.tokens import Doc, Span |
|
from thinc.types import Floats2d, Ints1d, Ragged, cast |
|
from thinc.api import Model, Linear, chain, Logistic, Optimizer |
|
|
|
from itertools import islice |
|
from typing import Tuple, List, Iterable, Optional, Dict, Callable, Any |
|
|
|
from spacy.scorer import PRFScore |
|
from thinc.types import Floats2d |
|
import numpy |
|
from spacy.training.example import Example |
|
from spacy.pipeline.trainable_pipe import TrainablePipe |
|
from spacy.vocab import Vocab |
|
from spacy import Language |
|
from thinc.model import set_dropout_rate |
|
from wasabi import Printer |
|
|
|
|
|
|
|
@spacy.registry.architectures("rel_model.v1")
def create_relation_model(
    create_instance_tensor: Model[List[Doc], Floats2d],
    classification_layer: Model[Floats2d, Floats2d],
) -> Model[List[Doc], Floats2d]:
    """Compose the relation model from its two stages.

    The instance-tensor layer is chained into the classification layer, and
    the instance layer's "get_instances" attribute is copied onto the
    combined model so callers can read it from the top-level model.
    """
    combined = chain(create_instance_tensor, classification_layer)
    # Re-expose the candidate generator on the outer model.
    instance_getter = create_instance_tensor.attrs["get_instances"]
    combined.attrs["get_instances"] = instance_getter
    return combined
|
|
|
|
|
@spacy.registry.architectures("rel_classification_layer.v1")
def create_classification_layer(
    nO: int = None, nI: int = None
) -> Model[Floats2d, Floats2d]:
    """Build the classification head: a Linear layer followed by Logistic.

    nO: output width passed to Linear (may be left as None).
    nI: input width passed to Linear (may be left as None).
    """
    linear = Linear(nO=nO, nI=nI)
    activation = Logistic()
    return chain(linear, activation)
|
|
|
|
|
@spacy.registry.misc("rel_instance_generator.v1")
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
    """Return a function that lists candidate entity pairs for a Doc.

    A pair (ent1, ent2) of distinct entities from doc.ents is kept when
    max_length is truthy and the distance between the two entities' start
    offsets is at most max_length. Both orderings of a pair are produced.
    """

    def get_instances(doc: Doc) -> List[Tuple[Span, Span]]:
        return [
            (first, second)
            for first in doc.ents
            for second in doc.ents
            if first != second
            and max_length
            and abs(second.start - first.start) <= max_length
        ]

    return get_instances
|
|
|
|
|
@spacy.registry.architectures("rel_instance_tensor.v1")
def create_tensors(
    tok2vec: Model[List[Doc], List[Floats2d]],
    pooling: Model[Ragged, Floats2d],
    get_instances: Callable[[Doc], List[Tuple[Span, Span]]],
) -> Model[List[Doc], Floats2d]:
    """Wrap tok2vec and pooling into the "instance_tensors" model.

    The resulting model uses instance_forward as its forward function and
    instance_init as its initializer. The two sublayers are registered both
    as layers and as named refs, and get_instances is stored as an attr.
    """
    sublayers = [tok2vec, pooling]
    named_refs = {"tok2vec": tok2vec, "pooling": pooling}
    return Model(
        "instance_tensors",
        instance_forward,
        init=instance_init,
        attrs={"get_instances": get_instances},
        refs=named_refs,
        layers=sublayers,
    )
|
|
|
|
|
def instance_forward(
    model: Model[List[Doc], Floats2d], docs: List[Doc], is_train: bool
) -> Tuple[Floats2d, Callable]:
    """Forward pass of the "instance_tensors" model.

    For every doc, the candidate entity pairs are obtained from the model's
    "get_instances" attr. The token vectors of every entity in every pair are
    gathered, pooled per entity by the "pooling" ref, and the pooled vectors
    of the two entities of a pair are joined into a single row.

    model: the model holding the "tok2vec" and "pooling" refs.
    docs: the batch of documents to process.
    is_train: forwarded to the tok2vec and pooling layers.
    RETURNS: the relation tensor (one row per candidate pair) and the
        backprop callback.
    """
    pooling = model.get_ref("pooling")
    tok2vec = model.get_ref("tok2vec")
    get_instances = model.attrs["get_instances"]
    all_instances = [get_instances(doc) for doc in docs]
    tokvecs, bp_tokvecs = tok2vec(docs, is_train)

    ents = []
    lengths = []

    for instances, tokvec in zip(all_instances, tokvecs):
        token_indices = []
        for instance in instances:
            for ent in instance:
                token_indices.extend(range(ent.start, ent.end))
                lengths.append(ent.end - ent.start)
        # One gather per doc: rows are ordered pair by pair, entity by entity.
        ents.append(tokvec[token_indices])
    lengths = cast(Ints1d, model.ops.asarray(lengths, dtype="int32"))
    entities = Ragged(model.ops.flatten(ents), lengths)
    pooled, bp_pooled = pooling(entities, is_train)

    # Two consecutive pooled rows (the two entities of a pair) become one row.
    relations = model.ops.reshape2f(pooled, -1, pooled.shape[1] * 2)

    def backprop(d_relations: Floats2d) -> List[Doc]:
        """Route the gradient of the relation tensor back to the token vectors."""
        # Undo the pairing reshape: one gradient row per entity again.
        d_pooled = model.ops.reshape2f(d_relations, d_relations.shape[0] * 2, -1)
        d_ents = bp_pooled(d_pooled).data
        d_tokvecs = []
        # Running row offset into d_ents; it follows the same traversal order
        # that was used to build the gathered rows in the forward pass.
        ent_index = 0
        for doc_nr, instances in enumerate(all_instances):
            shape = tokvecs[doc_nr].shape
            d_tokvec = model.ops.alloc2f(*shape)
            count_occ = model.ops.alloc2f(*shape)
            for instance in instances:
                for ent in instance:
                    ent_len = ent.end - ent.start
                    # Use the gradient rows of *all* tokens of this entity.
                    # Taking only the first row and broadcasting it is only
                    # right when the pooling gives every token of a span the
                    # same gradient (e.g. mean pooling).
                    d_tokvec[ent.start : ent.end] += d_ents[
                        ent_index : ent_index + ent_len
                    ]
                    count_occ[ent.start : ent.end] += 1
                    ent_index += ent_len
            # Average over the number of times each token was used; the small
            # constant avoids dividing by zero for tokens in no entity.
            d_tokvec /= count_occ + 0.00000000001
            d_tokvecs.append(d_tokvec)

        d_docs = bp_tokvecs(d_tokvecs)
        return d_docs

    return relations, backprop
|
|
|
|
|
def instance_init(model: Model, X: List[Doc] = None, Y: Floats2d = None) -> Model:
    """Initialize the "tok2vec" ref of the model when sample docs are given.

    model: the model whose "tok2vec" ref is looked up.
    X: optional sample documents, passed on to the tok2vec layer's initialize.
    Y: accepted but not used here.
    RETURNS: the same model object.
    """
    embedder = model.get_ref("tok2vec")
    if X is None:
        # Nothing to initialize from.
        return model
    embedder.initialize(X)
    return model
|
|
|
|
|
# Register the custom `Doc._.rel` extension attribute with an empty dict as
# its default. Elsewhere in this file it is used as a mapping from
# (start of first entity, start of second entity) to a {label: value} dict.
# NOTE(review): force=True presumably lets the registration be repeated
# (e.g. on re-import) without raising - confirm against the spaCy docs.
Doc.set_extension("rel", default={}, force=True)

# Module-level wasabi printer, used below for informational messages.
msg = Printer()
|
|
|
|
|
@Language.factory(
    "relation_extractor",
    requires=["doc.ents", "token.ent_iob", "token.ent_type"],
    assigns=["doc._.rel"],
    default_score_weights={
        "rel_micro_p": None,
        "rel_micro_r": None,
        "rel_micro_f": None,
    },
)
def make_relation_extractor(
    nlp: Language, name: str, model: Model, *, threshold: float
):
    """Factory for the "relation_extractor" pipeline component.

    nlp: the pipeline; its vocab is handed to the component.
    name: the component instance name.
    model: the model the component wraps.
    threshold: the cut-off stored in the component's config.
    RETURNS: a new RelationExtractor.
    """
    component = RelationExtractor(
        vocab=nlp.vocab, model=model, name=name, threshold=threshold
    )
    return component
|
|
|
|
|
class RelationExtractor(TrainablePipe):
    """Trainable pipeline component that scores relation labels for entity pairs.

    Candidate pairs come from the model's "get_instances" attr. Scores are
    written to `doc._.rel`, keyed by the start offsets of the two entities,
    with one entry per label.
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "rel",
        *,
        threshold: float,
    ) -> None:
        """Initialize a relation extractor.

        vocab: the shared vocabulary.
        model: the model to wrap; must expose a "get_instances" attr.
        name: the component instance name, also used as the key in `losses`.
        threshold: cut-off above which a prediction is seen as 'True'.
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.cfg = {"labels": [], "threshold": threshold}

    @property
    def labels(self) -> Tuple[str, ...]:
        """Returns the labels currently added to the component."""
        return tuple(self.cfg["labels"])

    @property
    def threshold(self) -> float:
        """Returns the threshold above which a prediction is seen as 'True'."""
        return self.cfg["threshold"]

    def add_label(self, label: str) -> int:
        """Add a new label to the pipe.

        RETURNS: 1 if the label was added, 0 if it was already present.
        RAISES: ValueError if the label is not a string.
        """
        if not isinstance(label, str):
            raise ValueError("Only strings can be added as labels to the RelationExtractor")
        if label in self.labels:
            return 0
        self.cfg["labels"] = list(self.labels) + [label]
        return 1

    def __call__(self, doc: Doc) -> Doc:
        """Apply the pipe to a Doc.

        The doc is returned unchanged when no candidate pairs are found.
        """
        total_instances = len(self.model.attrs["get_instances"](doc))
        if total_instances == 0:
            msg.info("Could not determine any instances in doc - returning doc as is.")
            return doc

        predictions = self.predict([doc])
        self.set_annotations([doc], predictions)
        return doc

    def predict(self, docs: Iterable[Doc]) -> Floats2d:
        """Apply the pipeline's model to a batch of docs, without modifying them."""
        # The batch is traversed twice (counting, then predicting), so a
        # one-shot iterable has to be materialised first.
        docs = list(docs)
        get_instances = self.model.attrs["get_instances"]
        total_instances = sum(len(get_instances(doc)) for doc in docs)
        if total_instances == 0:
            msg.info("Could not determine any instances in any docs - can not make any predictions.")
        scores = self.model.predict(docs)
        return self.model.ops.asarray(scores)

    def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None:
        """Modify a batch of `Doc` objects, using pre-computed scores.

        Row `c` of `scores` belongs to the c-th candidate pair when walking
        the docs in order; column `j` belongs to the j-th label.
        """
        c = 0
        get_instances = self.model.attrs["get_instances"]
        for doc in docs:
            for (e1, e2) in get_instances(doc):
                offset = (e1.start, e2.start)
                if offset not in doc._.rel:
                    doc._.rel[offset] = {}
                for j, label in enumerate(self.labels):
                    doc._.rel[offset][label] = scores[c, j]
                c += 1

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        set_annotations: bool = False,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples: the batch of training examples.
        drop: the dropout rate set on the model.
        set_annotations: whether to write the predictions to the predicted docs.
        sgd: optional optimizer; when given, the model update is finished with it.
        losses: optional dict to accumulate the loss into, keyed by component name.
        RETURNS: the (possibly newly created) losses dict.
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        set_dropout_rate(self.model, drop)

        # The batch is traversed several times below, so a one-shot iterable
        # has to be materialised first.
        examples = list(examples)
        docs = [eg.predicted for eg in examples]

        # Skip the update entirely when there is nothing to predict on.
        get_instances = self.model.attrs["get_instances"]
        total_instances = sum(len(get_instances(doc)) for doc in docs)
        if total_instances == 0:
            msg.info("Could not determine any instances in doc.")
            return losses

        predictions, backprop = self.model.begin_update(docs)
        loss, gradient = self.get_loss(examples, predictions)
        backprop(gradient)
        if sgd is not None:
            self.model.finish_update(sgd)
        losses[self.name] += loss
        if set_annotations:
            self.set_annotations(docs, predictions)
        return losses

    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, Floats2d]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        RETURNS: the mean (over rows) of the summed squared differences
            between scores and truths, and the gradient `scores - truths`.
        """
        truths = self._examples_to_truth(examples)
        gradient = scores - truths
        mean_square_error = (gradient ** 2).sum(axis=1).mean()
        return float(mean_square_error), gradient

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
        labels: Optional[List[str]] = None,
    ) -> None:
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples: callable returning the examples to initialize from.
        nlp: accepted but not used here.
        labels: optional explicit label list; when omitted, the labels are
            collected from the `_.rel` annotations of the reference docs.
        RAISES: ValueError if the sample contains no candidate instances.
        """
        if labels is not None:
            for label in labels:
                self.add_label(label)
        else:
            for example in get_examples():
                relations = example.reference._.rel
                for label_dict in relations.values():
                    for label in label_dict:
                        self.add_label(label)
        self._require_labels()

        # A small sample is enough to initialize the model.
        subbatch = list(islice(get_examples(), 10))
        doc_sample = [eg.reference for eg in subbatch]
        label_sample = self._examples_to_truth(subbatch)
        if label_sample is None:
            raise ValueError("Call begin_training with relevant entities and relations annotated in "
                             "at least a few reference examples!")
        self.model.initialize(X=doc_sample, Y=label_sample)

    def _examples_to_truth(self, examples: Iterable[Example]) -> Optional[numpy.ndarray]:
        """Build the gold-standard array for the reference docs of a batch.

        RETURNS: an array with one row per candidate pair and one column per
            label, filled from `eg.reference._.rel` (0 where no value is
            annotated), or None when the batch contains no candidate pairs.
        """
        get_instances = self.model.attrs["get_instances"]
        # Compute the candidate pairs once per example; this also makes the
        # method safe for one-shot iterables.
        per_example = [(eg, get_instances(eg.reference)) for eg in examples]

        nr_instances = sum(len(instances) for _, instances in per_example)
        if nr_instances == 0:
            return None

        truths = numpy.zeros((nr_instances, len(self.labels)), dtype="f")
        c = 0
        for eg, instances in per_example:
            for (e1, e2) in instances:
                gold_label_dict = eg.reference._.rel.get((e1.start, e2.start), {})
                for j, label in enumerate(self.labels):
                    truths[c, j] = gold_label_dict.get(label, 0)
                c += 1

        truths = self.model.ops.asarray(truths)
        return truths

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples."""
        return score_relations(examples, self.threshold)
|
|
|
|
|
def score_relations(examples: Iterable[Example], threshold: float) -> Dict[str, Any]:
    """Compute micro-averaged precision, recall and F-score for relations.

    Only entity pairs present in the predicted doc's `_.rel` are visited.
    A predicted label counts as positive when its value is >= threshold;
    a gold label counts as positive when its value equals 1.0.

    examples: the examples whose reference and predicted docs are compared.
    threshold: the cut-off applied to predicted values.
    RETURNS: a dict with "rel_micro_p", "rel_micro_r" and "rel_micro_f".
    """
    prf = PRFScore()
    for eg in examples:
        gold_rels = eg.reference._.rel
        pred_rels = eg.predicted._.rel
        for pair, label_scores in pred_rels.items():
            gold_for_pair = gold_rels.get(pair, {})
            positives = [name for name, value in gold_for_pair.items() if value == 1.0]
            for name, value in label_scores.items():
                is_gold = name in positives
                if value >= threshold:
                    if is_gold:
                        prf.tp += 1
                    else:
                        prf.fp += 1
                elif is_gold:
                    prf.fn += 1
    return {
        "rel_micro_p": prf.precision,
        "rel_micro_r": prf.recall,
        "rel_micro_f": prf.fscore,
    }
|
|