"""Interface for implementing a signal.""" import abc from typing import Any, ClassVar, Iterable, Optional, Sequence, Type, TypeVar, Union from pydantic import BaseModel, Extra, validator from pydantic.fields import ModelField from typing_extensions import override from ..embeddings.vector_store import VectorStore from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field EMBEDDING_KEY = 'embedding' class Signal(abc.ABC, BaseModel): """Interface for signals to implement. A signal can score documents and a dataset column.""" # ClassVars do not get serialized with pydantic. name: ClassVar[str] # The display name is just used for rendering in the UI. display_name: ClassVar[Optional[str]] signal_type: ClassVar[Type['Signal']] # The input type is used to populate the UI for signals that require other signals. For example, # if a signal is an TextEmbeddingModelSignal, it computes over embeddings, but it's input type # will be text. input_type: ClassVar[SignalInputType] # The compute type defines what should be passed to compute(). compute_type: ClassVar[SignalInputType] # The signal_name will get populated in init automatically from the class name so it gets # serialized and the signal author doesn't have to define both the static property and the field. signal_name: Optional[str] class Config: underscore_attrs_are_private = True extra = Extra.forbid @staticmethod def schema_extra(schema: dict[str, Any], signal: Type['Signal']) -> None: """Add the title to the schema from the display name and name. Pydantic defaults this to the class name. """ if hasattr(signal, 'display_name'): schema['title'] = signal.display_name if hasattr(signal, 'name'): schema['properties']['signal_name']['enum'] = [signal.name] @validator('signal_name', pre=True, always=True) def validate_signal_name(cls, signal_name: str) -> str: """Return the static name when the signal name hasn't yet been set.""" # When it's already been set from JSON, just return it. 
if signal_name: return signal_name if 'name' not in cls.__dict__: raise ValueError('Signal attribute "name" must be defined.') return cls.name @abc.abstractmethod def fields(self) -> Field: """Return the fields schema for this signal. Returns A Field object that describes the schema of the signal. """ pass def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: """Compute the signal for an iterable of documents or images. Args: data: An iterable of rich data to compute the signal over. Returns An iterable of items. Sparse signals should return "None" for skipped inputs. """ raise NotImplementedError def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Optional[Item]]: """Compute the signal for an iterable of keys that point to documents or images. Args: keys: An iterable of value ids (at row-level or lower) to lookup precomputed embeddings. vector_store: The vector store to lookup pre-computed embeddings. Returns An iterable of items. Sparse signals should return "None" for skipped inputs. """ raise NotImplementedError def vector_compute_topk( self, topk: int, vector_store: VectorStore, keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]: """Return signal results only for the top k documents or images. Signals decide how to rank each document/image in the dataset, usually by a similarity score obtained via the vector store. Args: topk: The number of items to return, ranked by the signal. vector_store: The vector store to lookup pre-computed embeddings. keys: Optional iterable of row ids to restrict the search to. Returns A list of (key, signal_output) tuples containing the `topk` items. Sparse signals should return "None" for skipped inputs. """ raise NotImplementedError def key(self, is_computed_signal: Optional[bool] = False) -> str: """Get the key for a signal. This is used to make sure signals with multiple arguments do not collide. 
NOTE: Overriding this method is sensitive. If you override it, make sure that it is globally unique. It will be used as the dictionary key for enriched values. Args: is_computed_signal: True when the signal is computed over the column and written to disk. False when the signal is used as a preview UDF. """ args_dict = self.dict(exclude_unset=True, exclude_defaults=True) # If a user explicitly defines a signal name for whatever reason, remove it as it's redundant. if 'signal_name' in args_dict: del args_dict['signal_name'] return self.name + _args_key_from_dict(args_dict) def setup(self) -> None: """Setup the signal.""" pass def teardown(self) -> None: """Tears down the signal.""" pass def _args_key_from_dict(args_dict: dict[str, Any]) -> str: args = None args_list: list[str] = [] for k, v in args_dict.items(): if v: args_list.append(f'{k}={v}') args = ','.join(args_list) return '' if not args_list else f'({args})' class SignalTypeEnum(str): """A class that represents a string enum for a signal type. This allows us to populate the JSON schema enum field in the UI. """ signal_type: ClassVar[Type[Signal]] @classmethod def __modify_schema__(cls, field_schema: dict[str, Any], field: Optional[ModelField]) -> None: if field: field_schema['enum'] = [x.name for x in get_signals_by_type(cls.signal_type)] return None class TextSplitterSignal(Signal): """An interface for signals that compute over text.""" input_type = SignalInputType.TEXT compute_type = SignalInputType.TEXT @override def fields(self) -> Field: return field(fields=['string_span']) class TextSplitterEnum(SignalTypeEnum): """A string enum that represents a text splitter signal.""" signal_type = TextSplitterSignal # Signal base classes, used for inferring the dependency chain required for computing a signal. 
class TextSignal(Signal):
    """An interface for signals that compute over text."""
    input_type = SignalInputType.TEXT
    compute_type = SignalInputType.TEXT

    @override
    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        """Return the key for this signal: its name plus any non-default arguments."""
        # The base implementation already builds the key from the non-default arguments and
        # drops `signal_name`, so delegate instead of duplicating it.
        return super().key(is_computed_signal)


class TextEmbeddingSignal(TextSignal):
    """An interface for signals that compute embeddings for text."""
    input_type = SignalInputType.TEXT
    compute_type = SignalInputType.TEXT

    # Private attribute (not a pydantic field); set from the `split` constructor argument.
    _split = True

    def __init__(self, split: bool = True, **kwargs: Any):
        """Initialize the signal.

        Args:
          split: Stored on the private `_split` attribute.
          **kwargs: Forwarded to the pydantic model constructor.
        """
        super().__init__(**kwargs)
        self._split = split

    @override
    def fields(self) -> Field:
        """Return the schema: 'string_span' entries that carry an embedding sub-field.

        NOTE: Override this method at your own risk if you want to add extra metadata.
        Embeddings should not come with extra metadata.
        """
        return field(fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})])


class TextEmbeddingEnum(SignalTypeEnum):
    """A string enum that represents a text embedding signal."""
    signal_type = TextEmbeddingSignal


class TextEmbeddingModelSignal(TextSignal):
    """An interface for signals that take embeddings and produce items."""
    input_type = SignalInputType.TEXT
    # compute() takes embeddings, while it operates over text fields by transitively computing
    # splits and embeddings.
    compute_type = SignalInputType.TEXT_EMBEDDING

    # The registered name of the embedding signal to use.
    embedding: TextEmbeddingEnum

    # The embedding signal instance resolved from `embedding` at construction time.
    _embedding_signal: TextEmbeddingSignal

    def __init__(self, **kwargs: Any):
        """Initialize the signal and resolve the embedding signal from the registry.

        Raises:
          ValueError: If `embedding` is not registered or is not a `TextEmbeddingSignal`.
        """
        super().__init__(**kwargs)

        # Validate the embedding signal is registered and the correct type.
        # TODO(nsthorat): Allow arguments passed to the embedding signal.
        self._embedding_signal = get_signal_by_type(self.embedding, TextEmbeddingSignal)()

    def get_embedding_signal(self) -> TextEmbeddingSignal:
        """Return the embedding signal."""
        return self._embedding_signal

    @override
    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        """Return the key for this signal, leaving out the embedding name."""
        # NOTE: The embedding and split already exists in the path structure. This means we do
        # not need to provide the signal names as part of the key, which still guarantees
        # uniqueness.
        args_dict = self.dict(exclude_unset=True)
        args_dict.pop('signal_name', None)
        # Use pop() rather than del so a missing 'embedding' entry cannot raise a KeyError.
        args_dict.pop('embedding', None)
        return self.name + _args_key_from_dict(args_dict)


Tsignal = TypeVar('Tsignal', bound=Signal)


def get_signal_by_type(signal_name: str, signal_type: Type[Tsignal]) -> Type[Tsignal]:
    """Return a signal class by name and signal type.

    Args:
      signal_name: The registered name of the signal.
      signal_type: The class the registered signal must be a subclass of.

    Returns:
      The registered signal class.

    Raises:
      ValueError: If the name is not registered, or the registered class is not a subclass of
        `signal_type`.
    """
    # Shares the registry lookup (and its "not found" error) with get_signal_cls().
    signal_cls = get_signal_cls(signal_name)
    if not issubclass(signal_cls, signal_type):
        raise ValueError(f'"{signal_name}" is a `{signal_cls.__name__}`, '
                         f'which is not a subclass of `{signal_type.__name__}`.')
    return signal_cls


def get_signals_by_type(signal_type: Type[Tsignal]) -> list[Type[Tsignal]]:
    """Return all registered signal classes that are subclasses of `signal_type`."""
    return [
        signal_cls for signal_cls in SIGNAL_REGISTRY.values()
        if issubclass(signal_cls, signal_type)
    ]


# Global registry that maps a signal name to its signal class.
SIGNAL_REGISTRY: dict[str, Type[Signal]] = {}


def register_signal(signal_cls: Type[Signal]) -> None:
    """Register a signal in the global registry.

    Raises:
      ValueError: If a signal with the same name has already been registered.
    """
    if signal_cls.name in SIGNAL_REGISTRY:
        raise ValueError(f'Signal "{signal_cls.name}" has already been registered!')

    SIGNAL_REGISTRY[signal_cls.name] = signal_cls


def get_signal_cls(signal_name: str) -> Type[Signal]:
    """Return a registered signal given the name in the registry.

    Raises:
      ValueError: If no signal with that name is registered.
    """
    if signal_name not in SIGNAL_REGISTRY:
        raise ValueError(f'Signal "{signal_name}" not found in the registry')

    return SIGNAL_REGISTRY[signal_name]


def resolve_signal(signal: Union[dict, Signal]) -> Signal:
    """Resolve a generic signal base class to a specific signal class.

    Args:
      signal: Either an already-constructed signal, or a dict of constructor arguments that
        includes a 'signal_name' entry.

    Returns:
      The given signal unchanged when it is already a `Signal`; otherwise a new instance of
      the registered class, constructed from the dict.

    Raises:
      ValueError: If the dict has no 'signal_name', or the name is not registered.
    """
    if isinstance(signal, Signal):
        # The signal config is already parsed.
        return signal

    signal_name = signal.get('signal_name')
    if not signal_name:
        raise ValueError('"signal_name" needs to be defined in the json dict.')

    signal_cls = get_signal_cls(signal_name)
    return signal_cls(**signal)


def clear_signal_registry() -> None:
    """Clear the signal registry."""
    SIGNAL_REGISTRY.clear()