import argparse
import random
import sys
import time

import numpy as np
import pytorch_lightning as pl
import torch

from tools.module import Net
from tools.utils import (
    format_predictions,
    get_data,
    get_model,
    plot_ls,
    save_predictions,
)


def main(args):
    """Load a trained ``Net`` checkpoint, sample predictions, and save them.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.model``, ``args.sigma`` and ``args.index_plot`` directly,
            and passes the whole namespace to ``get_data`` and
            ``save_predictions``.

    Returns:
        A ``(df, index_highlight, these_cords)`` tuple:
        - ``df`` is the object returned by ``format_predictions``.
        - ``index_highlight`` is ``args.index_plot``, unchanged.
        - ``these_cords`` is whatever ``save_predictions`` returns.
    """
    data, data_name, project_name = get_data(args)
    model_path, model_arch = get_model(args.model)

    # NOTE(review): this instance is discarded immediately, because
    # load_from_checkpoint constructs its own model. It is kept only because
    # constructing it may consume torch's RNG (e.g. weight init). Removing it
    # could therefore shift the random draws used by prior sampling below.
    # Confirm, then drop it if seed-for-seed reproducibility does not matter.
    Net(model_arch=model_arch)
    deep_struc = Net.load_from_checkpoint(model_path, model_arch=model_arch)

    # Pure inference: disable autograd so no computation graph or
    # intermediate activations are retained.
    with torch.no_grad():
        xyz_pred, latent_space, _kl, mu, sigma = deep_struc(
            data, mode='prior', sigma_scale=args.sigma
        )

    sampling_pairs = format_predictions(
        latent_space, data_name, mu, sigma, args.sigma
    )
    # Read before save_predictions, which receives (and could mutate) args.
    index_highlight = args.index_plot
    these_cords = save_predictions(
        xyz_pred, sampling_pairs, project_name, model_arch, args
    )
    return sampling_pairs, index_highlight, these_cords