# navigate-data-issues / prepare.py
# Last commit c6a85a5 by MarkusStoll: duplicated from renumics/cifar10-cleanlab.
import pickle
import datasets
import os
import umap
def load_cache(cache_file):
    """Return the object pickled in *cache_file*, or ``None`` on a cache miss.

    A missing file counts as a miss. So does one that is empty or truncated
    (e.g. left behind by an interrupted write), so the caller can rebuild it.

    NOTE: ``pickle.load`` can execute arbitrary code. Only point this at cache
    files this script wrote itself, never at downloaded or shared files.
    """
    try:
        file = open(cache_file, "rb")
    except FileNotFoundError:
        return None
    with file:
        try:
            return pickle.load(file)
        except (EOFError, pickle.UnpicklingError) as err:
            print(f"Ignoring unreadable cache {cache_file!r}: {err}")
            return None


def save_cache(obj, cache_file):
    """Pickle *obj* to *cache_file*, replacing any existing cache atomically.

    The data is first written to a sibling ``.tmp`` file and then moved into
    place with ``os.replace``. An interrupted run therefore never leaves a
    truncated cache at *cache_file*.
    """
    tmp_file = f"{os.fspath(cache_file)}.tmp"
    with open(tmp_file, "wb") as file:
        pickle.dump(obj, file)
    os.replace(tmp_file, cache_file)


def preview_rows(frame, head=4, tail=3):
    """Return the first *head* and last *tail* rows of *frame* for display.

    A frame with at most ``head + tail`` rows is returned whole. Small inputs
    therefore neither raise nor show the same row twice.
    """
    import pandas as pd

    if len(frame) <= head + tail:
        return frame
    return pd.concat([frame.head(head), frame.tail(tail)])


def build_issue_frame():
    """Load the CIFAR-100 test split, run cleanlab's Datalab, and return a DataFrame.

    The returned frame holds the dataset columns (``fine_label`` renamed to
    ``labels``), followed by the issue columns from ``Datalab.get_issues()``.
    A short preview of the issues table is printed along the way.
    """
    # Third-party imports are local so that a cache hit never imports them.
    from tabulate import tabulate
    from cleanlab import Datalab
    import numpy as np
    import pandas as pd

    ds = datasets.load_dataset("renumics/cifar100-enriched", split="test")
    print("Dataset loaded using datasets.load_dataset().")
    df = ds.rename_columns({"fine_label": "labels"}).to_pandas()

    # Datalab receives the un-renamed dataset, so its label column is "fine_label".
    lab = Datalab(data=ds, label_name="fine_label")
    # Presumably each cell is an array-like per-example vector (hence .tolist());
    # stacking them yields one row per example.
    features = np.array([x.tolist() for x in df["embedding"]])
    pred_probs = np.array([x.tolist() for x in df["probabilities"]])
    lab.find_issues(features=features, pred_probs=pred_probs)

    issues = lab.get_issues()
    print(tabulate(preview_rows(issues), headers="keys", tablefmt="psql"))
    # Column-wise concat aligns on the index. This assumes get_issues() returns
    # one row per example, in dataset order, with the same default RangeIndex
    # as df — TODO confirm.
    return pd.concat([df, issues], axis=1)


if __name__ == "__main__":
    cache_file = "dataset_cache.pkl"
    # Both the cache-hit and the rebuild path bind the result to ``df``.
    df = load_cache(cache_file)
    if df is not None:
        print("Dataset loaded from cache.")
    else:
        df = build_issue_frame()
        save_cache(df, cache_file)
        print("Dataset saved to cache.")