savtadepth/src/code/eval.py
Last commit by Dean (818ec2e): Finalized evaluation step, which now works. Ready to merge into master.
import sys
import yaml
import torch
from torchvision import transforms
from fastai.vision.all import unet_learner, Path, resnet34, MSELossFlat, get_files, L, tuplify
from custom_data_loading import create_data
from eval_metric_calculation import compute_eval_metrics
from dagshub import dagshub_logger
from tqdm import tqdm
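
# Evaluation script: rebuilds the trained U-Net learner, runs it on every test
# image, saves each predicted depth map as a grayscale PNG, and logs the
# evaluation metrics with DagsHub.
# Usage: python3 src/code/eval.py <test_data_path>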
if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("usage: %s <test_data_path>" % sys.argv[0], file=sys.stderr)
        sys.exit(1)
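
    # Load pipeline parameters (architecture, loss function, number of outputs) from params.yml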
    with open(r"./src/code/params.yml") as f:
        params = yaml.safe_load(f)
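
    # Build the test data loaders from the given directory (custom_data_loading.create_data)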
    data_path = Path(sys.argv[1])
    data = create_data(data_path)
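
    # Map the string names in params.yml to the corresponding fastai objects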
    arch = {'resnet34': resnet34}
    loss = {'MSELossFlat': MSELossFlat()}
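
    # Recreate the U-Net learner used in training and load the saved weights
    # from src/models/model.pth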
    learner = unet_learner(data,
                           arch.get(params['architecture']),
                           n_out=int(params['num_outs']),
                           loss_func=loss.get(params['loss_func']),
                           path='src/',
                           model_dir='models')
    learner = learner.load('model')
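
    # Collect all .jpg images under the test data path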
    filenames = get_files(Path(data_path), extensions='.jpg')
    test_files = L([Path(i) for i in filenames])
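
    # Predict on each test image and save the result as a grayscale PNG under src/eval/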
    for sample in tqdm(test_files.items, desc="Predicting on test images", total=len(test_files.items)):
        pred = learner.predict(sample)[0]
        pred = transforms.ToPILImage()(pred[:, :, :].type(torch.FloatTensor)).convert('L')
        pred.save("src/eval/" + str(sample.stem) + "_pred.png")
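
    # Compute evaluation metrics for the test files (implemented in eval_metric_calculation.py)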
print("Calculating metrics...")
metrics = compute_eval_metrics(test_files)
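
    # Log the computed metrics to logs/test_metrics.csv with the DagsHub logger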
    with dagshub_logger(
            metrics_path="logs/test_metrics.csv",
            should_log_hparams=False
    ) as logger:
        # Metric logging
        logger.log_metrics(metrics)

    print("Evaluation Done!")