epaps's picture
test
de93bc1
from cil.optimisation.utilities import AlgorithmDiagnostics
import numpy as np
from bm3d import bm3d, BM3DStages
from cil.optimisation.functions import Function
class StoppingCriterionTime(AlgorithmDiagnostics):
    """Callback that halts an algorithm once its accumulated run time exceeds a budget.

    At iteration 0 it replaces ``algo.should_stop`` with a hook that reports this
    object's ``should_stop`` flag. On every call it sums ``algo.timing`` and, once
    that total is strictly greater than ``time``, raises the flag and prints the
    stopping iteration and total time.

    Parameters
    ----------
    time : float
        Time budget, compared against ``np.sum(algo.timing)`` (presumably
        seconds, as recorded by the algorithm -- confirm against the caller).
    """

    def __init__(self, time):
        self.time = time
        super().__init__(verbose=0)
        # Flag polled by the algorithm through the hook installed in __call__.
        self.should_stop = False

    def _should_stop(self):
        """Return the current stop flag."""
        return self.should_stop

    def __call__(self, algo):
        """Check the elapsed time after an iteration of ``algo``."""
        if algo.iteration == 0:
            # Route the algorithm's stop check through this callback's flag.
            algo.should_stop = self._should_stop
        elapsed = np.sum(algo.timing)
        # Written as `not (>)` rather than `<=` so a NaN total still does not stop.
        if not (elapsed > self.time):
            return
        self.should_stop = True
        print("Stop at {} time {}".format(algo.iteration, elapsed))
class BM3DFunction(Function):
    """Plug-and-play "regulariser" whose proximal operator is a BM3D denoising step.

    Both the function value and its convex conjugate are reported as 0.0; only
    :meth:`proximal` does real work. In PnP-ISTA/FISTA the BM3D noise level
    ``sigma`` is typically kept FIXED (it acts as the regularisation strength)
    and does not depend on the gradient step size ``tau``. A damped update
    ``(1 - gamma) * z + gamma * BM3D(z)`` can be used to calm oscillations.

    Parameters
    ----------
    sigma : float
        Noise standard deviation handed to ``bm3d`` as ``sigma_psd``, in the
        same units as the image values.
    gamma : float, default 1.0
        Damping weight; must lie in (0, 1]. With 1.0 the result is the BM3D
        output itself.
    profile : str, default "np"
        BM3D profile name, forwarded unchanged to ``bm3d``.
    stage_arg : BM3DStages, default BM3DStages.ALL_STAGES
        Which BM3D stages to run, forwarded unchanged to ``bm3d``.
    positivity : bool, default True
        When True, negative values of the proximal output are clipped to 0.

    Raises
    ------
    ValueError
        If ``gamma`` is not in (0, 1].
    """

    def __init__(self, sigma, gamma=1.0, profile="np",
                 stage_arg=BM3DStages.ALL_STAGES, positivity=True):
        self.sigma = float(sigma)  # BM3D noise std
        self.gamma = float(gamma)  # damping weight in (0, 1]
        if not 0.0 < self.gamma <= 1.0:
            raise ValueError("gamma must be in (0,1].")
        self.profile = profile
        self.stage_arg = stage_arg
        self.positivity = positivity
        super().__init__()

    def __call__(self, x):
        """Return 0.0: the implicit BM3D prior has no explicit objective value."""
        return 0.0

    def convex_conjugate(self, x):
        """Return 0.0 (placeholder; no conjugate is defined for this prior)."""
        return 0.0

    def _denoise(self, znp: np.ndarray) -> np.ndarray:
        """Run BM3D on ``znp`` (cast to float32) and return a float32 array."""
        noisy = np.asarray(znp, dtype=np.float32)
        # sigma_psd is the noise std, expressed in image units.
        denoised = bm3d(noisy,
                        sigma_psd=self.sigma,
                        profile=self.profile,
                        stage_arg=self.stage_arg)
        return denoised.astype(np.float32)

    def proximal(self, x, tau, out=None):
        """Apply the (optionally damped) BM3D denoiser to ``x``.

        ``tau`` is accepted for interface compatibility but ignored, because
        the BM3D strength is fixed by ``sigma``. The result is written into
        ``out``, or into a new container shaped like ``x`` when ``out`` is None,
        and that container is returned.
        """
        src = x.array.astype(np.float32, copy=False)
        denoised = self._denoise(src)
        # Relaxation between the input and its denoised version.
        result = (1.0 - self.gamma) * src + self.gamma * denoised
        if self.positivity:
            result = np.maximum(result, 0.0)
        target = x * 0.0 if out is None else out
        target.fill(result)
        return target
def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean ``(h, w)`` array that is True inside a disc.

    Parameters
    ----------
    h, w : int
        Mask height (rows) and width (columns).
    center : sequence, optional
        Disc centre as ``(x, y)``, i.e. ``(column, row)``. Defaults to
        ``(int(w/2), int(h/2))``.
    radius : float, optional
        Disc radius in pixels. Defaults to
        ``min(cx, cy, w - cx, h - cy)``.

    Returns
    -------
    numpy.ndarray of bool, shape (h, w)
        True where the Euclidean distance from the pixel to the centre is
        at most ``radius``.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    cx, cy = center[0], center[1]
    if radius is None:
        # NOTE: w - cx and h - cy measure to one past the last pixel index, so
        # the default disc can touch or overrun the right/bottom border.
        radius = min(cx, cy, w - cx, h - cy)
    rows, cols = np.ogrid[:h, :w]
    dx = cols - cx
    dy = rows - cy
    return np.sqrt(dx ** 2 + dy ** 2) <= radius
class StoppingCriterion(AlgorithmDiagnostics):
    """Callback that stops an algorithm at a target NRSE or after an epoch budget.

    On every call it reads the latest value of ``algo.rse``. When ``epochs`` is
    given, it also reads the data-pass count exposed by ``algo.f.data_passes``.
    When either criterion fires, it sets :attr:`should_stop`; the algorithm polls
    this flag through the hook installed at iteration 0.

    Parameters
    ----------
    epsilon : float
        Accuracy target: stop once ``algo.rse[-1] <= epsilon``.
    epochs : float, optional
        Stop once the latest data-pass count exceeds this value. ``None``
        disables the epoch budget.

    Attributes
    ----------
    should_stop : bool
        True once a stopping criterion has fired.
    rse_reached : bool
        True if the stop was caused by the accuracy target (not the epoch budget).
    """

    def __init__(self, epsilon, epochs=None):
        self.epsilon = epsilon
        self.epochs = epochs
        super().__init__(verbose=0)
        self.should_stop = False
        self.rse_reached = False

    def _should_stop(self):
        """Return the current stop flag (installed as ``algo.should_stop``)."""
        return self.should_stop

    def _epochs_exceeded(self, algo):
        """Return True if ``algo.f`` reports more data passes than ``self.epochs``."""
        if self.epochs is None:
            return False
        try:
            passes = algo.f.data_passes
        except AttributeError:
            # Objective without data-pass bookkeeping: the budget cannot apply.
            return False
        if hasattr(passes, "__len__"):
            # A history of pass counts: only the latest entry matters. An empty
            # history means no pass has been completed yet.
            if len(passes) == 0:
                return False
            passes = passes[-1]
        return passes > self.epochs

    def __call__(self, algo):
        """Evaluate the stopping criteria after an iteration of ``algo``."""
        if algo.iteration == 0:
            # Route the algorithm's stop check through this callback's flag.
            algo.should_stop = self._should_stop

        stop_rse = algo.rse[-1] <= self.epsilon
        stop_epochs = self._epochs_exceeded(algo)

        if stop_rse or stop_epochs:
            # Checked before the iteration cap, so that reaching the target on the
            # final iteration is still reported (and recorded) as a success.
            self.rse_reached = stop_rse
            self.should_stop = True
            if stop_rse:
                print(f"Accuracy reached at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")
            else:
                # Stopped by the epoch budget: the accuracy target was NOT met.
                print(f"Failed to reach accuracy within {self.epochs} epochs. Stop at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")
        elif algo.iteration >= algo.max_iteration:
            print(f"Failed to reach accuracy. Stop at {algo.iteration}, time = {np.sum(algo.timing):.4f}, NRSE = {algo.rse[-1]:.4e}")