# personal_math / evaluate_set12.py
# Uploaded by psidharth567
# Commit dcd2bd2 (verified): "Sync full project: code, checkpoints, datasets, logs"
import glob
import math
import os
import torch
from PIL import Image
from torchvision import transforms
from train_network import UnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork
# Noise level of the benchmark: sigma = 25 on the 0-255 intensity scale,
# rescaled to the [0, 1] scale the images are evaluated on.
SIGMA = 25.0 / 255.0

# Directory holding the Set12 test images (relative to the working directory).
TEST_DIR = "./datasets/Test_Datasets/FFDNet-master/testsets/Set12"

# Prefer a CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
def calculate_psnr(img1, img2, *, data_range=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        img1: Reference image tensor.
        img2: Tensor to compare against ``img1`` (same or broadcastable shape).
        data_range: Peak value of the pixel scale. Defaults to ``1.0`` for
            images in [0, 1]; pass ``255.0`` for 8-bit intensity data.

    Returns:
        ``10 * log10(data_range**2 / MSE)`` as a Python float, or
        ``float("inf")`` when the two images are identical (MSE == 0).
    """
    # Reduce to a plain Python float once: the zero check and the log are then
    # scalar math, and a GPU tensor is synchronised exactly once.
    mse = float(torch.mean((img1 - img2) ** 2))
    if mse == 0.0:
        return float("inf")
    return 10.0 * math.log10(data_range ** 2 / mse)
def _make_model(spec):
if spec["kind"] == "telegraph":
return UnrolledNetwork(num_stages=spec["stages"], use_wave=spec["wave"]).to(DEVICE)
if spec["kind"] == "tnrd":
return TNRDBaselineNetwork(num_stages=spec["stages"]).to(DEVICE)
raise ValueError(f"Unknown model kind: {spec['kind']}")
def _autocast_context():
    """Return the mixed-precision context to run inference under.

    On CUDA this is ``torch.amp.autocast("cuda")``; on any other device it is a
    CPU autocast context with ``enabled=False``, i.e. full-precision inference.
    """
    if DEVICE.type != "cuda":
        return torch.autocast("cpu", enabled=False)
    return torch.amp.autocast("cuda")
def evaluate_model(spec):
    """Return the mean Set12 PSNR of the checkpoint described by ``spec``.

    Each ``*.png`` in ``TEST_DIR`` is loaded as a single-channel [0, 1] tensor,
    corrupted with standard-normal noise scaled by ``SIGMA`` (then clamped to
    [0, 1]), denoised by the model, and scored against the clean image.

    Args:
        spec: Model description dict with the keys read by ``_make_model``
            (``kind``, ``stages`` and, for telegraph models, ``wave``) plus
            ``file``: the path to a saved ``state_dict``.

    Returns:
        The average PSNR formatted as ``"<value> dB"`` (two decimals), or an
        error string if no ``*.png`` files are found in ``TEST_DIR``.
    """
    test_paths = sorted(glob.glob(os.path.join(TEST_DIR, "*.png")))
    if not test_paths:
        # Checked before loading the model so a missing dataset costs nothing.
        return "Error: No images found in Set12 directory."

    model = _make_model(spec)
    # The checkpoint is a plain state_dict: restrict unpickling to tensors and
    # primitive containers rather than executing arbitrary pickled objects.
    state_dict = torch.load(spec["file"], map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    # A private generator, re-seeded on every call, gives every model the same
    # noise realisations without mutating the global RNG. Sampling on the CPU
    # keeps the noise identical whether inference runs on the CPU or on CUDA.
    noise_gen = torch.Generator().manual_seed(42)

    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # The context manager closes the image file once it has been decoded.
            with Image.open(path) as img:
                clean = test_transform(img).unsqueeze(0)
            noise = torch.randn(clean.shape, generator=noise_gen, dtype=clean.dtype)
            noisy = torch.clamp(clean + noise * SIGMA, 0.0, 1.0)
            clean, noisy = clean.to(DEVICE), noisy.to(DEVICE)
            with _autocast_context():
                output = model(noisy)
            # Score in float32 (autocast may produce float16 on CUDA) and clip the
            # estimate to the valid [0, 1] intensity range before measuring PSNR.
            output = torch.clamp(output.float(), 0.0, 1.0)
            total_psnr += calculate_psnr(clean, output)
    return f"{total_psnr / len(test_paths):.2f} dB"
def main():
    """Evaluate each known checkpoint on Set12 and print a PSNR table.

    A checkpoint whose file does not exist (relative to the working directory)
    is reported as ``[File not found]`` instead of being evaluated.
    """
    rule = "-" * 60
    models_to_test = (
        {
            "name": "3-Stage TDE (Proposed)",
            "kind": "telegraph",
            "stages": 3,
            "wave": True,
            "file": "model_3stages_waveTrue.pth",
        },
        {
            "name": "5-Stage TDE (Proposed)",
            "kind": "telegraph",
            "stages": 5,
            "wave": True,
            "file": "model_5stages_waveTrue.pth",
        },
        {
            "name": "5-Stage Telegraph w/o wave",
            "kind": "telegraph",
            "stages": 5,
            "wave": False,
            "file": "model_5stages_waveFalse.pth",
        },
        {
            "name": "5-Stage TNRD baseline",
            "kind": "tnrd",
            "stages": 5,
            "file": "tnrd_baseline_5stages.pth",
        },
    )

    print("[*] Evaluating Checkpoints on Set12 (Sigma = 25)...")
    print(rule)
    for spec in models_to_test:
        checkpoint_present = os.path.exists(spec["file"])
        score = evaluate_model(spec) if checkpoint_present else "[File not found]"
        print(f"{spec['name']:<30} | PSNR: {score}")
    print(rule)
    print("PDE baseline: run classical_baseline.py for the fixed telegraph/PDE comparison.")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()