Source code for analyseur.cbgtc.stats.wavelet

# ~/analyseur/cbgtc/stat/wavelet.py
#
# Documentation by Lungsi 22 Oct 2025
#
# This contains the ContinuousWaveletTransform class for wavelet analysis of spike trains
#
import re

import numpy as np
from scipy.ndimage import gaussian_filter1d
import pywt

from analyseur.cbgtc.parameters import SignalAnalysisParams
from analyseur.cbgtc.curate import get_binary_spiketrains

siganal = SignalAnalysisParams()

class ContinuousWaveletTransform(object):
    """
    ============================
    Continuous Wavelet Transform
    ============================

    Helpers for turning spike times into smoothed signals and computing their
    continuous wavelet transform (CWT) with `pywt.cwt
    <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_.

    .. list-table:: Public methods
       :widths: 30 70
       :header-rows: 1

       * - Methods
         - Argument
       * - :py:meth:`.smooth_signal`
         - `spiketimes_set`: see
           :class:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`,
           also :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`;
           [OPTIONAL] `sampling_rate`, `window`, `neurons`, `sigma`
       * - :py:meth:`.scale_to_freq`
         - `scale`, `wavelet`; [OPTIONAL] `sampling_rate`
       * - :py:meth:`.freq_window_to_scales`
         - `freq_window`, `voices_per_octave`, `wavelet`; [OPTIONAL] `sampling_rate`
       * - :py:meth:`.compute_cwt_avg`
         - `spiketimes_set`; [OPTIONAL] `sampling_rate`, `window`, `sigma`,
           `scales`, `wavelet`, `neurons`
       * - :py:meth:`.compute_cwt_sum`
         - `spiketimes_set`; [OPTIONAL] `sampling_rate`, `window`, `sigma`,
           `scales`, `wavelet`, `neurons`

    Comments on Activity and Choices in Performing CWT
    --------------------------------------------------

    .. list-table::
       :widths: 20 30 50
       :header-rows: 1

       * - Activity
         - Description
         - Purpose
       * - Signal Smoothing
         - converts binary spike trains to continuous signals
         - some smoothing can help visualize rhythmic spiking;
           over smoothing can obscure precise timing
       * - scales
         - defines the frequencies analyzed
         - smaller scales for high frequencies and larger for lower frequencies;
           voices per octave ≜ number of scales between 2 frequencies (≜ octave);
           higher voices per octave give smoother scalogram but increased computation
       * - wavelet choice
         - determines trade-off between time and frequency resolution
         - Morlet ("cmorB-C") for oscillation;
           Mexican hat ("mexh") for transient spike detection

    .. raw:: html

        <hr style="border: 2px solid red; margin: 20px 0;">

    """

    # ===============================================================
    # Static methods that check for available individual wavelet options
    # in https://pywavelets.readthedocs.io/en/latest/ref/cwt.html
    # ===============================================================
    @staticmethod
    def _is_cmor_format(w):
        """Return True when `w` has the form "cmorB-C" (B, C floating point)."""
        pattern = r"^cmor(-?\d+\.?\d*)-(-?\d+\.?\d*)$"
        return bool(re.match(pattern, w))

    @staticmethod
    def _is_shan_format(w):
        """Return True when `w` has the form "shanB-C" (B, C floating point)."""
        pattern = r"^shan(-?\d+\.?\d*)-(-?\d+\.?\d*)$"
        return bool(re.match(pattern, w))

    @staticmethod
    def _is_fbsp_format(w):
        """Return True when `w` has the form "fbspM-B-C" (integer M; B, C floating point)."""
        # The separator between M and B is mandatory; without it a name such
        # as "fbsp21.5-1.0" (only two fields) would be accepted.
        pattern = r"^fbsp(-?\d+)-(-?\d+\.?\d*)-(-?\d+\.?\d*)$"
        return bool(re.match(pattern, w))

    @staticmethod
    def _is_gaus_format(w):
        """Return True when `w` has the form "gausP" (integer P)."""
        pattern = r"^gaus(-?\d+)$"
        return bool(re.match(pattern, w))

    @staticmethod
    def _is_cgau_format(w):
        """Return True when `w` has the form "cgauP" (integer P)."""
        pattern = r"^cgau(-?\d+)$"
        return bool(re.match(pattern, w))

    @classmethod
    def _check_pywt_wavelet_format(cls, wavelet):
        """
        This method checks for wavelet options available in
        `pywt.cwt <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_.

        :param wavelet: wavelet name; either one of the fixed names `"mexh"`, `"morl"`
            or a parametrised name (`"cmorB-C"`, `"gausP"`, `"cgauP"`, `"shanB-C"`, `"fbspM-B-C"`)
        :return: `None` when the name is acceptable
        :raises ValueError: when the name matches none of the supported forms
        """
        pywt_wavelets1 = ["mexh", "morl", ]
        pywt_wavelets2 = ["cmor", "gaus", "cgau", "shan", "fbsp"]

        # Fixed names need a membership test (an identity test against the
        # list itself can never succeed and would reject "mexh" and "morl").
        if wavelet in pywt_wavelets1:
            return

        format_checkers = (
            cls._is_cmor_format,
            cls._is_gaus_format,
            cls._is_cgau_format,
            cls._is_shan_format,
            cls._is_fbsp_format,
        )
        if not any(checker(wavelet) for checker in format_checkers):
            raise ValueError("wavelet must be one of the strings in the list: "
                             + str(pywt_wavelets1 + pywt_wavelets2))

    @staticmethod
    def _expand_scales(scales):
        """
        Turn the user-facing `scales` argument into the array handed to `pywt.cwt`.

        - `None` gives `np.arange(1, 128)`
        - a sequence of exactly two values `(start, stop)` gives `np.arange(start, stop)`
        - any other sequence is used as the explicit array of scales
          (e.g. the array returned by :py:meth:`.freq_window_to_scales`)
        """
        if scales is None:
            return np.arange(1, 128)
        if len(scales) == 2:
            # Historical meaning of the argument: a (start, stop) pair.
            return np.arange(scales[0], scales[1])
        return np.asarray(scales)

    @classmethod
    def _resolve_cwt_arguments(cls, window, sigma, scales, wavelet):
        """
        Fill in the defaults shared by every CWT method and validate the wavelet name.

        :return: 4-tuple `(window, sigma, scales, wavelet)` with defaults applied
        :raises ValueError: when `wavelet` is not a supported name
        """
        if window is None:
            window = siganal.window
        if sigma is None:
            sigma = siganal.std_Gaussian_kernel
        scales = cls._expand_scales(scales)
        if wavelet is None:
            wavelet = "cmor1.5-1.0"

        # Check wavelet chosen is one of available option in
        # https://pywavelets.readthedocs.io/en/latest/ref/cwt.html
        cls._check_pywt_wavelet_format(wavelet)

        return window, sigma, scales, wavelet

    @staticmethod
    def smooth_signal(spiketimes_set, sampling_rate=None, window=None, neurons=None, sigma=None):
        """
        This method takes the spike times and converts it into respective binary spike trains
        which in turn is smoothened. The returned smoothened signal can be used to create a
        firing rate signal.

        :param spiketimes_set: Dictionary returned using
            :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
            or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`
        :param sampling_rate: defaults to `1 / siganal.sampling_period`
            (documented as `1000/dt = 10000` Hz); sampling_rate ∊ (0, 10000)
        :param window: Tuple in the form `(start_time, end_time)`; defaults to
            `siganal.window` (documented as `(0, 10)`)
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids
            like `[2, 3, 6, 7]`

            - `"all"` means subset = superset
            - `N` (a scalar) means subset of first N neurons in the superset
            - `range(a, b)` or `[2, 3, 6, 7]` means subset of selected neurons

        :param sigma: standard deviation of the Gaussian kernel, passed unchanged to
            `scipy.ndimage.gaussian_filter1d`; defaults to `siganal.std_Gaussian_kernel`
            (documented as `2`)
        :return: 4-tuple

            - smooth signal
            - list of neuron id's
            - time axis
            - sampling period (`1 / sampling_rate`)

        Smoothening is done by Gaussian filtering
        =========================================

        - each binary spike (= 1) is placed into a Gaussian-shaped bump
        - the Gaussian filter replaces each point (spike = 1) with a Gaussian distribution
          centered at that position
        - the overall result is many Gaussian distributions (each per point when spike = 1)
        - where spikes are close the Gaussians overlap
        - sum together the overlapping Gaussians (i.e convolution)
        - this represents the "dense estimate" of the spikes, i.e, smoothened curve

        **Formula**

        .. list-table:: Formula
           :widths: 60 40
           :header-rows: 1

           * - Definitions
             - Interpretation
           * - total neurons, :math:`n_{nuc}`
             - total number of neurons in the Nucleus
           * - neuron index, :math:`i`
             - i-th neuron in the pool of :math:`n_{nuc}` neurons
           * - total spikes, :math:`n_{spk}^{(i)}`
             - total number of spikes (spike times) by i-th neuron
           * - :math:`\\vec{S}^{(i)}`
             - array of spike times of i-th neuron
           * - :math:`S = \\left\\{\\vec{S}^{(i)} \\mid \\forall{i \\in [1, n_{nuc}]} \\right\\}`
             - set of spike times of all neurons
           * - :math:`\\vec{B}^{(i)} = \\left[\\sum_{k=1}^{n_{spk}^{(i)}} \\delta[t - t_k]\\right]_{\\forall{t} \\in [t_1, t_{n_{spk}^{(i)}}]}`
             - binary spike train of i-th neuron for spike times :math:`\\vec{S}^{(i)}`
               at :math:`t_1, t_2, ..., t_{n_{spk}^{(i)}}`
           * - :math:`B = \\left\\{\\vec{B}^{(i)} \\mid \\forall{i \\in [1, n_{nuc}]} \\right\\}`
             - set of spike trains of all neurons
           * - :math:`\\vec{G} = \\left[\\frac{1}{\\sigma \\sqrt{2\\pi}}e^{-\\frac{k^2}{2\\sigma^2}}\\right]_{\\forall{k}}`
             - Gaussian kernel

        Then, the smoothened signal for i-th neuron is

        .. math::

            \\vec{M}^{(i)} &= \\vec{B}^{(i)} \\ast \\vec{G} \\\\
                &\\triangleq \\left[ \\sum_{k=1}^{n_{spk}^{(i)}} \\left(B^{(i)}[t_k] \\cdot G[t-t_k]\\right) \\right]_{\\forall{t} \\in [t_1, t_{n_{spk}^{(i)}}]} \\\\
                &= \\left[ \\sum_{k=1}^{n_{spk}^{(i)}} G[t-t_k] \\right]_{\\forall{t} \\in [t_1, t_{n_{spk}^{(i)}}]}

        Note that the last line is due to the fact that only non-zero :math:`B^{(i)}[t_k]`
        occurs at spike positions.

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        # ============== DEFAULT Parameters ==============
        if sampling_rate is None:
            sampling_rate = 1 / siganal.sampling_period
        sampling_period = 1.0 / sampling_rate

        if window is None:
            window = siganal.window

        if neurons is None:
            neurons = "all"

        if sigma is None:
            sigma = siganal.std_Gaussian_kernel

        # Convert spike times to spike train
        [spiketrains, yticks, time_axis] = get_binary_spiketrains(spiketimes_set,
                                                                  sampling_rate=sampling_rate,
                                                                  window=window,
                                                                  neurons=neurons)

        # Return the smoothened spike train
        # (gaussian_filter1d filters along the last axis by default)
        smoothed = gaussian_filter1d(spiketrains, sigma=sigma)
        return smoothed, yticks, time_axis, sampling_period

    @staticmethod
    def scale_to_freq(scale=None, wavelet=None, sampling_rate=None):
        """
        Converts a scale value to its corresponding frequency value.

        :param scale: a scalar
        :param wavelet: name of wavelet type available in
            `pywt.cwt <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_
        :param sampling_rate: defaults to `1 / siganal.sampling_period`
            (documented as `1000/dt = 10000` Hz); sampling_rate ∊ (0, 10000)
        :return: a frequency value, `central_frequency / (scale * sampling_period)`

        .. list-table:: **Scale is the dilation/compression factor applied to the wavelet.**
           :widths: 15 15
           :header-rows: 1

           * - *Scale vs Frequency*
             - Scale vs Voices/Octave
           * - Scales defines the frequencies in the wavelet transform analysis.
             - Voices per octave controls the number of scales between consecutive
               frequencies (≜ octave)
           * - :math:`s_a < s_b \\overset{\\frown}{=} f_a > f_b` where scale :math:`s_x`
               corresponds to frequency :math:`f_x`
             - In other words, the number of scales between octaves is the voices per octave

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        # ============== DEFAULT Parameters ==============
        if sampling_rate is None:
            sampling_period = siganal.sampling_period
        else:
            sampling_period = 1.0 / sampling_rate

        center_freq = pywt.central_frequency(wavelet)  # Property of a specific wavelet

        return center_freq / (scale * sampling_period)

    @staticmethod
    def freq_window_to_scales(freq_window=None, voices_per_octave=None,
                              wavelet=None, sampling_rate=None):
        """
        Returns the array of scales between octaves given

        :param freq_window: Tuple `(min_freq, max_freq)`
        :param voices_per_octave: scalar
        :param wavelet: name of wavelet type available in
            `pywt.cwt <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_
        :param sampling_rate: defaults to `1 / siganal.sampling_period`
            (documented as `1000/dt = 10000` Hz); sampling_rate ∊ (0, 10000)
        :return: geometrically spaced array of `int(log2(max_freq / min_freq) * voices_per_octave)`
            scales; the first entry is the scale of `min_freq` (the largest scale)
            and the last entry is the scale of `max_freq` (the smallest scale)
        :raises ValueError: when `min_freq` is not positive or `max_freq < min_freq`

        .. list-table::
           :widths: 20 40 40
           :header-rows: 1

           * - Parameter
             - Description
             - Comment
           * - frequency range
             - match signal's expected content
             - 1-100 Hz (neuron spikes)
           * - Voices Per Octave (VPO)
             - VPO ∝ frequency resolution; VPO ∝ 1/computational cost
             - 10-16 (general); 32 (high precision analysis)
           * - wavelet choice
             - time/frequency resolution trade-off
             - `"cmorB-C"` (good for neural data)

        **Scale is the dilation/compression factor applied to the wavelet.**

        - Scale vs Frequency
            - Scales defines the frequencies in the wavelet transform analysis.
            - :math:`s_a < s_b \\overset{\\frown}{=} f_a > f_b` where scale :math:`s_x`
              corresponds to frequency :math:`f_x`
        - **Scale vs Voices/Octave**
            - Voices per octave controls the number of scales between consecutive
              frequencies (≜ octave)
            - In other words, the number of scales between octaves is the voices per octave

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        # ============== DEFAULT Parameters ==============
        if sampling_rate is None:
            sampling_rate = 1 / siganal.sampling_period

        min_freq, max_freq = freq_window[0], freq_window[1]
        if min_freq <= 0 or max_freq < min_freq:
            raise ValueError("freq_window must satisfy 0 < min_freq <= max_freq")

        num_of_octaves = np.log2(max_freq / min_freq)
        num_of_scales = int(num_of_octaves * voices_per_octave)

        # Convert frequencies to approximate scales.
        # pywt.frequency2scale expects a frequency normalised by the sampling
        # frequency (sampling frequency == 1.0), hence the division; this is the
        # inverse of the relation used in scale_to_freq.
        scale_at_min_freq = pywt.frequency2scale(wavelet, min_freq / sampling_rate)
        scale_at_max_freq = pywt.frequency2scale(wavelet, max_freq / sampling_rate)

        # Generate scales
        return np.geomspace(scale_at_min_freq, scale_at_max_freq, num=num_of_scales)

    @classmethod
    def _compute_cwt_single(cls, spiketimes_set, sampling_rate=None, window=None,
                            sigma=None, scales=None, wavelet=None, neuron_indx=None,):
        """
        Compute the Continuous Wavelet Transform for a single neuron

        :param spiketimes_set: Dictionary returned using
            :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
            or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`

        [OPTIONAL]

        :param sampling_rate: defaults to `1 / siganal.sampling_period`
            (documented as `1000/dt = 10000` Hz); sampling_rate ∊ (0, 10000)
        :param window: Tuple in the form `(start_time, end_time)`; `(0, 10)` [default]
        :param sigma: standard deviation value, `2` [default]
        :param wavelet: `"cmor1.5-1.0"` [default], for possible options see
            `pywt.cwt <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_
        :param scales: Tuple `(start, stop)` expanded with `np.arange`; `(1, 128)` [default].
            A sequence whose length is not 2 is used as the explicit array of scales.
        :param neuron_indx: row index into the smoothed signal; randomly picks one [default]
        :return: 4-tuple

            - Continuous wavelet transform of the input signal
            - corresponding frequencies
            - time axis of the input signal
            - neuron id

        :raises ValueError: when `wavelet` is not a supported name

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        # ============== DEFAULT Parameters ==============
        window, sigma, scales, wavelet = cls._resolve_cwt_arguments(window, sigma, scales, wavelet)

        # Convert spike times to spike trains (always for the whole superset,
        # the neuron of interest is selected afterwards by row index)
        [smoothed_signal, yticks, time_axis, sampling_period] = \
            cls.smooth_signal(spiketimes_set, sampling_rate=sampling_rate,
                              window=window, neurons="all", sigma=sigma)

        if neuron_indx is None:
            neuron_indx = np.random.randint(0, high=len(yticks))

        single_neuron_train = smoothed_signal[neuron_indx, :].flatten()
        coefficients, frequencies = pywt.cwt(single_neuron_train, scales, wavelet,
                                             sampling_period=sampling_period)

        # Return the results for a single neuron
        return coefficients, frequencies, time_axis, yticks[neuron_indx]

    @classmethod
    def compute_cwt_avg(cls, spiketimes_set, sampling_rate=None, window=None,
                        sigma=None, scales=None, wavelet=None, neurons=None,):
        """
        Compute the Continuous Wavelet Transform as the average of the wavelet transform
        of the neurons in a population

        :param spiketimes_set: Dictionary returned using
            :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
            or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`

        [OPTIONAL]

        :param sampling_rate: defaults to `1 / siganal.sampling_period`
            (documented as `1000/dt = 10000` Hz); sampling_rate ∊ (0, 10000)
        :param window: Tuple in the form `(start_time, end_time)`; `(0, 10)` [default]
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids
            like `[2, 3, 6, 7]`; forwarded to :py:meth:`.smooth_signal`
        :param sigma: standard deviation value, `2` [default]
        :param wavelet: `"cmor1.5-1.0"` [default], for possible options see
            `pywt.cwt <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_
        :param scales: Tuple `(start, stop)` expanded with `np.arange`; `(1, 128)` [default].
            A sequence whose length is not 2 is used as the explicit array of scales.
        :return: 4-tuple

            - array of the magnitude of the wavelet coefficients, averaged across the chosen neurons
            - array of sample frequencies
            - list of neuron id's
            - array of time

        :raises ValueError: when `wavelet` is not a supported name or no neuron is selected

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        # ============== DEFAULT Parameters ==============
        if neurons is None:
            neurons = "all"
        window, sigma, scales, wavelet = cls._resolve_cwt_arguments(window, sigma, scales, wavelet)

        # Convert spike times to spike trains
        [smoothed_signal, yticks, time_axis, sampling_period] = \
            cls.smooth_signal(spiketimes_set, sampling_rate=sampling_rate,
                              window=window, neurons=neurons, sigma=sigma)

        n_neurons = smoothed_signal.shape[0]
        if n_neurons == 0:
            raise ValueError("cannot average the wavelet transform of an empty set of neurons")

        # Compute the Continuous Wavelet Transform for every neuron within the chosen
        # neurons option "all" or selective. A running sum is kept instead of a list
        # of every (scales x time) magnitude array, which keeps memory independent of
        # the number of neurons.
        magnitude_sum = None
        frequencies = None
        for i in range(n_neurons):
            single_neuron_train = smoothed_signal[i, :].flatten()
            coefficients, frequencies = pywt.cwt(single_neuron_train, scales, wavelet,
                                                 sampling_period=sampling_period)
            magnitude = np.abs(coefficients)
            if magnitude_sum is None:
                magnitude_sum = magnitude  # fresh array from np.abs, safe to accumulate into
            else:
                magnitude_sum += magnitude

        # Return the results as average coefficients across chosen neurons
        return magnitude_sum / n_neurons, frequencies, yticks, time_axis

    @classmethod
    def compute_cwt_sum(cls, spiketimes_set, sampling_rate=None, window=None,
                        sigma=None, scales=None, wavelet=None, neurons=None,):
        """
        Compute the Continuous Wavelet Transform on the sum of signals across neurons
        in a population.

        :param spiketimes_set: Dictionary returned using
            :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
            or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`

        [OPTIONAL]

        :param sampling_rate: defaults to `1 / siganal.sampling_period`
            (documented as `1000/dt = 10000` Hz); sampling_rate ∊ (0, 10000)
        :param window: Tuple in the form `(start_time, end_time)`; `(0, 10)` [default]
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids
            like `[2, 3, 6, 7]`

            - `"all"` means subset = superset
            - `N` (a scalar) means subset of first N neurons in the superset
            - `range(a, b)` or `[2, 3, 6, 7]` means subset of selected neurons

        :param sigma: standard deviation value, `2` [default]
        :param wavelet: `"cmor1.5-1.0"` [default], for possible options see
            `pywt.cwt <https://pywavelets.readthedocs.io/en/latest/ref/cwt.html>`_
        :param scales: Tuple `(start, stop)` expanded with `np.arange`; `(1, 128)` [default].
            A sequence whose length is not 2 is used as the explicit array of scales.
        :return: 4-tuple

            - array of wavelet coefficients of the summed (population) signal,
              as returned by `pywt.cwt`
            - array of sample frequencies
            - list of neuron id's
            - array of time

        :raises ValueError: when `wavelet` is not a supported name

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        # ============== DEFAULT Parameters ==============
        if neurons is None:
            neurons = "all"
        window, sigma, scales, wavelet = cls._resolve_cwt_arguments(window, sigma, scales, wavelet)

        # Convert spike times to spike trains
        [smoothed_signal, yticks, time_axis, sampling_period] = \
            cls.smooth_signal(spiketimes_set, sampling_rate=sampling_rate,
                              window=window, neurons=neurons, sigma=sigma)

        # Compute Population Firing Rate from the sum of all chosen neurons
        population_train = np.sum(smoothed_signal, axis=0).flatten()

        coefficients, frequencies = pywt.cwt(population_train, scales, wavelet,
                                             sampling_period=sampling_period)

        # Return Population CWT across chosen neurons
        return coefficients, frequencies, yticks, time_axis