Spaces:
Runtime error
Runtime error
File size: 10,745 Bytes
e4f9cbe 55dc3dd e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 |
"""Interface for implementing a signal."""
import abc
from typing import Any, ClassVar, Iterable, Optional, Sequence, Type, TypeVar, Union
from pydantic import BaseModel, Extra, validator
from pydantic.fields import ModelField
from typing_extensions import override
from ..embeddings.vector_store import VectorStore
from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field
# Key under which TextEmbeddingSignal.fields() declares the embedding sub-field of each
# 'string_span' entry.
EMBEDDING_KEY = 'embedding'
class Signal(abc.ABC, BaseModel):
  """Interface for signals to implement. A signal can score documents and a dataset column.

  Subclasses describe themselves through the class variables below and must implement
  `fields()`. The `compute()`, `vector_compute()` and `vector_compute_topk()` methods raise
  `NotImplementedError` unless a subclass overrides them.
  """
  # ClassVars do not get serialized with pydantic.
  name: ClassVar[str]
  # The display name is only used for rendering in the UI.
  display_name: ClassVar[Optional[str]]
  signal_type: ClassVar[Type['Signal']]

  # The input type is used to populate the UI for signals that require other signals. For
  # example, a TextEmbeddingModelSignal computes over embeddings, but its input type is text.
  input_type: ClassVar[SignalInputType]
  # The compute type defines what should be passed to compute().
  compute_type: ClassVar[SignalInputType]

  # Filled in by `validate_signal_name` from the static `name` when it is not provided, so the
  # name gets serialized and the signal author doesn't have to define both the static property
  # and the field.
  signal_name: Optional[str]

  class Config:
    underscore_attrs_are_private = True
    extra = Extra.forbid

    @staticmethod
    def schema_extra(schema: dict[str, Any], signal: Type['Signal']) -> None:
      """Customize the generated JSON schema using the signal's static metadata.

      Pydantic defaults the title to the class name; when the signal class has a
      `display_name` attribute, that is used as the title instead. When the class has a `name`
      attribute, the `signal_name` property is restricted to that single value via `enum`.

      Args:
        schema: The schema dict generated by pydantic; mutated in place.
        signal: The signal class the schema is generated for.
      """
      if hasattr(signal, 'display_name'):
        schema['title'] = signal.display_name
      if hasattr(signal, 'name'):
        signal_name_schema = schema['properties']['signal_name']
        signal_name_schema['enum'] = [signal.name]

  @validator('signal_name', pre=True, always=True)
  def validate_signal_name(cls, signal_name: str) -> str:
    """Return the static name when the signal name hasn't yet been set.

    Args:
      signal_name: The incoming value for the `signal_name` field (may be empty or missing).

    Returns:
      The given `signal_name` when it is truthy, otherwise the class-level `name`.

    Raises:
      ValueError: If no signal name was given and the class itself does not define `name`.
    """
    if not signal_name:
      # Only a `name` declared directly on this class counts; `cls.__dict__` does not include
      # attributes inherited from base classes.
      if 'name' not in cls.__dict__:
        raise ValueError('Signal attribute "name" must be defined.')
      return cls.name
    # Already set (e.g. from JSON): keep it as-is.
    return signal_name

  @abc.abstractmethod
  def fields(self) -> Field:
    """Return the fields schema for this signal.

    Returns:
      A Field object that describes the schema of the signal.
    """

  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Compute the signal for an iterable of documents or images.

    Args:
      data: An iterable of rich data to compute the signal over.

    Returns:
      An iterable of items. Sparse signals should return "None" for skipped inputs.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def vector_compute(self, keys: Iterable[VectorKey],
                     vector_store: VectorStore) -> Iterable[Optional[Item]]:
    """Compute the signal for an iterable of keys that point to documents or images.

    Args:
      keys: An iterable of value ids (at row-level or lower) to lookup precomputed embeddings.
      vector_store: The vector store to lookup pre-computed embeddings.

    Returns:
      An iterable of items. Sparse signals should return "None" for skipped inputs.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def vector_compute_topk(
      self,
      topk: int,
      vector_store: VectorStore,
      keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
    """Return signal results only for the top k documents or images.

    Signals decide how to rank each document/image in the dataset, usually by a similarity
    score obtained via the vector store.

    Args:
      topk: The number of items to return, ranked by the signal.
      vector_store: The vector store to lookup pre-computed embeddings.
      keys: Optional iterable of row ids to restrict the search to.

    Returns:
      A list of (key, signal_output) tuples containing the `topk` items. Sparse signals should
      return "None" for skipped inputs.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Get the key for a signal.

    This is used to make sure signals with multiple arguments do not collide.

    NOTE: Overriding this method is sensitive. If you override it, make sure that it is
    globally unique. It will be used as the dictionary key for enriched values.

    Args:
      is_computed_signal: True when the signal is computed over the column and written to
        disk. False when the signal is used as a preview UDF. This base implementation does
        not read it.

    Returns:
      The signal `name` followed by the formatted, explicitly-set non-default arguments.
    """
    explicit_args = self.dict(exclude_unset=True, exclude_defaults=True)
    # A user-provided signal name is redundant with the `name` prefix, so drop it.
    explicit_args.pop('signal_name', None)
    args_suffix = _args_key_from_dict(explicit_args)
    return self.name + args_suffix

  def setup(self) -> None:
    """Setup the signal. The base implementation does nothing."""

  def teardown(self) -> None:
    """Tears down the signal. The base implementation does nothing."""
def _args_key_from_dict(args_dict: dict[str, Any]) -> str:
args = None
args_list: list[str] = []
for k, v in args_dict.items():
if v:
args_list.append(f'{k}={v}')
args = ','.join(args_list)
return '' if not args_list else f'({args})'
class SignalTypeEnum(str):
  """A class that represents a string enum for a signal type.

  This allows us to populate the JSON schema enum field in the UI. Subclasses set
  `signal_type` to the signal base class whose registered implementations should be listed.
  """
  signal_type: ClassVar[Type[Signal]]

  @classmethod
  def __modify_schema__(cls, field_schema: dict[str, Any], field: Optional[ModelField]) -> None:
    """Set the schema `enum` to the names of all registered signals of `cls.signal_type`.

    Args:
      field_schema: The JSON schema dict for the field; mutated in place.
      field: The model field being described. Nothing is changed when it is falsy.
    """
    if not field:
      return
    matching_signals = get_signals_by_type(cls.signal_type)
    field_schema['enum'] = [signal_cls.name for signal_cls in matching_signals]
class TextSplitterSignal(Signal):
  """An interface for text splitter signals.

  Both the input type and the compute type are text; the declared output schema is a list of
  'string_span' values.
  """
  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  @override
  def fields(self) -> Field:
    """Return the output schema: a repeated 'string_span' field."""
    span_list_schema = field(fields=['string_span'])
    return span_list_schema
class TextSplitterEnum(SignalTypeEnum):
  """A string enum that represents a text splitter signal.

  The schema hook inherited from SignalTypeEnum lists the names of registered signals that
  subclass the `signal_type` set here.
  """
  signal_type = TextSplitterSignal
# Signal base classes, used for inferring the dependency chain required for computing a signal.
class TextSignal(Signal):
  """An interface for signals that compute over text.

  Both the input type and the compute type are text.
  """
  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  @override
  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Return the signal `name` followed by its explicitly-set, non-default arguments.

    Args:
      is_computed_signal: Accepted for interface compatibility; not read by this
        implementation.

    Returns:
      The key string identifying this signal configuration.
    """
    explicit_args = self.dict(exclude_unset=True, exclude_defaults=True)
    # The signal name is already the key prefix, so it is never repeated as an argument.
    explicit_args.pop('signal_name', None)
    args_suffix = _args_key_from_dict(explicit_args)
    return self.name + args_suffix
class TextEmbeddingSignal(TextSignal):
  """An interface for signals that compute embeddings for text."""
  input_type = SignalInputType.TEXT
  compute_type = SignalInputType.TEXT

  # Class-level default for the underscore-prefixed attribute; replaced per instance in
  # __init__ with the `split` argument.
  _split = True

  def __init__(self, split: bool = True, **kwargs: Any):
    """Initialize the signal.

    Args:
      split: Stored on the instance as `_split`. It is not forwarded to the base constructor.
      **kwargs: Remaining keyword arguments, forwarded to the base constructor.
    """
    super().__init__(**kwargs)
    self._split = split

  @override
  def fields(self) -> Field:
    """Return the schema: a list of 'string_span' fields, each carrying an embedding.

    NOTE: Override this method at your own risk if you want to add extra metadata.
    Embeddings should not come with extra metadata.
    """
    embedded_span = field('string_span', fields={EMBEDDING_KEY: 'embedding'})
    return field(fields=[embedded_span])
class TextEmbeddingEnum(SignalTypeEnum):
  """A string enum that represents a text embedding signal.

  The schema hook inherited from SignalTypeEnum lists the names of registered signals that
  subclass the `signal_type` set here.
  """
  signal_type = TextEmbeddingSignal
class TextEmbeddingModelSignal(TextSignal):
  """An interface for signals that take embeddings and produce items."""
  input_type = SignalInputType.TEXT
  # compute() takes embeddings, while it operates over text fields by transitively computing
  # splits and embeddings.
  compute_type = SignalInputType.TEXT_EMBEDDING

  # Name of the registered embedding signal to use.
  embedding: TextEmbeddingEnum
  # Instance of that embedding signal, created in __init__.
  _embedding_signal: TextEmbeddingSignal

  def __init__(self, **kwargs: Any):
    """Initialize the signal and instantiate the embedding signal named by `embedding`.

    Args:
      **kwargs: Field values, forwarded to the base constructor.

    Raises:
      ValueError: From `get_signal_by_type` when `embedding` is not registered or is not a
        `TextEmbeddingSignal`.
    """
    super().__init__(**kwargs)
    # Validate the embedding signal is registered and the correct type.
    # TODO(nsthorat): Allow arguments passed to the embedding signal.
    embedding_cls = get_signal_by_type(self.embedding, TextEmbeddingSignal)
    self._embedding_signal = embedding_cls()

  def get_embedding_signal(self) -> TextEmbeddingSignal:
    """Return the embedding signal."""
    return self._embedding_signal

  @override
  def key(self, is_computed_signal: Optional[bool] = False) -> str:
    """Return the signal `name` followed by its explicitly-set arguments.

    NOTE: The embedding and split already exist in the path structure. This means we do not
    need to provide the signal names as part of the key, which still guarantees uniqueness.

    Args:
      is_computed_signal: Accepted for interface compatibility; not read by this
        implementation.

    Returns:
      The key string, built from the set fields except `signal_name` and `embedding`.
    """
    set_args = self.dict(exclude_unset=True)
    set_args.pop('signal_name', None)
    # Unlike `signal_name`, `embedding` is removed unconditionally.
    del set_args['embedding']
    args_suffix = _args_key_from_dict(set_args)
    return self.name + args_suffix
# Type variable bound to Signal so the lookup helpers below return the same signal subtype
# that the caller passes in as `signal_type`.
Tsignal = TypeVar('Tsignal', bound=Signal)
def get_signal_by_type(signal_name: str, signal_type: Type[Tsignal]) -> Type[Tsignal]:
  """Return a signal class by name and signal type.

  Args:
    signal_name: The name the signal was registered under.
    signal_type: The class the registered signal must be a subclass of.

  Returns:
    The registered signal class.

  Raises:
    ValueError: If the name is not registered, or the registered class is not a subclass of
      `signal_type`.
  """
  registered_cls = SIGNAL_REGISTRY.get(signal_name)
  if registered_cls is None:
    raise ValueError(f'Signal "{signal_name}" not found in the registry')

  if issubclass(registered_cls, signal_type):
    return registered_cls

  raise ValueError(f'"{signal_name}" is a `{registered_cls.__name__}`, '
                   f'which is not a subclass of `{signal_type.__name__}`.')
def get_signals_by_type(signal_type: Type[Tsignal]) -> list[Type[Tsignal]]:
  """Return all signals that match a signal type.

  Args:
    signal_type: The class the returned signals must be subclasses of.

  Returns:
    The registered signal classes that subclass `signal_type`, in registry iteration order.
  """
  return [
    registered_cls for registered_cls in SIGNAL_REGISTRY.values()
    if issubclass(registered_cls, signal_type)
  ]
# Global registry mapping a signal's static `name` to its class. Written by register_signal()
# and emptied by clear_signal_registry().
SIGNAL_REGISTRY: dict[str, Type[Signal]] = {}
def register_signal(signal_cls: Type[Signal]) -> None:
  """Register a signal in the global registry.

  Args:
    signal_cls: The signal class to register under its static `name`.

  Raises:
    ValueError: If a signal with the same name is already registered.
  """
  registry_name = signal_cls.name
  if registry_name in SIGNAL_REGISTRY:
    raise ValueError(f'Signal "{registry_name}" has already been registered!')
  SIGNAL_REGISTRY[registry_name] = signal_cls
def get_signal_cls(signal_name: str) -> Type[Signal]:
  """Return a registered signal given the name in the registry.

  Args:
    signal_name: The name the signal was registered under.

  Returns:
    The registered signal class.

  Raises:
    ValueError: If no signal is registered under `signal_name`.
  """
  registered_cls = SIGNAL_REGISTRY.get(signal_name)
  if registered_cls is None:
    raise ValueError(f'Signal "{signal_name}" not found in the registry')
  return registered_cls
def resolve_signal(signal: Union[dict, Signal]) -> Signal:
  """Resolve a generic signal base class to a specific signal class.

  Args:
    signal: Either an already-constructed signal, or a dict of constructor arguments that
      includes a "signal_name" entry naming a registered signal.

  Returns:
    The given signal unchanged when it is already a `Signal`; otherwise a new instance of the
    registered class, constructed from the dict.

  Raises:
    ValueError: If the dict has no truthy "signal_name", or the name is not registered.
  """
  if not isinstance(signal, Signal):
    requested_name = signal.get('signal_name')
    if not requested_name:
      raise ValueError('"signal_name" needs to be defined in the json dict.')
    resolved_cls = get_signal_cls(requested_name)
    return resolved_cls(**signal)

  # The signal config is already parsed.
  return signal
def clear_signal_registry() -> None:
  """Clear the signal registry.

  Removes every entry from the module-level `SIGNAL_REGISTRY` in place.
  """
  SIGNAL_REGISTRY.clear()
|