File size: 888 Bytes
ce48cf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import sys, argparse
from tools.module import Net
import torch, random, time
import numpy as np
import pytorch_lightning as pl
from tools.utils import get_data, format_predictions, plot_ls, get_model, save_predictions

def main(args):
    """Load the checkpointed ``Net`` model, run it on the input data and save the predictions.

    Args:
        args: Parsed options object. This function reads ``args.model``,
            ``args.sigma`` and ``args.index_plot`` directly; the whole object
            is also forwarded to ``get_data`` and ``save_predictions``.

    Returns:
        A 3-tuple ``(df, index_highlight, these_cords)`` where ``df`` is the
        value returned by ``format_predictions``, ``index_highlight`` is
        ``args.index_plot`` and ``these_cords`` is the value returned by
        ``save_predictions``.
    """
    data, data_name, project_name = get_data(args)
    model_path, model_arch = get_model(args.model)

    # NOTE(review): this instance is discarded immediately and looks redundant
    # with load_from_checkpoint below. It is kept because constructing it may
    # consume RNG state or have other side effects -- confirm before removing.
    Net(model_arch=model_arch)
    model = Net.load_from_checkpoint(model_path, model_arch=model_arch)

    # Inference only: no backward pass follows, so skip autograd bookkeeping.
    with torch.no_grad():
        xyz_pred, latent_space, _kl, mu, sigma = model(
            data, mode='prior', sigma_scale=args.sigma
        )
    sampling_pairs = format_predictions(latent_space, data_name, mu, sigma, args.sigma)

    # Read before save_predictions so the returned value matches what the
    # caller passed in, regardless of what save_predictions does with args.
    index_highlight = args.index_plot
    these_cords = save_predictions(xyz_pred, sampling_pairs, project_name, model_arch, args)

    return sampling_pairs, index_highlight, these_cords