Skip to content

Commit

Permalink
Merge pull request #7 from srivarra/exposure/adjust_sigmoid
Browse files Browse the repository at this point in the history
✨ Added  Adjust Sigmoid
  • Loading branch information
srivarra authored Aug 22, 2024
2 parents e34b2b2 + 8c7326f commit a1f4130
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
exposure.adjust_gamma
exposure.adjust_log
exposure.adjust_sigmoid
```
18 changes: 18 additions & 0 deletions docs/notebooks/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,24 @@
"an.exposure.adjust_gamma(blobs_image, gamma=1.5, gain=2.0).pipe(an.exposure.adjust_log, gain=2.0).compute()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"se.adjust_sigmoid(blobs_image, cutoff=0.5, gain=10, inv=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"an.exposure.adjust_sigmoid(blobs_image, cutoff=0.5, gain=10, inv=True).compute()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
4 changes: 2 additions & 2 deletions src/anatomize/exposure/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .basic import adjust_gamma, adjust_log
from .basic import adjust_gamma, adjust_log, adjust_sigmoid

__all__ = ["adjust_gamma", "adjust_log"]
__all__ = ["adjust_gamma", "adjust_log", "adjust_sigmoid"]
26 changes: 25 additions & 1 deletion src/anatomize/exposure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


def _rechunk(image: xr.DataArray, chunks: Mapping[str, int] | None) -> xr.DataArray:
_, n_y, n_x = image.shape
n_c, n_y, n_x = image.shape
if image.chunks == (1, n_y, n_x):
return image
if chunks is None:
chunks = {C: 1, Y: n_y, X: n_x}
return image.chunk(chunks)
Expand Down Expand Up @@ -68,3 +70,25 @@ def _log(Ic, gain, out):
)
def _inv_log(Ic, gain, out):
out[:] = gain[:] * (2 ** Ic[:] - 1)


@guvectorize(
["void(float64[:,:], float64[:], float64[:], float64[:,:])"],
"(m,n),(c),(c)->(m,n)",
nopython=True,
cache=True,
target="cpu",
)
def _sigmoid(Ic, cutoff, gain, out):
out[:] = 1 / (1 + np.exp(gain[:] * (cutoff[:] - Ic[:])))


@guvectorize(
["void(float64[:,:], float64[:], float64[:], float64[:,:])"],
"(m,n),(c),(c)->(m,n)",
nopython=True,
cache=True,
target="cpu",
)
def _inv_sigmoid(Ic, cutoff, gain, out):
out[:] = 1 - (1 / (1 + np.exp(gain[:] * (cutoff[:] - Ic[:]))))
39 changes: 38 additions & 1 deletion src/anatomize/exposure/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from spatialdata.models import C, X, Y

from anatomize.core.decorators import convert_kwargs_to_xr_vec
from anatomize.exposure._utils import _gamma, _inv_log, _log, _normalize, _rechunk
from anatomize.exposure._utils import _gamma, _inv_log, _inv_sigmoid, _log, _normalize, _rechunk, _sigmoid


def iter_channels(image: xr.DataArray) -> xb.BatchGenerator: # noqa: D103
Expand Down Expand Up @@ -122,3 +122,40 @@ def adjust_log(image: xr.DataArray, gain: float = 1, inv=False) -> xr.DataArray:
output_dtypes=[data.dtype],
dask_gufunc_kwargs={"allow_rechunk": True},
)


@convert_kwargs_to_xr_vec("cutoff", "gain")
def adjust_sigmoid(image: xr.DataArray, cutoff: float = 0.5, gain: float = 10, inv=False) -> xr.DataArray:
"""Performs Sigmoid Correction on the input image.
Parameters
----------
image
The image to adjust.
cutoff
Cutoff of the sigmoid function that shifts the characteristic curve in horizontal direction, by default 0.5
gain
The constant multiplier in exponential's power of sigmoid function, by default 10
inv
If True, the negative sigmoid correction is performed, otherwise
the sigmoid correction is performed, by default False.
Returns
-------
The sigmoid adjusted image.
"""
data = _rechunk(image, chunks=None)

f = _sigmoid if not inv else _inv_sigmoid

return xr.apply_ufunc(
f,
data,
cutoff,
gain,
input_core_dims=[[Y, X], [], []],
output_core_dims=[[Y, X]],
dask="parallelized",
output_dtypes=[data.dtype],
dask_gufunc_kwargs={"allow_rechunk": True},
)
16 changes: 16 additions & 0 deletions tests/exposure/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,19 @@ def test_adjust_log(blobs_sdata_store: UPath, gain: float, inv: bool):
dims=blobs_sdata.images["blobs_image"].dims,
)
xr.testing.assert_allclose(a=blobs_image_log_xr, b=blobs_image_log_skimage, atol=1e-10)


@pytest.mark.parametrize("cutoff,gain,inv", ([0.5, 10, True], [2.0, 20, False]))
def test_adjust_sigmoid(blobs_sdata_store: UPath, cutoff: float, gain: float, inv: bool):
blobs_sdata = sd.read_zarr(blobs_sdata_store)

blobs_image_sigmoid_xr = an.exposure.adjust_sigmoid(
image=blobs_sdata.images["blobs_image"], cutoff=cutoff, gain=gain, inv=inv
)

blobs_image_sigmoid_skimage = xr.DataArray(
data=exposure.adjust_sigmoid(blobs_sdata.images["blobs_image"], cutoff=cutoff, gain=gain, inv=inv),
coords=blobs_sdata.images["blobs_image"].coords,
dims=blobs_sdata.images["blobs_image"].dims,
)
xr.testing.assert_allclose(a=blobs_image_sigmoid_xr, b=blobs_image_sigmoid_skimage, atol=1e-10)

0 comments on commit a1f4130

Please sign in to comment.