Spaces:
Running
Running
| """Utility functions to baseline-correct data.""" | |
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| import numpy as np | |
| from .utils import _check_option, _validate_type, logger, verbose | |
| def _log_rescale(baseline, mode="mean"): | |
| """Log the rescaling method.""" | |
| if baseline is not None: | |
| _check_option( | |
| "mode", | |
| mode, | |
| ["logratio", "ratio", "zscore", "mean", "percent", "zlogratio"], | |
| ) | |
| msg = f"Applying baseline correction (mode: {mode})" | |
| else: | |
| msg = "No baseline correction applied" | |
| return msg | |
@verbose
def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=None):
    """Rescale (baseline correct) data.

    Parameters
    ----------
    data : array
        It can be of any shape. The only constraint is that the last
        dimension should be time.
    times : 1D array
        Time instants in seconds.
    %(baseline_rescale)s
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the base-10 log
          ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the base-10 log,
          and dividing by the standard deviation of log baseline values
          ('zlogratio')

    copy : bool
        Whether to return a new instance or modify in place.
    picks : list of int | None
        Data to process along the axis=-2 (None, default, processes all).
    %(verbose)s

    Returns
    -------
    data_scaled : array
        Array of same shape as data after rescaling.

    Raises
    ------
    ValueError
        If ``mode`` is not a supported mode, or if ``baseline`` does not
        select at least one sample of ``times``.
    """
    if copy:
        data = data.copy()
    if verbose is not False:
        msg = _log_rescale(baseline, mode)
        logger.info(msg)
    if baseline is None or data.shape[-1] == 0:
        return data

    # ``_log_rescale`` validates ``mode``, but it is skipped when
    # ``verbose=False``. Validate here too so that an unknown mode raises a
    # clear ValueError instead of an UnboundLocalError further down.
    allowed_modes = ("logratio", "ratio", "zscore", "mean", "percent", "zlogratio")
    if mode not in allowed_modes:
        raise ValueError(
            "Invalid value for the 'mode' parameter. Allowed values are "
            f"{', '.join(repr(m) for m in allowed_modes)}, but got {mode!r} "
            "instead."
        )

    times = np.asarray(times)
    bmin, bmax = baseline

    # First sample at or after bmin (inclusive start index).
    if bmin is None:
        imin = 0
    else:
        above = np.where(times >= bmin)[0]
        if len(above) == 0:
            raise ValueError(
                f"bmin is too large ({bmin}), it exceeds the largest time value"
            )
        imin = int(above[0])

    # One past the last sample at or before bmax (exclusive end index).
    if bmax is None:
        imax = len(times)
    else:
        below = np.where(times <= bmax)[0]
        if len(below) == 0:
            raise ValueError(
                f"bmax is too small ({bmax}), it is smaller than the smallest time "
                "value"
            )
        imax = int(below[-1]) + 1

    if imin >= imax:
        raise ValueError(
            f"Bad rescaling slice ({imin}:{imax}) from time values {bmin}, {bmax}"
        )

    # technically this is inefficient when `picks` is given, but assuming
    # that we generally pick most channels for rescaling, it's not so bad.
    # keepdims=True leaves a length-1 time axis so the mean broadcasts.
    mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True)

    def _baseline_std(d):
        # Std over the baseline samples of the (already transformed) data.
        return np.std(d[..., imin:imax], axis=-1, keepdims=True)

    # Every rescaler modifies ``d`` in place. ``d`` is either ``data`` itself
    # or a basic-indexing view of it, so ``data`` is updated.
    def _mean(d, m):
        d -= m

    def _ratio(d, m):
        d /= m

    def _logratio(d, m):
        d /= m
        np.log10(d, out=d)

    def _percent(d, m):
        d -= m
        d /= m

    def _zscore(d, m):
        d -= m
        d /= _baseline_std(d)

    def _zlogratio(d, m):
        d /= m
        np.log10(d, out=d)
        d /= _baseline_std(d)

    fun = {
        "mean": _mean,
        "ratio": _ratio,
        "logratio": _logratio,
        "percent": _percent,
        "zscore": _zscore,
        "zlogratio": _zlogratio,
    }[mode]

    if picks is None:
        fun(data, mean)
    else:
        for pi in picks:
            fun(data[..., pi, :], mean[..., pi, :])
    return data
def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"):
    """Check if the baseline is valid and adjust it if requested.

    ``None`` values inside ``baseline`` will be replaced with ``times[0]`` and
    ``times[-1]``.

    Parameters
    ----------
    baseline : array-like, shape (2,) | None
        Beginning and end of the baseline period, in seconds. If ``None``,
        assume no baseline and return immediately.
    times : array
        The time points.
    sfreq : float
        The sampling rate.
    on_baseline_outside_data : 'raise' | 'info' | 'adjust'
        What to do if the baseline period exceeds the data.
        If ``'raise'``, raise an exception (default).
        If ``'info'``, log an info message.
        If ``'adjust'``, adjust the baseline such that it is within the data range.

    Returns
    -------
    (baseline_tmin, baseline_tmax) | None
        The baseline as Python floats, with ``None`` values replaced with
        times, and with adjusted times if ``on_baseline_outside_data='adjust'``;
        or ``None``, if ``baseline`` is ``None``.

    Raises
    ------
    ValueError
        If ``on_baseline_outside_data`` is not a supported option, if
        ``baseline`` does not have exactly two elements, if it is the default
        ``(None, 0)`` on data starting at 0 s (a single sample), if its start
        exceeds its end, or if it lies outside the data and
        ``on_baseline_outside_data='raise'``.
    """
    if baseline is None:
        return None

    # Previously an unknown option silently did nothing; reject it explicitly.
    _check_option(
        "on_baseline_outside_data",
        on_baseline_outside_data,
        ["raise", "info", "adjust"],
    )

    _validate_type(baseline, "array-like")
    baseline = tuple(baseline)

    if len(baseline) != 2:
        raise ValueError(
            f"baseline must have exactly two elements (got {len(baseline)})."
        )

    tmin, tmax = times[0], times[-1]
    tstep = 1.0 / float(sfreq)

    # check default value of baseline and `tmin=0`: (None, 0) would then
    # select a single sample, so require the user to ask for it explicitly.
    if baseline == (None, 0) and tmin == 0:
        raise ValueError(
            "Baseline interval is only one sample. Use `baseline=(0, 0)` if this is "
            "desired."
        )

    baseline_tmin, baseline_tmax = baseline

    if baseline_tmin is None:
        baseline_tmin = tmin
    baseline_tmin = float(baseline_tmin)

    if baseline_tmax is None:
        baseline_tmax = tmax
    baseline_tmax = float(baseline_tmax)

    if baseline_tmin > baseline_tmax:
        raise ValueError(
            f"Baseline min ({baseline_tmin}) must be less than baseline max ("
            f"{baseline_tmax})"
        )

    # Allow up to one sample period of slack beyond either end of the data.
    if (baseline_tmin < tmin - tstep) or (baseline_tmax > tmax + tstep):
        msg = (
            f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s is outside of "
            f"epochs data [{tmin}, {tmax}] s. Epochs were probably cropped."
        )
        if on_baseline_outside_data == "raise":
            raise ValueError(msg)
        elif on_baseline_outside_data == "info":
            logger.info(msg)
        elif on_baseline_outside_data == "adjust":
            # Each end is clamped independently, so a baseline lying entirely
            # outside the data yields an inverted interval (tmin > tmax).
            # Cast to float to match the None-replacement path above, instead
            # of returning NumPy scalars taken from ``times``.
            if baseline_tmin < tmin - tstep:
                baseline_tmin = float(tmin)
            if baseline_tmax > tmax + tstep:
                baseline_tmax = float(tmax)

    return baseline_tmin, baseline_tmax