Spaces:
Running
Running
import sys, argparse | |
from tools.module import Net | |
import torch, random, time | |
import numpy as np | |
import pytorch_lightning as pl | |
from tools.utils import get_data, format_predictions, plot_ls, get_model, save_predictions | |
def main(args):
    """Run DeepStruc prediction for the input selected by ``args`` and save it.

    Steps (helpers come from ``tools.utils`` / ``tools.module``):

    1. ``get_data(args)`` returns the input data together with a data name
       and a project name.
    2. ``get_model(args.model)`` returns a checkpoint path and a model
       architecture, used to restore ``Net`` via ``Net.load_from_checkpoint``.
    3. The restored model is called with ``mode='prior'`` and
       ``sigma_scale=args.sigma``.
    4. ``format_predictions`` builds sampling pairs from the latent-space
       output; ``save_predictions`` stores the predicted coordinates for the
       project.

    Args:
        args: Parsed command-line namespace. This function itself reads
            ``args.model``, ``args.sigma`` and ``args.index_plot``; the
            helpers it receives ``args`` in may read further attributes.

    Returns:
        A tuple ``(sampling_pairs, index_highlight, these_cords)``: the value
        returned by ``format_predictions``, ``args.index_plot`` unchanged, and
        the value returned by ``save_predictions``.
    """
    data, data_name, project_name = get_data(args)
    model_path, model_arch = get_model(args.model)

    # ``load_from_checkpoint`` constructs the network itself (assumes ``Net``
    # is a LightningModule — TODO confirm), so no separate ``Net(...)``
    # instance needs to be built beforehand.
    model = Net.load_from_checkpoint(model_path, model_arch=model_arch)

    # Inference only: put layers such as dropout/batch-norm into eval
    # behaviour and avoid building an autograd graph that is never used.
    model.eval()
    with torch.no_grad():
        # The KL term is part of the forward output but not needed here.
        xyz_pred, latent_space, _kl, mu, sigma = model(
            data, mode='prior', sigma_scale=args.sigma
        )

    sampling_pairs = format_predictions(latent_space, data_name, mu, sigma, args.sigma)
    these_cords = save_predictions(xyz_pred, sampling_pairs, project_name, model_arch, args)
    return sampling_pairs, args.index_plot, these_cords