Rimjhim Mittal commited on
Commit ·
fe7d37c
1
Parent(s): 625857f
Stateless Stateful check
Browse files
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app.py
CHANGED
|
@@ -17,67 +17,79 @@ def reset_simulation_state():
|
|
| 17 |
if 'selected_columns' in st.session_state:
|
| 18 |
del st.session_state.selected_columns
|
| 19 |
|
| 20 |
-
def run_simulation(param_inputs, mdf_model):
|
| 21 |
mod_graph = mdf_model.graphs[0]
|
| 22 |
nodes = mod_graph.nodes
|
| 23 |
-
|
| 24 |
-
duration = param_inputs["Simulation Duration (s)"]
|
| 25 |
-
dt = param_inputs["Time Step (s)"]
|
| 26 |
-
|
| 27 |
all_node_results = {}
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
for node in nodes:
|
| 30 |
-
eg = EvaluableGraph(mod_graph, verbose=False)
|
| 31 |
-
t = 0
|
| 32 |
-
times = []
|
| 33 |
-
node_outputs = {op.value : [] for op in node.output_ports}
|
| 34 |
-
node_outputs['Time'] = []
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
scalar_value = output_value[0] if len(output_value) > 0 else np.nan
|
| 49 |
-
node_outputs[op.value].append(float(scalar_value))
|
| 50 |
else:
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
return all_node_results
|
| 57 |
-
|
| 58 |
-
def show_simulation_results(all_node_results):
|
| 59 |
if all_node_results is not None:
|
| 60 |
for node_id, chart_data in all_node_results.items():
|
| 61 |
-
st.subheader(f"
|
| 62 |
-
if
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
# Filter the data based on selected checkboxes
|
| 77 |
-
filtered_data = chart_data[[col for col, selected in st.session_state.selected_columns[node_id].items() if selected]]
|
| 78 |
|
| 79 |
-
# Display the line chart with filtered data
|
| 80 |
-
st.line_chart(filtered_data, use_container_width=True, height=400)
|
| 81 |
def update_selected_columns(node_id, column):
|
| 82 |
st.session_state.selected_columns[node_id][column] = st.session_state[f"checkbox_{node_id}_{column}"]
|
| 83 |
|
|
@@ -92,16 +104,26 @@ def show_json_output(mdf_model):
|
|
| 92 |
st.json(mdf_model.to_json())
|
| 93 |
|
| 94 |
# st.cache_data()
|
| 95 |
-
def view_tabs(mdf_model, param_inputs): # view
|
| 96 |
tab1, tab2, tab3 = st.tabs(["Simulation Results", "MDF Graph", "Json Model"])
|
| 97 |
with tab1:
|
| 98 |
-
if
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
else:
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
with tab2:
|
| 106 |
show_mdf_graph(mdf_model) # view
|
| 107 |
with tab3:
|
|
@@ -131,17 +153,28 @@ def parameter_form_to_update_model_and_view(mdf_model):
|
|
| 131 |
mod_graph = mdf_model.graphs[0]
|
| 132 |
nodes = mod_graph.nodes
|
| 133 |
parameters = []
|
|
|
|
|
|
|
|
|
|
| 134 |
for node in nodes:
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
param_inputs = {}
|
| 137 |
-
if
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
| 145 |
|
| 146 |
with st.form(key="parameter_form"):
|
| 147 |
valid_inputs = True
|
|
@@ -178,26 +211,26 @@ def parameter_form_to_update_model_and_view(mdf_model):
|
|
| 178 |
valid_inputs = False
|
| 179 |
|
| 180 |
param_inputs[param.id] = value
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
|
| 202 |
run_button = st.form_submit_button("Run Simulation")
|
| 203 |
|
|
@@ -207,12 +240,12 @@ def parameter_form_to_update_model_and_view(mdf_model):
|
|
| 207 |
for param in node.parameters:
|
| 208 |
if param.id in param_inputs:
|
| 209 |
param.value = param_inputs[param.id]
|
| 210 |
-
st.session_state.simulation_results = run_simulation(param_inputs, mdf_model)
|
|
|
|
|
|
|
| 211 |
|
| 212 |
-
view_tabs(mdf_model, param_inputs)
|
| 213 |
def upload_file_and_load_to_model():
|
| 214 |
-
|
| 215 |
-
|
| 216 |
uploaded_file = st.sidebar.file_uploader("Choose a JSON/YAML/BSON file", type=["json", "yaml", "bson"])
|
| 217 |
github_url = st.sidebar.text_input("Enter GitHub raw file URL:", placeholder="Enter GitHub raw file URL")
|
| 218 |
example_models = {
|
|
@@ -281,7 +314,7 @@ def main():
|
|
| 281 |
header1, header2 = st.columns([1, 8], vertical_alignment="top")
|
| 282 |
with header1:
|
| 283 |
with st.container():
|
| 284 |
-
st.image("
|
| 285 |
with header2:
|
| 286 |
with st.container():
|
| 287 |
st.title("MDF: "+ mdf_model.id)
|
|
@@ -291,7 +324,7 @@ def main():
|
|
| 291 |
header1, header2 = st.columns([1, 8], vertical_alignment="top")
|
| 292 |
with header1:
|
| 293 |
with st.container():
|
| 294 |
-
st.image("
|
| 295 |
with header2:
|
| 296 |
with st.container():
|
| 297 |
st.title("Welcome to Model Description Format")
|
|
|
|
| 17 |
if 'selected_columns' in st.session_state:
|
| 18 |
del st.session_state.selected_columns
|
| 19 |
|
| 20 |
+
def run_simulation(param_inputs, mdf_model, stateful):
|
| 21 |
mod_graph = mdf_model.graphs[0]
|
| 22 |
nodes = mod_graph.nodes
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
all_node_results = {}
|
| 24 |
+
if stateful:
|
| 25 |
+
duration = param_inputs["Simulation Duration (s)"]
|
| 26 |
+
dt = param_inputs["Time Step (s)"]
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
|
| 30 |
+
for node in nodes:
|
| 31 |
+
eg = EvaluableGraph(mod_graph, verbose=False)
|
| 32 |
+
t = 0
|
| 33 |
+
times = []
|
| 34 |
+
node_outputs = {op.value : [] for op in node.output_ports}
|
| 35 |
+
node_outputs['Time'] = []
|
| 36 |
+
|
| 37 |
+
while t <= duration:
|
| 38 |
+
times.append(t)
|
| 39 |
+
if t == 0:
|
| 40 |
+
eg.evaluate()
|
|
|
|
|
|
|
| 41 |
else:
|
| 42 |
+
eg.evaluate(time_increment=dt)
|
| 43 |
+
|
| 44 |
+
node_outputs['Time'].append(t)
|
| 45 |
+
for op in node.output_ports:
|
| 46 |
+
eval_param = eg.enodes[node.id].evaluable_outputs[op.id]
|
| 47 |
+
output_value = eval_param.curr_value
|
| 48 |
+
if isinstance(output_value, (list, np.ndarray)):
|
| 49 |
+
scalar_value = output_value[0] if len(output_value) > 0 else np.nan
|
| 50 |
+
node_outputs[op.value].append(float(scalar_value))
|
| 51 |
+
else:
|
| 52 |
+
node_outputs[op.value].append(float(output_value))
|
| 53 |
+
t += dt
|
| 54 |
+
|
| 55 |
+
all_node_results[node.id] = pd.DataFrame(node_outputs).set_index('Time')
|
| 56 |
|
| 57 |
+
return all_node_results
|
| 58 |
+
else:
|
| 59 |
+
for node in nodes:
|
| 60 |
+
eg = EvaluableGraph(mod_graph, verbose=False)
|
| 61 |
+
eg.evaluate()
|
| 62 |
+
all_node_results[node.id] = pd.DataFrame({op.value: [float(eg.enodes[node.id].evaluable_outputs[op.id].curr_value)] for op in node.output_ports})
|
| 63 |
+
|
| 64 |
return all_node_results
|
| 65 |
+
def show_simulation_results(all_node_results, stateful_nodes):
|
|
|
|
| 66 |
if all_node_results is not None:
|
| 67 |
for node_id, chart_data in all_node_results.items():
|
| 68 |
+
st.subheader(f"Results for Node: {node_id}")
|
| 69 |
+
if node_id in stateful_nodes:
|
| 70 |
+
if 'selected_columns' not in st.session_state:
|
| 71 |
+
st.session_state.selected_columns = {node_id: {col: True for col in chart_data.columns}}
|
| 72 |
+
elif node_id not in st.session_state.selected_columns:
|
| 73 |
+
st.session_state.selected_columns[node_id] = {col: True for col in chart_data.columns}
|
| 74 |
+
columns = chart_data.columns
|
| 75 |
+
for column in columns:
|
| 76 |
+
st.checkbox(
|
| 77 |
+
f"{column}",
|
| 78 |
+
value=st.session_state.selected_columns[node_id][column],
|
| 79 |
+
key=f"checkbox_{node_id}_{column}",
|
| 80 |
+
on_change=update_selected_columns,
|
| 81 |
+
args=(node_id, column,)
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Filter the data based on selected checkboxes
|
| 85 |
+
filtered_data = chart_data[[col for col, selected in st.session_state.selected_columns[node_id].items() if selected]]
|
| 86 |
+
|
| 87 |
+
# Display the line chart with filtered data
|
| 88 |
+
st.line_chart(filtered_data, use_container_width=True, height=400)
|
| 89 |
+
else:
|
| 90 |
+
st.write(all_node_results[node_id])
|
| 91 |
|
|
|
|
|
|
|
| 92 |
|
|
|
|
|
|
|
| 93 |
def update_selected_columns(node_id, column):
|
| 94 |
st.session_state.selected_columns[node_id][column] = st.session_state[f"checkbox_{node_id}_{column}"]
|
| 95 |
|
|
|
|
| 104 |
st.json(mdf_model.to_json())
|
| 105 |
|
| 106 |
# st.cache_data()
|
| 107 |
+
def view_tabs(mdf_model, param_inputs, stateful): # view
|
| 108 |
tab1, tab2, tab3 = st.tabs(["Simulation Results", "MDF Graph", "Json Model"])
|
| 109 |
with tab1:
|
| 110 |
+
if stateful:
|
| 111 |
+
if 'simulation_results' not in st.session_state:
|
| 112 |
+
st.session_state.simulation_results = None
|
| 113 |
+
|
| 114 |
+
if st.session_state.simulation_results is not None:
|
| 115 |
+
show_simulation_results(st.session_state.simulation_results, stateful)
|
| 116 |
+
else:
|
| 117 |
+
st.write("Run the simulation to see results.") # model
|
| 118 |
else:
|
| 119 |
+
if 'simulation_results' not in st.session_state:
|
| 120 |
+
st.session_state.simulation_results = None
|
| 121 |
+
|
| 122 |
+
if st.session_state.simulation_results is not None:
|
| 123 |
+
show_simulation_results(st.session_state.simulation_results, stateful)
|
| 124 |
+
else:
|
| 125 |
+
st.write("Stateless.")
|
| 126 |
+
|
| 127 |
with tab2:
|
| 128 |
show_mdf_graph(mdf_model) # view
|
| 129 |
with tab3:
|
|
|
|
| 153 |
mod_graph = mdf_model.graphs[0]
|
| 154 |
nodes = mod_graph.nodes
|
| 155 |
parameters = []
|
| 156 |
+
stateful_nodes = []
|
| 157 |
+
stateful = False
|
| 158 |
+
|
| 159 |
for node in nodes:
|
| 160 |
+
for param in node.parameters:
|
| 161 |
+
if param.is_stateful():
|
| 162 |
+
stateful_nodes.append(node.id)
|
| 163 |
+
stateful = True
|
| 164 |
+
break
|
| 165 |
+
else:
|
| 166 |
+
stateful = False
|
| 167 |
+
|
| 168 |
param_inputs = {}
|
| 169 |
+
if stateful:
|
| 170 |
+
if mdf_model.metadata:
|
| 171 |
+
preferred_duration = float(mdf_model.metadata.get("preferred_duration", 10))
|
| 172 |
+
preferred_dt = float(mdf_model.metadata.get("preferred_dt", 0.1))
|
| 173 |
+
else:
|
| 174 |
+
preferred_duration = 100
|
| 175 |
+
preferred_dt = 0.1
|
| 176 |
+
param_inputs["Simulation Duration (s)"] = preferred_duration
|
| 177 |
+
param_inputs["Time Step (s)"] = preferred_dt
|
| 178 |
|
| 179 |
with st.form(key="parameter_form"):
|
| 180 |
valid_inputs = True
|
|
|
|
| 211 |
valid_inputs = False
|
| 212 |
|
| 213 |
param_inputs[param.id] = value
|
| 214 |
+
if stateful:
|
| 215 |
+
st.write("Simulation Parameters:")
|
| 216 |
+
with st.container(border=True):
|
| 217 |
+
# Add Simulation Duration and Time Step inputs
|
| 218 |
+
col1, col2 = st.columns(2)
|
| 219 |
+
with col1:
|
| 220 |
+
sim_duration = st.text_input("Simulation Duration (s)", value=str(param_inputs["Simulation Duration (s)"]), key="sim_duration")
|
| 221 |
+
with col2:
|
| 222 |
+
time_step = st.text_input("Time Step (s)", value=str(param_inputs["Time Step (s)"]), key="time_step")
|
| 223 |
+
|
| 224 |
+
try:
|
| 225 |
+
param_inputs["Simulation Duration (s)"] = float(sim_duration)
|
| 226 |
+
except ValueError:
|
| 227 |
+
st.error("Invalid input for Simulation Duration. Please enter a valid number.")
|
| 228 |
+
valid_inputs = False
|
| 229 |
+
try:
|
| 230 |
+
param_inputs["Time Step (s)"] = float(time_step)
|
| 231 |
+
except ValueError:
|
| 232 |
+
st.error("Invalid input for Time Step. Please enter a valid number.")
|
| 233 |
+
valid_inputs = False
|
| 234 |
|
| 235 |
run_button = st.form_submit_button("Run Simulation")
|
| 236 |
|
|
|
|
| 240 |
for param in node.parameters:
|
| 241 |
if param.id in param_inputs:
|
| 242 |
param.value = param_inputs[param.id]
|
| 243 |
+
st.session_state.simulation_results = run_simulation(param_inputs, mdf_model, stateful)
|
| 244 |
+
|
| 245 |
+
view_tabs(mdf_model, param_inputs, stateful_nodes)
|
| 246 |
|
|
|
|
| 247 |
def upload_file_and_load_to_model():
|
| 248 |
+
|
|
|
|
| 249 |
uploaded_file = st.sidebar.file_uploader("Choose a JSON/YAML/BSON file", type=["json", "yaml", "bson"])
|
| 250 |
github_url = st.sidebar.text_input("Enter GitHub raw file URL:", placeholder="Enter GitHub raw file URL")
|
| 251 |
example_models = {
|
|
|
|
| 314 |
header1, header2 = st.columns([1, 8], vertical_alignment="top")
|
| 315 |
with header1:
|
| 316 |
with st.container():
|
| 317 |
+
st.image("logo.jpg")
|
| 318 |
with header2:
|
| 319 |
with st.container():
|
| 320 |
st.title("MDF: "+ mdf_model.id)
|
|
|
|
| 324 |
header1, header2 = st.columns([1, 8], vertical_alignment="top")
|
| 325 |
with header1:
|
| 326 |
with st.container():
|
| 327 |
+
st.image("logo.jpg")
|
| 328 |
with header2:
|
| 329 |
with st.container():
|
| 330 |
st.title("Welcome to Model Description Format")
|