davanstrien committed
Commit ed553e8 · 1 Parent(s): 16eb3c6

improve refresh logic

Files changed (1)
  1. main.py +44 -19
main.py CHANGED

@@ -13,6 +13,7 @@ import polars as pl
 from huggingface_hub import HfApi
 from transformers import AutoTokenizer
 import torch
+import dateutil.parser
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
@@ -89,15 +90,11 @@ def get_embedding_function():
 def setup_database():
     try:
         embedding_function = get_embedding_function()
-
-        # Create dataset collection
         dataset_collection = client.get_or_create_collection(
             embedding_function=embedding_function,
             name="dataset_cards",
             metadata={"hnsw:space": "cosine"},
         )
-
-        # Create model collection
         model_collection = client.get_or_create_collection(
             embedding_function=embedding_function,
             name="model_cards",
@@ -111,26 +108,52 @@ def setup_database():
         df = df.filter(
             pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
         )
-        row_count = df.select(pl.len()).collect().item()
-        logger.info(f"Row count of dataset data: {row_count}")
 
-        # Check if we need to update the collection
-        current_count = dataset_collection.count()
-        logger.info(f"Current dataset collection count: {current_count}")
+        # Get the most recent last_modified date from the collection
+        latest_update = None
+        if dataset_collection.count() > 0:
+            metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
+            logger.info(f"Found {len(metadata)} existing records in collection")
+
+            last_modifieds = [
+                dateutil.parser.parse(m.get("last_modified")) for m in metadata
+            ]
+            latest_update = max(last_modifieds)
+            logger.info(f"Most recent record in DB from: {latest_update}")
+            logger.info(f"Oldest record in DB from: {min(last_modifieds)}")
+
+        # Filter and process only newer records
+        df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])
+
+        # Log some stats about the incoming data
+        sample_dates = df.select("last_modified").limit(5).collect()
+        logger.info(f"Sample of incoming dates: {sample_dates}")
+
+        total_incoming = df.select(pl.len()).collect().item()
+        logger.info(f"Total incoming records: {total_incoming}")
+
+        if latest_update:
+            logger.info(f"Filtering records newer than {latest_update}")
+            df = df.filter(pl.col("last_modified") > latest_update)
+            filtered_count = df.select(pl.len()).collect().item()
+            logger.info(f"Found {filtered_count} records to update after filtering")
 
-        if current_count < row_count:
+        df = df.collect()
+        total_rows = len(df)
+
+        if total_rows > 0:
+            logger.info(f"Updating dataset collection with {total_rows} new records")
             logger.info(
-                f"Updating dataset collection with {row_count - current_count} new records"
-            )
-            # Load parquet files and upsert into ChromaDB
-            df = df.select(
-                ["datasetId", "summary", "likes", "downloads", "last_modified"]
+                f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
             )
-            df = df.collect()
-            total_rows = len(df)
 
             for i in range(0, total_rows, BATCH_SIZE):
                 batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
+                batch_size = len(batch_df)
+                logger.info(
+                    f"Processing batch {i // BATCH_SIZE + 1}: {batch_size} records "
+                    f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
+                )
 
                 dataset_collection.upsert(
                     ids=batch_df.select(["datasetId"]).to_series().to_list(),
@@ -148,9 +171,11 @@ def setup_database():
                        )
                    ],
                )
-                logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
+                logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")
 
-        logger.info(f"Database initialized with {dataset_collection.count():,} rows")
+        logger.info(
+            f"Database initialized with {dataset_collection.count():,} total rows"
+        )
 
         # Load model data
         model_df = pl.scan_parquet(
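The heart of the change is the switch from a count-based staleness check to a timestamp-based one: instead of comparing row counts, the refresh now reads the newest last_modified stored in the ChromaDB collection and upserts only rows that are newer. A minimal standalone sketch of that pattern, assuming (as the diff does) that each record's metadata stores last_modified as an ISO-8601 string; the helper names here are mine, not the file's:

import dateutil.parser
import polars as pl


def newest_stored_timestamp(collection):
    """Most recent last_modified already in the Chroma collection, or None if empty."""
    if collection.count() == 0:
        return None
    # collection.get() with no ids returns every record, including its metadata
    metadatas = collection.get(include=["metadatas"])["metadatas"]
    return max(dateutil.parser.parse(m["last_modified"]) for m in metadatas)


def rows_needing_refresh(df: pl.LazyFrame, collection) -> pl.DataFrame:
    """Collect only the rows newer than anything already stored."""
    latest = newest_stored_timestamp(collection)
    if latest is not None:
        df = df.filter(pl.col("last_modified") > latest)
    return df.collect()

Unlike the old current_count < row_count check, this still catches updated records when the total row count is unchanged; the trade-off is that computing the maximum pulls every stored metadata record into memory.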
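The upsert call itself is truncated in the diff after the ids= argument. A plausible reconstruction of the batching loop, with the documents/metadatas mapping guessed from the selected columns, so treat the field layout as an assumption rather than the file's actual code:

BATCH_SIZE = 1000  # placeholder; the real constant is defined near the top of main.py


def upsert_in_batches(df: pl.DataFrame, collection) -> None:
    total_rows = len(df)
    for i in range(0, total_rows, BATCH_SIZE):
        batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
        collection.upsert(
            ids=batch_df.select(["datasetId"]).to_series().to_list(),
            documents=batch_df.select(["summary"]).to_series().to_list(),
            metadatas=[
                {
                    "likes": row["likes"],
                    "downloads": row["downloads"],
                    "last_modified": str(row["last_modified"]),
                }
                for row in batch_df.iter_rows(named=True)
            ],
        )

Because upsert is keyed on datasetId, re-running the refresh is idempotent: existing records are overwritten rather than duplicated, which is what makes the "only rows newer than the stored maximum" filter safe.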