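# DeepStruc / predict.py
# Source page metadata (Hugging Face file view): author AndySAnker,
# commit da795f8 "Update predict.py", page links "raw" / "history" / "blame",
# scan status "No virus", file size 1.23 kB.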
import sys, argparse
import streamlit as st
from tools.module import Net
import torch, random, time
import numpy as np
import pytorch_lightning as pl
from tools.utils import get_data, format_predictions, plot_ls, get_model, save_predictions
def main(args):
    """Run DeepStruc prediction for the input selected by ``args`` and save it.

    Loads the input data and the checkpoint chosen by ``args.model``. It then
    runs the model in ``'prior'`` mode with ``args.sigma`` as the sampling
    scale, formats the sampled latent-space results, and writes the predicted
    coordinates via ``save_predictions``.

    Args:
        args: Parsed options object. This function reads ``args.model``,
            ``args.sigma`` and ``args.index_plot`` directly. ``get_data`` and
            ``save_predictions`` receive the whole object and may read more
            fields.

    Returns:
        tuple: ``(df, index_highlight, these_cords)``, where ``df`` is the
        value returned by ``format_predictions``, ``index_highlight`` is
        ``args.index_plot``, and ``these_cords`` is the value returned by
        ``save_predictions``.
    """
    data, data_name, project_name = get_data(args)
    model_path, model_arch = get_model(args.model)

    # ``load_from_checkpoint`` constructs the network itself. The previous
    # standalone ``Net(model_arch=model_arch)`` call only allocated a model
    # that was immediately discarded.
    model = Net.load_from_checkpoint(model_path, model_arch=model_arch)

    # Pure inference: skip autograd graph construction to save memory/time.
    # NOTE(review): the model is not switched to ``model.eval()`` here because
    # it is unclear from this file whether Net's 'prior' sampling depends on
    # ``self.training`` -- confirm before adding it.
    with torch.no_grad():
        xyz_pred, latent_space, _kl, mu, sigma = model(
            data, mode='prior', sigma_scale=args.sigma
        )

    sampling_pairs = format_predictions(latent_space, data_name, mu, sigma, args.sigma)
    these_cords = save_predictions(xyz_pred, sampling_pairs, project_name, model_arch, args)
    return sampling_pairs, args.index_plot, these_cords