submission / tasks /data /data_loaders.py
pierre-loic's picture
update content with the text model from Thomas repository https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
from abc import ABC, abstractmethod
from typing import Optional

from datasets import load_dataset, DatasetDict

from tasks.utils.evaluation import TextEvaluationRequest
class DataLoader(ABC):
    """Abstract interface for objects that supply a training and a test dataset.

    Both accessors are abstract, so this class cannot be instantiated directly;
    concrete loaders must override each of them.
    """

    @abstractmethod
    def get_train_dataset(self):
        """Return the dataset used for training."""

    @abstractmethod
    def get_test_dataset(self):
        """Return the dataset used for evaluation."""
class TextDataLoader(DataLoader):
    """Load the text-classification dataset and expose train/test splits.

    The dataset named by ``request.dataset_name`` is loaded with
    ``datasets.load_dataset``, and its string ``label`` values are replaced by
    the integer ids in ``label_mapping``. Its ``"train"`` split is then divided
    into new ``"train"``/``"test"`` subsets using ``request.test_size`` and
    ``request.test_seed``.
    """

    def __init__(self, request: Optional[TextEvaluationRequest] = None, light: bool = False):
        """Load, relabel and split the dataset.

        Args:
            request: Settings providing ``dataset_name``, ``test_size`` and
                ``test_seed``. When omitted (or ``None``), a fresh
                ``TextEvaluationRequest()`` is built for this call. This avoids
                one default instance being created at import time and shared
                by every loader.
            light: If True, keep only a small shuffled subset (at most 10
                training and 2 test examples) for quick testing.

        Raises:
            KeyError: if the loaded dataset has no ``"train"`` split. A
                ``label`` value missing from ``label_mapping`` also makes the
                relabelling step fail on the dict lookup.
        """
        if request is None:
            request = TextEvaluationRequest()

        # String class names -> integer class ids.
        self.label_mapping = {
            "0_not_relevant": 0,
            "1_not_happening": 1,
            "2_not_human": 2,
            "3_not_bad": 3,
            "4_solutions_harmful_unnecessary": 4,
            "5_science_unreliable": 5,
            "6_proponents_biased": 6,
            "7_fossil_fuels_needed": 7,
        }

        # Load the dataset and convert string labels to integers.
        dataset = load_dataset(request.dataset_name)
        dataset = dataset.map(lambda x: {"label": self.label_mapping[x["label"]]})
        # Only the original "train" split is used; it is re-split into train/test here.
        self.dataset = dataset["train"].train_test_split(
            test_size=request.test_size, seed=request.test_seed
        )

        # Create a smaller version of the dataset for quick testing. Sizes are
        # clamped so splits shorter than the target size do not fail in select().
        if light:
            train = self.dataset["train"]
            test = self.dataset["test"]
            self.dataset = DatasetDict({
                "train": train.shuffle(seed=42).select(range(min(10, len(train)))),
                "test": test.shuffle(seed=42).select(range(min(2, len(test)))),
            })

    def get_train_dataset(self):
        """Return the training subset."""
        return self.dataset["train"]

    def get_test_dataset(self):
        """Return the held-out test subset."""
        return self.dataset["test"]

    def get_label_to_id_mapping(self):
        """Return the mapping from string label to integer id (the stored dict itself)."""
        return self.label_mapping

    def get_id_to_label_mapping(self):
        """Return a new dict mapping integer id back to string label."""
        return {v: k for k, v in self.label_mapping.items()}