nikhil_no_persistent / docker_start.py
nsthorat's picture
Push
8176be9
raw
history blame contribute delete
No virus
3.76 kB
"""Startup work before running the web server."""
import os
import shutil
from typing import TypedDict
import yaml
from huggingface_hub import scan_cache_dir, snapshot_download
from lilac.concepts.db_concept import CONCEPTS_DIR, DiskConceptDB, get_concept_output_dir
from lilac.env import data_path, env
from lilac.utils import get_datasets_dir, get_lilac_cache_dir, log
def delete_old_files() -> None:
  """Prune the HuggingFace hub cache, keeping only the newest revision of each repo.

  Best-effort: if the cache directory does not exist yet (fresh container),
  this is a no-op.
  """
  try:
    scan = scan_cache_dir()
  except Exception:
    # Cache was not found. NOTE: the original caught BaseException, which would
    # also swallow KeyboardInterrupt/SystemExit; Exception is the correct scope.
    return
  # For every cached repo, mark every revision except the most recent for deletion.
  to_delete: list[str] = []
  for repo in scan.repos:
    latest_revision = max(repo.revisions, key=lambda rev: rev.last_modified)
    to_delete.extend(
      [revision.commit_hash for revision in repo.revisions if revision != latest_revision])
  strategy = scan.delete_revisions(*to_delete)
  # Execute the deletion plan computed by huggingface_hub.
  log(f'Will delete {len(to_delete)} old revisions and save {strategy.expected_freed_size_str}')
  strategy.execute()
class HfSpaceConfig(TypedDict):
  """Typed view of the YAML frontmatter in a HuggingFace Space's README.md.

  Only the keys this script reads are declared; the full schema is documented at
  https://huggingface.co/docs/hub/spaces-config-reference
  """
  # Display title of the space.
  title: str
  # HF dataset repo ids to mirror into the local data directory.
  datasets: list[str]
def main() -> None:
  """Download dataset files from the HF space that was uploaded before building the image."""
  # SPACE_ID is the HuggingFace Space ID environment variable that is automatically set by HF.
  repo_id = env('SPACE_ID', None)
  if not repo_id:
    # Not running inside a HF Space; nothing to sync.
    return

  delete_old_files()

  with open(os.path.abspath('README.md')) as f:
    readme = f.read()
  # The space config is YAML frontmatter delimited by '---' lines at the top of
  # README.md. The previous `.strip('---')` removed dash *characters* from the
  # ends only, leaving the markdown body in the string, which yaml.safe_load can
  # reject as a multi-document stream. Extract just the frontmatter instead.
  _, _, after_open = readme.partition('---')
  frontmatter, _, _ = after_open.partition('---')
  hf_config: HfSpaceConfig = yaml.safe_load(frontmatter) or {}

  # Download the huggingface space data. This includes code and datasets, so we move the datasets
  # alone to the data directory. `datasets` is optional in the space config.
  for lilac_hf_dataset in hf_config.get('datasets') or []:
    print('Downloading dataset from HuggingFace: ', lilac_hf_dataset)
    snapshot_download(
      repo_id=lilac_hf_dataset,
      repo_type='dataset',
      token=env('HF_ACCESS_TOKEN'),
      local_dir=get_datasets_dir(data_path()),
      ignore_patterns=['.gitattributes', 'README.md'])

  snapshot_dir = snapshot_download(repo_id=repo_id, repo_type='space', token=env('HF_ACCESS_TOKEN'))
  # Copy datasets.
  spaces_data_dir = os.path.join(snapshot_dir, 'data')

  # Delete cache files from persistent storage.
  cache_dir = get_lilac_cache_dir(data_path())
  if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)

  # NOTE: This is temporary during the move of concepts into the pip package. Once all the demos
  # have been updated, this block can be deleted.
  old_lilac_concepts_data_dir = os.path.join(data_path(), CONCEPTS_DIR, 'lilac')
  if os.path.exists(old_lilac_concepts_data_dir):
    shutil.rmtree(old_lilac_concepts_data_dir)

  # Copy cache files from the space if they exist.
  spaces_cache_dir = get_lilac_cache_dir(spaces_data_dir)
  if os.path.exists(spaces_cache_dir):
    shutil.copytree(spaces_cache_dir, cache_dir)

  # Copy concepts from the space snapshot into persistent storage, replacing any
  # stale copies. Lilac-namespace concepts ship with the source code and are skipped.
  concepts = DiskConceptDB(spaces_data_dir).list()
  for concept in concepts:
    # Ignore lilac concepts, they're already part of the source code.
    if concept.namespace == 'lilac':
      continue
    spaces_concept_output_dir = get_concept_output_dir(spaces_data_dir, concept.namespace,
                                                      concept.name)
    persistent_output_dir = get_concept_output_dir(data_path(), concept.namespace, concept.name)
    # Replace the persistent copy wholesale, then remove the snapshot copy so it
    # is not picked up as a duplicate concept source.
    shutil.rmtree(persistent_output_dir, ignore_errors=True)
    shutil.copytree(spaces_concept_output_dir, persistent_output_dir, dirs_exist_ok=True)
    shutil.rmtree(spaces_concept_output_dir, ignore_errors=True)
# Script entry point: run the startup sync when invoked directly (e.g. from the Dockerfile).
if __name__ == '__main__':
  main()