# Spaces:
# Runtime error
# Runtime error
"""A signal to compute a score along a concept."""
from typing import Iterable, Optional

import numpy as np
from typing_extensions import override

from ..auth import UserInfo
from ..concepts.concept import DEFAULT_NUM_NEG_EXAMPLES, DRAFT_MAIN, ConceptColumnInfo, ConceptModel
from ..concepts.db_concept import DISK_CONCEPT_MODEL_DB, ConceptModelDB
from ..data.dataset_utils import lilac_span
from ..embeddings.vector_store import VectorDBIndex
from ..schema import Field, Item, PathKey, RichData, SignalInputType, field
from ..signals.signal import VectorSignal


class ConceptScoreSignal(VectorSignal):
    """Compute scores along a given concept for documents.

    The concept model is obtained from `_concept_model_db` (created there if missing) and is
    used to score either raw inputs (`compute`) or stored embeddings (`vector_compute`).
    """

    name = 'concept_score'
    input_type = SignalInputType.TEXT
    display_name = 'Concept'

    namespace: str
    concept_name: str

    # The draft version of the concept to use. If not provided, the latest version is used.
    draft: str = DRAFT_MAIN

    # Number of randomly chosen negative examples to use when training the concept. This is used
    # to obtain a better suited model for the concrete dataset.
    num_negative_examples = DEFAULT_NUM_NEG_EXAMPLES

    # Set through `set_column_info` / `set_user`; both are forwarded to the concept model DB.
    _column_info: Optional[ConceptColumnInfo] = None
    _concept_model_db: ConceptModelDB = DISK_CONCEPT_MODEL_DB
    _user: Optional[UserInfo] = None

    def fields(self) -> Field:
        """Return the output schema of this signal.

        The schema is a list holding one `string_span` field whose `score` sub-field is a
        `float32` binned at 0.5 into 'Not in concept' and 'In concept'.
        """
        return field(fields=[
            field(
                dtype='string_span',
                fields={
                    'score': field(
                        'float32',
                        bins=[('Not in concept', None, 0.5), ('In concept', 0.5, None)],
                    )
                })
        ])

    def set_column_info(self, column_info: ConceptColumnInfo) -> None:
        """Set the dataset info for this signal.

        NOTE: `column_info` is stored by reference and its `num_negative_examples` attribute is
        overwritten with this signal's value, so the caller's object is modified as well.
        """
        self._column_info = column_info
        self._column_info.num_negative_examples = self.num_negative_examples

    def set_user(self, user: Optional[UserInfo]) -> None:
        """Set the user for this signal."""
        self._user = user

    def _get_concept_model(self) -> ConceptModel:
        """Return the concept model for this signal, creating it if the DB has none.

        The model is synced through the DB on every call before it is returned.
        """
        db = self._concept_model_db
        # The same identifying arguments are used for both the lookup and the creation.
        model_args = (self.namespace, self.concept_name, self.embedding, self._column_info)

        model = db.get(*model_args, user=self._user)
        # Only a missing model triggers creation; an existing model is never replaced.
        if model is None:
            model = db.create(*model_args, user=self._user)
        db.sync(model, self._user)
        return model

    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Score `data` with the concept model, using the configured draft."""
        concept_model = self._get_concept_model()
        return concept_model.score(self.draft, data)

    def vector_compute(self, keys: Iterable[PathKey],
                       vector_index: VectorDBIndex) -> Iterable[Optional[Item]]:
        """Yield scored spans for each group of vector spans stored under `keys`.

        For every group returned by `vector_index.get(keys)`, the vectors are scored together
        and one list is yielded, containing a span per vector annotated with its `score`.

        This is a generator: nothing (including the concept model lookup) runs until the
        result is iterated.
        """
        concept_model = self._get_concept_model()
        all_vector_spans = vector_index.get(keys)

        # TODO(smilkov): Do this with batched computation.
        for vector_spans in all_vector_spans:
            # The group is traversed twice below (vectors first, then span offsets), so
            # materialize it in case the index hands back a one-shot iterable.
            spans = list(vector_spans)
            embeddings = np.array([vector_span['vector'] for vector_span in spans])
            scores = concept_model.score_embeddings(self.draft, embeddings)

            res: Item = []
            for vector_span, score in zip(spans, scores):
                start, end = vector_span['span']
                res.append(lilac_span(start, end, {'score': score}))
            yield res

    def vector_compute_topk(
            self,
            topk: int,
            vector_index: VectorDBIndex,
            keys: Optional[Iterable[PathKey]] = None) -> list[tuple[PathKey, Optional[Item]]]:
        """Return the top-k keys for this concept paired with their scored spans.

        The concept model's coefficients for the configured draft are used as the query to
        `vector_index.topk`; only the keys of its results are kept, and the items are then
        recomputed with `vector_compute`.
        """
        concept_model = self._get_concept_model()
        query: np.ndarray = concept_model.coef(self.draft)
        topk_keys = [key for key, _ in vector_index.topk(query, topk, keys)]
        return list(zip(topk_keys, self.vector_compute(topk_keys, vector_index)))

    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        """Return `namespace/concept_name/embedding`, plus `/v<version>` for computed signals."""
        # NOTE: The embedding is a value so already exists in the path structure. This means we
        # do not need to provide the name as part of the key, which still guarantees uniqueness.
        # The version comes from `_get_concept_model()`, which may create and sync the model, so
        # it is only looked up when the suffix is actually requested.
        version = f'/v{self._get_concept_model().version}' if is_computed_signal else ''
        return f'{self.namespace}/{self.concept_name}/{self.embedding}{version}'