import sys
import argparse
import random
import time

import numpy as np
import torch
import pytorch_lightning as pl
import streamlit as st

from tools.module import Net
from tools.utils import get_data, format_predictions, plot_ls, get_model, save_predictions


def main(args):
    """Load a trained checkpoint, sample predictions in 'prior' mode, and save them.

    Args:
        args: Namespace-like object. This function reads ``args.model`` (passed
            to ``get_model``), ``args.sigma`` (used as ``sigma_scale`` for the
            forward pass and forwarded to ``format_predictions``) and
            ``args.index_plot``. The whole object is also forwarded to
            ``get_data`` and ``save_predictions``, which may read more fields.

    Returns:
        tuple: ``(df, index_highlight, these_cords)`` where
            - ``df`` is the return value of ``format_predictions`` (presumably a
              table of latent-space sampling results -- TODO confirm),
            - ``index_highlight`` is ``args.index_plot`` unchanged,
            - ``these_cords`` is the return value of ``save_predictions``.
    """
    data, data_name, project_name = get_data(args)
    model_path, model_arch = get_model(args.model)

    deep_struc = Net.load_from_checkpoint(model_path, model_arch=model_arch)
    # load_from_checkpoint does not switch to inference mode; do it explicitly so
    # dropout / batch-norm layers (if the architecture has any) behave as at eval time.
    deep_struc.eval()
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        xyz_pred, latent_space, _kl, mu, sigma = deep_struc(
            data, mode='prior', sigma_scale=args.sigma
        )

    sampling_pairs = format_predictions(latent_space, data_name, mu, sigma, args.sigma)
    these_cords = save_predictions(xyz_pred, sampling_pairs, project_name, model_arch, args)
    return sampling_pairs, args.index_plot, these_cords