Spaces:
Runtime error
Runtime error
"""Interface for implementing a signal.""" | |
import abc | |
from typing import Any, ClassVar, Iterable, Optional, Sequence, Type, TypeVar, Union | |
from pydantic import BaseModel, Extra, validator | |
from pydantic.fields import ModelField | |
from typing_extensions import override | |
from ..embeddings.vector_store import VectorStore | |
from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field | |
# Key of the sub-field, nested under each `string_span`, where TextEmbeddingSignal.fields()
# places the embedding value.
EMBEDDING_KEY: str = 'embedding'
class Signal(abc.ABC, BaseModel):
    """Interface for signals to implement. A signal can score documents and a dataset column.

    Concrete signals define the `name` ClassVar and implement `fields()`. They override whichever
    compute methods they support; the base versions raise NotImplementedError.
    """

    # ClassVars do not get serialized with pydantic.
    name: ClassVar[str]
    # The display name is just used for rendering in the UI.
    display_name: ClassVar[Optional[str]]

    signal_type: ClassVar[Type['Signal']]

    # The input type is used to populate the UI for signals that require other signals. For
    # example, if a signal is a TextEmbeddingModelSignal, it computes over embeddings, but its
    # input type will be text.
    input_type: ClassVar[SignalInputType]

    # The compute type defines what should be passed to compute().
    compute_type: ClassVar[SignalInputType]

    # The signal_name is populated automatically (see validate_signal_name) from the static `name`
    # ClassVar so it gets serialized and the signal author doesn't have to define both the static
    # property and the field.
    signal_name: Optional[str]

    class Config:
        # Attributes starting with an underscore become pydantic private attributes, not fields.
        underscore_attrs_are_private = True
        # Reject unknown constructor arguments instead of silently ignoring them.
        extra = Extra.forbid

        @staticmethod
        def schema_extra(schema: dict[str, Any], signal: Type['Signal']) -> None:
            """Add the title to the schema from the display name and name.

            Pydantic defaults this to the class name.
            """
            if hasattr(signal, 'display_name'):
                schema['title'] = signal.display_name
            if hasattr(signal, 'name'):
                # Pin the serialized `signal_name` to this class's static name.
                schema['properties']['signal_name']['enum'] = [signal.name]

    # always=True so this also runs when `signal_name` is omitted (it then defaults to None);
    # pre=True so it sees the raw input before type coercion.
    @validator('signal_name', pre=True, always=True)
    def validate_signal_name(cls, signal_name: Optional[str]) -> str:
        """Return the static name when the signal name hasn't yet been set.

        Raises:
          ValueError: If the class does not define `name` in its own class body.
        """
        # When it's already been set from JSON, just return it.
        if signal_name:
            return signal_name

        # Checked against the class's own __dict__, so an inherited `name` is not enough.
        if 'name' not in cls.__dict__:
            raise ValueError('Signal attribute "name" must be defined.')

        return cls.name

    @abc.abstractmethod
    def fields(self) -> Field:
        """Return the fields schema for this signal.

        Returns:
          A Field object that describes the schema of the signal.
        """
        raise NotImplementedError

    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Compute the signal for an iterable of documents or images.

        Args:
          data: An iterable of rich data to compute the signal over.

        Returns:
          An iterable of items. Sparse signals should return "None" for skipped inputs.
        """
        raise NotImplementedError

    def vector_compute(self, keys: Iterable[VectorKey],
                       vector_store: VectorStore) -> Iterable[Optional[Item]]:
        """Compute the signal for an iterable of keys that point to documents or images.

        Args:
          keys: An iterable of value ids (at row-level or lower) to lookup precomputed embeddings.
          vector_store: The vector store to lookup pre-computed embeddings.

        Returns:
          An iterable of items. Sparse signals should return "None" for skipped inputs.
        """
        raise NotImplementedError

    def vector_compute_topk(
            self,
            topk: int,
            vector_store: VectorStore,
            keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
        """Return signal results only for the top k documents or images.

        Signals decide how to rank each document/image in the dataset, usually by a similarity score
        obtained via the vector store.

        Args:
          topk: The number of items to return, ranked by the signal.
          vector_store: The vector store to lookup pre-computed embeddings.
          keys: Optional iterable of row ids to restrict the search to.

        Returns:
          A list of (key, signal_output) tuples containing the `topk` items. Sparse signals should
          return "None" for skipped inputs.
        """
        raise NotImplementedError

    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        """Get the key for a signal.

        This is used to make sure signals with multiple arguments do not collide.

        NOTE: Overriding this method is sensitive. If you override it, make sure that it is globally
        unique. It will be used as the dictionary key for enriched values.

        Args:
          is_computed_signal: True when the signal is computed over the column and written to
            disk. False when the signal is used as a preview UDF. Not used by this base
            implementation.

        Returns:
          The static `name` followed by a suffix that `_args_key_from_dict` builds from the
          explicitly-set, non-default arguments.
        """
        args_dict = self.dict(exclude_unset=True, exclude_defaults=True)
        # If a user explicitly defines a signal name for whatever reason, remove it as it's
        # redundant.
        args_dict.pop('signal_name', None)
        return self.name + _args_key_from_dict(args_dict)

    def setup(self) -> None:
        """Setup the signal."""
        pass

    def teardown(self) -> None:
        """Tears down the signal."""
        pass
def _args_key_from_dict(args_dict: dict[str, Any]) -> str: | |
args = None | |
args_list: list[str] = [] | |
for k, v in args_dict.items(): | |
if v: | |
args_list.append(f'{k}={v}') | |
args = ','.join(args_list) | |
return '' if not args_list else f'({args})' | |
class SignalTypeEnum(str):
    """A class that represents a string enum for a signal type.

    This allows us to populate the JSON schema enum field in the UI. Subclasses set `signal_type`;
    the schema's `enum` then lists the names of all registered subclasses of that type.
    """

    signal_type: ClassVar[Type[Signal]]

    # Must be a classmethod: pydantic calls it as `__modify_schema__(field_schema, field=field)`,
    # and reads `cls.signal_type` from the concrete subclass. The parameter name `field` is what
    # pydantic looks for in the signature, so it cannot be renamed.
    @classmethod
    def __modify_schema__(cls, field_schema: dict[str, Any], field: Optional[ModelField]) -> None:
        """Restrict the JSON schema to the names of registered `cls.signal_type` signals."""
        if field:
            field_schema['enum'] = [x.name for x in get_signals_by_type(cls.signal_type)]
class TextSplitterSignal(Signal):
    """An interface for signals that split text into spans."""

    input_type = SignalInputType.TEXT
    compute_type = SignalInputType.TEXT

    @override
    def fields(self) -> Field:
        """Return the schema: a list of `string_span` values."""
        return field(fields=['string_span'])
class TextSplitterEnum(SignalTypeEnum):
    """String type whose JSON-schema `enum` lists registered TextSplitterSignal names."""

    # Consulted by SignalTypeEnum.__modify_schema__ to pick which registered signals to list.
    signal_type = TextSplitterSignal
# Signal base classes, used for inferring the dependency chain required for computing a signal.
class TextSignal(Signal):
    """An interface for signals that compute over text.

    `key()` is inherited from `Signal`: it is the static name plus a suffix built from the
    explicitly-set, non-default arguments (with `signal_name` removed).
    """

    input_type = SignalInputType.TEXT
    compute_type = SignalInputType.TEXT
class TextEmbeddingSignal(TextSignal):
    """An interface for signals that compute embeddings for text."""

    input_type = SignalInputType.TEXT
    compute_type = SignalInputType.TEXT

    # Stored as a pydantic private attribute (Signal.Config sets underscore_attrs_are_private),
    # so it is not a model field. It is set from the `split` constructor argument.
    _split = True

    def __init__(self, split: bool = True, **kwargs: Any) -> None:
        """Create the signal.

        Args:
          split: Stored on the private `_split` attribute.
          **kwargs: Forwarded to the pydantic model constructor.
        """
        super().__init__(**kwargs)
        self._split = split

    @override
    def fields(self) -> Field:
        """NOTE: Override this method at your own risk if you want to add extra metadata.

        Embeddings should not come with extra metadata.
        """
        return field(fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})])
class TextEmbeddingEnum(SignalTypeEnum):
    """String type whose JSON-schema `enum` lists registered TextEmbeddingSignal names."""

    # Consulted by SignalTypeEnum.__modify_schema__ to pick which registered signals to list.
    signal_type = TextEmbeddingSignal
class TextEmbeddingModelSignal(TextSignal):
    """An interface for signals that take embeddings and produce items."""

    input_type = SignalInputType.TEXT
    # compute() takes embeddings, while it operates over text fields by transitively computing
    # splits and embeddings.
    compute_type = SignalInputType.TEXT_EMBEDDING

    # Name of the embedding signal to use. Its JSON schema is restricted to registered
    # TextEmbeddingSignal names via TextEmbeddingEnum.
    embedding: TextEmbeddingEnum

    # Instance of the embedding signal named by `embedding`, built in __init__.
    _embedding_signal: TextEmbeddingSignal

    def __init__(self, **kwargs: Any) -> None:
        """Create the signal and instantiate its embedding signal.

        Raises:
          ValueError: If `embedding` is not registered, or is not a TextEmbeddingSignal.
        """
        super().__init__(**kwargs)
        # Validate the embedding signal is registered and the correct type.
        # TODO(nsthorat): Allow arguments passed to the embedding signal.
        self._embedding_signal = get_signal_by_type(self.embedding, TextEmbeddingSignal)()

    def get_embedding_signal(self) -> TextEmbeddingSignal:
        """Return the embedding signal."""
        return self._embedding_signal

    @override
    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        """Return the static name plus a suffix from the explicitly-set arguments.

        Unlike Signal.key, arguments equal to their defaults are not excluded here. That is kept
        as-is because keys are used for stored enriched values.
        """
        # NOTE: The embedding and split already exist in the path structure. This means we do not
        # need to provide the signal names as part of the key, which still guarantees uniqueness.
        args_dict = self.dict(exclude_unset=True)
        args_dict.pop('signal_name', None)
        # pop() rather than del: don't raise KeyError if `embedding` is absent from the dict.
        args_dict.pop('embedding', None)
        return self.name + _args_key_from_dict(args_dict)
# Lets get_signal_by_type / get_signals_by_type return the exact signal subclass the caller asked
# for.
Tsignal = TypeVar('Tsignal', bound=Signal)


def get_signal_by_type(signal_name: str, signal_type: Type[Tsignal]) -> Type[Tsignal]:
    """Look up a registered signal class by name, requiring it to subclass `signal_type`.

    Raises:
      ValueError: If nothing is registered under `signal_name`, or the registered class is not a
        subclass of `signal_type`.
    """
    found = SIGNAL_REGISTRY.get(signal_name)
    if found is None:
        raise ValueError(f'Signal "{signal_name}" not found in the registry')
    if issubclass(found, signal_type):
        return found
    raise ValueError(f'"{signal_name}" is a `{found.__name__}`, '
                     f'which is not a subclass of `{signal_type.__name__}`.')
def get_signals_by_type(signal_type: Type[Tsignal]) -> list[Type[Tsignal]]:
    """List every registered signal class that subclasses `signal_type`, in registration order."""
    return [cls for cls in SIGNAL_REGISTRY.values() if issubclass(cls, signal_type)]
# Global mapping from a signal's static `name` to its class.
SIGNAL_REGISTRY: dict[str, Type[Signal]] = {}


def register_signal(signal_cls: Type[Signal]) -> None:
    """Add `signal_cls` to the global registry under its static `name`.

    Raises:
      ValueError: If a signal with the same name is already registered.
    """
    registered_name = signal_cls.name
    if registered_name in SIGNAL_REGISTRY:
        raise ValueError(f'Signal "{registered_name}" has already been registered!')
    SIGNAL_REGISTRY[registered_name] = signal_cls
def get_signal_cls(signal_name: str) -> Type[Signal]:
    """Look up a registered signal class by its name.

    Raises:
      ValueError: If nothing is registered under `signal_name`.
    """
    found = SIGNAL_REGISTRY.get(signal_name)
    if found is None:
        raise ValueError(f'Signal "{signal_name}" not found in the registry')
    return found
def resolve_signal(signal: Union[dict, Signal]) -> Signal:
    """Turn a signal config into a concrete `Signal` instance.

    A `Signal` instance is returned unchanged. A dict is dispatched on its 'signal_name' entry to
    the registered class, which is then constructed from the whole dict.

    Raises:
      ValueError: If the dict's 'signal_name' is missing or empty, or is not registered.
    """
    if not isinstance(signal, Signal):
        requested_name = signal.get('signal_name')
        if not requested_name:
            raise ValueError('"signal_name" needs to be defined in the json dict.')
        return get_signal_cls(requested_name)(**signal)
    # The signal config is already parsed.
    return signal
def clear_signal_registry() -> None:
    """Empty the global signal registry in place.

    The dict object is cleared rather than rebound, so every reference to SIGNAL_REGISTRY sees
    the now-empty registry.
    """
    registry = SIGNAL_REGISTRY
    registry.clear()