davanstrien HF staff commited on
Commit
3e2784f
1 Parent(s): 551f450

load viewer data

Browse files
Files changed (2) hide show
  1. load_viewer_data.py +88 -0
  2. prep_viewer_data.py +158 -0
load_viewer_data.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+
4
+ import chromadb
5
+ import httpx
6
+ import requests
7
+ import stamina
8
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
9
+ from huggingface_hub import InferenceClient
10
+ from tqdm.auto import tqdm
11
+ from tqdm.contrib.concurrent import thread_map
12
+
13
+ from prep_viewer_data import prep_data
14
+
15
+ # Set up logging
16
+ logger = logging.getLogger(__name__)
17
+ logger.setLevel(logging.INFO)
18
+
19
+
20
+ def initialize_clients():
21
+ logger.info("Initializing clients")
22
+ chroma_client = chromadb.PersistentClient()
23
+ inference_client = InferenceClient(
24
+ "https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
25
+ )
26
+ return chroma_client, inference_client
27
+
28
+
29
+ def create_collection(chroma_client):
30
+ logger.info("Creating or getting collection")
31
+ embedding_function = SentenceTransformerEmbeddingFunction(
32
+ model_name="davanstrien/dataset-viewer-descriptions-processed-st",
33
+ trust_remote_code=True,
34
+ )
35
+ return chroma_client.create_collection(
36
+ name="dataset-viewer-descriptions",
37
+ get_or_create=True,
38
+ embedding_function=embedding_function,
39
+ metadata={"hnsw:space": "cosine"},
40
+ )
41
+
42
+
43
+ @stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
44
+ def embed_card(text, client):
45
+ text = text[:8192]
46
+ return client.feature_extraction(text)
47
+
48
+
49
+ def embed_and_upsert_datasets(
50
+ dataset_rows_and_ids, collection, inference_client, batch_size=10
51
+ ):
52
+ logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
53
+ for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
54
+ batch = dataset_rows_and_ids[i : i + batch_size]
55
+ ids = []
56
+ documents = []
57
+ for item in batch:
58
+ ids.append(item["dataset_id"])
59
+ documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
60
+ results = thread_map(
61
+ lambda doc: embed_card(doc, inference_client), documents, leave=False
62
+ )
63
+ collection.upsert(
64
+ ids=ids,
65
+ embeddings=[embedding.tolist()[0] for embedding in results],
66
+ )
67
+ logger.debug(f"Processed batch {i//batch_size + 1}")
68
+
69
+
70
+ async def refresh_viewer_data(sample_size=100_000, min_likes=2):
71
+ logger.info(
72
+ f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
73
+ )
74
+ chroma_client, inference_client = initialize_clients()
75
+ collection = create_collection(chroma_client)
76
+
77
+ logger.info("Preparing data")
78
+ df = await prep_data(sample_size=sample_size, min_likes=min_likes)
79
+ dataset_rows_and_ids = df.to_dicts()
80
+
81
+ logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
82
+ embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
83
+ logger.info("Refresh completed successfully")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ logging.basicConfig(level=logging.INFO)
88
+ asyncio.run(refresh_viewer_data())
prep_viewer_data.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import random
4
+
5
+ import httpx
6
+ import polars as pl
7
+ from huggingface_hub import list_datasets
8
+ from tqdm import tqdm
9
+ from tqdm.asyncio import tqdm_asyncio
10
+
11
+ # Initialize the HTTP client
12
+ client = httpx.AsyncClient(timeout=60, http2=True)
13
+
14
+
15
+ async def generate_dataset_prompt(dataset_name, num_rows=2):
16
+ try:
17
+ base_url = "https://datasets-server.huggingface.co"
18
+
19
+ # Get splits and configs
20
+ splits_url = f"{base_url}/splits?dataset={dataset_name}"
21
+ splits_response = await client.get(splits_url)
22
+ splits_data = splits_response.json()
23
+
24
+ if not splits_data.get("splits"):
25
+ return None
26
+
27
+ # Get the first config and split
28
+ first_split = splits_data["splits"][0]
29
+ config_name = first_split["config"]
30
+ split_name = first_split["split"]
31
+
32
+ # Get dataset info for the specific config
33
+ info_url = f"{base_url}/info?dataset={dataset_name}&config={config_name}"
34
+ info_response = await client.get(info_url)
35
+ info_data = info_response.json()
36
+
37
+ # Get first rows for the specific config and split
38
+ first_rows_url = f"{base_url}/first-rows?dataset={dataset_name}&config={config_name}&split={split_name}"
39
+ first_rows_response = await client.get(first_rows_url)
40
+ first_rows_data = first_rows_response.json()
41
+
42
+ # Get size information
43
+ size_url = f"{base_url}/size?dataset={dataset_name}"
44
+ size_response = await client.get(size_url)
45
+ size_data = size_response.json()
46
+
47
+ # Extract relevant information
48
+ dataset_info = info_data.get("dataset_info", {})
49
+ features = dataset_info.get("features", {})
50
+ splits = dataset_info.get("splits", {})
51
+
52
+ # Calculate total examples and size
53
+ total_examples = sum(split.get("num_examples", 0) for split in splits.values())
54
+ total_size = (
55
+ size_data.get("size", {})
56
+ .get("dataset", {})
57
+ .get("num_bytes_original_files", 0)
58
+ )
59
+
60
+ # Format features
61
+ def format_feature(name, details):
62
+ if isinstance(details, dict):
63
+ feature_type = details.get(
64
+ "dtype", details.get("_type", "unknown type")
65
+ )
66
+ elif isinstance(details, list):
67
+ feature_type = "list"
68
+ else:
69
+ feature_type = str(type(details).__name__)
70
+ return f"- {name} ({feature_type})"
71
+
72
+ formatted_features = "\n".join(
73
+ format_feature(name, details) for name, details in features.items()
74
+ )
75
+
76
+ # Format sample data (specified number of rows)
77
+ sample_data = json.dumps(first_rows_data.get("rows", [])[:num_rows], indent=2)
78
+
79
+ # Create the formatted prompt
80
+ prompt = f"""
81
+ Dataset: "{dataset_name}"
82
+
83
+ Features:
84
+ {formatted_features}
85
+
86
+ Splits and Configs:
87
+ {', '.join(f"{split['config']}/{split['split']}" for split in splits_data['splits'])}
88
+
89
+ Size Statistics:
90
+ Total Examples: {total_examples}
91
+ Split Sizes: {', '.join(f"{split}: {info['num_examples']}" for split, info in splits.items())}
92
+
93
+ Data Sample ({num_rows} rows out of {total_examples} total):
94
+ {sample_data}
95
+ """
96
+
97
+ return prompt.strip()
98
+ except Exception as e:
99
+ print(f"Error for {dataset_name}: {e}")
100
+ return None
101
+
102
+
103
+ async def process_batch(batch):
104
+ results = await tqdm_asyncio.gather(
105
+ *[generate_dataset_prompt(dataset) for dataset in batch], leave=False
106
+ )
107
+ return [
108
+ (dataset_id, prompt)
109
+ for dataset_id, prompt in zip(batch, results)
110
+ if prompt is not None
111
+ ]
112
+
113
+
114
+ async def prep_data(sample_size=200_000, min_likes=1):
115
+ # Load the dataset containing dataset IDs
116
+ df = pl.read_parquet(
117
+ "hf://datasets/davanstrien/dataset-viewer-descriptions-processed/data/train-00000-of-00001.parquet"
118
+ )
119
+ in_train_or_test = set(df["dataset_id"].unique().to_list())
120
+
121
+ # Get all datasets
122
+ datasets = [
123
+ dataset for dataset in list_datasets() if dataset.id not in in_train_or_test
124
+ ]
125
+ # filter to datasets with 1 or more likes
126
+ if min_likes:
127
+ datasets = [dataset for dataset in datasets if dataset.likes >= min_likes]
128
+ datasets = [dataset.id for dataset in datasets]
129
+ # Sample datasets (adjust the number as needed)
130
+ datasets = random.sample(datasets, min(sample_size, len(datasets)))
131
+
132
+ # Process datasets in batches of 100
133
+ batch_size = 500
134
+ all_results = []
135
+
136
+ for i in tqdm(range(0, len(datasets), batch_size), desc="Processing batches"):
137
+ batch = datasets[i : i + batch_size]
138
+ batch_results = await process_batch(batch)
139
+ all_results.extend(batch_results)
140
+
141
+ # Optional: Save intermediate results
142
+ if len(all_results) % 1000 == 0:
143
+ intermediate_df = pl.DataFrame(
144
+ {
145
+ "dataset_id": [row[0] for row in all_results],
146
+ "formatted_prompt": [row[1] for row in all_results],
147
+ }
148
+ )
149
+ intermediate_df.write_parquet(
150
+ f"dataset_prompts_intermediate_{len(all_results)}.parquet"
151
+ )
152
+
153
+ return pl.DataFrame(
154
+ {
155
+ "dataset_id": [row[0] for row in all_results],
156
+ "formatted_prompt": [row[1] for row in all_results],
157
+ }
158
+ )