"""Label the MedMCQA training questions by age and gender, then publish them.

The script builds a ``Disaggregator`` from a custom-configured ``Age`` module
plus the built-in "gender" module, maps it over the "train" split of the
"medmcqa" dataset, and pushes the mapped dataset to the hub repository
"society-ethics/medmcqa_age_gender_custom".
"""

from datasets import load_dataset
from disaggregators import Disaggregator
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig


class MeSHAgeLabels(AgeLabels):
    """Custom set of age-group labels handed to ``AgeConfig``.

    NOTE(review): the names look like the MeSH age-group vocabulary; confirm
    against the intended source before relying on that.
    """

    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"


# Column that both the Age module and the Disaggregator are pointed at.
_TEXT_COLUMN = "question"

# Labels listed explicitly, in the order they are passed to AgeConfig.
_AGE_GROUPS = [
    MeSHAgeLabels.INFANT,
    MeSHAgeLabels.CHILD_PRESCHOOL,
    MeSHAgeLabels.CHILD,
    MeSHAgeLabels.ADOLESCENT,
    MeSHAgeLabels.ADULT,
    MeSHAgeLabels.MIDDLE_AGED,
    MeSHAgeLabels.AGED,
    MeSHAgeLabels.AGED_80_OVER,
]

# One breakpoint per label above.
# NOTE(review): presumably the i-th value is the starting age (in years) of
# the i-th label; verify against the AgeConfig documentation.
_AGE_BREAKPOINTS = [0, 2, 5, 12, 18, 44, 64, 79]

_age_config = AgeConfig(
    labels=MeSHAgeLabels,
    ages=_AGE_GROUPS,
    breakpoints=_AGE_BREAKPOINTS,
)
age = Age(config=_age_config, column=_TEXT_COLUMN)

# The custom Age module instance is combined with the "gender" module, which
# is requested by name.
disaggregator = Disaggregator([age, "gender"], column=_TEXT_COLUMN)

# Load the training split, apply the disaggregator to it, and upload the result.
ds = load_dataset("medmcqa", split="train")
ds_mapped = ds.map(disaggregator)
ds_mapped.push_to_hub("society-ethics/medmcqa_age_gender_custom")