# File-page header (converted to comments so the module parses):
# uploader baxtrax, commit fc8f677 "Update main.py", size 15.6 kB,
# scan result "No virus".
import gradio as gr
import helpers.models as h_models
import helpers.listeners as listeners
# Custom stylesheet handed to gr.Blocks in main(). It sets the z-index of
# every element tagged data-testid="block-label" to the theme's --layer-3
# value, presumably so component labels are not hidden behind neighbouring
# layers.
css = 'div[data-testid="block-label"] {z-index: var(--layer-3)}'
# Banner shown at the top of the app. Kept at column 0 so source-code
# indentation can never leak into the rendered Markdown (4+ leading spaces
# would turn a paragraph into a code block).
_BANNER_MD = """# Feature Visualization Generator

Feature Visualizations (FVs) answer questions
about what a network—or parts of a network—are
looking for by generating examples.
([Read more about it here](https://distill.pub/2017/feature-visualization/))
FVs are a part of a wider field called Explainable
Artificial Intelligence (XAI). This generator aims
to make it easier to explore different concepts
used in FV generation and allow for experimentation.
Currently, convolutional and linear layers have been tested.


**Start by selecting a model from the drop down.**"""


def _layer_index_number(label):
    """Return a hidden, empty, non-negative integer input for a layer index.

    The channel / node / node X / node Y inputs all share this configuration.
    They start hidden and are later updated by ``listeners.on_layer``
    (presumably revealed) once a layer is chosen.
    """
    return gr.Number(label=label,
                     info="Please choose a layer",
                     precision=0,
                     minimum=0,
                     interactive=True,
                     visible=False,
                     value=None)


def main():
    """Build the Feature Visualization Generator UI and launch it.

    The layout is a banner row above two columns: inputs (model/layer
    selection, visualization, image and transform settings) and an output
    gallery. Per-session values that are not displayed are kept in
    ``gr.State`` objects and threaded through the handlers in
    ``helpers.listeners``.
    """
    with gr.Blocks(title="Feature Visualization Generator",
                   css=css,
                   theme=gr.themes.Soft(primary_hue="blue",
                                        secondary_hue="blue",
                                        )) as demo:
        # Session state init (all empty until a listener fills them in).
        (model, model_layers, selected_layer, ft_map_sizes,
         thresholds, channel_max, nodeX_max, nodeY_max,
         node_max) = (gr.State(None) for _ in range(9))

        # ---- GUI elements ----
        with gr.Row():  # Upper banner
            gr.Markdown(_BANNER_MD)

        with gr.Row():  # Lower inputs and outputs
            with gr.Column():  # Inputs
                gr.Markdown("""## Model Settings""")
                model_dd = gr.Dropdown(
                    label="Model",
                    info="Select a model. Some models take longer to setup",
                    choices=[m.name for m in h_models.ModelTypes])
                # The layer list is repopulated by listeners.on_model, so it
                # depends on the model that was selected.
                layer_dd = gr.Dropdown(
                    label="Layer",
                    info="Select a layer. List will change depending on model selected",
                    interactive=True,
                    visible=False)
                layer_text = gr.Markdown("""## Layer Settings (Optional)""",
                                         visible=False)
                with gr.Row():  # Inputs specific to layer selection
                    channel_num = _layer_index_number("Channel")
                    node_num = _layer_index_number("Node")
                    nodeX_num = _layer_index_number("Node X")
                    nodeY_num = _layer_index_number("Node Y")

                gr.Markdown("""## Visualization Settings""")
                lr_sl = gr.Slider(
                    label="Learning Rate",
                    info="How aggressive each \"step\" towards the visualization is",
                    minimum=0.000001,
                    maximum=3,
                    step=0.000001,
                    value=0.125)
                epoch_num = gr.Number(label="Epochs",
                                      info="How many steps (epochs) to perform",
                                      precision=0,
                                      minimum=1,
                                      value=200)

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Column(variant="panel"):
                        gr.Markdown("""## Image Settings""")
                        img_num = gr.Number(
                            label="Image Size",
                            info="Image is square (<value> by <value>)",
                            precision=0,
                            minimum=1,
                            value=227)
                        chan_decor_ck = gr.Checkbox(
                            label="Channel Decorrelation",
                            info="Reduces channel-to-channel correlations",
                            value=True)
                        spatial_decor_ck = gr.Checkbox(
                            label="Spatial Decorrelation (FFT)",
                            info="Reduces pixel-to-pixel correlations",
                            value=True)
                        sd_num = gr.Number(
                            label="Standard Deviation",
                            info="The STD of the randomly generated starter image",
                            value=0.01)

                    with gr.Column(variant="panel"):
                        gr.Markdown("""## Transform Settings (WIP)""")
                        preprocess_ck = gr.Checkbox(
                            label="Preprocess",
                            info="Enable or disable preprocessing via transformations",
                            value=True,
                            interactive=True)
                        transform_choices = [t.value for t in h_models.TransformTypes]
                        transforms_dd = gr.Dropdown(label="Applied Transforms",
                                                    info="Transforms to apply",
                                                    choices=transform_choices,
                                                    multiselect=True,
                                                    value=transform_choices,
                                                    interactive=True)

                        # Transform-specific settings; each group's
                        # visibility is driven by preprocess_ck and
                        # transforms_dd (see event bindings below).
                        pad_col = gr.Column()
                        with pad_col:
                            gr.Markdown("""### Pad Settings""")
                            with gr.Row():
                                pad_num = gr.Number(
                                    label="Padding",
                                    info="How many pixels of padding",
                                    minimum=0,
                                    value=12,
                                    precision=0,
                                    interactive=True)
                                mode_rad = gr.Radio(
                                    label="Mode",
                                    info="Constant fills padded pixels with a value. Reflect fills with edge pixels",
                                    choices=["Constant", "Reflect"],
                                    value="Constant",
                                    interactive=True)
                                constant_num = gr.Number(
                                    label="Constant Fill Value",
                                    info="Value to fill padded pixels",
                                    value=0.5,
                                    interactive=True)

                        jitter_col = gr.Column()
                        with jitter_col:
                            gr.Markdown("""### Jitter Settings""")
                            with gr.Row():
                                jitter_num = gr.Number(
                                    label="Jitter",
                                    info="How much to jitter image by",
                                    minimum=1,
                                    value=8,
                                    precision=0,
                                    interactive=True)

                        rand_scale_col = gr.Column()
                        with rand_scale_col:
                            gr.Markdown("""### Random Scale Settings""")
                            with gr.Row():
                                scale_num = gr.Number(
                                    label="Max scale",
                                    info="How much to scale (from 1.0) in both directions (+ and -)",
                                    minimum=0,
                                    value=0.1,
                                    interactive=True)

                        rand_rotate_col = gr.Column()
                        with rand_rotate_col:
                            gr.Markdown("""### Random Rotate Settings""")
                            with gr.Row():
                                rotate_num = gr.Number(
                                    label="Max angle",
                                    info="How much to rotate in both directions (+ and -)",
                                    minimum=0,
                                    value=10,
                                    precision=0,
                                    interactive=True)

                        ad_jitter_col = gr.Column()
                        with ad_jitter_col:
                            gr.Markdown("""### Additional Jitter Settings""")
                            with gr.Row():
                                ad_jitter_num = gr.Number(
                                    label="Jitter",
                                    info="How much to jitter image by",
                                    minimum=1,
                                    value=4,
                                    precision=0,
                                    interactive=True)

                confirm_btn = gr.Button("Generate", visible=False)

            with gr.Column():  # Output
                gr.Markdown("""## Feature Visualization Output""")
                with gr.Row():
                    images_gal = gr.Gallery(show_label=False,
                                            preview=True,
                                            allow_preview=True)

        # ---- Event listener binding ----
        # Settings groups toggled together; order must match the outputs of
        # listeners.on_transform.
        transform_cols = [pad_col,
                          jitter_col,
                          rand_scale_col,
                          rand_rotate_col,
                          ad_jitter_col]

        # Picking a model reveals the layer dropdown and hides Generate.
        # on_model does not reset selected_layer, so the stored layer would
        # belong to the previous model. Generate only reappears once a layer
        # of the new model is selected (layer_dd.select below).
        model_dd.select(lambda: (gr.Dropdown.update(visible=True),
                                 gr.Button.update(visible=False)),
                        outputs=[layer_dd, confirm_btn])
        model_dd.select(listeners.on_model,
                        inputs=[model, model_layers, ft_map_sizes],
                        outputs=[layer_dd, model, model_layers, ft_map_sizes])

        layer_dd.select(lambda: gr.Button.update(visible=True),
                        outputs=confirm_btn)
        layer_dd.select(listeners.on_layer,
                        inputs=[selected_layer, model_layers, ft_map_sizes],
                        outputs=[layer_text,
                                 channel_num,
                                 nodeX_num,
                                 nodeY_num,
                                 node_num,
                                 selected_layer,
                                 channel_max,
                                 nodeX_max,
                                 nodeY_max,
                                 node_max])

        # Check typed-in indices against the per-layer maxima held in state.
        for index_input, index_max in ((channel_num, channel_max),
                                       (nodeX_num, nodeX_max),
                                       (nodeY_num, nodeY_max),
                                       (node_num, node_max)):
            index_input.blur(listeners.check_input,
                             inputs=[index_input, index_max])

        images_gal.select(listeners.update_img_label,
                          inputs=thresholds,
                          outputs=images_gal)

        # Show/hide the transform picker and every transform settings group
        # with the Preprocess checkbox. A fresh update dict is built per
        # output so Gradio never receives one shared, mutable dict.
        # NOTE(review): re-enabling shows every group, including transforms
        # not currently selected in transforms_dd.
        preprocess_outputs = [transforms_dd, *transform_cols]
        preprocess_ck.select(
            lambda status: tuple(gr.update(visible=status)
                                 for _ in preprocess_outputs),
            inputs=preprocess_ck,
            outputs=preprocess_outputs)
        transforms_dd.change(listeners.on_transform,
                             inputs=transforms_dd,
                             outputs=transform_cols)
        mode_rad.select(listeners.on_pad_mode,
                        outputs=constant_num)

        confirm_btn.click(listeners.generate,
                          inputs=[lr_sl,
                                  epoch_num,
                                  img_num,
                                  channel_num,
                                  nodeX_num,
                                  nodeY_num,
                                  node_num,
                                  selected_layer,
                                  model,
                                  thresholds,
                                  chan_decor_ck,
                                  spatial_decor_ck,
                                  sd_num,
                                  transforms_dd,
                                  pad_num,
                                  mode_rad,
                                  constant_num,
                                  jitter_num,
                                  scale_num,
                                  rotate_num,
                                  ad_jitter_num],
                          outputs=[images_gal, thresholds])

    # Launch only after the Blocks context has closed and the layout and
    # event graph are final.
    demo.queue().launch()
# Start the app only when this file is executed directly (e.g.
# `python main.py`). Importing the module no longer launches a server.
if __name__ == "__main__":
    main()