# virny-demo / app.py
# Latest commit: ad06253 by denysgerasymuk799 — "Added a demonstration notebook"
# (File-viewer page residue: raw · history · blame · contribute · delete · "No virus" · 2.45 kB)
import os
import warnings
# Mute warnings before the third-party imports below are executed.
# PYTHONWARNINGS is only read when a Python interpreter starts, so this
# environment entry affects child Python processes; the filter installed
# right after it covers the current process.
os.environ["PYTHONWARNINGS"] = "ignore"
warnings.simplefilter('ignore')
import gradio as gr
import pandas as pd
from virny.datasets import ACSIncomeDataset, ACSPublicCoverageDataset, LawSchoolDataset
from virny.custom_classes.metrics_interactive_visualizer import MetricsInteractiveVisualizer
if __name__ == '__main__':
# Define configs for sample datasets
demo_configs = {
'ACS_Income': {
'data_loader': ACSIncomeDataset(state=['GA'], year=2018, with_nulls=False,
subsample_size=15_000, subsample_seed=42),
'sensitive_attributes_dct': {'SEX': '2', 'RAC1P': ['2', '3', '4', '5', '6', '7', '8', '9'], 'SEX&RAC1P': None},
'model_metrics': pd.read_csv(os.path.join('.', 'data', 'acs_income_metrics.csv'), header=0),
},
'ACS_Public_Coverage': {
'data_loader': ACSPublicCoverageDataset(state=['CA'], year=2018, with_nulls=False,
subsample_size=15_000, subsample_seed=42),
'sensitive_attributes_dct': {'SEX': '2', 'RAC1P': ['2', '3', '4', '5', '6', '7', '8', '9'], 'SEX&RAC1P': None},
'model_metrics': pd.read_csv(os.path.join('.', 'data', 'acs_pub_cov_metrics.csv'), header=0),
},
'Law_School': {
'data_loader': LawSchoolDataset(),
'sensitive_attributes_dct': {'male': '0.0', 'race': 'Non-White', 'male&race': None},
'model_metrics': pd.read_csv(os.path.join('.', 'data', 'law_school_metrics.csv'), header=0),
},
}
# Create gradio demo objects for each sample dataset
dataset_names = list(demo_configs.keys())
sample_demos = []
for dataset_name in dataset_names:
sample_demo = MetricsInteractiveVisualizer(
X_data=demo_configs[dataset_name]['data_loader'].X_data,
y_data=demo_configs[dataset_name]['data_loader'].y_data,
model_metrics=demo_configs[dataset_name]['model_metrics'],
sensitive_attributes_dct=demo_configs[dataset_name]['sensitive_attributes_dct']
).create_web_app(start_app=False)
sample_demos.append(sample_demo)
# Build a web application with tabs for each sample dataset
demo = gr.TabbedInterface(sample_demos, [name.replace('_', ' ') for name in dataset_names], theme=gr.themes.Soft())
demo.launch(inline=False, debug=True, show_error=True)