File size: 1,579 Bytes
55dc3dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Compute named entity recognition with SpaCy."""
from typing import Iterable, Optional

import spacy
from pydantic import Field as PydanticField
from typing_extensions import override

from ..data.dataset_utils import lilac_span
from ..schema import Field, Item, RichData, SignalInputType, field
from .signal import TextSignal


class SpacyNER(TextSignal):
  """Named entity recognition with SpaCy.

  For every input text, emits one span per entity found by the SpaCy pipeline, with the entity
  type stored under the `label` key. Inputs without any entity produce `None`.

  For details see: [spacy.io/models](https://spacy.io/models).
  """
  name = 'spacy_ner'
  display_name = 'Named Entity Recognition'

  model: Optional[str] = PydanticField(
    title='SpaCy package name or model path.', default='en_core_web_sm')

  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  # The loaded SpaCy pipeline. Only assigned in `setup()`, so `compute()` requires `setup()` to
  # have been called first.
  _nlp: spacy.language.Language

  @override
  def setup(self) -> None:
    """Load the SpaCy pipeline named by `self.model`, downloading the package if needed.

    `self.model` may be either an installed SpaCy package name or a path to a model directory.
    A download is attempted only when the value is neither an installed package nor an existing
    local path, so that model paths are handed to `spacy.load` as-is instead of being sent to
    the downloader.

    Raises:
      ValueError: If `self.model` is empty or `None`.
    """
    # Local import keeps this block self-contained without touching the file-level imports.
    import os

    if not self.model:
      raise ValueError('SpacyNER requires `model` to be a SpaCy package name or a model path.')

    is_local_path = os.path.exists(self.model)
    if not is_local_path and not spacy.util.is_package(self.model):
      spacy.cli.download(self.model)

    self._nlp = spacy.load(
      self.model,
      # Disable everything except the NER component. See: https://spacy.io/models
      disable=['tok2vec', 'tagger', 'parser', 'attribute_ruler', 'lemmatizer'])

  @override
  def fields(self) -> Field:
    """Return the output schema: a list of string spans, each with a string `label`."""
    return field(fields=[field('string_span', fields={'label': 'string'})])

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield the entity spans of each input, in input order.

    Args:
      data: The rows to process. Rows that are not `str` are treated as empty text.

    Yields:
      For each row, a list of spans built from each entity's character offsets and label, or
      `None` when no entity was found.
    """
    # Substitute '' for non-string rows so that exactly one result is yielded per input row.
    texts = (row if isinstance(row, str) else '' for row in data)

    for doc in self._nlp.pipe(texts):
      spans = [lilac_span(ent.start_char, ent.end_char, {'label': ent.label_}) for ent in doc.ents]
      # An empty list is reported as `None` rather than as an empty result.
      yield spans or None