nikhil_staging / src /concepts /db_concept.py
nsthorat's picture
Push
e4f9cbe
raw
history blame
12.9 kB
"""The concept database."""
import abc
import glob
import os
import pickle
import shutil
# NOTE: We have to import the module for uuid so it can be mocked.
import uuid
from typing import List, Optional, Union, cast
from pydantic import BaseModel
from typing_extensions import override
from ..config import data_path
from ..schema import PATH_WILDCARD, SignalInputType, normalize_path
from ..signals.signal import get_signal_cls
from ..utils import DATASETS_DIR_NAME, delete_file, file_exists, get_dataset_output_dir, open_file
from .concept import (
DRAFT_MAIN,
Concept,
ConceptColumnInfo,
ConceptModel,
DraftId,
Example,
ExampleIn,
)
# Name of the directory, nested under a dataset column's output path, that holds the concept
# models computed for that column.
DATASET_CONCEPTS_DIR = '.concepts'
# Filename of the serialized concept inside its concept output directory.
CONCEPT_JSON_FILENAME = 'concept.json'
class ConceptInfo(BaseModel):
  """Summary information about a concept, as returned when listing concepts."""
  # Namespace the concept lives under.
  namespace: str
  # Name of the concept within its namespace.
  name: str
  # The kind of input the concept applies to.
  type: SignalInputType
  # Draft ids reported by the concept's `drafts()` method.
  drafts: list[DraftId]
class ConceptUpdate(BaseModel):
  """A batch of changes to apply to a concept.

  Every field is optional; a missing or `None` field is treated as an empty list by
  `DiskConceptDB.edit`.
  """
  # List of examples to be inserted. Ids are assigned by the database on insertion.
  insert: Optional[list[ExampleIn]] = []
  # List of examples to be updated. Each must carry the id of an existing example.
  update: Optional[list[Example]] = []
  # The ids of the examples to be removed.
  remove: Optional[list[str]] = []
class ConceptDB(abc.ABC):
  """Interface for the concept database.

  Concepts are addressed by a `(namespace, name)` pair.
  """

  @abc.abstractmethod
  def list(self) -> list[ConceptInfo]:
    """List all the concepts."""
    pass

  @abc.abstractmethod
  def get(self, namespace: str, name: str) -> Optional[Concept]:
    """Return a concept or None if there isn't one."""
    pass

  @abc.abstractmethod
  def create(self, namespace: str, name: str, type: SignalInputType) -> Concept:
    """Create a concept of the given input type and return it."""
    pass

  @abc.abstractmethod
  def edit(self, namespace: str, name: str, change: ConceptUpdate) -> Concept:
    """Edit a concept and return the edited concept.

    If the concept doesn't exist, throw an error.
    """
    pass

  @abc.abstractmethod
  def remove(self, namespace: str, name: str) -> None:
    """Remove a concept."""
    pass

  @abc.abstractmethod
  def merge_draft(self, namespace: str, name: str, draft: DraftId) -> Concept:
    """Merge a draft of a concept and return the resulting concept."""
    pass
class ConceptModelDB(abc.ABC):
  """Abstract store for concept models, backed by a `ConceptDB` for concept lookups."""

  _concept_db: ConceptDB

  def __init__(self, concept_db: ConceptDB) -> None:
    self._concept_db = concept_db

  @abc.abstractmethod
  def create(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             column_info: Optional[ConceptColumnInfo] = None) -> ConceptModel:
    """Create a model for the given concept and embedding."""

  @abc.abstractmethod
  def get(self,
          namespace: str,
          concept_name: str,
          embedding_name: str,
          column_info: Optional[ConceptColumnInfo] = None) -> Optional[ConceptModel]:
    """Look up the model for the given concept and embedding.

    Returns None if the model does not exist.
    """

  @abc.abstractmethod
  def _save(self, model: ConceptModel, column_info: Optional[ConceptColumnInfo]) -> None:
    """Persist the concept model."""

  def in_sync(self, model: ConceptModel) -> bool:
    """Tell whether the model's version matches the version of its concept.

    Raises:
      ValueError: If the concept the model refers to does not exist.
    """
    source_concept = self._concept_db.get(model.namespace, model.concept_name)
    if source_concept:
      return source_concept.version == model.version
    raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')

  def sync(self, model: ConceptModel, column_info: Optional[ConceptColumnInfo]) -> bool:
    """Bring the model up to date with its concept and save it.

    Returns:
      True if syncing changed the model.

    Raises:
      ValueError: If the concept the model refers to does not exist.
    """
    source_concept = self._concept_db.get(model.namespace, model.concept_name)
    if not source_concept:
      raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
    was_updated = model.sync(source_concept)
    # The model is saved whether or not syncing changed it.
    self._save(model, column_info)
    return was_updated

  @abc.abstractmethod
  def remove(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             column_info: Optional[ConceptColumnInfo] = None) -> None:
    """Delete the model of a concept for the given embedding."""

  @abc.abstractmethod
  def remove_all(self, namespace: str, concept_name: str) -> None:
    """Delete every model associated with a concept."""
class DiskConceptModelDB(ConceptModelDB):
  """A concept model database that stores each model as a pickle file on disk."""

  @override
  def create(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             column_info: Optional[ConceptColumnInfo] = None) -> ConceptModel:
    """Create a new, not yet saved, concept model.

    The model starts at version -1 and is only written to disk by `sync()`.

    Raises:
      ValueError: If a model already exists on disk, or if `get()` rejects the concept or the
        embedding.
    """
    if self.get(namespace, concept_name, embedding_name, column_info):
      raise ValueError('Concept model already exists.')

    return ConceptModel(
      namespace=namespace,
      concept_name=concept_name,
      embedding_name=embedding_name,
      version=-1,
      column_info=column_info)

  @override
  def get(self,
          namespace: str,
          concept_name: str,
          embedding_name: str,
          column_info: Optional[ConceptColumnInfo] = None) -> Optional[ConceptModel]:
    """Load the concept model from disk, or return None if it has not been saved.

    Raises:
      ValueError: If the concept does not exist or the embedding signal is not registered.
    """
    # Make sure the concept exists.
    concept = self._concept_db.get(namespace, concept_name)
    if not concept:
      raise ValueError(f'Concept "{namespace}/{concept_name}" does not exist.')

    # Make sure that the embedding signal exists.
    if not get_signal_cls(embedding_name):
      raise ValueError(f'Embedding signal "{embedding_name}" not found in the registry.')

    concept_model_path = _concept_model_path(namespace, concept_name, embedding_name, column_info)
    if not file_exists(concept_model_path):
      return None

    # NOTE: Unpickling executes arbitrary code. This is only safe while the data directory is
    # trusted, i.e. the files were written by `_save` below.
    with open_file(concept_model_path, 'rb') as f:
      return pickle.load(f)

  @override
  def _save(self, model: ConceptModel, column_info: Optional[ConceptColumnInfo]) -> None:
    """Pickle the concept model to its path on disk."""
    concept_model_path = _concept_model_path(model.namespace, model.concept_name,
                                             model.embedding_name, column_info)
    with open_file(concept_model_path, 'wb') as f:
      pickle.dump(model, f)

  @override
  def remove(self,
             namespace: str,
             concept_name: str,
             embedding_name: str,
             column_info: Optional[ConceptColumnInfo] = None) -> None:
    """Delete the pickled model for the given concept and embedding.

    Raises:
      ValueError: If the model file does not exist.
    """
    concept_model_path = _concept_model_path(namespace, concept_name, embedding_name, column_info)
    if not file_exists(concept_model_path):
      raise ValueError(f'Concept model {namespace}/{concept_name}/{embedding_name} does not exist.')

    delete_file(concept_model_path)

  @override
  def remove_all(self, namespace: str, concept_name: str) -> None:
    """Delete the concept's model directories found under the datasets directory.

    Only directories matching `<datasets>/**/.concepts/<namespace>/<concept_name>` are removed.
    Removal is best-effort: errors while deleting a directory are ignored.
    """
    datasets_path = os.path.join(data_path(), DATASETS_DIR_NAME)
    # Skip if 'datasets' doesn't exist.
    if not os.path.isdir(datasets_path):
      return

    # Escape the literal path components so that glob metacharacters ('*', '?', '[') in the
    # data path, namespace or concept name cannot match unrelated directories.
    pattern = os.path.join(
      glob.escape(datasets_path), '**', DATASET_CONCEPTS_DIR, glob.escape(namespace),
      glob.escape(concept_name))
    # Collect all matches before deleting so the tree is not mutated while glob walks it.
    for model_dir in glob.glob(pattern, recursive=True):
      shutil.rmtree(model_dir, ignore_errors=True)
def _concept_output_dir(namespace: str, name: str) -> str:
  """Build the directory under the data path where the concept `namespace/name` is stored."""
  concepts_root = os.path.join(data_path(), 'concept')
  return os.path.join(concepts_root, namespace, name)
def _concept_json_path(namespace: str, name: str) -> str:
  """Build the path of the JSON file that stores the concept `namespace/name`."""
  concept_dir = _concept_output_dir(namespace, name)
  return os.path.join(concept_dir, CONCEPT_JSON_FILENAME)
def _concept_model_path(namespace: str,
                        concept_name: str,
                        embedding_name: str,
                        column_info: Optional[ConceptColumnInfo] = None) -> str:
  """Build the path of the pickle file for a concept model.

  Without `column_info` the model lives next to the concept itself. With `column_info` it lives
  under the dataset's output directory, nested by the column path (wildcards dropped) and the
  dataset concepts directory.
  """
  model_filename = f'{embedding_name}.pkl'
  if column_info:
    dataset_dir = get_dataset_output_dir(data_path(), column_info.namespace, column_info.name)
    # Wildcard path components do not name a directory, so they are left out.
    concrete_parts = [part for part in normalize_path(column_info.path) if part != PATH_WILDCARD]
    return os.path.join(dataset_dir, *concrete_parts, DATASET_CONCEPTS_DIR, namespace,
                        concept_name, model_filename)

  return os.path.join(_concept_output_dir(namespace, concept_name), model_filename)
class DiskConceptDB(ConceptDB):
  """A concept database that stores each concept as a JSON file on disk."""

  @override
  def list(self) -> list[ConceptInfo]:
    """Return a `ConceptInfo` for every concept JSON file found under the data path."""
    concept_infos = []
    for root, _, files in os.walk(data_path()):
      if CONCEPT_JSON_FILENAME not in files:
        continue

      # The last two directory components are `<namespace>/<name>`. Use os.path so this also
      # works on platforms whose path separator is not '/'.
      namespace_dir, name = os.path.split(os.path.normpath(root))
      namespace = os.path.basename(namespace_dir)

      concept = self.get(namespace, name)
      if not concept:
        # A stray concept file outside of the concept output directory cannot be resolved by
        # `get()`; skip it instead of failing the whole listing.
        continue

      concept_infos.append(
        ConceptInfo(
          namespace=namespace,
          name=name,
          # Report the type stored with the concept rather than assuming text.
          type=concept.type,
          drafts=concept.drafts()))

    return concept_infos

  @override
  def get(self, namespace: str, name: str) -> Optional[Concept]:
    """Read the concept from its JSON file, or return None if the file does not exist."""
    concept_json_path = _concept_json_path(namespace, name)
    if not file_exists(concept_json_path):
      return None

    with open_file(concept_json_path) as f:
      return Concept.parse_raw(f.read())

  @override
  def create(self, namespace: str, name: str, type: SignalInputType) -> Concept:
    """Create an empty concept at version 0 and save it.

    Raises:
      ValueError: If the concept already exists.
    """
    concept_json_path = _concept_json_path(namespace, name)
    if file_exists(concept_json_path):
      raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" already exists.')

    concept = Concept(namespace=namespace, concept_name=name, type=type, data={}, version=0)
    self._save(concept)
    return concept

  def _validate_examples(self, examples: List[Union[ExampleIn, Example]],
                         type: SignalInputType) -> None:
    """Check that every example matches the concept type.

    An example with non-empty text is considered 'text', anything else 'img'.

    Raises:
      ValueError: If an example's inferred type differs from `type`.
    """
    for example in examples:
      inferred_type = 'text' if example.text else 'img'
      if inferred_type != type:
        raise ValueError(f'Example type "{inferred_type}" does not match concept type "{type}".')

  @override
  def edit(self, namespace: str, name: str, change: ConceptUpdate) -> Concept:
    """Apply removals, insertions and updates to a concept, bump its version and save it.

    Nothing is written to disk if any part of the change is invalid.

    Raises:
      ValueError: If the concept does not exist, an example has the wrong type, or a removed or
        updated example id is unknown.
    """
    concept = self.get(namespace, name)
    if not concept:
      raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist. '
                       'Please call create() first.')

    inserted_points = change.insert or []
    updated_points = change.update or []
    removed_points = change.remove or []

    self._validate_examples([*inserted_points, *updated_points], concept.type)

    for example_id in removed_points:
      if example_id not in concept.data:
        raise ValueError(f'Example with id "{example_id}" does not exist.')
      concept.data.pop(example_id)

    for example in inserted_points:
      # NOTE: Call through the `uuid` module so it can be mocked.
      example_id = uuid.uuid4().hex
      concept.data[example_id] = Example(id=example_id, **example.dict())

    for example in updated_points:
      if example.id not in concept.data:
        raise ValueError(f'Example with id "{example.id}" does not exist.')
      # Replace the old example with a copy of the updated one under the same id. Popping first
      # moves the updated example to the end of the insertion order.
      concept.data.pop(example.id)
      concept.data[example.id] = example.copy()

    concept.version += 1

    self._save(concept)
    return concept

  def _save(self, concept: Concept) -> None:
    """Write the concept as JSON to its path on disk."""
    concept_json_path = _concept_json_path(concept.namespace, concept.concept_name)
    with open_file(concept_json_path, 'w') as f:
      f.write(concept.json(exclude_none=True, indent=2, exclude_defaults=True))

  @override
  def remove(self, namespace: str, name: str) -> None:
    """Delete the concept's output directory.

    Raises:
      ValueError: If the concept directory does not exist.
    """
    concept_dir = _concept_output_dir(namespace, name)
    if not file_exists(concept_dir):
      raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist.')

    # Best-effort removal: errors while deleting are ignored.
    shutil.rmtree(concept_dir, ignore_errors=True)

  @override
  def merge_draft(self, namespace: str, name: str, draft: DraftId) -> Concept:
    """Move the examples of a draft into main, bump the version and save the concept.

    A main example whose text equals the text of a merged draft example is removed. Merging the
    main draft is a no-op that returns the concept unchanged and unsaved.

    Raises:
      ValueError: If the concept does not exist.
    """
    concept = self.get(namespace, name)
    if not concept:
      raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" does not exist.')

    if draft == DRAFT_MAIN:
      return concept

    # Map the text of examples in main so we can remove them if they are duplicates.
    main_text_ids: dict[Optional[str], str] = {
      example.text: example_id
      for example_id, example in concept.data.items()
      if example.draft == DRAFT_MAIN
    }

    draft_examples: dict[str, Example] = {
      example_id: example
      for example_id, example in concept.data.items()
      if example.draft == draft
    }
    for example in draft_examples.values():
      example.draft = DRAFT_MAIN
      # Remove duplicates in main. Popping the text from the map makes sure a main example is
      # deleted at most once, even when several draft examples share its text.
      duplicate_id = main_text_ids.pop(example.text, None)
      if duplicate_id is not None:
        concept.data.pop(duplicate_id, None)

    concept.version += 1

    self._save(concept)

    return concept
# Module-level singletons: the on-disk concept database, and the on-disk concept model database
# that uses it to look up concepts.
DISK_CONCEPT_DB = DiskConceptDB()
DISK_CONCEPT_MODEL_DB = DiskConceptModelDB(DISK_CONCEPT_DB)