# File size: 2,118 Bytes
# b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
from functools import lru_cache
from hashlib import sha1

import numpy as np
import pandas as pd

from atoms_detection.dl_detection import DLDetection
from atoms_detection.dataset import CoordinatesDataset
from utils.constants import Split, ModelArgs
from utils.paths import PT_DATASET, PREDS_PATH, DETECTION_PATH, PRED_MAP_TABLE_LOGS


# Passed to DLDetection(threshold=...) and also embedded in the output
# directory and CSV file names produced by this script.
threshold = 0.89
# Tag that distinguishes this run's outputs (used in directory and CSV names).
extension_name = "replicate"
# Directory handed to DLDetection as `detections_path`, plus a directory of
# the same name under PREDS_PATH that caches per-image prediction maps (.npy).
detections_path = os.path.join(
    DETECTION_PATH, "dl_detection_{}_{}".format(extension_name, threshold)
)
inference_cache_path = os.path.join(PREDS_PATH, os.path.split(detections_path)[1])


@lru_cache(maxsize=1)
def _get_detector() -> DLDetection:
    """Build the DLDetection model once and reuse it for every cache miss."""
    # NOTE(review): checkpoint and dataset CSV are hard-coded absolute paths
    # from one developer's machine; consider moving them to utils.paths.
    return DLDetection(
        model_name=ModelArgs.BASICCNN,
        ckpt_filename="/home/fpares/PycharmProjects/stem_atoms/models/basic_replicate.ckpt",
        dataset_csv="/home/fpares/PycharmProjects/stem_atoms/dataset/Coordinate_image_pairs.csv",
        threshold=threshold,
        detections_path=detections_path
    )


def get_pred_map(img_filename: str) -> np.ndarray:
    """Return the model's prediction map for an image, using an on-disk cache.

    The cache file lives in ``inference_cache_path`` and is named after the
    SHA-1 of the *filename string*, not the file contents. Editing an image
    in place will therefore not invalidate its cached prediction.

    Args:
        img_filename: Path of the image to run inference on.

    Returns:
        The array produced by ``DLDetection.image_to_pred_map``, either freshly
        computed (and then saved to the cache) or loaded from the cache.
    """
    img_hash = sha1(img_filename.encode()).hexdigest()
    prediction_cache = os.path.join(inference_cache_path, f"{img_hash}.npy")
    if os.path.exists(prediction_cache):
        return np.load(prediction_cache)

    # Open the image passed in, not whatever path happens to be in a global.
    img = DLDetection.open_image(img_filename)
    pred_map = _get_detector().image_to_pred_map(img)
    # The cache directory is not created anywhere else; np.save would fail
    # with FileNotFoundError on the very first miss without this.
    os.makedirs(inference_cache_path, exist_ok=True)
    np.save(prediction_cache, pred_map)
    return pred_map


# For every TEST-split image, dump its prediction map as a long-format CSV
# with one row per map element (X, Y, Z) into PRED_MAP_TABLE_LOGS.
os.makedirs(PRED_MAP_TABLE_LOGS, exist_ok=True)

coordinates_dataset = CoordinatesDataset(PT_DATASET)
for image_path, coordinates_path in coordinates_dataset.iterate_data(Split.TEST):
    pred_map = get_pred_map(image_path)

    # X = first index, Y = second index, Z = value at that position, in C
    # (row-major) order -- the same ordering np.ndenumerate walks. Built with
    # vectorized indexing instead of a per-element Python loop.
    grid = np.indices(pred_map.shape)
    pred_df = pd.DataFrame({
        'X': grid[0].ravel(),
        'Y': grid[1].ravel(),
        # tolist() yields Python scalars so pandas infers the column dtype from
        # the values (float64 for float maps) rather than from pred_map.dtype.
        'Z': pred_map.ravel().tolist(),
    })

    img_name = os.path.splitext(os.path.basename(image_path))[0]
    pred_table_output_path = os.path.join(PRED_MAP_TABLE_LOGS, f"{img_name}_likelihood_{extension_name}_{threshold}.csv")
    pred_df.to_csv(pred_table_output_path, index=False)