from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

import gradio as gr

from ui.functions.run_gym_env import run_gym_env
from ui.labels import simulation_labels
from ui.tabs.tab_builder import ITabBuilder


@dataclass
class WithAgentTab(ITabBuilder):
    """Gradio tab that runs a simulation with optional intervention agents.

    The tab collects simulation settings and per-agent scheduling parameters.
    Clicking the run button calls ``f`` with the values of every input
    component, in the order the components are created in :meth:`build`.
    """

    # Callback wired to the run button; receives all input values positionally.
    f: Callable = run_gym_env
    # Unannotated, so this and ``output_format`` are plain class attributes
    # rather than dataclass fields (not accepted by ``__init__``).
    tab_name = "With agent(s)"
    default_output_path: str = "./gradio/agent-example/"
    output_format = "gif"
    button_label: str = "Run"

    @staticmethod
    def _step_slider(steps_slider: gr.Slider, label: str, value: int) -> gr.Slider:
        """Create an interactive step slider whose range matches ``steps_slider``."""
        return gr.Slider(
            label=label,
            minimum=steps_slider.minimum,
            value=value,
            maximum=steps_slider.maximum,
            interactive=True,
        )

    def _agent_group(
        self,
        heading: str,
        description: str,
        steps_slider: gr.Slider,
        schedule: Sequence[Tuple[str, int]],
    ) -> List[gr.Slider]:
        """Render one agent's parameter group and return its step sliders.

        Args:
            heading: Bold title shown at the top of the group.
            description: Short explanation of what the agent does.
            steps_slider: The simulation-steps slider; its minimum/maximum
                bound every slider created here.
            schedule: ``(label, default_value)`` pairs, one per slider, in the
                order their values must be passed to ``f``.

        Returns:
            The created sliders, in ``schedule`` order.
        """
        with gr.Group():
            gr.Markdown(f"**{heading}**\n\n{description}")
            # Sliders must be created inside the Group context to render there.
            return [
                self._step_slider(steps_slider, label, value)
                for label, value in schedule
            ]

    def build(
        self,
    ) -> None:
        """Lay out the tab and wire the run button to ``self.f``.

        Components are appended to ``inputs`` in a fixed order, which defines
        the positional arguments ``f`` receives:

        1. output path
        2. number of steps
        3. test rate
        4. enabled agents
        5. total actions per turn
        6. the distancing, vaccination, treatment and masking step sliders,
           in that order
        """
        inputs = []
        with gr.Tab(self.tab_name):
            gr.Markdown(
                """
                Run a simulation with fixed parameters and optional agents.
                Agents have fixed efficiency, and number of actions they perform each turn can be configured.
                """
            )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(simulation_labels.TITLE)
                    inputs.append(
                        gr.Text(
                            value=self.default_output_path,
                            label=simulation_labels.OUTPUT_PATH,
                        )
                    )
                    steps_slider_component = gr.Slider(
                        minimum=5,
                        maximum=200,
                        step=1,
                        value=100,
                        label=simulation_labels.STEPS,
                    )
                    inputs.extend(
                        [
                            steps_slider_component,
                            gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                value=1.0,
                                label=simulation_labels.TEST_RATE,
                                interactive=False,
                            ),
                        ]
                    )
                    with gr.Accordion("Agents"):
                        agents_selection_component = gr.Checkboxgroup(
                            choices=[
                                ("Distancing", "distancing"),
                                ("Vaccination", "vaccination"),
                                ("Treatment", "treatment"),
                                ("Masking", "masking"),
                            ],
                            value=["distancing"],
                            label="Enabled agents",
                        )
                        inputs.extend(
                            [
                                agents_selection_component,
                                gr.Slider(
                                    minimum=0,
                                    maximum=20,
                                    step=1,
                                    value=6,
                                    label="Total number of actions allowed per turn",
                                    info="Number of actions per turn, split equally between active agents (rounded down where required.)",
                                    interactive=True,
                                ),
                            ]
                        )
                        # The "stop" sliders of the two-phase agents default to
                        # the largest possible step count.
                        last_step = steps_slider_component.maximum
                        inputs.extend(
                            self._agent_group(
                                "Distancing agent params",
                                "The distancing agent disconnects nodes with 95% efficiency.",
                                steps_slider_component,
                                (
                                    ("Start isolating on step", 5),
                                    ("Stop isolating on step", 50),
                                    ("Start reconnecting on step", 40),
                                    ("Stop reconnecting on step", last_step),
                                ),
                            )
                        )
                        inputs.extend(
                            self._agent_group(
                                "Vaccination agent params",
                                "The vaccination agent adds 95% immunity to targeted uninfected nodes.",
                                steps_slider_component,
                                (
                                    ("Start vaccinating on step", 5),
                                    ("Stop vaccinating on step", 50),
                                ),
                            )
                        )
                        inputs.extend(
                            self._agent_group(
                                "Treatment agent params",
                                "The treatment agent treats infected nodes with 60% immediate recovery, and a boost to recovery rate otherwise.",
                                steps_slider_component,
                                (
                                    ("Start treating on step", 5),
                                    ("Stop treating on step", 50),
                                ),
                            )
                        )
                        inputs.extend(
                            self._agent_group(
                                "Masking agent params",
                                "The masking agent gives nodes masks (indicated by circles in plot), which reduces disease spread to/from the node by 25%.",
                                steps_slider_component,
                                (
                                    ("Start masking on step", 5),
                                    ("Stop masking on step", 50),
                                    ("Start removing masks on step", 40),
                                    ("Stop removing masks on step", last_step),
                                ),
                            )
                        )
                with gr.Column():
                    outputs = [
                        gr.Image(format=self.output_format),
                        gr.File(label="Log"),
                    ]
                    button = gr.Button(self.button_label)
                    button.click(self.f, inputs=inputs, outputs=outputs)