# Source code for analyseur.cbgtc.visual.connections

# ~/analyseur/cbgtc/visual/connections.py
#
# Documentation by Lungsi 12 March 2026

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np

from analyseur.cbgtc.loader import FetchConnectionList
from analyseur.cbgtc.parameters import SimulationParams

# Module-level simulation parameters, read by Conn for region aliases,
# the set of allowed region connections and the number of channels.
simparam = SimulationParams()
# Rebind nuclei_thal to a sequence that additionally lists "CMPf".
# `+` builds a new object and reassigns the attribute instead of appending
# in place (assumes nuclei_thal is a list -- TODO confirm in SimulationParams).
simparam.nuclei_thal = simparam.nuclei_thal + ["CMPf"]

class Conn(object):
    """
    =========
    Use Cases
    =========

    ------------------
    1. Pre-requisites
    ------------------

    1.1. Import Modules
    ````````````````````
    ::

        from analyseur.cbgtc.visual.connections import Conn

    1.2. Assign path to data location
    `````````````````````````````````
    ::

        root_folder = "/path/to/data_folder/"

    The `root_folder` is the CBGT data directory whose structure is shown below

    .. code-block:: text

        .
        ├── BG/
        │   ├── connection_list/
        │   │   ├── scale=4_nbchannels=4/
        │   │   │   └── model_9/
        │   │   └── active_cortex_inputs_scale=4_nbchannels=4/
        │   │       └── model_9/
        │   └── ...
        ├── CORTEX/
        │   ├── connection_list/
        │   │   ├── Thalamus_inputs_nbpops=4/
        │   │   └── nbpops=4/
        │   └── ...
        ├── THALAMUS/
        │   ├── connection_list/
        │   │   ├── nbpops=4/
        │   │   ├── BG_inputs_nbpops=4/
        │   │   └── active_cortex_inputs_nbpops=4/
        │   └── ...
        ├── ...

    where

    * terminal folders in `connection_list/` contain the files
      `connection_lists_i.dat` and `connection_lists_j.dat`

    1.3. Instantiate class object
    `````````````````````````````

    Following the choice of desired connected regions

    * `"CTX->CTX"`
    * `"CTX->BG"`
    * `"CTX->THAL"`
    * `"BG->THAL"`
    * `"THAL->CTX"`
    * `"BG->BG"`
    * `"THAL->THAL"`

    Note that the region abbreviations are **not** case-sensitive.

    Instantiate for `"ctx->bg"`
    ::

        conn = Conn(rootfolder=root_folder, region_connections="ctx->bg")
        # or simply
        conn = Conn(root_folder, "ctx->bg")

    ---------
    2. Cases
    ---------

    For visualizing connection related stuffs invoke `conn.<method_name>`
    from the available options:

    +--------------------------------------------+--------------------------------+
    | Method name                                | Obligatory argument            |
    +============================================+================================+
    | :py:meth:`.connections_bar_chart`          | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.overall_connections_bar_chart`  | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.global_stats`                   | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_connectivity_matrix`       | string: "<nucleus>-><nucleus>" |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_all_connectivity_matrices` | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_density`                   | string: "<nucleus>-><nucleus>" |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_degree_distribution`       | string: "<nucleus>-><nucleus>" |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_all_density`               | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_global_connectivity`       | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_global_density`            | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_channel_projection`        | string: "<nucleus>-><nucleus>" |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_all_channel_projections`   | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+
    | :py:meth:`.plot_population_connectome`     | no argument is mandatory       |
    +--------------------------------------------+--------------------------------+

    .. raw:: html

        <hr style="border: 2px solid red; margin: 20px 0;">

    """

    # Dispatch table as a class attribute: maps a normalised
    # (source region, target region) tuple to the name of the
    # FetchConnectionList attribute that loads that connection list.
    _FETCH_HANDLERS = {
        ("Cortex", "Cortex"): "within_cortex",
        ("BasalGanglia", "BasalGanglia"): "within_bg",
        ("Thalamus", "Thalamus"): "within_thalamus",
        ("Cortex", "BasalGanglia"): "cortex_to_bg",
        ("Cortex", "Thalamus"): "cortex_to_thalamus",
        ("BasalGanglia", "Thalamus"): "bg_to_thalamus",
        ("Thalamus", "Cortex"): "thalamus_to_cortex",
    }

    def __init__(self, rootfolder=None, region_connections="ctx->bg"):
        """
        Load the connection lists for the requested region pair.

        :param rootfolder: path handed to the selected `FetchConnectionList` loader
        :param region_connections: string of the form `"<source>-><target>"`
        :raises ValueError: if the string is malformed, names an unknown
            region or names a region pair that is not allowed
        """
        (self.source_region, self.target_region), fetch = self.__fetch_function(region_connections)

        # Two mappings keyed by population pair; the plotting methods pair
        # conn_i[pair][k] (x-axis / "source") with conn_j[pair][k] (y-axis / "target").
        self.conn_i, self.conn_j = fetch(rootfolder=rootfolder, verbose=True, nuclei_filter=True)

        self.source_target_pairs = list(self.conn_i.keys())
        self.n_pairs = len(self.source_target_pairs)

        # A key without "->" is treated as a bare target name, the same way
        # plot_population_connectome interprets it, instead of raising IndexError.
        endpoints = [self.__pair_endpoints(key) for key in self.conn_i]
        self.unique_sources = {src for src, _ in endpoints if src is not None}
        self.unique_targets = {dst for _, dst in endpoints}

        self.n_channels = simparam.size_info["bg"]["TOTAL_NUMBER_OF_CHANNELS"]

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def __fetch_function(self, region_connections: str):
        """
        Resolve `"<source>-><target>"` into the normalised region tuple and
        the matching `FetchConnectionList` loader.

        :raises ValueError: on a malformed string, an unknown region or a
            region pair missing from `simparam.connected_regions`
        """
        # 1. Split
        if "->" not in region_connections:
            raise ValueError("Format must be 'Source->Target'")
        src_raw, dst_raw = region_connections.split("->")

        # 2. Normalise through the alias table; the results are used as keys
        #    of _FETCH_HANDLERS (e.g. "Cortex", "BasalGanglia")
        src_norm = self.__normalize_region(src_raw)
        dst_norm = self.__normalize_region(dst_raw)

        # 3. Validate against the canonical set
        if (src_norm, dst_norm) not in simparam.connected_regions:
            raise ValueError(f"Connection '{region_connections}' is not allowed")

        handler_name = self._FETCH_HANDLERS[(src_norm, dst_norm)]
        return (src_norm, dst_norm), getattr(FetchConnectionList, handler_name)

    @staticmethod
    def __normalize_region(name: str) -> str:
        """
        Map a user supplied region name onto its canonical name using
        `simparam.REGION_ALIASES` (case, spaces and underscores are ignored).

        :raises ValueError: if the cleaned name is not a known alias
        """
        clean = name.strip().lower().replace(" ", "").replace("_", "")
        if clean not in simparam.REGION_ALIASES:
            raise ValueError(f"Unknown region '{name}'")
        return simparam.REGION_ALIASES[clean]

    def __validate_n_channels(self, n_channels):
        """
        Check n_channels.

        `None` selects the total number of channels; any other value must
        lie in `[1, self.n_channels]`.

        :raises ValueError: if the value is outside that range

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        if n_channels is None:
            n_channels = self.n_channels
        if not (1 <= n_channels <= self.n_channels):
            raise ValueError(f"n_channels must be between 1 and {self.n_channels}")
        return n_channels

    def __has_connections(self, pair):
        """Return True when both index lists of `pair` are non-empty."""
        return len(self.conn_i[pair]) > 0 and len(self.conn_j[pair]) > 0

    @staticmethod
    def __pair_endpoints(pair):
        """
        Split a `"<src>-><dst>"` key into `(src, dst)`.

        A key without `"->"` yields `(None, key)`.
        """
        if "->" in pair:
            parts = pair.split("->")
            return parts[0], parts[1]
        return None, pair

    def __pair_labels(self, pair):
        """Like `__pair_endpoints` but a missing source falls back to the source region name."""
        src, dst = self.__pair_endpoints(pair)
        if src is None:
            src = f"{self.source_region}"
        return src, dst

    def __regions_label(self):
        """Return the `"<source region> → <target region>"` text used in figure titles."""
        return f"{self.source_region} → {self.target_region}"

    def __pair_arrays(self, pair_name):
        """
        Return the two index arrays of one population pair.

        :raises KeyError: if `pair_name` is not a loaded pair
        :raises ValueError: if the pair holds no connections
        """
        if pair_name not in self.conn_i:
            raise KeyError(f"Population pair '{pair_name}' not found")
        i = np.asarray(self.conn_i[pair_name])
        j = np.asarray(self.conn_j[pair_name])
        if i.size == 0 or j.size == 0:
            raise ValueError(f"Population pair '{pair_name}' has no connections")
        return i, j

    def __subplot_grid(self):
        """
        Create one panel per population pair, laid out with one column per
        unique target (at least a 1x1 grid), and return `(fig, flat_axes)`.
        """
        cols = max(1, len(self.unique_targets))
        rows = max(1, int(np.ceil(self.n_pairs / cols)))
        fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
        if rows * cols > 1:
            axes = np.atleast_1d(axes).flatten()
        else:
            axes = [axes]
        return fig, axes

    @staticmethod
    def __channel_matrix(source_idx, target_idx, n_channels):
        """
        Count connections per (source channel, target channel).

        Channel width is estimated from the largest index on each side;
        rows of the result are source channels, columns are target channels.
        """
        source_per_channel = source_idx.max() // n_channels + 1
        target_per_channel = target_idx.max() // n_channels + 1
        matrix = np.zeros((n_channels, n_channels))
        # np.add.at accumulates repeated (row, col) pairs, unlike matrix[r, c] += 1
        np.add.at(matrix, (source_idx // source_per_channel, target_idx // target_per_channel), 1)
        return matrix

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def connections_bar_chart(self, show=True):
        """
        Show summary of connections at population level

        .. code-block:: text

            Source Region → Target Region Connectivity

            Populations      Number of Connections
            src   target
            --------------------------------
            R1a → R2a   ██████████████████████████
            R1b → R2a   ████
            R1a → R2b   ████
            R1b → R2c   █
            R1b → R2b   ▏

        Also prints, for every pair, the lengths of both index lists.

        :param show: call `plt.show()` when True
        :return: `(fig, ax)`

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        fig, ax = plt.subplots(figsize=(12, 6))

        connection_counts = [len(self.conn_i[pair]) for pair in self.source_target_pairs]

        ax.bar(self.source_target_pairs, connection_counts, color="skyblue")
        ax.set_title(f"{self.__regions_label()} Connections per Population")
        ax.set_xlabel(f"{self.target_region} Populations")
        ax.set_ylabel("Number of Connections")
        ax.tick_params(axis="x", rotation=45)
        ax.grid(True, alpha=0.3)

        # Add value labels on bars
        for position, count in enumerate(connection_counts):
            ax.text(position, count, str(count), ha="center", va="bottom")

        fig.tight_layout()
        if show:
            plt.show()

        print("Connection Statistics:")
        for pair in self.source_target_pairs:
            i_conn = len(self.conn_i[pair])
            j_conn = len(self.conn_j[pair])
            print(f" {pair}: {i_conn} connections (should equal {j_conn})")

        return fig, ax

    def overall_connections_bar_chart(self, show=True):
        """
        Compare connection patterns across all populations

        .. code-block:: text

            Source Region → Target Region Connectivity

            Total Connections
            R1a→R2a  █████████████████████████████████████████
            R1a→R2b  ██
            R1b→R2a  ██
            R1b→R2c  ▏
            R1b→R2b  ▏

            Unique Neurons
            Cortex:
            R1a→R2a  ████████
            R1a→R2b  ███████
            R1b→R2a  ███████
            BasalGanglia:
            R1a→R2a  ██████████████████████████
            R1b→R2a  ████████

            Avg Convergence (BG neurons)
            R1a→R2a  █████████████████
            R1a→R2b  █████████████
            R1b→R2c  █████

            Avg Divergence (Cortex neurons)
            R1a→R2a  █████████████████████████████████████
            R1a→R2b  ██
            R1b→R2a  ██

        :param show: call `plt.show()` when True
        :return: `(fig, axes)` with `axes` the 2x2 array of panels

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        pairs = self.source_target_pairs

        # Plot 1: Total connections
        totals = [len(self.conn_i[pair]) for pair in pairs]
        axes[0, 0].bar(pairs, totals, color="lightblue")
        axes[0, 0].set_title("Total Connections")
        axes[0, 0].tick_params(axis="x", rotation=45)

        # Plot 2: Unique neurons
        unique_source = [len(set(self.conn_i[pair])) for pair in pairs]
        unique_target = [len(set(self.conn_j[pair])) for pair in pairs]

        x = np.arange(self.n_pairs)
        width = 0.35
        axes[0, 1].bar(x - width / 2, unique_source, width, label=f"{self.source_region}", alpha=0.7)
        axes[0, 1].bar(x + width / 2, unique_target, width, label=f"{self.target_region}", alpha=0.7)
        axes[0, 1].set_title("Unique Neurons")
        axes[0, 1].set_xticks(x)
        axes[0, 1].set_xticklabels(pairs, rotation=45)
        axes[0, 1].legend()

        # Plot 3: Convergence ratio (connections per unique target neuron)
        convergence = [total / n_tgt if n_tgt > 0 else 0 for total, n_tgt in zip(totals, unique_target)]
        axes[1, 0].bar(pairs, convergence, color="orange")
        axes[1, 0].set_title(f"Average Convergence (conns/{self.target_region} neurons)")
        axes[1, 0].tick_params(axis="x", rotation=45)

        # Plot 4: Divergence ratio (connections per unique source neuron)
        divergence = [total / n_src if n_src > 0 else 0 for total, n_src in zip(totals, unique_source)]
        axes[1, 1].bar(pairs, divergence, color="green")
        axes[1, 1].set_title(f"Average Divergence (conns/{self.source_region} neurons)")
        axes[1, 1].tick_params(axis="x", rotation=45)

        fig.tight_layout()
        if show:
            plt.show()

        return fig, axes

    def plot_connectivity_matrix(self, pair_name, show=True):
        """
        Plot connection matrix for a `<source nucleus>-><target nucleus>` (e.g `PTN->MSN`)

        .. code-block:: text

            Target neuron index
                  │
            42000 ┤                          ███████████████
                  │                          ███████████████
            32000 ┤                   ███████████████
                  │                   ███████████████
            21000 ┤            ███████████████
                  │            ███████████████
            10000 ┤██████████████
                  │██████████████
                0 └──────────────────────────────────────────→ Source neuron index
                  0       2000      4000      6000      8000

        :param pair_name: key of the population pair
        :param show: call `plt.show()` when True
        :return: `(fig, ax)`; `None` (after printing a message) when
            `pair_name` is not a loaded pair
        :raises ValueError: if the pair holds no connections

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        if pair_name not in self.conn_i:
            print("Population not found")
            return

        i, j = self.__pair_arrays(pair_name)

        fig, ax = plt.subplots(figsize=(8, 8))

        # rasterized keeps vector exports small for very many points
        ax.scatter(i, j, s=1, alpha=0.3, rasterized=True)

        ax.set_xlim(0, i.max() + 1)
        ax.set_ylim(0, j.max() + 1)

        src, dst = self.__pair_labels(pair_name)
        ax.set_xlabel(f"{src} neuron index")
        ax.set_ylabel(f"{dst} neuron index")
        ax.set_title(f"Connectivity matrix: {self.source_region} ({src}) → {self.target_region} ({dst})")
        ax.grid(alpha=0.2)

        if show:
            plt.show()

        return fig, ax

    def plot_all_connectivity_matrices(self, show=True):
        """
        Shows all the projection patterns.

        .. code-block:: text

            Source → Target connectivity

            Target
              │          ████
              │       ████
              │    ████
              │ ████
              └──────────────── Source

        :param show: call `plt.show()` when True
        :return: `(fig, axes)` with `axes` the flat sequence of panels

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        fig, axes = self.__subplot_grid()

        for panel, pair in zip(axes, self.source_target_pairs):
            src, dst = self.__pair_labels(pair)
            i = np.asarray(self.conn_i[pair])
            j = np.asarray(self.conn_j[pair])

            panel.scatter(i, j, s=1, alpha=0.4)
            panel.set_xlabel(f"{self.source_region} ({src})")
            panel.set_ylabel(f"{self.target_region} ({dst})")
            panel.grid(alpha=0.2)

        # Hide empty panels
        for panel in axes[self.n_pairs:]:
            panel.axis("off")

        fig.suptitle(f"{self.__regions_label()} Connectivity (All Populations)")
        fig.tight_layout()

        if show:
            plt.show()

        return fig, axes

    def plot_density(self, pair_name, bins=100, show=True):
        """
        Plot density heatmap for a `<Source nucleus>-><Target nucleus>` (e.g `PTN->MSN`)

        .. code-block:: text

            Basal Ganglia neurons
            16000 |                                   :##::*:*:#*
            14000 |                              **:#*::*:#:
            12000 |                         :*#*::*:#*:
            10000 |                    :*#::*:#*#:
             8000 |               *:#*#::*#:
             6000 |          :##::*:*#*
             4000 |     *:#*::*:#:
             2000 | :*#::*:#*
                  ----------------------------------------------------→ Cortex neurons
                  0        2000       4000       6000       8000

        :param pair_name: key of the population pair
        :param bins: forwarded to `Axes.hist2d`
        :param show: call `plt.show()` when True
        :return: `(fig, ax)`
        :raises KeyError: if `pair_name` is not a loaded pair
        :raises ValueError: if the pair holds no connections

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        i, j = self.__pair_arrays(pair_name)

        fig, ax = plt.subplots(figsize=(8, 8))

        hist = ax.hist2d(i, j, bins=bins, cmap="inferno", rasterized=True)
        # hist2d returns (counts, xedges, yedges, image); the image feeds the colorbar
        fig.colorbar(hist[3], ax=ax, label="Number of connections")

        ax.set_xlim(0, i.max() + 1)
        ax.set_ylim(0, j.max() + 1)

        src, dst = self.__pair_labels(pair_name)
        ax.set_xlabel(f"{src} neuron")
        ax.set_ylabel(f"{dst} neuron")
        ax.set_title(f"Projection density: {self.source_region} ({src}) → {self.target_region} ({dst})")

        if show:
            plt.show()

        return fig, ax

    def plot_all_density(self, bins=100, show=True):
        """
        Shows connection density patterns for all source region nucleus to target region nucleus.

        .. code-block:: text

            Projection Density: Source Region → Target

            Target neurons
              │                      [::::***:::#:]
              │              [:::**:*::]
              │       [::*:#::*]
              │ [::**:#::]
              └────────────────────────────────→ Source neurons

        :param bins: forwarded to `Axes.hist2d`
        :param show: call `plt.show()` when True
        :return: `(fig, axes)` with `axes` the flat sequence of panels

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        fig, axes = self.__subplot_grid()

        for panel, pair in zip(axes, self.source_target_pairs):
            src, dst = self.__pair_labels(pair)
            i = np.asarray(self.conn_i[pair])
            j = np.asarray(self.conn_j[pair])

            panel.hist2d(i, j, bins=bins, cmap="inferno")
            panel.set_xlabel(f"{self.source_region} ({src})")
            panel.set_ylabel(f"{self.target_region} ({dst})")

        # Hide empty panels
        for panel in axes[self.n_pairs:]:
            panel.axis("off")

        fig.suptitle(f"Connection Density: {self.__regions_label()}")
        fig.tight_layout()

        if show:
            plt.show()

        return fig, axes

    def plot_degree_distribution(self, pair_name, show=True):
        """
        Plot convergence and divergence patterns for a `<Source nucleus>-><Target nucleus>` (e.g `PTN->MSN`)

        .. code-block:: text

            Source divergence
            1  ███████████████████████████████████
            2  ███████████████████
            3  █████████
            4  ███
            5  █
            6  ▏
            7  ▏

            Target convergence
            1  █████████████████████████████████████████████████
            2  ▏
            3  ▏
            4  ▏

        :param pair_name: key of the population pair
        :param show: call `plt.show()` when True
        :return: `(fig, axes)` with `axes[0]` divergence and `axes[1]` convergence
        :raises KeyError: if `pair_name` is not a loaded pair
        :raises ValueError: if the pair holds no connections

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        i, j = self.__pair_arrays(pair_name)

        # Occurrences of each neuron index = number of connections of that neuron
        source_deg = np.bincount(i)
        target_deg = np.bincount(j)

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Indices that never occur are not connected neurons; leave them out
        axes[0].hist(source_deg[source_deg > 0], bins=50)
        axes[0].set_title(f"{self.source_region} divergence")
        axes[0].set_xlabel("Connections per neuron")

        axes[1].hist(target_deg[target_deg > 0], bins=50)
        axes[1].set_title(f"{self.target_region} convergence")
        axes[1].set_xlabel("Connections per neuron")

        fig.tight_layout()

        if show:
            plt.show()

        return fig, axes

    def plot_global_connectivity(self, n_channels=None, band_height=2000, density_contours=False, show=True):
        """
        Shows global connectivity scatter plot with channel boundaries.

        .. code-block:: text

            Target neurons
            |----|----|----|----|
            | ██ |    |    |    |
            |----|----|----|----|
            |    | ██ |    |    |
            |----|----|----|----|
            |    |    | ██ |    |
                                  → Source neurons

        Every population pair gets its own horizontal band of height
        `band_height`; target indices are min-max rescaled into that band.

        :param n_channels: number of source channel boundaries to draw
            (`None` selects all channels)
        :param band_height: height of the band given to each pair
        :param density_contours: overlay a hexbin density when True
        :param show: call `plt.show()` when True
        :return: `(fig, ax)`
        :raises ValueError: if `n_channels` is out of range

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        n_channels = self.__validate_n_channels(n_channels)

        fig, ax = plt.subplots(figsize=(10, 8))

        yticks = []
        ylabels = []
        source_max = 0

        n_bands = len(self.conn_i)
        colors = plt.cm.tab10(np.linspace(0, 1, n_bands))

        for idx, pair in enumerate(self.conn_i):
            i = np.asarray(self.conn_i[pair])
            j = np.asarray(self.conn_j[pair])

            # An empty pair keeps its (blank) band so that bands stay
            # aligned with the order of the pairs
            if len(i) == 0:
                continue

            source_max = max(source_max, i.max())

            # normalize target neurons inside band; 1e-9 avoids 0/0 when
            # all targets share one index
            j_norm = (j - j.min()) / (j.max() - j.min() + 1e-9)
            j_scaled = j_norm * band_height + idx * band_height

            ax.scatter(i, j_scaled, s=1, alpha=0.4, color=colors[idx])

            # Density contours highlights where most synapses occur
            if density_contours:
                ax.hexbin(i, j_scaled, gridsize=100, cmap="inferno", alpha=0.6)

            yticks.append(idx * band_height + band_height / 2)
            ylabels.append(pair)

        # draw source channel boundaries
        source_per_channel = source_max / n_channels
        for c in range(1, n_channels):
            ax.axvline(c * source_per_channel, color="black", linestyle="--", alpha=0.3)

        # draw target population boundaries
        for p in range(1, n_bands):
            ax.axhline(p * band_height, color="black", linewidth=1)

        ax.set_xlabel(f"{self.source_region} neuron index")
        ax.set_ylabel(f"{self.target_region} populations")
        ax.set_yticks(yticks)
        ax.set_yticklabels(ylabels)
        ax.set_title(f"Global {self.__regions_label()} Connectivity")

        fig.tight_layout()
        if show:
            plt.show()

        return fig, ax

    def plot_global_density(self, bins=150, show=True):
        """
        Connectome-style plot.

        .. code-block:: text

            Source → Target projection density

            Target
              │       ░▓█
              │     ░▓█
              │   ░▓█
              │░▓█
              └──────── Source

            Legend:
            █ very high
            ▓ high
            ▒ medium
            ░ low
              zero

        The target indices of successive pairs are stacked on the y-axis
        (each pair is shifted past the previous one plus a gap of 10).

        :param bins: forwarded to `Axes.hist2d`
        :param show: call `plt.show()` when True
        :return: `(fig, ax)`
        :raises ValueError: if no pair holds any connection

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        stacked_i = []
        stacked_j = []
        offset = 0

        for pair in self.conn_i:
            i = np.asarray(self.conn_i[pair])
            j = np.asarray(self.conn_j[pair])

            # Empty pairs contribute nothing and must not advance the offset
            if len(j) == 0:
                continue

            stacked_i.append(i)
            stacked_j.append(j + offset)
            offset += j.max() + 10

        if not stacked_i:
            raise ValueError("No connections available to plot")

        all_i = np.concatenate(stacked_i)
        all_j = np.concatenate(stacked_j)

        fig, ax = plt.subplots(figsize=(10, 8))

        hist = ax.hist2d(all_i, all_j, bins=bins, cmap="inferno", rasterized=True)
        fig.colorbar(hist[3], ax=ax, label="Number of connections")

        ax.set_xlabel(f"{self.source_region} neurons")
        ax.set_ylabel(f"{self.target_region} neurons")
        ax.set_title(f"Global {self.__regions_label()} Projection Density")

        fig.tight_layout()

        if show:
            plt.show()

        return fig, ax

    def global_stats(self):
        """
        Prints connection statistics for each population connection pair:

        * total connections
        * number of unique source region neurons
        * number of unique target region neurons
        * convergence (connections per unique target neuron)
        * divergence (connections per unique source neuron)

        Nothing is returned.

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        print("\nGlobal Connectivity Statistics\n")

        for pair in self.conn_i:
            i = np.asarray(self.conn_i[pair])
            j = np.asarray(self.conn_j[pair])

            total = len(i)
            unique_source = len(np.unique(i))
            unique_target = len(np.unique(j))

            # Guard the ratios for pairs without connections
            convergence = total / unique_target if unique_target else 0
            divergence = total / unique_source if unique_source else 0

            print(f"{pair}")
            print(f" total connections : {total}")
            print(f" {self.source_region} neurons : {unique_source}")
            print(f" {self.target_region} neurons : {unique_target}")
            print(f" convergence : {convergence:.2f}")
            print(f" divergence : {divergence:.2f}")
            print()

    def plot_channel_projection(self, pair_name, n_channels=None, show=True):
        """
        Show channel-projection map for desired population pair as diagonal channel blocks.

        .. code-block:: text

                   Target channels
                   0    1    2    3
            Sx 0   ██
            Sx 1        ██
            Sx 2             ██
            Sx 3                  ██

        Rows (y-axis) are source channels, columns (x-axis) are target channels.

        **Patterns References:**

        *Focused connectivity*

        .. code-block:: text

            █
              █
                █

        *Diffuse connectivity*

        .. code-block:: text

            ████
            ████
            ████

        *Surround inhibition*

        .. code-block:: text

            ███

        *Channel crosstalk*

        .. code-block:: text

            █ █
            █ █

        :param pair_name: key of the population pair
        :param n_channels: number of channels to bin into (`None` selects all)
        :param show: call `plt.show()` when True
        :return: `(fig, ax)`
        :raises KeyError: if `pair_name` is not a loaded pair
        :raises ValueError: if `n_channels` is out of range or the pair
            holds no connections

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        n_channels = self.__validate_n_channels(n_channels)

        i, j = self.__pair_arrays(pair_name)
        matrix = self.__channel_matrix(i, j, n_channels)

        fig, ax = plt.subplots(figsize=(6, 6))

        im = ax.imshow(matrix, cmap="inferno", origin="lower")
        fig.colorbar(im, ax=ax, label="Number of connections")

        # Set ticks from 1 to n_channels
        tick_positions = list(range(n_channels))
        tick_labels = list(range(1, n_channels + 1))
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels)

        # imshow puts matrix columns on x and rows on y, and the matrix is
        # indexed [source channel, target channel]
        src, dst = self.__pair_labels(pair_name)
        ax.set_xlabel(f"{dst} channel")
        ax.set_ylabel(f"{src} channel")
        ax.set_title(f"Channel Projection: {self.source_region} ({src}) → {self.target_region} ({dst})")

        fig.tight_layout()

        if show:
            plt.show()

        return fig, ax

    def plot_all_channel_projections(self, n_channels=None):
        """
        Show channel-projection map for all population pairs
        using :py:meth:`.plot_channel_projection`

        Pairs without connections are skipped.

        :param n_channels: forwarded to :py:meth:`.plot_channel_projection`

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        for pair in self.conn_i:
            if not self.__has_connections(pair):
                continue
            print("Population:", pair)
            self.plot_channel_projection(pair, n_channels)

    def plot_population_connectome(self, show=True):
        """
        Displays model connectivity as a connectome diagram.

        .. code-block:: text

                 Cortex
                //    \\
              CSN      PTN
                \\    //
                Striatum
                //    \\
              FSI      MSN

        Nodes are populations, edges are population pairs and the edge
        width grows with the number of connections of the pair.

        :param show: call `plt.show()` when True
        :return: `(fig, ax)`

        .. raw:: html

            <hr style="border: 2px solid red; margin: 20px 0;">

        """
        legend_elements = [
            Patch(facecolor='lightgreen', label='Source'),
            Patch(facecolor='salmon', label='Target'),
        ]

        G = nx.DiGraph()
        for pair in self.conn_i:
            # A key without "->" is drawn as an edge from the source region
            src, dst = self.__pair_labels(pair)
            G.add_edge(src, dst, weight=len(self.conn_i[pair]))

        # Determine node colors based on role
        in_deg = dict(G.in_degree())
        out_deg = dict(G.out_degree())

        color_map = []
        for node in G.nodes():
            if out_deg[node] > 0 and in_deg[node] == 0:
                color_map.append("lightgreen")  # source nodes
            elif in_deg[node] > 0 and out_deg[node] == 0:
                color_map.append("salmon")  # target nodes
            else:
                color_map.append("lightblue")  # intermediate (both)

        # Fixed seed keeps the layout reproducible between calls
        pos = nx.spring_layout(G, seed=42)

        weights = np.array([G[u][v]["weight"] for u, v in G.edges()])
        # Scale widths into [1, 7]; fall back to unit width when there are
        # no edges or every pair is empty (avoids dividing by zero)
        if weights.size > 0 and weights.max() > 0:
            widths = 1 + 6 * weights / weights.max()
        else:
            widths = np.ones_like(weights)

        fig, ax = plt.subplots(figsize=(7, 7))

        nx.draw_networkx_nodes(G, pos, node_size=2500, node_color=color_map, ax=ax)
        nx.draw_networkx_labels(G, pos, ax=ax)
        nx.draw_networkx_edges(G, pos, width=widths, arrows=True, arrowsize=20, ax=ax)

        # Source/target colours only carry meaning between two different regions
        if self.source_region != self.target_region:
            ax.legend(handles=legend_elements, loc='upper right')

        ax.set_title(f"{self.__regions_label()} Population Connectivity")
        ax.axis("off")

        if show:
            plt.show()

        return fig, ax