"""Compute named entity recognition with SpaCy.""" from typing import Iterable, Optional import spacy from pydantic import Field as PydanticField from typing_extensions import override from ..data.dataset_utils import lilac_span from ..schema import Field, Item, RichData, SignalInputType, field from .signal import TextSignal class SpacyNER(TextSignal): """Named entity recognition with SpaCy. For details see: [spacy.io/models](https://spacy.io/models). """ name = 'spacy_ner' display_name = 'Named Entity Recognition' model: Optional[str] = PydanticField( title='SpaCy package name or model path.', default='en_core_web_sm') input_type = SignalInputType.TEXT compute_type = SignalInputType.TEXT _nlp: spacy.language.Language @override def setup(self) -> None: if not spacy.util.is_package(self.model): spacy.cli.download(self.model) self._nlp = spacy.load( self.model, # Disable everything except the NER component. See: https://spacy.io/models disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer']) @override def fields(self) -> Field: return field(fields=[field('string_span', fields={'label': 'string'})]) @override def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: text_data = (row if isinstance(row, str) else '' for row in data) for doc in self._nlp.pipe(text_data): result = [lilac_span(ent.start_char, ent.end_char, {'label': ent.label_}) for ent in doc.ents] if result: yield result else: yield None