nsthorat's picture
Push
55dc3dd
raw
history blame
No virus
1.58 kB
"""Compute named entity recognition with SpaCy."""
import os
from typing import Iterable, Optional

import spacy
from pydantic import Field as PydanticField
from typing_extensions import override

from ..data.dataset_utils import lilac_span
from ..schema import Field, Item, RichData, SignalInputType, field
from .signal import TextSignal
class SpacyNER(TextSignal):
  """Named entity recognition with SpaCy.

  For details see: [spacy.io/models](https://spacy.io/models).
  """
  name = 'spacy_ner'
  display_name = 'Named Entity Recognition'

  model: Optional[str] = PydanticField(
    title='SpaCy package name or model path.', default='en_core_web_sm')

  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  # The loaded SpaCy pipeline; assigned in `setup()` and used by `compute()`.
  _nlp: spacy.language.Language

  @override
  def setup(self) -> None:
    """Load the SpaCy pipeline named by `model`, downloading it first if needed.

    `model` may be an installable SpaCy package name (e.g. 'en_core_web_sm') or a path
    to a model on disk. Only package names are downloaded; an existing local path is
    handed straight to `spacy.load`.

    Raises:
      ValueError: if `model` is None or empty.
    """
    if not self.model:
      raise ValueError('`model` must be a SpaCy package name or a model path.')
    # `spacy.cli.download` only resolves package names, so never pass it a local model
    # path; `spacy.load` below loads paths directly.
    if not spacy.util.is_package(self.model) and not os.path.exists(self.model):
      spacy.cli.download(self.model)
    self._nlp = spacy.load(
      self.model,
      # Disable everything except the NER component. See: https://spacy.io/models
      # NOTE(review): this is safe for pipelines whose NER has its own tok2vec (e.g. the
      # `*_sm` models); pipelines whose NER listens to a shared tok2vec may need it kept.
      disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])

  @override
  def fields(self) -> Field:
    """Return the output schema: a list of `string_span` items, each with a string `label`."""
    return field(fields=[field('string_span', fields={'label': 'string'})])

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield, for each input row, its entity spans or None when it has no entities.

    Each span covers the entity's character offsets and carries the SpaCy entity label
    under 'label'. Non-string rows are processed as empty text, so they yield None.
    """
    # Feed every row through the pipeline (non-strings as '') so the output stays
    # aligned one-to-one with the input.
    texts = (row if isinstance(row, str) else '' for row in data)
    for doc in self._nlp.pipe(texts):
      spans = [lilac_span(ent.start_char, ent.end_char, {'label': ent.label_}) for ent in doc.ents]
      yield spans or None