from vizualize_nn import *

# ## Launch the app
#
# Builds a two-tab Gradio UI: one tab trains the network, the other
# visualizes it. `torch`, `partial`, `gr`, `NETWORK_ORIENTAION`,
# `dummy_context` and the callback functions are all expected to come from
# the star import above.
#
# NOTE(review): the original source had lost its line breaks and indentation;
# the nesting of the layout containers below was reconstructed from the order
# of the statements. Confirm the placement of `generate_button`, `plot_output`
# and `overall_net_output` against the intended layout.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Bind the selected device once so the UI callback only receives the
# slider values.
init_net_and_train_part = partial(init_net_and_train, device=device)

with gr.Blocks() as iface:
    tab_train = gr.Tab("Network Training")
    tab_viz = gr.Tab("Network Visualization")

    with tab_train:
        hidden_units_slider = gr.Slider(
            minimum=1, maximum=10, step=1, value=4,
            label="number of neurons in hidden layer",
        )
        noise_slider = gr.Slider(
            minimum=0.001, maximum=0.7, step=0.01, value=0.2,
            label="Noise",
        )
        epochs_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=30,
            label="Epochs",
        )
        lr_slider = gr.Slider(
            minimum=0.001, maximum=0.05, step=0.001, value=0.008,
            label="Learning Rate",
        )
        data_points_slider = gr.Slider(
            minimum=100, maximum=2000, step=4, value=1000,
            label="Data Points",
        )
        train_button = gr.Button("Train Network")
        learning_curve = gr.Plot(label="Learning Curve")

    with tab_viz:
        # Each layout container is swapped for a no-op context depending on
        # NETWORK_ORIENTAION ('h' or 'v'), so the same widget definitions
        # serve both orientations.
        with (gr.Row() if NETWORK_ORIENTAION != 'h' else dummy_context()):
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with (gr.Row() if NETWORK_ORIENTAION != 'v' else dummy_context()):
                    # TODO: Dynamic update needed here (the maximum is fixed
                    # at 50 rather than following the trained epoch count).
                    epoch_viz_slider = gr.Slider(
                        minimum=1, maximum=50, step=1, value=1,
                        label="Visualize Epoch",
                    )
                    ner_bounds = gr.Checkbox(
                        label="Invidual neurons decision boundaries",
                    )
                    generate_button = gr.Button("Visualize Network")
                plot_output = gr.Plot(label="Decision Boundary")
                overall_net_output = gr.Image(
                    type="filepath", label="Network Visualization",
                )
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with gr.Row():
                    input_x = gr.Number(label="Input X")
                    input_y = gr.Number(label="Input Y")
                update_button = gr.Button("Check Input")
                net_activity_sample_output = gr.HTML(
                    label="Network Activity for an Input",
                )
                # Alternative output widget:
                # net_activity_sample_output = gr.Image(type="filepath", label="Network Activity for an Input")

    # Set up button click actions (registered inside the Blocks context).
    # Use the device-bound partial; the original wired the bare
    # `init_net_and_train`, leaving the partial (and `device`) unused.
    train_button.click(
        fn=init_net_and_train_part,
        inputs=[
            hidden_units_slider,
            noise_slider,
            epochs_slider,
            data_points_slider,
            lr_slider,
        ],
        outputs=learning_curve,
    )
    generate_button.click(
        fn=generate_images,
        inputs=[epoch_viz_slider, ner_bounds],
        outputs=[plot_output, overall_net_output],
    )
    update_button.click(
        fn=get_network_with_inputs,
        inputs=[epoch_viz_slider, input_x, input_y],
        outputs=net_activity_sample_output,
    )

# No explicit "add tabs" call: the tabs were created inside the Blocks context.
iface.title = "Neural Network Visualization"
# NOTE(review): `description` is an option of gr.Interface; it is probably not
# rendered for gr.Blocks -- confirm, or show this text with a Markdown block.
iface.description = "Adjust parameters and train the network to see its performance and visualization."

# Launch the app
iface.launch()