"""Map a Hugging Face dataset into a Nomic Atlas project."""

import time

import nomic
import pandas as pd
import pyarrow as pa
from dateutil.parser import parse
from tqdm import tqdm
from datasets import (
    load_dataset,
    get_dataset_split_names,
    get_dataset_config_names,
    ClassLabel,
    utils,
)

utils.logging.set_verbosity_error()


def get_datum_fields(dataset_dict, n_samples=100, unique_cutoff=20):
    """Inspect a sample of the dataset and bucket each field by inferred type."""
    # Take a sample of points from the first split.
    dataset = dataset_dict["first_split_dataset"]
    sample = pd.DataFrame(dataset.shuffle(seed=42).take(n_samples))
    features = dataset.features

    indexable_field = None
    numeric_fields = []
    string_fields = []
    bool_fields = []
    list_fields = []
    label_fields = []
    categorical_fields = []
    datetime_fields = []
    uncategorized_fields = []

    # A cutoff below 1 is treated as a fraction of the sample size.
    if unique_cutoff < 1:
        unique_cutoff = unique_cutoff * len(sample)

    for field, dtype in dataset_dict["schema"].items():
        try:
            num_unique = sample[field].nunique()
        except Exception:
            # Unhashable values (e.g. lists) cannot be counted; assume all unique.
            num_unique = len(sample)

        if dtype == "string":
            if num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                # Treat the field as a datetime only if every sampled value parses.
                is_datetime = True
                for value in sample[field]:
                    try:
                        parse(value, fuzzy=False)
                    except Exception:
                        is_datetime = False
                        break
                if is_datetime:
                    datetime_fields.append(field)
                else:
                    string_fields.append(field)
        elif dtype in ("float", "double"):  # Arrow float32 / float64
            numeric_fields.append(field)
        elif dtype in ("int64", "int32", "int16", "int8"):
            if features is not None and field in features and isinstance(features[field], ClassLabel):
                label_fields.append(field)
            elif num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                numeric_fields.append(field)
        elif dtype == "bool":
            bool_fields.append(field)
        elif dtype.startswith("list"):
            list_fields.append(field)
        else:
            uncategorized_fields.append(field)

    # Index the string field with the most total words across the sample.
    longest_length = 0
    for field in string_fields:
        length = 0
        for i in range(len(sample)):
            if sample[field][i]:
                length += len(str(sample[field][i]).split())
        if length > longest_length:
            longest_length = length
            indexable_field = field

    return (
        features,
        numeric_fields,
        string_fields,
        bool_fields,
        list_fields,
        label_fields,
        categorical_fields,
        datetime_fields,
        uncategorized_fields,
        indexable_field,
    )


def load_dataset_and_metadata(dataset_name, config=None, streaming=True):
    """Load the first split of a dataset along with its configs, splits, and schema."""
    configs = get_dataset_config_names(dataset_name)
    if config is None:
        config = configs[0]
    splits = get_dataset_split_names(dataset_name, config)
    dataset = load_dataset(dataset_name, config, split=splits[0], streaming=streaming)

    # Peek at the head of the stream to infer the Arrow schema.
    head = pa.Table.from_pydict(dataset._head())
    schema_dict = {field.name: str(field.type) for field in head.schema}

    dataset_dict = {
        "first_split_dataset": dataset,
        "name": dataset_name,
        "config": config,
        "splits": splits,
        "schema": schema_dict,
        "head": head,
    }
    return dataset_dict


def upload_dataset_to_atlas(dataset_dict,
                            atlas_api_token: str,
                            project_name=None,
                            unique_id_field_name=None,
                            indexed_field=None,
                            modality=None,
                            organization_name=None,
                            wait_for_map=True,
                            datum_limit=30000):
    """Stream every split of the dataset into a new Atlas project and build an index."""
    nomic.login(atlas_api_token)

    if modality is None:
        modality = "text"
    if unique_id_field_name is None:
        unique_id_field_name = "atlas_datum_id"
    if project_name is None:
        project_name = dataset_dict["name"].replace("/", "--") + "--hf-atlas-map"
    desc = f"Config: {dataset_dict['config']}"

    (features,
     numeric_fields,
     string_fields,
     bool_fields,
     list_fields,
     label_fields,
     categorical_fields,
     datetime_fields,
     uncategorized_fields,
     indexable_field) = get_datum_fields(dataset_dict)

    if indexed_field is None:
        indexed_field = indexable_field

    topic_label_field = None
    if modality == "embedding":
        topic_label_field = indexed_field
        indexed_field = None

    # Fields that can be uploaded as plain strings without extra handling.
    easy_fields = string_fields + bool_fields + list_fields + categorical_fields

    proj = nomic.AtlasProject(name=project_name,
                              modality=modality,
                              unique_id_field=unique_id_field_name,
                              organization_name=organization_name,
                              description=desc,
                              reset_project_if_exists=True)

    colorable_fields = {"split"}
    batch_size = 1000
    batched_texts = []
    allow_upload = True

    for split in dataset_dict["splits"]:
        if not allow_upload:
            break
        dataset = load_dataset(dataset_dict["name"],
                               dataset_dict["config"],
                               split=split,
                               streaming=True)
        for i, ex in tqdm(enumerate(dataset)):
            # Brief pause every 10,000 points to avoid hammering the API.
            if i % 10000 == 0:
                time.sleep(2)
            if i == datum_limit:
                print(f"Datum upload limited to {datum_limit} points. Stopping upload...")
                allow_upload = False
                break

            data_to_add = {"split": split, unique_id_field_name: f"{split}_{i}"}

            for field in numeric_fields:
                data_to_add[field] = ex[field]

            for field in easy_fields:
                val = ""
                if ex[field]:
                    val = str(ex[field])
                data_to_add[field] = val

            for field in datetime_fields:
                try:
                    data_to_add[field] = parse(ex[field], fuzzy=False)
                except Exception:
                    data_to_add[field] = None

            for field in label_fields:
                label_name = ""
                if ex[field] is not None:
                    index = ex[field]
                    # NOTE: assumes -1 means "no label"; may break if -1 is a real class index.
                    if index != -1:
                        label_name = features[field].names[index]
                data_to_add[field] = str(ex[field])
                data_to_add[field + "_name"] = label_name
                colorable_fields.add(field + "_name")

            for field in list_fields:
                list_str = ""
                if ex[field]:
                    try:
                        list_str = str(ex[field])
                    except Exception:
                        continue
                data_to_add[field] = list_str

            batched_texts.append(data_to_add)
            if len(batched_texts) >= batch_size:
                proj.add_text(batched_texts)
                batched_texts = []

    if len(batched_texts) > 0:
        proj.add_text(batched_texts)

    colorable_fields = list(colorable_fields) + \
        categorical_fields + label_fields + bool_fields + datetime_fields

    projection = proj.create_index(name=project_name + " index",
                                   indexed_field=indexed_field,
                                   colorable_fields=colorable_fields,
                                   topic_label_field=topic_label_field,
                                   build_topic_model=True)
    if wait_for_map:
        with proj.wait_for_project_lock():
            time.sleep(1)
    return projection.map_link


# Run test
if __name__ == "__main__":
    dataset_name = "databricks/databricks-dolly-15k"
    # dataset_name = "fka/awesome-chatgpt-prompts"
    project_name = "huggingface_auto_upload_test-dolly-15k"
    dataset_dict = load_dataset_and_metadata(dataset_name)
    api_token = "ODdPKqJHYci4Gq4jnCC5-VR0L-rnIdfIy-6djgC4CTPCJ"
    print(upload_dataset_to_atlas(dataset_dict, api_token, project_name=project_name))