File size: 1,852 Bytes
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Compute named entity recognition with SpaCy."""
from typing import TYPE_CHECKING, Iterable, Optional

from pydantic import Field as PydanticField
from typing_extensions import override

from ..schema import Field, Item, RichData, SignalInputType, field, lilac_span
from ..signal import TextSignal

if TYPE_CHECKING:
  import spacy


class SpacyNER(TextSignal):
  """Named entity recognition with SpaCy.

  Each input text is run through a SpaCy pipeline, and every recognized entity is emitted as a
  span annotated with the entity label predicted by the model.

  For details see: [spacy.io/models](https://spacy.io/models).
  """
  name = 'spacy_ner'
  display_name = 'Named Entity Recognition'

  model: str = PydanticField(title='SpaCy package name or model path.', default='en_core_web_sm')

  input_type = SignalInputType.TEXT

  # The loaded SpaCy pipeline; populated by `setup()`, `None` until then.
  _nlp: Optional['spacy.language.Language'] = None

  @override
  def setup(self) -> None:
    """Load the SpaCy pipeline named by `model`, downloading the package first if needed.

    Raises:
      ImportError: If the `spacy` package is not installed.
    """
    import os

    try:
      import spacy
      import spacy.cli
    except ImportError as e:
      raise ImportError('Could not import the "spacy" python package. '
                        'Please install it with `pip install spacy`.') from e

    # `model` may be a local model directory rather than a package name. Only package names can
    # be fetched with `spacy.cli.download`; `spacy.load` accepts an existing path directly.
    if not spacy.util.is_package(self.model) and not os.path.exists(self.model):
      spacy.cli.download(self.model)
    self._nlp = spacy.load(
      self.model,
      # Disable everything except the NER component. See: https://spacy.io/models
      disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])

  @override
  def fields(self) -> Field:
    """Return the output schema: a list of string spans, each carrying a string `label`."""
    return field(fields=[field('string_span', fields={'label': 'string'})])

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield the entity spans found in each input row, one result per row in input order.

    Each result is a list of spans (character offsets into the text) labeled with the entity
    type, or `None` when no entity was found. Non-string rows are processed as empty text and
    therefore yield `None`.

    Raises:
      RuntimeError: If `setup()` has not been called yet.
    """
    if self._nlp is None:
      raise RuntimeError('SpaCy model is not initialized.')

    # Keep a 1:1 mapping between input rows and outputs by substituting '' for non-text rows.
    text_data = (row if isinstance(row, str) else '' for row in data)

    for doc in self._nlp.pipe(text_data):
      result = [lilac_span(ent.start_char, ent.end_char, {'label': ent.label_}) for ent in doc.ents]

      if result:
        yield result
      else:
        yield None