Commit 5e2248d
davanstrien (HF staff) committed
1 Parent(s): e8f13e9

chore: Refactor load_data.py for improved readability and maintainability

Files changed (1)
  1. load_data.py +87 -50
load_data.py CHANGED
@@ -1,59 +1,82 @@
 import chromadb
 import platform
 import polars as pl
-import polars as pl
 from chromadb.utils import embedding_functions
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Literal
 from huggingface_hub import InferenceClient
 from tqdm.contrib.concurrent import thread_map
-from huggingface_hub import login
 from dotenv import load_dotenv
 import os
-from datetime import datetime, timedelta
+from datetime import datetime
 import stamina
 import requests
-import polars as pl
-from typing import Literal
+import logging
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
 
 load_dotenv()
+
+# Top-level module variables
 HF_TOKEN = os.getenv("HF_TOKEN")
+EMBEDDING_MODEL_NAME = "Snowflake/snowflake-arctic-embed-m-long"
+INFERENCE_MODEL_URL = (
+    "https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud"
+)
+DATASET_PARQUET_URL = "hf://datasets/librarian-bots/dataset_cards_with_metadata_with_embeddings/data/train-00000-of-00001.parquet"
+COLLECTION_NAME = "dataset_cards"
+MAX_EMBEDDING_LENGTH = 8192
 
 
 def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
-    return "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
+    path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
+    logger.info(f"Using save path: {path}")
+    return path
 
 
-save_path = get_save_path()
+SAVE_PATH = get_save_path()
 
 
-chroma_client = chromadb.PersistentClient(
-    path=save_path,
-)
-sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-    model_name="Snowflake/snowflake-arctic-embed-m-long", trust_remote_code=True
-)
-collection = chroma_client.create_collection(
-    name="dataset_cards", get_or_create=True, embedding_function=sentence_transformer_ef
-)
+def get_chroma_client():
+    logger.info("Initializing Chroma client")
+    return chromadb.PersistentClient(path=SAVE_PATH)
+
+
+def get_embedding_function():
+    logger.info(f"Initializing embedding function with model: {EMBEDDING_MODEL_NAME}")
+    return embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=EMBEDDING_MODEL_NAME, trust_remote_code=True
+    )
 
 
-def get_last_modified_in_collection() -> datetime | None:
-    all_items = collection.get(
-        include=[
-            "metadatas",
-        ]
+def get_collection(chroma_client, embedding_function):
+    logger.info(f"Getting or creating collection: {COLLECTION_NAME}")
+    return chroma_client.create_collection(
+        name=COLLECTION_NAME, get_or_create=True, embedding_function=embedding_function
     )
+
+
+def get_last_modified_in_collection(collection) -> datetime | None:
+    logger.info("Fetching last modified date from collection")
+    all_items = collection.get(include=["metadatas"])
     if last_modified := [
         datetime.fromisoformat(item["last_modified"]) for item in all_items["metadatas"]
     ]:
-        return max(last_modified)
+        last_mod = max(last_modified)
+        logger.info(f"Last modified date: {last_mod}")
+        return last_mod
     else:
+        logger.info("No last modified date found")
         return None
 
 
 def parse_markdown_column(
     df: pl.DataFrame, markdown_column: str, dataset_id_column: str
 ) -> pl.DataFrame:
+    logger.info("Parsing markdown column")
     return df.with_columns(
         parsed_markdown=(
             pl.col(markdown_column)
@@ -81,58 +104,72 @@ def load_cards(
     min_len: int = 50,
     min_likes: int | None = None,
     last_modified: Optional[datetime] = None,
-) -> (
-    None
-    | Tuple[
-        List[str],
-        List[str],
-        List[datetime],
-    ]
-):
-    df = pl.read_parquet(
-        "hf://datasets/librarian-bots/dataset_cards_with_metadata_with_embeddings/data/train-00000-of-00001.parquet"
+) -> Optional[Tuple[List[str], List[str], List[datetime]]]:
+    logger.info(
+        f"Loading cards with min_len={min_len}, min_likes={min_likes}, last_modified={last_modified}"
     )
+    df = pl.read_parquet(DATASET_PARQUET_URL)
     df = parse_markdown_column(df, "card", "datasetId")
     df = df.with_columns(pl.col("parsed_markdown").str.len_chars().alias("card_len"))
-    print(df)
     df = df.filter(pl.col("card_len") > min_len)
-    print(df)
     if min_likes:
         df = df.filter(pl.col("likes") > min_likes)
     if last_modified:
         df = df.filter(pl.col("last_modified") > last_modified)
     if len(df) == 0:
+        logger.info("No cards found matching criteria")
         return None
 
     cards = df.get_column("prepended_markdown").to_list()
     model_ids = df.get_column("datasetId").to_list()
     last_modifieds = df.get_column("last_modified").to_list()
+    logger.info(f"Loaded {len(cards)} cards")
     return cards, model_ids, last_modifieds
 
 
-client = InferenceClient(
-    model="https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud",
-    token=HF_TOKEN,
-)
-
-
 @stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
-def embed_card(text):
-    text = text[:8192]
+def embed_card(text, client):
+    text = text[:MAX_EMBEDDING_LENGTH]
     return client.feature_extraction(text)
 
 
-most_recent = get_last_modified_in_collection()
+def get_inference_client():
+    logger.info(f"Initializing inference client with model: {INFERENCE_MODEL_URL}")
+    return InferenceClient(
+        model=INFERENCE_MODEL_URL,
+        token=HF_TOKEN,
+    )
+
+
+def refresh_data(min_len: int = 200, min_likes: Optional[int] = None):
+    logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
+    chroma_client = get_chroma_client()
+    embedding_function = get_embedding_function()
+    collection = get_collection(chroma_client, embedding_function)
 
-if data := load_cards(min_len=200, min_likes=None, last_modified=most_recent):
+    most_recent = get_last_modified_in_collection(collection)
+
+    if data := load_cards(
+        min_len=min_len, min_likes=min_likes, last_modified=most_recent
+    ):
+        _create_and_upsert_embeddings(data, collection)
+    else:
+        logger.info("No new data to refresh")
+
+
+def _create_and_upsert_embeddings(data, collection):
     cards, model_ids, last_modifieds = data
-    print("mapping...")
-    results = thread_map(embed_card, cards)
+    logger.info("Embedding cards...")
+    inference_client = get_inference_client()
+    results = thread_map(lambda card: embed_card(card, inference_client), cards)
+    logger.info(f"Upserting {len(model_ids)} items to collection")
     collection.upsert(
         ids=model_ids,
         embeddings=[embedding.tolist()[0] for embedding in results],
         metadatas=[{"last_modified": str(lm)} for lm in last_modifieds],
    )
-    print("done")
-else:
-    print("no new data")
+    logger.info("Data refresh completed successfully")
+
+
+if __name__ == "__main__":
+    refresh_data()
 
 
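Once refresh_data() has populated the persistent store, the collection can be read back from other code. A minimal query sketch, assuming the same save path, collection name, and embedding model the script uses (the query text and n_results are illustrative):

import chromadb
from chromadb.utils import embedding_functions

# Same location and model as load_data.py; on macOS the path would be "chroma/".
client = chromadb.PersistentClient(path="/data/chroma/")
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="Snowflake/snowflake-arctic-embed-m-long", trust_remote_code=True
)
collection = client.get_collection(name="dataset_cards", embedding_function=ef)

# Semantic search over the embedded dataset cards; the query string is an example.
results = collection.query(query_texts=["text classification datasets"], n_results=5)
print(results["ids"])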