Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 7,307 Bytes
cb643a9 e8f13e9 cb643a9 e8f13e9 cb643a9 e8f13e9 cb643a9 e8f13e9 5e2248d e35e532 5e2248d e8f13e9 5e2248d 25678ac e8f13e9 5e2248d a2828aa 5e2248d 4ca0034 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d a2828aa 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d e35e532 e8f13e9 5e2248d e8f13e9 e35e532 e8f13e9 5e2248d e8f13e9 5e2248d 4a5d03c e8f13e9 e35e532 e8f13e9 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d e8f13e9 5e2248d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import logging
import os
import platform
from datetime import datetime
from typing import List, Literal, Optional, Tuple
import chromadb
import polars as pl
import requests
import stamina
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from tqdm.contrib.concurrent import thread_map
# NOTE(review): presumably opts huggingface_hub into the hf_transfer download
# backend. huggingface_hub is already imported above, and it likely reads this
# variable at import time, so this assignment may have no effect — confirm, and
# move it above the imports if hf_transfer is actually wanted.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Set up logging: process-wide INFO level with timestamped, level-tagged lines.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Populate environment variables from a .env file, if one is found.
load_dotenv()
# Token handed to the InferenceClient; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
# SentenceTransformer model attached to the Chroma collection, pinned to an
# exact revision so the locally loaded weights do not change underneath us.
EMBEDDING_MODEL_NAME = "Snowflake/snowflake-arctic-embed-m-long"
EMBEDDING_MODEL_REVISION = "ac9d0cb43661ee1f7d67b3aa63614d65a6c86463"
# Dedicated endpoint that computes the card embeddings that get upserted.
# NOTE(review): presumably serves the same model as EMBEDDING_MODEL_NAME so
# stored vectors and collection-side vectors share one space — confirm.
INFERENCE_MODEL_URL = (
    "https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud"
)
# Glob over the train-split parquet shards of the dataset-card dump.
DATASET_PARQUET_URL = (
    "hf://datasets/librarian-bots/dataset_cards_with_metadata/data/train-*.parquet"
)
# Name of the Chroma collection that holds the card embeddings.
COLLECTION_NAME = "dataset_cards"
# Maximum number of characters (not tokens) of a card sent to the endpoint.
MAX_EMBEDDING_LENGTH = 8192
def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
    """Choose the Chroma persistence directory for the current host.

    macOS (``platform.system() == "Darwin"``) uses the relative ``chroma/``
    directory; every other OS uses ``/data/chroma/``. The choice is logged.
    """
    if platform.system() == "Darwin":
        save_path = "chroma/"
    else:
        save_path = "/data/chroma/"
    logger.info(f"Using save path: {save_path}")
    return save_path
# Resolved once at import time; get_chroma_client persists the database here.
SAVE_PATH = get_save_path()
def get_chroma_client():
    """Return a ``chromadb.PersistentClient`` that stores its data at ``SAVE_PATH``."""
    logger.info("Initializing Chroma client")
    client = chromadb.PersistentClient(path=SAVE_PATH)
    return client
def get_embedding_function():
    """Build the SentenceTransformer embedding function used by the collection.

    The model ``EMBEDDING_MODEL_NAME`` is loaded at the pinned
    ``EMBEDDING_MODEL_REVISION`` with ``trust_remote_code=True``.
    NOTE(review): remote code is presumably required by this model's custom
    architecture — confirm before changing the model or revision.
    """
    logger.info(f"Initializing embedding function with model: {EMBEDDING_MODEL_NAME}")
    settings = dict(
        model_name=EMBEDDING_MODEL_NAME,
        revision=EMBEDDING_MODEL_REVISION,
        trust_remote_code=True,
    )
    return embedding_functions.SentenceTransformerEmbeddingFunction(**settings)
def get_collection(chroma_client, embedding_function):
    """Return the ``COLLECTION_NAME`` collection, creating it when it does not exist.

    Args:
        chroma_client: Chroma client that owns the collection.
        embedding_function: Embedding function attached to the collection.
    """
    logger.info(f"Getting or creating collection: {COLLECTION_NAME}")
    collection = chroma_client.create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_function,
        get_or_create=True,
    )
    return collection
def get_last_modified_in_collection(collection) -> datetime | None:
    """Return the newest ``last_modified`` timestamp stored in *collection*.

    Every item's metadata is read, and its ``"last_modified"`` string is parsed
    with ``datetime.fromisoformat``. Items whose metadata is missing, or that
    lack a ``"last_modified"`` value, are skipped instead of aborting the scan.

    Args:
        collection: Chroma collection whose item metadata is inspected.

    Returns:
        The maximum parsed timestamp. Returns ``None`` when no item carries one,
        or when reading or parsing fails; the failure is logged with its
        traceback. This is best-effort: a ``None`` result means no
        modification-date cutoff is applied by the caller.
    """
    logger.info("Fetching last modified date from collection")
    try:
        all_items = collection.get(include=["metadatas"])
        timestamps = [
            datetime.fromisoformat(metadata["last_modified"])
            # Chroma may return None for items stored without metadata.
            for metadata in all_items["metadatas"] or []
            if metadata and metadata.get("last_modified")
        ]
        # Kept inside the try: comparing naive with aware datetimes raises.
        latest = max(timestamps, default=None)
    except Exception as e:
        logger.exception(f"Error fetching last modified date: {str(e)}")
        return None
    if latest is None:
        logger.info("No last modified date found")
        return None
    logger.info(f"Last modified date: {latest}")
    return latest
def parse_markdown_column(
    df: pl.DataFrame, markdown_column: str, dataset_id_column: str
) -> pl.DataFrame:
    """Add front-matter-stripped versions of *markdown_column* to *df*.

    Two columns are added:
      * ``parsed_markdown``: the card text after a leading ``---`` ... ``---``
        block, whitespace-stripped. A card without such a block is kept whole
        (also stripped).
      * ``prepended_markdown``: the literal ``"Dataset ID "``, the value of
        *dataset_id_column* cast to string, a blank line, then the same
        stripped text.

    Args:
        df: Input frame containing both columns.
        markdown_column: Name of the column holding raw card markdown.
        dataset_id_column: Name of the column holding the dataset identifier.

    Returns:
        *df* with the two extra columns.
    """
    logger.info("Parsing markdown column")
    # Built once and reused, so both output columns strip front matter identically.
    stripped_markdown = (
        pl.col(markdown_column)
        # Capture everything after the leading ``---`` ... ``---`` block;
        # ``(?s)`` lets ``.`` match across newlines.
        .str.extract(r"(?s)^---.*?---\s*(.*)", group_index=1)
        # No front matter -> extract yields null -> fall back to the raw card.
        .fill_null(pl.col(markdown_column))
        .str.strip_chars()
    )
    return df.with_columns(
        parsed_markdown=stripped_markdown,
        prepended_markdown=pl.concat_str(
            [
                pl.lit("Dataset ID "),
                pl.col(dataset_id_column).cast(pl.Utf8),
                pl.lit("\n\n"),
                stripped_markdown,
            ]
        ),
    )
def is_unmodified_template(
    card: str,
    *,
    min_indicators: int = 3,
    min_more_info_needed: int = 7,
) -> bool:
    """Return True if *card* looks like an unfilled dataset-card template.

    The card is flagged when EITHER condition holds (logical OR):
      * at least ``min_indicators`` of the known template phrases below occur
        in it, or
      * the placeholder ``"[More Information Needed]"`` occurs at least
        ``min_more_info_needed`` times.

    Note that the placeholder is also one of the template phrases, so it
    contributes to both checks.

    Args:
        card: Card markdown to inspect.
        min_indicators: Distinct template phrases needed to flag the card.
        min_more_info_needed: Placeholder occurrences needed to flag the card.

    Returns:
        True when either threshold is reached.
    """
    template_indicators = (
        "# Dataset Card for Dataset Name",
        "<!-- Provide a quick summary of the dataset. -->",
        "This dataset card aims to be a base template for new datasets",
        "[More Information Needed]",
    )
    # Each phrase counts at most once here, however often it repeats.
    indicator_count = sum(indicator in card for indicator in template_indicators)
    # Total placeholder occurrences, counting repeats.
    more_info_needed_count = card.count("[More Information Needed]")
    # Either signal alone is sufficient to treat the card as a template.
    return (
        indicator_count >= min_indicators
        or more_info_needed_count >= min_more_info_needed
    )
def load_cards(
    min_len: int = 50,
    min_likes: int | None = None,
    last_modified: Optional[datetime] = None,
) -> Optional[Tuple[List[str], List[str], List[datetime]]]:
    """Load dataset cards from the parquet dump and filter them for embedding.

    A row is kept when all of the following hold:
      * it is not tagged ``not-for-all-audiences`` (rows with null ``tags`` are
        kept);
      * its front-matter-stripped card is longer than ``min_len`` characters;
      * it has more than ``min_likes`` likes (this filter is skipped when
        ``min_likes`` is falsy, i.e. ``None`` or ``0``);
      * it was modified strictly after ``last_modified`` (skipped when
        ``None``);
      * it is not flagged by ``is_unmodified_template``.

    Returns:
        ``(cards, dataset_ids, last_modifieds)`` as parallel lists. Each card is
        the ``prepended_markdown`` text produced by ``parse_markdown_column``.
        Returns ``None`` when no row survives the filters.
    """
    logger.info(
        f"Loading cards with min_len={min_len}, min_likes={min_likes}, last_modified={last_modified}"
    )
    df = pl.read_parquet(DATASET_PARQUET_URL)
    # ``list.contains`` yields null for a null ``tags`` list, and ``filter``
    # drops null rows. Treat "no tags" as "not NSFW" explicitly so untagged
    # datasets are not silently discarded.
    df = df.filter(
        ~pl.col("tags").list.contains("not-for-all-audiences").fill_null(False)
    )
    df = parse_markdown_column(df, "card", "datasetId")
    df = df.with_columns(pl.col("parsed_markdown").str.len_chars().alias("card_len"))
    df = df.filter(pl.col("card_len") > min_len)
    # Deliberate truthiness check: both None and 0 disable the likes filter.
    if min_likes:
        df = df.filter(pl.col("likes") > min_likes)
    if last_modified is not None:
        df = df.filter(pl.col("last_modified") > last_modified)
    # Filter out unmodified template cards
    df = df.filter(
        ~pl.col("prepended_markdown").map_elements(
            is_unmodified_template, return_dtype=pl.Boolean
        )
    )
    if df.is_empty():
        logger.info("No cards found matching criteria")
        return None
    cards = df.get_column("prepended_markdown").to_list()
    dataset_ids = df.get_column("datasetId").to_list()
    last_modifieds = df.get_column("last_modified").to_list()
    logger.info(f"Loaded {len(cards)} cards")
    return cards, dataset_ids, last_modifieds
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Return the endpoint's feature-extraction output for a single card.

    Only the first ``MAX_EMBEDDING_LENGTH`` characters of *text* are sent.
    ``stamina`` retries the call on ``requests.HTTPError`` (3 attempts, 10 s
    initial wait).
    NOTE(review): confirm the installed huggingface_hub raises a
    ``requests.HTTPError`` subclass from ``feature_extraction``; otherwise the
    retry never triggers.
    """
    # Character-based cut, not a token-based one.
    truncated = text[:MAX_EMBEDDING_LENGTH]
    return client.feature_extraction(truncated)
def get_inference_client():
    """Return an ``InferenceClient`` targeting ``INFERENCE_MODEL_URL``, authenticated with ``HF_TOKEN``."""
    logger.info(f"Initializing inference client with model: {INFERENCE_MODEL_URL}")
    client_kwargs = {"model": INFERENCE_MODEL_URL, "token": HF_TOKEN}
    return InferenceClient(**client_kwargs)
def refresh_data(min_len: int = 200, min_likes: Optional[int] = None):
    """Embed cards newer than the collection's latest timestamp and upsert them.

    Args:
        min_len: Minimum stripped card length, forwarded to ``load_cards``.
        min_likes: Likes threshold, forwarded to ``load_cards``.
    """
    logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
    # Arguments evaluate left to right: client first, then embedding function.
    collection = get_collection(get_chroma_client(), get_embedding_function())
    cutoff = get_last_modified_in_collection(collection)
    data = load_cards(min_len=min_len, min_likes=min_likes, last_modified=cutoff)
    if not data:
        logger.info("No new data to refresh")
        return
    _create_and_upsert_embeddings(data, collection)
def _create_and_upsert_embeddings(data, collection, batch_size: int = 5000):
    """Embed the cards in *data* via the inference endpoint and upsert them.

    Args:
        data: ``(cards, ids, last_modifieds)`` parallel lists as returned by
            ``load_cards``.
        collection: Chroma collection receiving the embeddings. Each item's
            ``last_modified`` is stored as a string in its metadata.
        batch_size: Maximum number of items per ``collection.upsert`` call.
            Chroma rejects add/upsert calls larger than its max batch size, so
            a large (e.g. initial full) load must be split.
            NOTE(review): confirm the default stays below
            ``client.get_max_batch_size()`` for the deployed chromadb.

    Raises:
        ValueError: if ``batch_size`` is less than 1.

    NOTE(review): if a later batch fails, earlier batches remain persisted.
    The next run resumes after the newest stored timestamp, so unpersisted
    items older than that could be skipped. Consider sorting by last_modified
    before batching if this matters.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    cards, model_ids, last_modifieds = data
    logger.info("Embedding cards...")
    inference_client = get_inference_client()
    # Endpoint calls are network-bound, so a thread pool overlaps them.
    results = thread_map(lambda card: embed_card(card, inference_client), cards)
    # Takes the first row of each result — assumes shape (1, dim); TODO confirm.
    embeddings = [embedding.tolist()[0] for embedding in results]
    metadatas = [{"last_modified": str(lm)} for lm in last_modifieds]
    logger.info(f"Upserting {len(model_ids)} items to collection")
    for start in range(0, len(model_ids), batch_size):
        end = start + batch_size
        collection.upsert(
            ids=model_ids[start:end],
            embeddings=embeddings[start:end],
            metadatas=metadatas[start:end],
        )
    logger.info("Data refresh completed successfully")
if __name__ == "__main__":
    # Run a single refresh using refresh_data's default arguments.
    refresh_data()
|