# ~/analyseur/cbgtc/visual/peristimulus.py
#
# Documentation by Lungsi 8 Oct 2025
#
# This contains function for Peri-Stimulus Time Histogram (PSTH)
#
import numpy as np
import matplotlib.pyplot as plt
from analyseur.cbgtc.parameters import SignalAnalysisParams
from analyseur.cbgtc.stats.psth import PSTH
class VizPSTH(object):
    """
    The Peri-Stimulus Time Histogram (PSTH) Class.

    ==============================  ========================================================
    Methods                         Return
    ==============================  ========================================================
    :py:meth:`.plot_pool`           tuple ``(fig, ax)``; pooled spike-count histogram
    :py:meth:`.plot_pool_in_ax`     the given ``ax`` with the pooled histogram drawn in it
    :py:meth:`.plot_avg`            tuple ``(fig, ax)``; averaged firing-rate bar chart
    :py:meth:`.plot_avg_in_ax`      the given ``ax`` with the averaged bar chart drawn in it
    ==============================  ========================================================

    **Use Case:**

    1. Setup

    ::

        from analyseur.cbgtc.loader import LoadSpikeTimes
        from analyseur.cbgtc.visual.peristimulus import VizPSTH

        loadST = LoadSpikeTimes("/full/path/to/spikes_GPi.csv")
        spiketimes_superset = loadST.get_spiketrains()

    2. Pooled PSTH with the default window ``(0, 10)`` and bin size ``0.01``

    ::

        VizPSTH.plot_pool(spiketimes_superset)

    3. Pooled PSTH for a desired window and bin size

    ::

        VizPSTH.plot_pool(spiketimes_superset, window=(0, 5), binsz=1)  # time unit in seconds
        VizPSTH.plot_pool(spiketimes_superset, window=(0, 5), binsz=0.05)

    4. Averaged PSTH (mean firing rate per bin with error bars), saved without displaying

    ::

        fig, ax = VizPSTH.plot_avg(spiketimes_superset, window=(0, 5), binsz=0.05, show=False)
        fig.savefig("psth_avg.png")

    * The PSTH summarises how population activity evolves over time, either as pooled
      spike counts (:py:meth:`.plot_pool`) or as averaged firing rates (:py:meth:`.plot_avg`).
    * The computation is done by :class:`~analyseur.cbgtc.stats.psth.PSTH`.

    .. raw:: html

        <hr style="border: 2px solid red; margin: 20px 0;">
    """
    __siganal = SignalAnalysisParams()

    # ------------------------------------------------------------------
    # Private helpers shared by the pooled and averaged plots
    # ------------------------------------------------------------------
    @classmethod
    def _resolve_defaults(cls, binsz, window, neurons):
        """
        Return ``(binsz, window, neurons)`` with every ``None`` replaced by its default.

        ``window`` falls back to ``SignalAnalysisParams().window``, ``binsz`` to
        ``SignalAnalysisParams().binsz_100perbin`` and ``neurons`` to ``"all"``.
        """
        if window is None:
            window = cls.__siganal.window
        if binsz is None:
            binsz = cls.__siganal.binsz_100perbin
        if neurons is None:
            neurons = "all"
        return binsz, window, neurons

    @staticmethod
    def _population_title(kind, n_neurons, nucleus, mean_firing_rate):
        """Build the two-line axes title shared by the pooled and averaged PSTH plots."""
        nucname = "" if nucleus is None else " in " + nucleus
        return (kind + " PSTH - Population Activity of " + str(n_neurons) + " neurons" + nucname
                + "\n (mean firing rate within the window = " + str(mean_firing_rate) + " Hz)")

    @classmethod
    def _plot_in_new_figure(cls, draw_in_ax, spiketimes_set, binsz, window, neurons, nucleus, show):
        """
        Create a 10x6 inch figure, draw into its axes with ``draw_in_ax``, optionally show it.

        :return: tuple ``(fig, ax)``
        """
        # Stored on the class object itself (shared by every caller), not on an instance.
        cls.binsz = binsz
        cls.window = window
        fig, ax = plt.subplots(figsize=(10, 6))
        ax = draw_in_ax(ax, spiketimes_set, binsz=binsz, window=window,
                        neurons=neurons, nucleus=nucleus)
        if show:
            plt.show()
        return fig, ax

    # ------------------------------------------------------------------
    # Pooled PSTH
    # ------------------------------------------------------------------
    @classmethod
    def plot_pool_in_ax(cls, ax, spiketimes_set, binsz=None, window=None, neurons=None, nucleus=None):
        """
        Draws the Pooled Peri-Stimulus Time Histogram (PSTH) on the given
        `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_.

        .. code-block:: text

                        Pooled PSTH (Population Activity)

            Total Spike Count
              |
              |    ████      █████      ████      █████      ████
              |    ████      █████      ████      █████      ████
              |    ████      █████      ████      █████      ████
              |
              +--------------------------------------------------> Time (s)
              0         2          4          6          8        10

        Histogram bins show spike counts aggregated across neurons: each bar is the
        number of spikes from the pooled neuron population within one time bin.

        :param ax: `matplotlib.axes.Axes` object to draw into
        :param spiketimes_set: Dictionary returned using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
                               or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`
        :param binsz: integer or float; bin size in seconds. When ``None`` [default] it is
                      ``SignalAnalysisParams().binsz_100perbin``
        :param window: Tuple in the form `(start_time, end_time)`. When ``None`` [default] it is
                       ``SignalAnalysisParams().window``
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids like `[2, 3, 6, 7]`

            - `"all"` means subset = superset
            - `N` (a scalar) means subset of first N neurons in the superset
            - `range(a, b)` or `[2, 3, 6, 7]` means subset of selected neurons

        :param nucleus: string; name of the nucleus, appended to the title when given
        :return: the same ``ax`` with the PSTH drawn into it

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">
        """
        binsz, window, neurons = cls._resolve_defaults(binsz, window, neurons)

        # Only the bins, the firing-rate summary and the pooled spikes are used here.
        [_, bin_info, _, true_avg_rate, pooled_spikes] = PSTH.compute_poolPSTH(
            spiketimes_set, neurons=neurons, binsz=binsz, window=window)
        # Neuron count = number of per-neuron firing rates returned by the computation.
        n_neurons = len(true_avg_rate["firing_rates"])

        # Let matplotlib count the pooled spikes using the PSTH's own bins.
        ax.hist(pooled_spikes, bins=bin_info["bins"], alpha=0.7, color="blue", edgecolor="black")
        ax.grid(True, alpha=0.3)
        ax.set_ylabel("Total Spike Count")
        ax.set_xlabel("Time (s)")
        ax.set_title(cls._population_title("Pooled", n_neurons, nucleus,
                                           true_avg_rate["mean_firing_rate"]))
        return ax

    @classmethod
    def plot_pool(cls, spiketimes_set, binsz=0.01, window=(0, 10), neurons="all", nucleus=None, show=True):
        """
        Displays the Pooled Peri-Stimulus Time Histogram (PSTH) of the given spike times (seconds)
        and returns the figure (to save if necessary), drawing via :py:meth:`.plot_pool_in_ax`.

        :param spiketimes_set: Dictionary returned using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
                               or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`
        :param binsz: integer or float; `0.01` [default]
        :param window: Tuple in the form `(start_time, end_time)`; `(0, 10)` [default]
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids like `[2, 3, 6, 7]`

            - `"all"` means subset = superset
            - `N` (a scalar) means subset of first N neurons in the superset
            - `range(a, b)` or `[2, 3, 6, 7]` means subset of selected neurons

        :param nucleus: string; name of the nucleus
        :param show: boolean; call ``plt.show()`` after drawing [default: True]
        :return: tuple ``(fig, ax)`` — the new `matplotlib.figure.Figure` and its
                 `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_
                 containing the histogram drawn with ``Axes.hist``

        * `window` controls the binning range as well as the spike counting window
        * CBGT simulation was done in seconds so window `(0, 10)` signifies time 0 s to 10 s
        * As a side effect, `binsz` and `window` are stored as attributes of the class itself

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">
        """
        return cls._plot_in_new_figure(cls.plot_pool_in_ax, spiketimes_set, binsz, window,
                                       neurons, nucleus, show)

    # ------------------------------------------------------------------
    # Averaged PSTH
    # ------------------------------------------------------------------
    @classmethod
    def plot_avg_in_ax(cls, ax, spiketimes_set, binsz=None, window=None, neurons=None, nucleus=None):
        """
        Draws the Averaged Peri-Stimulus Time Histogram (PSTH) on the given
        `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_.

        .. code-block:: text

                        Averaged PSTH (Population Activity)

            Average Firing Rate (Hz)
              ^
              |    ┬         ┬          ┬          ┬          ┬
              |   ███       ███        ███        ███        ███
              |   ███       ███        ███        ███        ███
              |   ███       ███        ███        ███        ███
              |    ┴         ┴          ┴          ┴          ┴
              +--------------------------------------------------> Time (s)
              0         2          4          6          8        10

        Bars represent the population-average firing rate per time bin.
        Error bars indicate variability across neurons (the standard error returned by
        :meth:`~analyseur.cbgtc.stats.psth.PSTH.compute_avgPSTH`).

        :param ax: `matplotlib.axes.Axes` object to draw into
        :param spiketimes_set: Dictionary returned using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
                               or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`
        :param binsz: integer or float; bin size in seconds. When ``None`` [default] it is
                      ``SignalAnalysisParams().binsz_100perbin``
        :param window: Tuple in the form `(start_time, end_time)`. When ``None`` [default] it is
                       ``SignalAnalysisParams().window``
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids like `[2, 3, 6, 7]`

            - `"all"` means subset = superset
            - `N` (a scalar) means subset of first N neurons in the superset
            - `range(a, b)` or `[2, 3, 6, 7]` means subset of selected neurons

        :param nucleus: string; name of the nucleus, appended to the title when given
        :return: the same ``ax`` with the PSTH drawn into it

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">
        """
        binsz, window, neurons = cls._resolve_defaults(binsz, window, neurons)

        # compute_avgPSTH returns (average_psth, std_err_psth, bin_info, popfirerates,
        # true_avg_rate, desired_spiketimes_subset); population rates and subset are unused here.
        [average_psth, std_err_psth, bin_info, _, true_avg_rate, _] = PSTH.compute_avgPSTH(
            spiketimes_set, neurons=neurons, binsz=binsz, window=window)
        # Neuron count = number of per-neuron firing rates returned by the computation.
        n_neurons = len(true_avg_rate["firing_rates"])

        # One bar per bin centre, as wide as the bin, with the standard error as error bars.
        ax.bar(bin_info["bin_centers"], average_psth, width=bin_info["binsz"],
               alpha=0.7, color="green", edgecolor="black",
               yerr=std_err_psth, capsize=3, error_kw={"elinewidth": 2})
        ax.grid(True, alpha=0.3)
        ax.set_ylabel("Average Firing Rate (1/s)")
        ax.set_xlabel("Time (s)")
        ax.set_title(cls._population_title("Averaged", n_neurons, nucleus,
                                           true_avg_rate["mean_firing_rate"]))
        return ax

    @classmethod
    def plot_avg(cls, spiketimes_set, binsz=0.01, window=(0, 10), neurons="all", nucleus=None, show=True):
        """
        Displays the Averaged Peri-Stimulus Time Histogram (PSTH) of the given spike times (seconds)
        and returns the figure (to save if necessary), drawing via :py:meth:`.plot_avg_in_ax`.

        :param spiketimes_set: Dictionary returned using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_superset`
                               or using :meth:`~analyseur.cbgtc.loader.LoadSpikeTimes.get_spiketimes_subset`
        :param binsz: integer or float; `0.01` [default]
        :param window: Tuple in the form `(start_time, end_time)`; `(0, 10)` [default]
        :param neurons: `"all"` [default] or `scalar` or `range(a, b)` or list of neuron ids like `[2, 3, 6, 7]`

            - `"all"` means subset = superset
            - `N` (a scalar) means subset of first N neurons in the superset
            - `range(a, b)` or `[2, 3, 6, 7]` means subset of selected neurons

        :param nucleus: string; name of the nucleus
        :param show: boolean; call ``plt.show()`` after drawing [default: True]
        :return: tuple ``(fig, ax)`` — the new `matplotlib.figure.Figure` and its
                 `matplotlib.axes.Axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html>`_
                 containing the bar chart drawn with ``Axes.bar``

        * `window` controls the binning range as well as the spike counting window
        * CBGT simulation was done in seconds so window `(0, 10)` signifies time 0 s to 10 s
        * As a side effect, `binsz` and `window` are stored as attributes of the class itself

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">
        """
        return cls._plot_in_new_figure(cls.plot_avg_in_ax, spiketimes_set, binsz, window,
                                       neurons, nucleus, show)