POMFinder / predict.py
AndySAnker's picture
Duplicate from AndySAnker/DeepStruc
ce48cf1
raw history blame
No virus
888 Bytes
import sys, argparse
from tools.module import Net
import torch, random, time
import numpy as np
import pytorch_lightning as pl
from tools.utils import get_data, format_predictions, plot_ls, get_model, save_predictions
def main(args):
    """Run the checkpointed model on the inputs described by ``args``.

    Loads the input data and the requested checkpoint, calls the model in
    ``'prior'`` mode, formats the latent-space output and hands the
    predictions to ``save_predictions``.

    Args:
        args: Options object. This function reads ``args.model``,
            ``args.sigma`` and ``args.index_plot`` directly; ``args`` is
            also passed whole to ``get_data`` and ``save_predictions``.

    Returns:
        A 3-tuple ``(df, index_highlight, these_cords)``: the value returned
        by ``format_predictions``, the value of ``args.index_plot``, and the
        value returned by ``save_predictions``.
    """
    data, data_name, project_name = get_data(args)
    model_path, model_arch = get_model(args.model)

    # load_from_checkpoint instantiates the module itself, so no separate
    # Net(...) construction is needed before restoring the weights.
    model = Net.load_from_checkpoint(model_path, model_arch=model_arch)

    # The third output of the forward call is not used by this function.
    xyz_pred, latent_space, _, mu, sigma = model(
        data, mode='prior', sigma_scale=args.sigma
    )

    sampling_pairs = format_predictions(
        latent_space, data_name, mu, sigma, args.sigma
    )

    # Read before save_predictions runs, since ``args`` is handed to it.
    index_highlight = args.index_plot

    these_cords = save_predictions(
        xyz_pred, sampling_pairs, project_name, model_arch, args
    )

    return sampling_pairs, index_highlight, these_cords