from functools import partial

# Mock dataset in a dictionary form, similar to what you might find in a
# data processing library (e.g. Hugging Face `datasets`).
dataset = {
    "train": [
        {"text": "Hello world", "id": 1},
        {"text": "Partial functions are cool", "id": 2},
    ]
}


def prepare_train_dataset(example):
    """Transform one example: upper-case its "text" field."""
    return {"text": example["text"].upper()}


# Columns to drop from each example after the transformation.
columns_to_remove = ["id"]


def dataset_map(batch, function, remove_columns, batched, batch_size):
    """Apply *function* to every example in *batch* and drop columns.

    Fixes two defects of the naive version: the transform's output is
    merged over the original example (so fields the transform does not
    touch survive, and ``remove_columns`` is no longer a no-op), and
    ``batched``/``batch_size`` are actually honoured — when *batched* is
    true the data is walked in chunks of *batch_size*.  The return value
    is always a flat list of transformed example dicts.
    """
    # Chunk size: honour batch_size when batched; otherwise one pass over
    # everything ("or 1" guards the empty-batch / zero-size step).
    size = batch_size if batched else (len(batch) or 1)
    transformed = []
    for start in range(0, len(batch), size):
        for example in batch[start:start + size]:
            # Merge transform output over the original so the transform
            # only has to return the fields it changed.
            item = {**example, **function(example)}
            for column in remove_columns:
                item.pop(column, None)
            transformed.append(item)
    return transformed


# Pre-configure the map call with functools.partial.
map_fn_train = partial(
    dataset_map,
    batch=dataset["train"],
    function=prepare_train_dataset,
    remove_columns=columns_to_remove,
    batched=True,
    batch_size=2,  # Assuming we process all data in one batch for simplicity
)

# Using the configured function
transformed_dataset = map_fn_train()
print(transformed_dataset)