Source code for design_research_analysis.sequence.visualization

"""Visualization helpers for Markov-chain and HMM transition dynamics.

Provides two Matplotlib plotters that accept either a model result object or
a raw square transition matrix: :func:`plot_transition_matrix` draws the
matrix as an annotated heatmap, and :func:`plot_state_graph` draws a directed
state-transition graph. The graph view needs the optional ``networkx``
dependency (installed via the ``[seq]`` extra).
"""

from __future__ import annotations

from typing import Any, cast

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from .models import _state_labels, _transition_like_matrix

# Message for the ImportError raised when the optional graph dependency
# (``networkx``, shipped with the ``[seq]`` extra) cannot be imported.
_SEQ_IMPORT_ERROR: str = (
    "State-graph visualization requires optional sequence dependencies. "
    + "Install with `pip install design-research-analysis[seq]`."
)


def _resolve_transition_and_labels(
    transition: Any,
    state_labels: list[str] | None,
) -> tuple[np.ndarray, list[str]]:
    """Normalize ``transition`` to a float matrix and choose the state labels.

    ``transition`` is first passed to the ``.models`` helpers
    ``_transition_like_matrix`` and ``_state_labels``. If that raises
    ``TypeError``, it is instead treated as a raw array-like and default
    labels ``S0, S1, ...`` are generated.

    Args:
        transition: Result object understood by the ``.models`` helpers, or a
            raw square array-like.
        state_labels: Explicit labels. ``None`` or an empty sequence means
            "use the default labels". Any sized sequence (list, tuple, NumPy
            array) is accepted.

    Returns:
        ``(matrix, labels)`` where ``matrix`` is a 2-D float array and
        ``labels`` is a new list with one entry per state.

    Raises:
        ValueError: If the matrix is not square, or the number of labels does
            not match the matrix size.
    """
    try:
        matrix = np.asarray(_transition_like_matrix(transition), dtype=float)
        default_labels = _state_labels(transition)
    except TypeError:
        matrix = np.asarray(transition, dtype=float)
        default_labels = [f"S{idx}" for idx in range(matrix.shape[0])] if matrix.ndim == 2 else []

    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Transition matrix must be square.")

    # Test emptiness with len() instead of truthiness: ``state_labels or ...``
    # raises "truth value ... is ambiguous" for multi-element NumPy arrays.
    # Copy into a list so the declared ``list[str]`` return type always holds.
    if state_labels is None or len(state_labels) == 0:
        labels = list(default_labels)
    else:
        labels = list(state_labels)

    if len(labels) != matrix.shape[0]:
        raise ValueError("state_labels length must match transition matrix size.")
    return matrix, labels


def plot_transition_matrix(
    transition: Any,
    *,
    state_labels: list[str] | None = None,
    ax: Axes | None = None,
    cmap: str = "Blues",
    annotate: bool = True,
    fmt: str = ".2f",
    title: str = "Transition Matrix",
) -> tuple[Figure, Axes]:
    """Plot a transition matrix as a heatmap.

    Rows are the current state and columns the next state. The colour scale
    runs from ``0`` to ``max(1.0, matrix.max())``, so probability matrices share
    a fixed ``[0, 1]`` scale while larger entries still fit. A matrix with no
    states produces an empty, axis-less plot carrying only the title.

    Args:
        transition: Result object or raw square matrix.
        state_labels: Optional display labels for states.
        ax: Optional Matplotlib axis. A new ``6x5`` figure is created when omitted.
        cmap: Heatmap colormap.
        annotate: Whether to annotate each cell with its value.
        fmt: Format spec (as accepted by :func:`format`) for annotations.
        title: Plot title.

    Returns:
        ``(figure, axis)`` tuple.

    Raises:
        ValueError: If the matrix is not square or ``state_labels`` has the
            wrong length.
    """
    matrix, labels = _resolve_transition_and_labels(transition, state_labels)

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 5))
    else:
        fig = cast(Figure, ax.figure)

    ax.set_title(title)
    if matrix.size == 0:
        # Nothing to draw; ``matrix.max()`` would raise on a zero-size array.
        ax.axis("off")
        return fig, ax

    vmax = max(1.0, float(matrix.max()))
    image = ax.imshow(matrix, cmap=cmap, vmin=0.0, vmax=vmax)
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    ticks = np.arange(matrix.shape[0])
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticklabels(labels)
    ax.set_xlabel("Next state")
    ax.set_ylabel("Current state")

    if annotate:
        # Use white text on the darker half of the colour scale. The cut-off must
        # use the same ``vmax`` as ``imshow``: basing it on the data maximum puts
        # white text on near-white cells whenever every value is small.
        dark_cutoff = 0.5 * vmax
        for (row, col), value in np.ndenumerate(matrix):
            color = "white" if value > dark_cutoff else "black"
            ax.text(col, row, format(value, fmt), ha="center", va="center", color=color)
    return fig, ax
def plot_state_graph(
    transition: Any,
    *,
    state_labels: list[str] | None = None,
    threshold: float = 0.0,
    ax: Axes | None = None,
    seed: int = 0,
    title: str = "State Transition Graph",
) -> tuple[Figure, Axes]:
    """Render a directed state-transition graph.

    Each state becomes a node. Each matrix entry ``[src, dst]`` strictly above
    ``threshold`` becomes an edge ``src -> dst``, labelled with its weight.
    Edge widths are proportional to weight, scaled by ``max(1, heaviest edge)``
    so that count matrices do not produce unbounded line widths.

    Args:
        transition: Result object or raw square matrix.
        state_labels: Optional display labels for states.
        threshold: Draw edges with weight strictly above this value.
        ax: Optional Matplotlib axis. A new ``7x5`` figure is created when omitted.
        seed: Random seed passed to layout generation.
        title: Plot title.

    Returns:
        ``(figure, axis)`` tuple.

    Raises:
        ValueError: If ``threshold`` is negative or NaN, the matrix is not
            square, or ``state_labels`` has the wrong length.
        ImportError: If ``networkx`` is not installed.
    """
    # ``not x >= 0`` also rejects NaN, which would otherwise silently drop every edge.
    if not threshold >= 0.0:
        raise ValueError("threshold must be non-negative.")
    try:
        import networkx as nx
    except ImportError as exc:
        raise ImportError(_SEQ_IMPORT_ERROR) from exc

    matrix, labels = _resolve_transition_and_labels(transition, state_labels)

    graph = nx.DiGraph()
    for idx, label in enumerate(labels):
        graph.add_node(idx, label=label)
    # Rows are sources, columns destinations; NaN entries never pass the comparison.
    for src in range(matrix.shape[0]):
        for dst in range(matrix.shape[1]):
            weight = float(matrix[src, dst])
            if weight > threshold:
                graph.add_edge(src, dst, weight=weight)

    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 5))
    else:
        fig = cast(Figure, ax.figure)

    if graph.number_of_nodes() == 0:
        ax.set_title(title)
        ax.axis("off")
        return fig, ax

    positions = nx.spring_layout(graph, seed=seed)
    nx.draw_networkx_nodes(graph, positions, node_size=1400, node_color="#d0e3ff", ax=ax)
    nx.draw_networkx_labels(graph, positions, labels=dict(enumerate(labels)), font_size=9, ax=ax)

    if graph.number_of_edges() > 0:
        weighted_edges = list(graph.edges(data="weight"))
        # Unchanged for probability matrices (all weights <= 1); keeps count
        # matrices from yielding absurdly thick lines.
        width_scale = max(1.0, max(weight for _, _, weight in weighted_edges))
        edge_widths = [max(0.8, 6.0 * weight / width_scale) for _, _, weight in weighted_edges]
        nx.draw_networkx_edges(
            graph,
            positions,
            edgelist=[(u, v) for u, v, _ in weighted_edges],
            width=edge_widths,
            edge_color="#4f5d75",
            arrows=True,
            arrowsize=18,
            connectionstyle="arc3,rad=0.12",
            ax=ax,
        )
        edge_labels = {(u, v): f"{weight:.2f}" for u, v, weight in weighted_edges}
        # Off-centre placement keeps the labels of reciprocal edges (u->v and
        # v->u) apart; at the default midpoint (0.5) they are drawn on top of
        # each other.
        nx.draw_networkx_edge_labels(
            graph,
            positions,
            edge_labels=edge_labels,
            label_pos=0.3,
            font_size=8,
            ax=ax,
        )

    ax.set_title(title)
    ax.axis("off")
    return fig, ax
# Public plotting API; the underscore-prefixed helpers stay private.
__all__ = [
    "plot_state_graph",
    "plot_transition_matrix",
]