import os | |
from datasets import load_dataset | |
# Directory handed to `load_dataset` as `cache_dir` in `load_data`.
CACHE_DIR = 'cache'

# Upper bound on the number of records returned by `load_data`.
N_SAMPLES = 15

# Commit hashes whose rows are dropped from the dataset in `load_data`.
REMOVED_COMMITS = ['9cc896202dc38d962c01aa2637dbc5bbc3e3dd9b']
def load_data():
    """Load the commit-rewriting samples as a list of record dicts.

    Fetches the ``train`` split of
    ``JetBrains-Research/commit-rewriting-samples`` through ``load_dataset``
    (cached under ``CACHE_DIR``), converts it to a pandas DataFrame, drops
    every row whose ``hash`` is listed in ``REMOVED_COMMITS`` and keeps the
    first ``N_SAMPLES`` remaining rows.

    The access token is read from the ``HF_REWRITING_TOKEN`` environment
    variable; when the variable is unset, ``None`` is passed as ``token``.

    Returns:
        A list with at most ``N_SAMPLES`` dicts, one per retained row, keyed
        by column name.
    """
    df = load_dataset(
        "JetBrains-Research/commit-rewriting-samples",
        split="train",
        token=os.environ.get('HF_REWRITING_TOKEN'),
        cache_dir=CACHE_DIR,
    ).to_pandas()

    removed_idx = df['hash'].isin(REMOVED_COMMITS)

    # Truncate positionally *before* converting, so that only the rows that
    # are actually returned get materialised as Python dicts.
    return df[~removed_idx].iloc[:N_SAMPLES].to_dict('records')