from functools import partial | |
# Toy stand-in for a real dataset object: a mapping from split name to a
# list of example records, mirroring the shape of data-processing libraries.
_train_examples = [
    {"text": "Hello world", "id": 1},
    {"text": "Partial functions are cool", "id": 2},
]
dataset = {"train": _train_examples}
def prepare_train_dataset(example):
    """Toy preprocessing step: return a record whose text is uppercased.

    Only the "text" field is kept; any other keys on *example* are dropped
    because a brand-new dict is returned.
    """
    upper_text = example["text"].upper()
    return {"text": upper_text}
# Columns dropped from every record after the map transformation runs.
columns_to_remove = ["id"]
# Mock of a dataset .map() method (HF-datasets-style interface).
def dataset_map(batch, function, remove_columns=None, batched=False, batch_size=None):
    """Apply *function* to every example in *batch*, then drop columns.

    Fix: the original accepted ``batched`` and ``batch_size`` but silently
    ignored both; they now control how examples are chunked for processing.
    Because *function* is applied per example either way, the returned data
    is identical — the parameters are honored rather than misleading.

    Args:
        batch: list of example dicts to transform.
        function: callable mapping one example dict to a new dict.
        remove_columns: iterable of keys to drop from each result
            (missing keys are ignored); ``None`` means drop nothing.
        batched: when true, iterate the input in chunks of ``batch_size``.
        batch_size: chunk length used when ``batched`` is true; a falsy
            value falls back to per-example processing.

    Returns:
        A new list of transformed dicts; *batch* itself is never mutated
        (each result dict comes from *function*).
    """
    if batched and batch_size:
        # Honor the requested chunking: slice the input into batches.
        chunks = (batch[i:i + batch_size] for i in range(0, len(batch), batch_size))
    else:
        chunks = ([example] for example in batch)

    transformed_data = [function(example) for chunk in chunks for example in chunk]

    # Strip unwanted columns from the *new* dicts; pop(..., None) tolerates
    # results that never contained the column in the first place.
    for item in transformed_data:
        for column in remove_columns or ():
            item.pop(column, None)
    return transformed_data
# Pre-bind every argument of dataset_map up front so the configured
# pipeline can later be invoked with no arguments at all.
map_config = {
    "batch": dataset["train"],
    "function": prepare_train_dataset,
    "remove_columns": columns_to_remove,
    "batched": True,
    "batch_size": 2,  # small enough that the whole split fits in one batch
}
map_fn_train = partial(dataset_map, **map_config)

# Run the fully pre-configured mapping and show what came out.
transformed_dataset = map_fn_train()
print(transformed_dataset)