File size: 10,745 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55dc3dd
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""Interface for implementing a signal."""

import abc
from typing import Any, ClassVar, Iterable, Optional, Sequence, Type, TypeVar, Union

from pydantic import BaseModel, Extra, validator
from pydantic.fields import ModelField
from typing_extensions import override

from ..embeddings.vector_store import VectorStore
from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field

# Key under which the embedding is nested inside each `string_span` in the schema returned by
# `TextEmbeddingSignal.fields()`.
EMBEDDING_KEY = 'embedding'


class Signal(abc.ABC, BaseModel):
  """Interface for signals to implement. A signal can score documents and a dataset column."""
  # NOTE: The ClassVar attributes below are class-level metadata. ClassVars do not get serialized
  # with pydantic.

  # Static name of the signal. It is the registry key and the prefix of `key()`.
  name: ClassVar[str]
  # The display name is just used for rendering in the UI.
  display_name: ClassVar[Optional[str]]

  signal_type: ClassVar[Type['Signal']]
  # The input type is used to populate the UI for signals that require other signals. For example,
  # if a signal is a TextEmbeddingModelSignal, it computes over embeddings, but its input type
  # will be text.
  input_type: ClassVar[SignalInputType]
  # The compute type defines what should be passed to compute().
  compute_type: ClassVar[SignalInputType]

  # Serialized counterpart of `name`. When omitted it is filled in by `validate_signal_name`, so
  # the signal author doesn't have to define both the static property and the field.
  signal_name: Optional[str]

  class Config:
    underscore_attrs_are_private = True
    extra = Extra.forbid

    @staticmethod
    def schema_extra(schema: dict[str, Any], signal: Type['Signal']) -> None:
      """Customize the JSON schema generated for a signal class.

      The title is taken from the display name (pydantic defaults it to the class name), and
      `signal_name` is restricted to the static name of the signal.

      Args:
        schema: The generated schema, modified in place.
        signal: The signal class the schema was generated for.
      """
      if hasattr(signal, 'display_name'):
        schema['title'] = signal.display_name
      if hasattr(signal, 'name'):
        signal_name_schema = schema['properties']['signal_name']
        signal_name_schema['enum'] = [signal.name]

  @validator('signal_name', pre=True, always=True)
  def validate_signal_name(cls, signal_name: str) -> str:
    """Default the `signal_name` field to the static `name` of the class.

    Args:
      signal_name: The incoming value, e.g. from JSON. May be empty.

    Returns:
      The incoming value when it is truthy, otherwise `cls.name`.

    Raises:
      ValueError: If no value was given and the class itself does not define `name`.
    """
    if not signal_name:
      # Only a `name` declared directly on this class is accepted; an inherited one is not.
      if 'name' not in cls.__dict__:
        raise ValueError('Signal attribute "name" must be defined.')
      return cls.name

    # Already set (e.g. from JSON), so pass it through untouched.
    return signal_name

  @abc.abstractmethod
  def fields(self) -> Field:
    """Return the fields schema for this signal.

    Returns:
      A Field object that describes the schema of the signal.
    """

  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Compute the signal for an iterable of documents or images.

    Args:
      data: An iterable of rich data to compute the signal over.

    Returns:
      An iterable of items. Sparse signals should return "None" for skipped inputs.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def vector_compute(self, keys: Iterable[VectorKey],
                     vector_store: VectorStore) -> Iterable[Optional[Item]]:
    """Compute the signal for an iterable of keys that point to documents or images.

    Args:
      keys: An iterable of value ids (at row-level or lower) to lookup precomputed embeddings.
      vector_store: The vector store to lookup pre-computed embeddings.

    Returns:
      An iterable of items. Sparse signals should return "None" for skipped inputs.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def vector_compute_topk(
      self,
      topk: int,
      vector_store: VectorStore,
      keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
    """Return signal results only for the top k documents or images.

    Signals decide how to rank each document/image in the dataset, usually by a similarity score
    obtained via the vector store.

    Args:
      topk: The number of items to return, ranked by the signal.
      vector_store: The vector store to lookup pre-computed embeddings.
      keys: Optional iterable of row ids to restrict the search to.

    Returns:
      A list of (key, signal_output) tuples containing the `topk` items. Sparse signals should
      return "None" for skipped inputs.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Get the key for a signal.

    This is used to make sure signals with multiple arguments do not collide.

    NOTE: Overriding this method is sensitive. If you override it, make sure that it is globally
    unique. It will be used as the dictionary key for enriched values.

    Args:
      is_computed_signal: True when the signal is computed over the column and written to
        disk. False when the signal is used as a preview UDF. Not used by this implementation.

    Returns:
      The signal name followed by the explicitly-set, non-default arguments, as formatted by
      `_args_key_from_dict`.
    """
    args_dict = self.dict(exclude_unset=True, exclude_defaults=True)
    # If a user explicitly defines a signal name for whatever reason, drop it as it's redundant
    # with the name prefix.
    args_dict.pop('signal_name', None)

    return self.name + _args_key_from_dict(args_dict)

  def setup(self) -> None:
    """Setup the signal. A no-op by default."""

  def teardown(self) -> None:
    """Tears down the signal. A no-op by default."""


def _args_key_from_dict(args_dict: dict[str, Any]) -> str:
  """Format signal arguments as a `(k1=v1,k2=v2)` key suffix.

  Args:
    args_dict: Mapping from argument name to value, rendered in iteration order.

  Returns:
    The parenthesized, comma-separated `name=value` pairs. Entries with a falsy value (e.g.
    `None`, `0`, `False` or `''`) are left out, and the empty string is returned when no entry
    remains.
  """
  rendered = [f'{name}={value}' for name, value in args_dict.items() if value]
  if not rendered:
    return ''
  return '(' + ','.join(rendered) + ')'


class SignalTypeEnum(str):
  """A string whose JSON schema is an enum of registered signal names.

  Subclasses set `signal_type`; the names of all registered signals of that type become the
  allowed values, which lets the UI populate the JSON schema enum field.
  """
  # The signal base class whose registered subclasses make up the enum values.
  signal_type: ClassVar[Type[Signal]]

  @classmethod
  def __modify_schema__(cls, field_schema: dict[str, Any], field: Optional[ModelField]) -> None:
    """Fill the `enum` entry of the field schema with the matching signal names.

    Args:
      field_schema: The schema of the field, modified in place.
      field: The model field being described. Nothing is changed when it is falsy.
    """
    if not field:
      return
    matching_signals = get_signals_by_type(cls.signal_type)
    field_schema['enum'] = [signal_cls.name for signal_cls in matching_signals]


class TextSplitterSignal(Signal):
  """An interface for text splitter signals. They take text as input and compute over text."""
  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  @override
  def fields(self) -> Field:
    """Return the schema of the splitter output: a list containing `string_span`."""
    return field(fields=['string_span'])


class TextSplitterEnum(SignalTypeEnum):
  """A string enum that represents a text splitter signal."""
  # Read by `SignalTypeEnum.__modify_schema__` to list the names of the registered
  # `TextSplitterSignal` subclasses in the JSON schema enum.
  signal_type = TextSplitterSignal


# Signal base classes, used for inferring the dependency chain required for computing a signal.
class TextSignal(Signal):
  """An interface for signals that compute over text."""
  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  @override
  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Get the key for a text signal.

    The key format is the same as for the base `Signal`, so this delegates to it instead of
    duplicating the logic.

    Args:
      is_computed_signal: True when the signal is computed over the column and written to
        disk. False when the signal is used as a preview UDF.

    Returns:
      The signal name followed by its explicitly-set, non-default arguments.
    """
    return super().key(is_computed_signal)


class TextEmbeddingSignal(TextSignal):
  """An interface for signals that compute embeddings for text."""
  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  # Underscore attributes are private on signals (`Signal.Config.underscore_attrs_are_private`).
  # Defaults to True and is overwritten by the `split` argument of `__init__`.
  _split = True

  def __init__(self, split: bool = True, **kwargs: Any):
    """Create the embedding signal.

    Args:
      split: Stored on the private `_split` attribute.
        NOTE(review): presumably controls whether text is split before embedding -- confirm
        against the code that reads `_split`.
      **kwargs: Forwarded unchanged to the parent constructor.
    """
    # Initialize the model first, then record the private flag on the instance.
    super().__init__(**kwargs)
    self._split = split

  @override
  def fields(self) -> Field:
    """NOTE: Override this method at your own risk if you want to add extra metadata.

    Embeddings should not come with extra metadata.
    """
    return field(fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})])


class TextEmbeddingEnum(SignalTypeEnum):
  """A string enum that represents a text embedding signal."""
  # Read by `SignalTypeEnum.__modify_schema__` to list the names of the registered
  # `TextEmbeddingSignal` subclasses in the JSON schema enum.
  signal_type = TextEmbeddingSignal


class TextEmbeddingModelSignal(TextSignal):
  """An interface for signals that take embeddings and produce items."""
  input_type = SignalInputType.TEXT
  # compute() takes embeddings, while it operates over text fields by transitively computing splits
  # and embeddings.
  compute_type = SignalInputType.TEXT_EMBEDDING

  # Name of the registered embedding signal this signal depends on.
  embedding: TextEmbeddingEnum
  # Instance of the embedding signal named by `embedding`, created in `__init__`.
  _embedding_signal: TextEmbeddingSignal

  def __init__(self, **kwargs: Any):
    """Create the signal and instantiate the embedding signal it depends on.

    Args:
      **kwargs: The signal fields, forwarded to the parent constructor.

    Raises:
      ValueError: If `embedding` is not registered or is not a `TextEmbeddingSignal`.
    """
    super().__init__(**kwargs)

    # Validate the embedding signal is registered and the correct type.
    embedding_cls = get_signal_by_type(self.embedding, TextEmbeddingSignal)
    # TODO(nsthorat): Allow arguments passed to the embedding signal.
    self._embedding_signal = embedding_cls()

  def get_embedding_signal(self) -> TextEmbeddingSignal:
    """Return the embedding signal instance created at construction time."""
    return self._embedding_signal

  @override
  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Get the key for the signal, leaving out the embedding name.

    NOTE: The embedding and split already exist in the path structure. This means we do not
    need to provide the signal names as part of the key, which still guarantees uniqueness.

    Args:
      is_computed_signal: Not used by this implementation.

    Returns:
      The signal name followed by the explicitly-set arguments other than `signal_name` and
      `embedding`.
    """
    args_dict = self.dict(exclude_unset=True)
    # `signal_name` is only present when it was set explicitly.
    args_dict.pop('signal_name', None)
    del args_dict['embedding']
    return self.name + _args_key_from_dict(args_dict)


# Type variable bound to `Signal`, so that `get_signal_by_type` and `get_signals_by_type` return
# classes typed as the requested signal type.
Tsignal = TypeVar('Tsignal', bound=Signal)


def get_signal_by_type(signal_name: str, signal_type: Type[Tsignal]) -> Type[Tsignal]:
  """Look up a registered signal class and check that it is of the expected type.

  Args:
    signal_name: The name the signal was registered under.
    signal_type: The class the registered signal must be a subclass of.

  Returns:
    The registered signal class.

  Raises:
    ValueError: If no signal is registered under `signal_name`, or if the registered class is
      not a subclass of `signal_type`.
  """
  signal_cls = SIGNAL_REGISTRY.get(signal_name)
  if signal_cls is None:
    raise ValueError(f'Signal "{signal_name}" not found in the registry')

  if issubclass(signal_cls, signal_type):
    return signal_cls

  raise ValueError(f'"{signal_name}" is a `{signal_cls.__name__}`, '
                   f'which is not a subclass of `{signal_type.__name__}`.')


def get_signals_by_type(signal_type: Type[Tsignal]) -> list[Type[Tsignal]]:
  """Return every registered signal class that is a subclass of `signal_type`.

  Args:
    signal_type: The class to filter the registry by.

  Returns:
    The matching signal classes, in registry iteration order.
  """
  return [
    registered_cls for registered_cls in SIGNAL_REGISTRY.values()
    if issubclass(registered_cls, signal_type)
  ]


# Global registry mapping a signal's static `name` to its class. Filled by `register_signal` and
# emptied by `clear_signal_registry`.
SIGNAL_REGISTRY: dict[str, Type[Signal]] = {}


def register_signal(signal_cls: Type[Signal]) -> None:
  """Add a signal class to the global registry under its static `name`.

  Args:
    signal_cls: The signal class to register.

  Raises:
    ValueError: If a signal with the same name is already registered.
  """
  registry_name = signal_cls.name
  if registry_name in SIGNAL_REGISTRY:
    raise ValueError(f'Signal "{registry_name}" has already been registered!')

  SIGNAL_REGISTRY[registry_name] = signal_cls


def get_signal_cls(signal_name: str) -> Type[Signal]:
  """Return the signal class registered under `signal_name`.

  Args:
    signal_name: The name the signal was registered under.

  Returns:
    The registered signal class.

  Raises:
    ValueError: If no signal is registered under `signal_name`.
  """
  registered_cls = SIGNAL_REGISTRY.get(signal_name)
  if registered_cls is None:
    raise ValueError(f'Signal "{signal_name}" not found in the registry')

  return registered_cls


def resolve_signal(signal: Union[dict, Signal]) -> Signal:
  """Turn a signal config into an instance of the specific registered signal class.

  Args:
    signal: Either an already-constructed signal, or a dict of signal fields that includes
      `signal_name`.

  Returns:
    `signal` itself when it is already a `Signal`, otherwise a new instance of the class
    registered under `signal['signal_name']`, constructed from the dict.

  Raises:
    ValueError: If the dict has no truthy `signal_name`, or the name is not registered.
  """
  # The signal config is already parsed.
  if isinstance(signal, Signal):
    return signal

  signal_name = signal.get('signal_name')
  if signal_name:
    return get_signal_cls(signal_name)(**signal)

  raise ValueError('"signal_name" needs to be defined in the json dict.')


def clear_signal_registry() -> None:
  """Clear the signal registry.

  Empties the global `SIGNAL_REGISTRY` dict in place, so existing references to it see the
  change.
  """
  SIGNAL_REGISTRY.clear()