# disaggregators / prep_data.py
# Last commit by dawood (HF staff): "Duplicate from society-ethics/disaggregators" (924d3bd)
from datasets import load_dataset
from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
import matplotlib
matplotlib.use('TKAgg')
import joblib
import os
# Per-dataset summaries, seeded from a previous run's cache file when present
# so that entries not regenerated by this script are preserved on rewrite.
# NOTE: joblib.load unpickles the file; only load a cache this script produced.
cache_file = "cached_data.pkl"
cache_dict = {}
if os.path.exists(cache_file):
    # Use the same path variable as the later dump instead of repeating the literal.
    cache_dict = joblib.load(cache_file)
class MeSHAgeLabels(AgeLabels):
    """Age-group labels used by the medmcqa ``age`` module.

    The names appear to follow the MeSH "Age Groups" vocabulary (TODO confirm).
    Member order is significant: ``list(MeSHAgeLabels)`` is passed to
    ``AgeConfig`` together with an 8-element breakpoint list, presumably
    aligned position-by-position (verify against the Age module).
    """
    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"
# Age disaggregation module for medmcqa, built on the MeSHAgeLabels groups.
# NOTE(review): `ages` wraps the label list in a second list; confirm this
# nesting is what AgeConfig expects.
# NOTE(review): column="question" duplicates configs["medmcqa"]["column"];
# keep the two in sync.
age = Age(
    column="question",
    config=AgeConfig(
        # One breakpoint per label, in label order (presumably lower bounds in years).
        breakpoints=[0, 2, 5, 12, 18, 44, 64, 79],
        ages=[[*MeSHAgeLabels]],
        labels=MeSHAgeLabels,
    ),
)
class TabsSpacesLabels(DisaggregationModuleLabels):
    """Labels produced by the ``TabsSpaces`` custom disaggregator."""
    TABS = "tabs"
    SPACES = "spaces"
class TabsSpaces(CustomDisaggregator):
    """Classify a row as tab- or space-indented from its text column.

    A row is ``tabs`` when the text in ``self.column`` contains at least one
    tab character anywhere. Every other row is ``spaces``, including text
    with no indentation at all. Exactly one label is True per row.
    """

    module_id = "tabs_spaces"
    labels = TabsSpacesLabels

    def __call__(self, row, *args, **kwargs):
        uses_tabs = "\t" in row[self.column]
        return {self.labels.TABS: uses_tabs, self.labels.SPACES: not uses_tabs}
class ReactComponentLabels(DisaggregationModuleLabels):
    """Labels produced by the ``ReactComponent`` custom disaggregator."""
    CLASS = "class"
    FUNCTION = "function"
class ReactComponent(CustomDisaggregator):
    """Classify a row as a React class component or function component.

    Text in ``self.column`` containing ``extends React.Component`` or
    ``extends Component`` is labelled ``class``. Everything else is
    labelled ``function``. Exactly one label is True per row.
    """

    module_id = "react_component"
    labels = ReactComponentLabels

    def __call__(self, row, *args, **kwargs):
        source = row[self.column]
        # NOTE(review): `extends React.PureComponent` / `extends PureComponent`
        # are not matched, so such classes end up labelled "function".
        class_markers = ("extends React.Component", "extends Component")
        is_class = any(marker in source for marker in class_markers)
        return {self.labels.CLASS: is_class, self.labels.FUNCTION: not is_class}
# Per-dataset settings. Each value is unpacked as keyword arguments into
# generate_cached_data(**config), so its keys mirror that function's parameters:
#   disaggregation_modules: what Disaggregator is built from. Here that is module
#       ids ("continent", "gender"), the configured `age` instance, or
#       CustomDisaggregator subclasses (TabsSpaces, ReactComponent).
#   dataset_name: dataset id passed to datasets.load_dataset.
#   column: text column handed to Disaggregator.
#   feature_names: display names. Keys look like "<module>.<label>" field names,
#       plus parent-level module ids. "Both" presumably names the two-module
#       combination (TODO confirm against whatever reads the cache).
configs = {
    "laion": {
        "disaggregation_modules": ["continent"],
        "dataset_name": "society-ethics/laion2B-en_continents",
        "column": "TEXT",
        "feature_names": {
            "continent.africa": "Africa",
            "continent.americas": "Americas",
            "continent.asia": "Asia",
            "continent.europe": "Europe",
            "continent.oceania": "Oceania",
            # Parent level
            "continent": "Continent",
        }
    },
    "medmcqa": {
        # `age` is the Age module instance configured with MeSHAgeLabels.
        "disaggregation_modules": [age, "gender"],
        "dataset_name": "society-ethics/medmcqa_age_gender_custom",
        "column": "question",
        "feature_names": {
            "age.infant": "Infant",
            "age.child_preschool": "Preschool",
            "age.child": "Child",
            "age.adolescent": "Adolescent",
            "age.adult": "Adult",
            "age.middle_aged": "Middle Aged",
            "age.aged": "Aged",
            "age.aged_80_over": "Aged 80+",
            "gender.male": "Male",
            "gender.female": "Female",
            # Parent level
            "gender": "Gender",
            "age": "Age",
            "Both": "Age + Gender",
        }
    },
    "stack": {
        # Custom disaggregators are passed as classes, not instances.
        "disaggregation_modules": [TabsSpaces, ReactComponent],
        "dataset_name": "society-ethics/the-stack-tabs_spaces",
        "column": "content",
        "feature_names": {
            "react_component.class": "Class",
            "react_component.function": "Function",
            "tabs_spaces.tabs": "Tabs",
            "tabs_spaces.spaces": "Spaces",
            # Parent level
            "react_component": "React Component Syntax",
            "tabs_spaces": "Tabs vs. Spaces",
            "Both": "React Component Syntax + Tabs vs. Spaces",
        }
    }
}
def generate_cached_data(disaggregation_modules, dataset_name, column, feature_names, *, split="train"):
    """Load a dataset and summarise the joint distribution of its disaggregation fields.

    Args:
        disaggregation_modules: Modules passed to ``Disaggregator``: module ids,
            configured module instances, or ``CustomDisaggregator`` subclasses.
        dataset_name: Dataset id passed to ``datasets.load_dataset``.
        column: Text column handed to ``Disaggregator``.
        feature_names: Display names for the fields; stored in the result unchanged.
        split: Dataset split to load. Defaults to ``"train"``.

    Returns:
        A dict with these keys:

        - ``fields``: the disaggregator fields plus ``"None"``.
        - ``data_fields``: the disaggregator fields.
        - ``distributions``: pandas ``value_counts`` over the field columns,
          selected in sorted order.
        - ``disaggregators``: the module names.
        - ``column`` and ``feature_names``: echoed back unchanged.

    Raises:
        KeyError: if a disaggregator field is missing from the dataset's columns.

    Note:
        The disaggregators are not applied here. The field columns must already
        exist in the dataset; the Disaggregator only supplies their names.
    """
    disaggregator = Disaggregator(disaggregation_modules, column=column)
    data_fields = sorted(disaggregator.fields)

    ds = load_dataset(dataset_name, split=split)
    # Only the field columns are counted. Drop everything else (e.g. large text
    # columns) before converting, so they are never copied into pandas.
    keep = set(data_fields)
    unused_columns = [name for name in ds.column_names if name not in keep]
    if unused_columns:
        ds = ds.remove_columns(unused_columns)
    df = ds.to_pandas()

    return {
        "fields": {*disaggregator.fields, "None"},
        "data_fields": disaggregator.fields,
        # Counts of each distinct combination of field values across rows.
        "distributions": df[data_fields].value_counts(),
        "disaggregators": [module.name for module in disaggregator.modules],
        "column": column,
        "feature_names": feature_names,
    }
# Regenerate the summary for every entry in `configs`, rather than a hand-kept
# key list that could miss new entries. Merge the summaries into the cache, so
# keys loaded from an earlier run that are not regenerated here are kept, then
# write it back.
cache_dict.update(
    {name: generate_cached_data(**config) for name, config in configs.items()}
)
joblib.dump(cache_dict, cache_file)