File size: 2,541 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""A signal to compute span offsets of already labeled concept text."""
from typing import Iterable, Optional

from typing_extensions import override

from ..concepts.concept import DRAFT_MAIN, draft_examples
from ..concepts.db_concept import DISK_CONCEPT_DB, ConceptDB
from ..data.dataset_utils import lilac_span
from ..schema import Field, Item, RichData, field
from .signal import TextSignal


class ConceptLabelsSignal(TextSignal):
  """Computes spans where text is labeled for the concept, either positive or negative.

  For every input text, the text of each labeled example of the concept (as returned by
  `draft_examples` for `draft`) is searched for verbatim. Every non-overlapping occurrence becomes
  a span carrying the example's `label`, plus the example's `draft` name when that draft is not
  `DRAFT_MAIN`.
  """
  name = 'concept_labels'
  display_name = 'Concept Labels'

  namespace: str
  concept_name: str

  # The draft of the concept whose examples are used. Defaults to `DRAFT_MAIN`.
  draft: str = DRAFT_MAIN

  _concept_db: ConceptDB = DISK_CONCEPT_DB

  @override
  def fields(self) -> Field:
    """Return the output schema: a list of `string_span`s with a boolean `label` and a `draft`."""
    return field(fields=[field('string_span', fields={'label': 'boolean', 'draft': 'string'})])

  def _get_concept(self):
    """Load `namespace/concept_name` from the concept DB.

    Raises:
      ValueError: if the concept DB returns nothing for this namespace/name.
    """
    concept = self._concept_db.get(self.namespace, self.concept_name)
    if not concept:
      raise ValueError(f'Concept "{self.namespace}/{self.concept_name}" does not exist.')
    return concept

  @staticmethod
  def _find_occurrences(text: str, needle: str) -> Iterable[int]:
    """Yield the start offset of every non-overlapping occurrence of a non-empty `needle`."""
    offset = 0
    while offset < len(text):
      offset = text.find(needle, offset)
      if offset == -1:
        break
      yield offset
      # Resume after the match, so occurrences never overlap.
      offset += len(needle)

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield, for each input text, the spans where a labeled example's text occurs.

    Args:
      data: The input texts. Falsy entries (e.g. None or '') yield None.

    Yields:
      A list of spans for each text, or None when no example text occurs in it.

    Raises:
      ValueError: if the concept does not exist, or if a truthy entry is not a `str`. Because this
        is a generator, both are raised while iterating, not when `compute` is called.
    """
    concept = self._get_concept()
    # The examples are identical for every input text, so resolve them once.
    examples = draft_examples(concept, draft=self.draft)
    for text in data:
      if not text:
        yield None
        continue

      if not isinstance(text, str):
        raise ValueError(f'{str(text)} is a {type(text)}, which is not a string.')

      label_spans: list[Item] = []
      for example in examples.values():
        if not example.text:
          # An empty example text would "match" everywhere; skip it.
          continue

        metadata = {'label': example.label}
        # Only tag spans with their draft when it differs from the main draft.
        if example.draft != DRAFT_MAIN:
          metadata['draft'] = example.draft

        end_delta = len(example.text)
        for start in self._find_occurrences(text, example.text):
          # Give every span its own metadata dict, as the original per-span literal did.
          label_spans.append(lilac_span(start, start + end_delta, dict(metadata)))

      yield label_spans or None

  @override
  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Return the key identifying this signal: `<namespace>/<concept_name>/labels[/v<version>]`.

    Args:
      is_computed_signal: When truthy, the concept's current version is appended as `/v<version>`.

    Raises:
      ValueError: if `is_computed_signal` is truthy and the concept does not exist.
    """
    # NOTE(review): `draft` is not part of the key, so instances differing only in `draft` share a
    # key. Confirm this is intended.
    version = f'/v{self._get_concept().version}' if is_computed_signal else ''
    return f'{self.namespace}/{self.concept_name}/labels{version}'