# training / test_partial_function.py
# Uploaded by supawichwac; commit f544a5d (verified): "Saving train state of step 5"
# File size: 1.3 kB
from functools import partial
# In-memory stand-in for a real dataset object: it maps a split name to a
# list of row dicts, roughly the layout a data-processing library exposes.
dataset = {
    "train": [
        {"text": "Hello world", "id": 1},
        {"text": "Partial functions are cool", "id": 2},
    ],
}
# Preprocessing step applied to each training example.
def prepare_train_dataset(example):
    """Return a fresh example holding only the upper-cased ``"text"`` field.

    Args:
        example: A single row dict whose ``"text"`` value supports ``.upper()``.

    Returns:
        A new dict ``{"text": <upper-cased text>}``. No other key of
        ``example`` (e.g. ``"id"``) is carried into the result, and
        ``example`` itself is not modified.
    """
    shouted = example["text"].upper()
    return {"text": shouted}
# Column names stripped from every example after the preprocessing step runs.
columns_to_remove = ["id"]
# Mock of a library-style ``map``: transform every example, then drop columns.
def dataset_map(batch, function, remove_columns, batched, batch_size):
    """Apply ``function`` to each example in ``batch`` and strip columns.

    Args:
        batch: Iterable of example dicts to transform.
        function: Callable taking one example dict and returning a mapping.
        remove_columns: Column name(s) to drop from every transformed
            example. A single ``str`` is treated as one column name, and
            ``None`` means "remove nothing". Names not present are ignored.
        batched: Accepted for signature parity with library ``map`` APIs;
            ignored here, because ``function`` is always called per example.
        batch_size: Accepted for signature parity; ignored here.

    Returns:
        A new list of new dicts, one per input example, in input order.
        Neither the input examples nor the mappings returned by
        ``function`` are mutated.
    """
    if remove_columns is None:
        remove_columns = ()
    elif isinstance(remove_columns, str):
        # A bare string would otherwise be iterated character by character,
        # removing keys like "i" and "d" instead of the column "id".
        remove_columns = (remove_columns,)
    drop = set(remove_columns)
    # Build filtered copies rather than calling ``pop`` on ``function``'s
    # result: ``function`` may return the input example itself, and popping
    # from it would silently corrupt the source dataset.
    return [
        {key: value for key, value in function(example).items() if key not in drop}
        for example in batch
    ]
# Pre-bind every argument of the mock map with ``partial`` so the configured
# transformation can later be invoked with no arguments at all.
map_fn_train = partial(
    dataset_map,
    batch=dataset["train"],
    function=prepare_train_dataset,
    remove_columns=columns_to_remove,
    batched=True,
    # Derive the batch size from the data itself so the whole split really is
    # treated as a single batch, whatever its length (a hard-coded 2 only
    # matched the current two-row mock by coincidence).
    batch_size=len(dataset["train"]),
)
# Run the pre-configured transformation and show the result.
transformed_dataset = map_fn_train()
print(transformed_dataset)