Lewislou's picture
Upload 40 files
991881f
raw
history blame contribute delete
No virus
7.33 kB
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import warnings
from skimage.measure import regionprops
from skimage.draw import polygon
from csbdeep.utils import _raise
from ..utils import path_absolute, _is_power_of_2, _normalize_grid
from ..matching import _check_label_array
from stardist.lib.stardist2d import c_star_dist
def _ocl_star_dist(lbl, n_rays=32, grid=(1,1)):
    """compute star-convex distances of a label image with the OpenCL kernel

    lbl    : label image, converted to uint16 before upload
    n_rays : positive scalar, passed to the kernel as compile-time define N_RAYS
    grid   : per-axis subsampling step; output is computed for lbl[::grid[0], ::grid[1]]

    returns a float32 array of shape (subsampled shape) + (n_rays,)
    raises ValueError if n_rays is not a positive scalar
    """
    # imported lazily so that gputools is only required when this backend is used
    from gputools import OCLProgram, OCLArray, OCLImage

    if not (np.isscalar(n_rays) and 0 < int(n_rays)):
        _raise(ValueError())
    n_rays = int(n_rays)

    # number of samples per axis when taking every g-th pixel,
    # i.e. the shape of lbl[tuple(slice(0, None, g) for g in grid)]
    out_shape = tuple((size - 1) // step + 1 for size, step in zip(lbl.shape, grid))

    labels_dev = OCLImage.from_array(lbl.astype(np.uint16, copy=False))
    dist_dev = OCLArray.empty(out_shape + (n_rays,), dtype=np.float32)

    kernel_file = path_absolute("kernels/stardist2d.cl")
    program = OCLProgram(kernel_file, build_options=['-D', 'N_RAYS=%d' % n_rays])

    # the kernel's global size is the output shape in reversed axis order
    program.run_kernel(
        'star_dist', out_shape[::-1], None,
        dist_dev.data, labels_dev,
        np.int32(grid[0]), np.int32(grid[1]),
    )
    return dist_dev.get()
def _cpp_star_dist(lbl, n_rays=32, grid=(1,1)):
    """compute star-convex distances via `c_star_dist` (imported from stardist.lib.stardist2d)

    lbl    : label image, converted to uint16 before the call
    n_rays : positive scalar, forwarded as int32
    grid   : two subsampling steps, each forwarded as int32

    raises ValueError if n_rays is not a positive scalar
    """
    if not (np.isscalar(n_rays) and 0 < int(n_rays)):
        _raise(ValueError())
    labels = lbl.astype(np.uint16, copy=False)
    return c_star_dist(labels, np.int32(n_rays), np.int32(grid[0]), np.int32(grid[1]))
def _py_star_dist(a, n_rays=32, grid=(1,1)):
    """pure-python computation of star-convex distances for a 2-D label image

    For every non-zero pixel, a ray is walked in unit steps along each of
    `n_rays` equally spaced directions until it leaves the image or hits a
    pixel with a different label; the (boundary-corrected) length walked is
    stored. Pixels with value 0 get distance 0 for all rays.

    a      : 2-D label image (converted to uint16), 0 is skipped
    n_rays : positive scalar number of directions
    grid   : only a grid equal to (1,1) is supported (tuple, list or array)

    returns a float32 array of shape a.shape + (n_rays,)
    raises ValueError if n_rays is not a positive scalar
    raises NotImplementedError for any grid other than (1,1)
    """
    if not (np.isscalar(n_rays) and 0 < int(n_rays)):
        raise ValueError("'n_rays' must be a positive scalar")
    # compare by content so that e.g. [1,1] or np.array([1,1]) are accepted too
    if np.shape(grid) != (2,) or tuple(grid) != (1, 1):
        raise NotImplementedError(grid)

    n_rays = int(n_rays)
    a = a.astype(np.uint16, copy=False)
    n_rows, n_cols = a.shape[0], a.shape[1]
    # zero-initialised, so background pixels need no further work
    dst = np.zeros(a.shape + (n_rays,), np.float32)

    # ray directions and overshoot corrections do not depend on the pixel,
    # so they are computed once up front
    st_rays = np.float32((2*np.pi) / n_rays)
    rays = []
    for k in range(n_rays):
        phi = np.float32(k*st_rays)
        dy = np.cos(phi)
        dx = np.sin(phi)
        # small correction as we overshoot the boundary
        t_corr = 1-.5/max(np.abs(dx), np.abs(dy))
        rays.append((dx, dy, t_corr))

    for i in range(n_rows):
        for j in range(n_cols):
            value = a[i, j]
            if value == 0:
                continue
            for k, (dx, dy, t_corr) in enumerate(rays):
                # x is the offset along axis 0, y the offset along axis 1
                x, y = np.float32(0), np.float32(0)
                while True:
                    x += dx
                    y += dy
                    ii = int(round(i+x))
                    jj = int(round(j+y))
                    left_object = (ii < 0 or ii >= n_rows or
                                   jj < 0 or jj >= n_cols or
                                   value != a[ii, jj])
                    if left_object:
                        x -= t_corr*dx
                        y -= t_corr*dy
                        dst[i, j, k] = np.sqrt(x**2+y**2)
                        break
    return dst
def star_dist(a, n_rays=32, grid=(1,1), mode='cpp'):
    """compute star-convex distances of label image 'a' with the chosen backend

    'a' is assumed to be a label image with integer values that encode object ids;
    id 0 denotes background.

    n_rays : number of ray directions, must be >= 3
    grid   : forwarded unchanged to the backend
    mode   : 'python', 'cpp' or 'opencl'

    raises ValueError if n_rays < 3 or mode is unknown
    """
    if not (n_rays >= 3):
        _raise(ValueError("need 'n_rays' >= 3"))

    if mode == 'python':
        return _py_star_dist(a, n_rays, grid=grid)
    if mode == 'cpp':
        return _cpp_star_dist(a, n_rays, grid=grid)
    if mode == 'opencl':
        return _ocl_star_dist(a, n_rays, grid=grid)

    _raise(ValueError("Unknown mode %s" % mode))
def _dist_to_coord_old(rhos, grid=(1,1)):
    """convert from polar to cartesian coordinates for a single image (3-D array) or multiple images (4-D array)

    rhos.shape = (h, w, n_rays) or (n_images, h, w, n_rays)
    return coord.shape = rhos.shape[:-1] + (2, n_rays), same dtype as rhos;
    index 0 of the added axis is the row coordinate, index 1 the col coordinate
    """
    grid = _normalize_grid(grid,2)

    single = rhos.ndim == 3
    if single:
        # add a leading image axis so both cases share the same code path
        rhos = rhos[np.newaxis]
    assert rhos.ndim == 4

    n_images, h, w, n_rays = rhos.shape
    coord = np.empty((n_images, h, w, 2, n_rays), dtype=rhos.dtype)

    # ray origins: pixel index along each axis, scaled by the grid step
    pixel_index = np.indices((h, w))
    for axis in (0, 1):
        origin = grid[axis] * pixel_index[axis]
        coord[..., axis, :] = origin[np.newaxis, :, :, np.newaxis]

    # add the offset travelled along each ray
    phis = ray_angles(n_rays).reshape(1, 1, 1, n_rays)
    coord[..., 0, :] += rhos * np.sin(phis)  # row coordinate
    coord[..., 1, :] += rhos * np.cos(phis)  # col coordinate

    if single:
        return coord[0]
    return coord
def _polygons_to_label_old(coord, prob, points, shape=None, thr=-np.inf):
    """render the polygons at the given points into a label image

    coord  : per-pixel polygon vertices, indexed as coord[row, col, 0/1]
    prob   : per-pixel value used for ordering and thresholding
    points : array of (row, col) pixel positions whose polygons are drawn
    shape  : output shape, defaults to coord.shape[:2]
    thr    : points with prob < thr are skipped

    Polygons are drawn in order of increasing prob, so higher-prob polygons
    overwrite lower ones; ids start at 1 and follow the drawing order.
    """
    out_shape = shape if shape is not None else coord.shape[:2]
    canvas = np.zeros(out_shape, np.int32)

    # sort points with increasing probability
    order = np.argsort([prob[p[0], p[1]] for p in points])

    next_id = 1
    for p in points[order]:
        row, col = p[0], p[1]
        if not prob[row, col] < thr:
            rr, cc = polygon(coord[row, col, 0], coord[row, col, 1], out_shape)
            canvas[rr, cc] = next_id
            next_id += 1
    return canvas
def dist_to_coord(dist, points, scale_dist=(1,1)):
    """convert from polar to cartesian coordinates for a list of distances and center points

    dist.shape   = (n_polys, n_rays)
    points.shape = (n_polys, 2)
    len(scale_dist) = 2
    return coord.shape = (n_polys,2,n_rays), dtype float32

    coord[:,0] = points[:,0] + scale_dist[0] * dist * sin(angle)
    coord[:,1] = points[:,1] + scale_dist[1] * dist * cos(angle)
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    assert dist.ndim == 2
    assert points.ndim == 2
    assert len(dist) == len(points)
    assert points.shape[1] == 2
    assert len(scale_dist) == 2

    phis = ray_angles(dist.shape[1])
    # unit direction of every ray, shape (2, n_rays)
    directions = np.array([np.sin(phis), np.cos(phis)])

    offsets = dist[:, np.newaxis] * directions
    coord = offsets.astype(np.float32)
    # in-place updates keep the result float32
    coord *= np.asarray(scale_dist).reshape(1, 2, 1)
    coord += points[..., np.newaxis]
    return coord
def polygons_to_label_coord(coord, shape, labels=None):
    """renders polygons to image of given shape

    coord.shape = (n_polys, 2, n_rays)
    labels      = one integer per polygon, defaults to 0..n_polys-1

    polygon k is filled with value labels[k]+1; polygons are drawn in the
    order given, so later ones overwrite earlier ones where they overlap
    """
    coord = np.asarray(coord)
    if labels is None:
        labels = np.arange(len(coord))
    _check_label_array(labels, "labels")
    assert coord.ndim==3 and coord.shape[1]==2 and len(coord)==len(labels)

    canvas = np.zeros(shape, np.int32)
    for label, vertices in zip(labels, coord):
        rr, cc = polygon(vertices[0], vertices[1], shape)
        canvas[rr, cc] = label + 1
    return canvas
def polygons_to_label(dist, points, shape, prob=None, thr=-np.inf, scale_dist=(1,1)):
    """converts distances and center points to label image

    dist.shape   = (n_polys, n_rays)
    points.shape = (n_polys, 2)
    prob.shape   = (n_polys,), defaults to +inf for every polygon

    only polygons with prob > thr are drawn; they are rendered in order of
    increasing prob (stable), so higher-prob polygons end up on top

    label ids will be consecutive and adhere to the order given
    (position among the polygons that pass the threshold, plus one)
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    if prob is None:
        prob = np.inf*np.ones(len(points))
    else:
        prob = np.asarray(prob)

    assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points)
    assert len(points)==len(prob) and points.shape[1]==2 and prob.ndim==1

    # drop everything at or below the threshold
    keep = prob > thr
    points, dist, prob = points[keep], dist[keep], prob[keep]

    # drawing order: lowest prob first; `order` doubles as the label of each polygon
    order = np.argsort(prob, kind='stable')
    coord = dist_to_coord(dist[order], points[order], scale_dist=scale_dist)

    return polygons_to_label_coord(coord, shape=shape, labels=order)
def relabel_image_stardist(lbl, n_rays, **kwargs):
    """relabel each label region in `lbl` with its star representation

    lbl    : 2-D label image
    n_rays : number of rays used for the star representation
    kwargs : forwarded to `star_dist`
             NOTE(review): distances are looked up at full-resolution pixel
             positions, so a `grid` other than (1,1) looks unsupported here - confirm

    For every region the distances are read at its centroid (truncated to
    integer pixel position) and the resulting polygon is rendered.

    returns an int32 label image of shape lbl.shape; all zeros if `lbl`
    contains no regions
    raises ValueError if `lbl` is not 2 dimensional
    """
    _check_label_array(lbl, "lbl")
    if not lbl.ndim==2:
        raise ValueError("lbl image should be 2 dimensional")
    dist = star_dist(lbl, n_rays, **kwargs)

    regions = regionprops(lbl)
    if not regions:
        # without regions `points` would be an empty 1-D array and the
        # indexing below would return the whole distance volume
        return np.zeros(lbl.shape, np.int32)

    points = np.array(tuple(np.array(r.centroid).astype(int) for r in regions))
    # pick the ray distances stored at each centroid pixel -> (n_regions, n_rays)
    dist = dist[tuple(points.T)]
    return polygons_to_label(dist, points, shape=lbl.shape)
def ray_angles(n_rays=32):
    """return `n_rays` equally spaced angles in [0, 2*pi), starting at 0"""
    full_turn = 2*np.pi
    return np.linspace(0, full_turn, num=n_rays, endpoint=False)