import os
import time
import streamlit as st
import torch
import torch.nn.functional as F
import random
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd
from model.lstm import LSTMModel
from model.tcn import TCNModel
from model.tcn import move_custom_layers_to_device
from utils.lowlevel import LowLevel
from utils.highlevel import HighLevel
from utils.midpoint import MidPoint
from utils.transform import compute_gradient

st.set_page_config(page_title="Inference", page_icon=":chart_with_upwards_trend:", layout="wide",
                   initial_sidebar_state="auto")

# Fourth-order triple-jump (Yoshida) composition: one step of size h applies the base
# integrator with sub-steps A1*h, A2*h, A1*h (A1 + A2 + A1 == 1).
_A1 = 1 / (2 - 2 ** (1 / 3))
_A2 = 1 - 2 * _A1
_SUBSTEP_COEFFS = (_A1, _A2, _A1)

# Every integration starts at this time value.
_T_START = 0.1
# Only every 10th integration step is recorded.
_RECORD_EVERY = 10
# Step index at which the "30000 Time Steps" timer stops and the "2000 Time Steps" timer starts.
# NOTE(review): step 300000 corresponds to t ~= 30000 only if the helpers use h == 0.1 — confirm.
_SPLIT_STEP = 300000
# End time of the classical segment that feeds the LSTM/TCN models.
_CLASSICAL_HORIZON = 30000
# n_sample passed to uniform_sampling before model inference.
_MODEL_SAMPLE_COUNT = 300
# Number of autoregressive TCN calls.
_TCN_ROUNDS = 20


def uniform_sampling(data, n_sample):
    """Return every k-th row of ``data``, where k = len(data) // n_sample (at least 1).

    The result is a slice (a view for NumPy arrays). It has ceil(len(data) / k) rows, which can
    be slightly more than ``n_sample``. When ``data`` has fewer rows than ``n_sample``, every
    row is returned; previously this case crashed with a zero slice step.

    Raises:
        ValueError: if ``n_sample`` is not positive.
    """
    if n_sample <= 0:
        raise ValueError(f"n_sample must be positive, got {n_sample}")
    stride = max(1, len(data) // n_sample)
    return data[::stride]


def _progress_percent(iteration, total_iterations):
    """Integer percentage for ``st.progress``, clamped to [0, 100]."""
    if total_iterations <= 0:
        return 100
    return max(0, min(100, int((iteration / total_iterations) * 100)))


def _initial_conditions(helper, n_coords=12):
    """Unpack ``helper.initial()`` into ``(h, b, state)``.

    ``initial()`` yields x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza. ``symplectic``/``rejust``
    take x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, so the state is reordered to match.
    The first ``n_coords`` entries are kept (6 for the midpoint integrator).
    """
    _j, h, b, _n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = helper.initial()
    state = (x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
    return h, b, state[:n_coords]


def _eps_step(helper, h, b, state):
    """One composed step: three ``symplectic`` sub-steps followed by ``rejust`` (12-entry state)."""
    for coeff in _SUBSTEP_COEFFS:
        state = helper.symplectic(h * coeff, *state, b)
    return helper.rejust(*state)


def _high_level_step(helper, h, b, state):
    """Same composed step as ``_eps_step``, then a call to ``helper.f`` on the new state."""
    state = _eps_step(helper, h, b, state)
    x, y, z, px, py, pz = state[:6]
    # NOTE(review): every output of f is discarded; the call is kept so the High-Level timing
    # still includes its cost.
    _vx, _vy, _vz, _vpx, _vpy, _vpz, _e = helper.f(x, y, z, px, py, pz, b)
    return state


def _midpoint_step(helper, h, b, state):
    """One composed step of three ``implecitsymplectic`` sub-steps (6-entry state)."""
    for coeff in _SUBSTEP_COEFFS:
        state = helper.implecitsymplectic(coeff * h, *state, b)
    return state


def _integrate(advance, state, b, h, end_time, progress_bar, split_step=None, finish=True):
    """Advance ``state`` from t = 0.1 while t < end_time, recording every 10th step.

    Args:
        advance: callable that maps the current state to the next one (one step of size h).
        state: initial state; its first six entries are recorded as x, y, z, px, py, pz.
        b: value from ``helper.initial()``, stored as the first column of every record.
        h: time increment added to t after every step.
        end_time: the loop stops once t reaches this value.
        progress_bar: Streamlit progress bar. It is redrawn only when the integer percentage
            changes; redrawing on every step used to send ~300k UI messages and skew the timings.
        split_step: step index at which ``time.time()`` is sampled (None disables it).
        finish: whether to set the bar to 100 after the loop.

    Returns:
        ``(records, split_time)``. ``records`` is a list of ``[b, x, y, z, px, py, pz]`` rows.
        ``split_time`` is the wall-clock time at ``split_step``, or None if it was never reached.
    """
    end_time = float(end_time)
    total_iterations = (end_time - _T_START) / h
    records = []
    split_time = None
    shown_percent = None
    t = _T_START
    step = 0
    while t < end_time:
        state = advance(state)
        t = t + h
        if step % _RECORD_EVERY == 0:
            records.append([b, *state[:6]])
        percent = _progress_percent(step, total_iterations)
        if percent != shown_percent:
            progress_bar.progress(percent)
            shown_percent = percent
        if step == split_step:
            split_time = time.time()
        step += 1
    if finish:
        progress_bar.progress(100)
    return records, split_time


def _split_timings(phase_start, split_time, end_time):
    """Seconds before and after the split point; (NaN, NaN) when the split step was never reached."""
    if split_time is None:
        return float("nan"), float("nan")
    return split_time - phase_start, end_time - split_time


def low_level(option_time, slider_sample_orbit, progress_bar):
    """Integrate orbit ``slider_sample_orbit`` with the EPS (LowLevel) scheme until t = option_time.

    Returns:
        ``(first_phase_s, second_phase_s, total_s, result)``:
        - first_phase_s: wall-clock seconds up to step 300000.
        - second_phase_s: seconds after that step; both phases are NaN if the run is shorter.
        - total_s: seconds for the whole run, excluding the initial 0.1 s sleep.
        - result: the recorded rows, down-sampled with ``n_sample = int(option_time / 100)``.
    """
    time.sleep(0.1)
    start = time.time()
    helper = LowLevel(j=slider_sample_orbit)
    h, b, state = _initial_conditions(helper)
    records, split_time = _integrate(lambda s: _eps_step(helper, h, b, s), state, b, h,
                                     option_time, progress_bar, split_step=_SPLIT_STEP)
    end = time.time()
    first_phase, second_phase = _split_timings(start, split_time, end)
    result = uniform_sampling(np.array(records), n_sample=int(option_time / 100))
    return first_phase, second_phase, end - start, result


def high_level(option_time, slider_sample_orbit, progress_bar):
    """Same as ``low_level`` but with the HighLevel helper (plus a per-step ``helper.f`` call)."""
    time.sleep(0.1)
    start = time.time()
    helper = HighLevel(j=slider_sample_orbit)
    h, b, state = _initial_conditions(helper)
    records, split_time = _integrate(lambda s: _high_level_step(helper, h, b, s), state, b, h,
                                     option_time, progress_bar, split_step=_SPLIT_STEP)
    end = time.time()
    first_phase, second_phase = _split_timings(start, split_time, end)
    result = uniform_sampling(np.array(records), n_sample=int(option_time / 100))
    return first_phase, second_phase, end - start, result


def midpoint(option_time, slider_sample_orbit, progress_bar):
    """Integrate with the MidPoint helper's ``implecitsymplectic`` until t = option_time.

    Returns the same tuple as ``low_level``. The first-phase timer starts just before the
    integration loop (after helper construction), as in the original implementation.
    """
    time.sleep(0.1)
    start = time.time()
    helper = MidPoint(j=slider_sample_orbit)
    h, b, state = _initial_conditions(helper, n_coords=6)
    loop_start = time.time()
    records, split_time = _integrate(lambda s: _midpoint_step(helper, h, b, s), state, b, h,
                                     option_time, progress_bar, split_step=_SPLIT_STEP)
    end = time.time()
    first_phase, second_phase = _split_timings(loop_start, split_time, end)
    result = uniform_sampling(np.array(records), n_sample=int(option_time / 100))
    return first_phase, second_phase, end - start, result


def _load_lstm_model():
    """Load ``model/lstm.ckpt`` on the CPU in eval mode."""
    model = LSTMModel.load_from_checkpoint(os.path.join("model", "lstm.ckpt"))
    model.to("cpu")
    model.eval()
    return model


def _load_tcn_model():
    """Load ``model/tcn.ckpt``, move its custom layers to the CPU, and switch to eval mode."""
    model = TCNModel.load_from_checkpoint(os.path.join("model", "tcn.ckpt"))
    move_custom_layers_to_device(model, "cpu")
    model.eval()
    return model


def _prepare_model_input(trajectory):
    """Down-sample, min-max scale and gradient-transform ``trajectory`` for the models.

    Returns ``(scaler, sampled, features)``:
    - scaler: the fitted MinMaxScaler.
    - sampled: the unscaled down-sampled rows (fit_transform copies its input, so these are
      left untouched).
    - features: the ``compute_gradient`` rows stacked with a leading batch axis of size 1.
    """
    sampled = uniform_sampling(trajectory, n_sample=_MODEL_SAMPLE_COUNT)
    scaler = MinMaxScaler()
    scaled = scaler.fit_transform(sampled)
    rows = torch.tensor(scaled).float()
    features = torch.stack([compute_gradient(row, degree=2) for row in rows]).unsqueeze(0)
    return scaler, sampled, features


def _lstm_predict(model, features):
    """Run the LSTM on rows 100..299 of the sampled sequence (axis 1)."""
    return model(features[:, 100:300, :])


def _tcn_predict(model, features):
    """Autoregressive TCN rollout over 20 calls.

    Each later call feeds rows ``i..299`` plus the gradient of the first ``i`` predicted rows,
    and appends its output along axis 1.
    """
    preds = model(features[:, :300, :])
    for i in range(1, _TCN_ROUNDS):
        gd_y_hat = compute_gradient(preds[:, :i, :], degree=2).to('cpu')
        output = model(torch.cat([features[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
        preds = torch.cat([preds, output], dim=1)
    return preds


def _hybrid_forecast(helper_cls, stepper, n_coords, slider_sample_orbit, progress_bar,
                     load_model, predict):
    """Classical integration to t = 30000 followed by a neural forecast.

    Returns:
        ``(classical_s, model_s, total_s, combined)``:
        - classical_s: seconds for the integration loop.
        - model_s: seconds for inference plus inverse scaling.
        - total_s: seconds for the whole call, excluding the initial 0.1 s sleep.
        - combined: the down-sampled classical rows with the inverse-scaled predictions
          appended below them.
    """
    time.sleep(0.1)
    start = time.time()
    model = load_model()
    helper = helper_cls(j=slider_sample_orbit)
    h, b, state = _initial_conditions(helper, n_coords=n_coords)
    classical_start = time.time()
    # The bar reaches 100 only after inference, so it is not finished here.
    records, _ = _integrate(lambda s: stepper(helper, h, b, s), state, b, h,
                            _CLASSICAL_HORIZON, progress_bar, finish=False)
    trajectory = np.array(records)
    classical_end = time.time()

    scaler, sampled, features = _prepare_model_input(trajectory)
    model_start = time.time()
    with torch.no_grad():
        preds = predict(model, features)
    forecast = scaler.inverse_transform(preds.squeeze().cpu().numpy())
    model_end = time.time()
    progress_bar.progress(100)

    combined = np.concatenate([sampled, forecast], axis=0)
    return classical_end - classical_start, model_end - model_start, time.time() - start, combined


def low_level_lstm(slider_sample_orbit, lstm_progress_bar):
    """EPS integration to t = 30000, then an LSTM forecast (see ``_hybrid_forecast``)."""
    return _hybrid_forecast(LowLevel, _eps_step, 12, slider_sample_orbit, lstm_progress_bar,
                            _load_lstm_model, _lstm_predict)


def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
    """Midpoint integration to t = 30000, then an LSTM forecast (see ``_hybrid_forecast``)."""
    return _hybrid_forecast(MidPoint, _midpoint_step, 6, slider_sample_orbit, lstm_progress_bar,
                            _load_lstm_model, _lstm_predict)


def low_level_tcn(slider_sample_orbit, tcn_progress_bar):
    """EPS integration to t = 30000, then a TCN rollout (see ``_hybrid_forecast``)."""
    return _hybrid_forecast(LowLevel, _eps_step, 12, slider_sample_orbit, tcn_progress_bar,
                            _load_tcn_model, _tcn_predict)


def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
    """Midpoint integration to t = 30000, then a TCN rollout (see ``_hybrid_forecast``)."""
    return _hybrid_forecast(MidPoint, _midpoint_step, 6, slider_sample_orbit, tcn_progress_bar,
                            _load_tcn_model, _tcn_predict)


def _new_progress_bar(label):
    """Write ``label`` and return a new progress bar at 0 in the active container."""
    st.write(label)
    return st.progress(0)


def _show_timings(model_name, first_phase_time, second_phase_time, total_time):
    """Render a single-row timing table for ``model_name``."""
    st.table(pd.DataFrame({'Model': model_name,
                           '30000 Time Steps (s)': [first_phase_time],
                           '2000 Time Steps (s)': [second_phase_time],
                           'Total Time (s)': [total_time]}))


container = st.container()
container1, container2 = st.columns(2)
plot_container = st.container()

with st.sidebar:
    slider_sample_orbit = st.slider('Orbit Sample ID', 1, 10, 1)
    option_time = 32000
    st.write(f'Total Time Step: {option_time}')
    options_method = st.multiselect(
        'Compared Methods',
        ['EPS', 'Midpoint', 'EPS with LSTM', 'EPS with TCN', 'Midpoint with LSTM', 'Midpoint with TCN'],
        ['EPS'])
    btn_go = st.button("Go", type="primary", use_container_width=True)

if btn_go:
    if 'EPS' in options_method:
        with container1:
            low_level_progress_bar = _new_progress_bar('EPS Progress Bar')
            (low_level_30000_time, low_level_2000_time, low_level_total_time,
             low_level_result) = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
        with container2:
            _show_timings("EPS", low_level_30000_time, low_level_2000_time, low_level_total_time)

    # NOTE(review): 'High-Level' is not one of the multiselect options above, so this branch is
    # currently unreachable; add the option if the HighLevel integrator should be selectable.
    if 'High-Level' in options_method:
        with container1:
            high_level_progress_bar = _new_progress_bar('High Level Progress Bar')
            (high_level_30000_time, high_level_2000_time, high_level_total_time,
             high_level_result) = high_level(option_time, slider_sample_orbit, high_level_progress_bar)
        with container2:
            _show_timings("High Level", high_level_30000_time, high_level_2000_time, high_level_total_time)

    if 'Midpoint' in options_method:
        with container1:
            mid_point_progress_bar = _new_progress_bar('Midpoint Progress Bar')
            (mid_point_30000_time, mid_point_2000_time, mid_point_total_time,
             mid_point_result) = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
        with container2:
            _show_timings("Midpoint", mid_point_30000_time, mid_point_2000_time, mid_point_total_time)

    if 'EPS with LSTM' in options_method:
        with container1:
            low_level_lstm_progress_bar = _new_progress_bar('EPS LSTM Progress Bar')
            lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(
                slider_sample_orbit, low_level_lstm_progress_bar)
        with container2:
            _show_timings("EPS + LSTM", lstm_30000_time, lstm_2000_time, lstm_total_time)

    if 'EPS with TCN' in options_method:
        with container1:
            low_level_tcn_progress_bar = _new_progress_bar('EPS TCN Progress Bar')
            tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(
                slider_sample_orbit, low_level_tcn_progress_bar)
        with container2:
            _show_timings("EPS + TCN", tcn_30000_time, tcn_2000_time, tcn_total_time)

    if 'Midpoint with LSTM' in options_method:
        with container1:
            mid_point_lstm_progress_bar = _new_progress_bar('Midpoint LSTM Progress Bar')
            md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time, md_lstm_result = mid_point_lstm(
                slider_sample_orbit, mid_point_lstm_progress_bar)
        with container2:
            _show_timings("Midpoint + LSTM", md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time)

    if 'Midpoint with TCN' in options_method:
        with container1:
            mid_point_tcn_progress_bar = _new_progress_bar('Midpoint TCN Progress Bar')
            md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time, md_tcn_result = mid_point_tcn(
                slider_sample_orbit, mid_point_tcn_progress_bar)
        with container2:
            _show_timings("Midpoint + TCN", md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time)