Spaces:
Sleeping
Sleeping
import argparse
import os
import random

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
# NOTE(review): this pulls `pd` out of a networkx *test* module; kept so no
# existing import is dropped, but the direct pandas import below is the one
# the code should rely on.
from networkx.tests.test_convert_pandas import pd
import pandas as pd

from atoms_detection.dl_detection import DLDetection
from atoms_detection.dl_detection_scaled import DLScaled
from atoms_detection.evaluation import Evaluation
from utils.paths import MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, FE_DATASET, PRED_GT_VIS_PATH
from utils.constants import ModelArgs, Split
from visualizations.prediction_gt_images import plot_gt_pred_on_img, get_gt_coords
def _normalise_to_unit_range(img_arr):
    """Min-max scale ``img_arr`` to [0, 1]; a constant image maps to all zeros."""
    lowest = img_arr.min()
    span = img_arr.max() - lowest
    if span == 0:
        # A flat image would otherwise divide by zero and produce NaNs.
        return np.zeros_like(img_arr)
    return (img_arr - lowest) / span


def _save_prediction_overlays(detection, detections_path, vis_folder, gt_coords_dict):
    """Write one PNG per test image showing predictions (and GT when given).

    Args:
        detection: Detector whose ``image_dataset.iterate_data(Split.TEST)``
            yields the image paths to visualise.
        detections_path: Folder holding one ``<image stem>.csv`` per image
            with ``x`` and ``y`` columns.
        vis_folder: Output folder for the PNG files; created if missing.
        gt_coords_dict: Mapping from image file name to ground-truth
            coordinates, or ``None`` when no evaluation was run.
    """
    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(vis_folder, exist_ok=True)

    for image_path in detection.image_dataset.iterate_data(Split.TEST):
        print(image_path)
        img_name = os.path.basename(image_path)
        image_stem = os.path.splitext(img_name)[0]

        gt_coords = gt_coords_dict[img_name] if gt_coords_dict is not None else None

        df_predicted = pd.read_csv(os.path.join(detections_path, f"{image_stem}.csv"))
        pred_coords = list(zip(df_predicted["x"], df_predicted["y"]))

        # Context manager releases the file handle once the pixels are copied.
        with Image.open(image_path) as img:
            img_arr = np.array(img).astype(np.float32)

        plot_gt_pred_on_img(_normalise_to_unit_range(img_arr), gt_coords, pred_coords)
        vis_path = os.path.join(vis_folder, f"{image_stem}.png")
        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
        plt.close()


def detection_pipeline(args):
    """Run detection (and optionally evaluation / visualisation) per threshold.

    For every threshold in a fixed list, a detector is built and run; when
    requested, an ``Evaluation`` is run on its output and PNG overlays of the
    predictions are saved.

    Args:
        args: Parsed CLI namespace. The attributes read here are
            ``extension_name``, ``dataset``, ``experimental_rescale``,
            ``eval`` and ``visualise``.
    """
    extension_name = args.extension_name
    print(f"Storing at {extension_name}")

    experiment_tag = f"dl_detection_{extension_name}"
    architecture = ModelArgs.BASICCNN
    ckpt_filename = os.path.join(MODELS_PATH, "model_sac_cnn.ckpt")
    inference_cache_path = os.path.join(PREDS_PATH, experiment_tag)

    testing_thresholds = [0.8, 0.85, 0.9, 0.95]

    for threshold in testing_thresholds:
        run_tag = f"{experiment_tag}_{threshold}"
        detections_path = os.path.join(DETECTION_PATH, experiment_tag, run_tag)

        print(f"Detecting atoms on test data with threshold={threshold}...")
        if args.experimental_rescale:
            print("Using experimental ruler rescaling")
            detector_cls = DLScaled
        else:
            detector_cls = DLDetection
        # Both detector classes are constructed with the same keyword set.
        detection = detector_cls(
            model_name=architecture,
            ckpt_filename=ckpt_filename,
            dataset_csv=args.dataset,
            threshold=threshold,
            detections_path=detections_path,
            inference_cache_path=inference_cache_path,
        )
        detection.run()

        evaluation = None
        if args.eval:
            logging_filename = os.path.join(LOGS_PATH, experiment_tag, f"{run_tag}.csv")
            evaluation = Evaluation(
                coords_csv=args.dataset,
                predictions_path=detections_path,
                logging_filename=logging_filename,
            )
            evaluation.run()

        if args.visualise:
            # Ground truth is only available when an evaluation was run.
            gt_coords_dict = (
                get_gt_coords(evaluation.coordinates_dataset)
                if evaluation is not None
                else None
            )
            vis_folder = os.path.join(PRED_GT_VIS_PATH, experiment_tag, run_tag)
            _save_prediction_overlays(detection, detections_path, vis_folder, gt_coords_dict)

    print(f"Experiment {extension_name} completed")
def get_args(argv=None):
    """Parse the command-line options of the detection script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing
            ``get_args()`` callers behave as before.

    Returns:
        argparse.Namespace with ``extension_name``, ``dataset``, ``eval``,
        ``visualise``, ``experimental_rescale`` and ``feature``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name",
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Dataset file upon which to do inference",
    )
    # store_true flags already default to False.
    parser.add_argument(
        "--eval",
        action='store_true',
        help="Whether to perform evaluation after inference",
    )
    parser.add_argument(
        "--visualise",
        action='store_true',
        help="Whether to store inference results as visual png images",
    )
    parser.add_argument(
        "--experimental_rescale",
        action='store_true',
        help="Whether to rescale inputs based on the ruler of the image as preprocess",
    )
    # NOTE(review): accepted but not read anywhere in the visible code.
    parser.add_argument('--feature')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the pipeline.
    args = get_args()
    detection_pipeline(args)