Source code for qrl.env.core.bloch_sphere

'''
Implementation of BlochSphereV0 environment

Author: Jay Shah (@Jayshah25)

Contact: jay.shah@qrlqai.com

License: Apache-2.0
'''


from gymnasium import spaces
from pennylane import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as mpatches
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.lines as mlines
import networkx as nx
import shutil
import warnings
from typing import Literal, Optional
from ._base import QuantumEnv
from .utils import GATES, RX, RY, RZ, STATE_LABELS, STATE_BLOCH, STATE_VECTORS, ACTION_NAMES, _TRANSITIONS, _GRAPH_POS


[docs] class BlochSphereV0(QuantumEnv): """ Single-qubit Bloch sphere environment for reinforcement learning. ``BlochSphereV0`` is a ``gymnasium.Env``-compatible environment where an agent controls a single qubit via a discrete set of quantum gates. The qubit state is represented internally as a statevector and exposed to the agent as a 3D Bloch vector ``(x, y, z)``. The objective is to steer the qubit from the fixed initial state ``|0⟩`` to a target pure state (default ``|+⟩``) within a limited number of steps by applying unitary gate actions. Key details -------------- - **Action space**: Discrete set of single-qubit gates (Clifford + common rotations). - **Observation space**: Bloch vector ``(x, y, z)``, each component in ``[-1, 1]``. - **Reward**: Fidelity ``|⟨target | state⟩|²`` in ``[0, 1]``. - **Termination**: Success when reward exceeds ``reward_tolerance`` or truncation at ``max_steps``. Rendering --------- The ``render()`` method visualizes the Bloch sphere and the agent’s trajectory, showing the current state and target state as arrows in 3D. Input Parameters ---------- - **target_state**: Target pure state as a Numpy complex 2-vector, defaults to ``|+⟩``. - **max_steps**: Maximum number of steps per episode. - **reward_tolerance**: Fidelity threshold for successful termination. - **ffmpeg**: If set to True, animations are saved as mp4 videos, else as GIFs. Default is False. See Also -------- :doc:`tutorials/bloch_sphere` """ def __init__(self, target_state, max_steps=20, reward_tolerance=0.99, ffmpeg=False): super().__init__() self.max_steps = max_steps self.target_state = target_state self.state = np.array([1, 0], dtype=complex) # Initial State -> |0> self.writer = "ffmpeg" if ffmpeg else "pillow" self.render_extension = "mp4" if ffmpeg else "gif" if ffmpeg==True and shutil.which("ffmpeg") is None: raise ValueError("ffmpeg not found on system. Please install ffmpeg or set ffmpeg=False") # Bloch vector (x, y, z) self.observation_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) # Discrete action space self.actions = ["H", "X", "Y", "Z", "S", "SDG", "T", "TDG", "RX_pi_2", "RX_pi_4", "RX_-pi_4", "RY_pi_2", "RY_pi_4", "RY_-pi_4", "RZ_pi_2", "RZ_pi_4", "RZ_-pi_4"] self.action_space = spaces.Discrete(len(self.actions)) self.reward_tolerance = reward_tolerance if self.reward_tolerance <= 0 or self.reward_tolerance > 1: raise ValueError("reward_tolerance must be in (0, 1]") self.history = [] self.steps = 0 def _state_to_bloch(self, state): """ Convert a single-qubit statevector to its Bloch vector representation. Parameters ---------- state : np.ndarray Complex 2-element statevector |ψ⟩ representing a pure qubit state. Returns ------- np.ndarray Bloch vector ``(x, y, z)`` as a float32 array of shape ``(3,)``, with each component in the range ``[-1, 1]``. """ rho = np.outer(state, np.conj(state)) x = 2*np.real(rho[0,1]) y = 2*np.imag(rho[1,0]) z = np.real(rho[0,0] - rho[1,1]) return np.array([x, y, z], dtype=np.float32)
[docs] def reset(self): """ Reset the environment to the initial state. The qubit is initialized to the computational basis state |0⟩. Episode step count and history are cleared. Returns ------- observation : np.ndarray Initial Bloch vector corresponding to |0⟩, shape ``(3,)``. info : dict Empty dictionary provided for compatibility with Gymnasium API. """ self.steps = 0 self.state = np.array([1, 0], dtype=complex) # |0> self.history = [(self._state_to_bloch(self.state),'None','None')] # Default target state (|+>) self.target = np.array([1/np.sqrt(2), 1/np.sqrt(2)], dtype=complex) return self._state_to_bloch(self.state), {}
[docs] def get_reward(self, action): """ Apply a quantum gate action and compute the resulting reward. This method evolves the internal qubit state by applying the unitary corresponding to the selected action and evaluates the fidelity with respect to the target state. Parameters ---------- action : int Index of the selected action in ``self.actions``. Returns ------- float Fidelity between the current state and the target state, defined as ``|⟨target | state⟩|²`` and bounded in ``[0, 1]``. """ gate = self.actions[action] if gate in GATES: U = GATES[gate] elif "RX" in gate: U = RX(eval(gate.split("_")[1].replace("pi", "np.pi"))) elif "RY" in gate: U = RY(eval(gate.split("_")[1].replace("pi", "np.pi"))) elif "RZ" in gate: U = RZ(eval(gate.split("_")[1].replace("pi", "np.pi"))) self.state = U @ self.state # evolve state reward = np.abs(np.vdot(self.target_state, self.state))**2 return reward
[docs] def step(self, action): """ Execute one environment step. Applies the selected quantum gate, updates the internal state and history, computes the reward, and checks termination conditions. Parameters ---------- action : int Index of the selected action in ``self.actions``. Returns ------- observation : np.ndarray Updated Bloch vector of the qubit state, shape ``(3,)``. reward : float Fidelity-based reward after applying the action. done : bool True if the episode has terminated due to success or truncation. info : dict Empty dictionary provided for compatibility with Gymnasium API. """ reward = self.get_reward(action) new_obs = self._state_to_bloch(self.state) gate = self.actions[action] self.history.append((new_obs, round(reward, 3), gate)) self.steps += 1 done = reward > self.reward_tolerance or self.steps >= self.max_steps return self._state_to_bloch(self.state), reward, done, {}
[docs] def render(self, save_path_without_extension=None, interval=800): """ Render the Bloch sphere trajectory as a 3D animation. The visualization shows: - A translucent Bloch sphere with labeled basis states, - The target Bloch vector (green, static), - The evolving qubit state trajectory (red, dynamic). Parameters ---------- save_path_without_extension : str or None, optional Path (without file extension) to save the animation. If provided, the animation is saved using the configured writer (MP4 for FFmpeg or GIF for Pillow). If None, the animation is displayed interactively. interval : int, optional Delay between animation frames in milliseconds. Default is 800. Returns ------- None This method produces a visualization but does not return a value. """ fig = plt.figure(figsize=(6,6)) ax = fig.add_subplot(111, projection='3d') ax.set_box_aspect([1,1,1]) ax.view_init(elev=20, azim=-60) # Add Qiskit-style Bloch sphere labels ax.text(0, 0, 1.1, r'$|0\rangle$', fontsize=12, color='black') ax.text(0, 0, -1.2, r'$|1\rangle$', fontsize=12, color='black') ax.text(0, 1.2, 0, r'$|+i\rangle$', fontsize=12, color='black') ax.text(0, -1.4, 0, r'$|-i\rangle$', fontsize=12, color='black') ax.text(1.2, 0, 0, r'$|+\rangle$', fontsize=12, color='black') ax.text(-1.4, 0, 0, r'$|-\rangle$', fontsize=12, color='black') # Sphere (draw once) u = np.linspace(0, 2*np.pi, 100) v = np.linspace(0, np.pi, 100) x = np.outer(np.cos(u), np.sin(v)) y = np.outer(np.sin(u), np.sin(v)) z = np.outer(np.ones_like(u), np.cos(v)) ax.plot_surface(x, y, z, color='lightblue', alpha=0.5, edgecolor='gray', linewidth=0.1) # Solid lines for X, Y and Z axes ax.plot([-1, 1], [0, 0], [0, 0], color="black", linewidth=1) ax.plot([0, 0], [-1, 1], [0, 0], color="black", linewidth=1) ax.plot([0, 0], [0, 0], [-1, 1], color="black", linewidth=1) # Solid Planes for XY, XZ, YZ phi = np.linspace(0, 2*np.pi, 200) ax.plot(np.cos(phi), np.sin(phi), 0, color="black", linewidth=0.8) ax.plot(np.cos(phi), 0*np.cos(phi), np.sin(phi), color="black", linewidth=0.8) ax.plot(0*np.cos(phi), np.cos(phi), np.sin(phi), color="black", linewidth=0.8) # Axes limits ax.set_xlim([-1,1]); ax.set_ylim([-1,1]); ax.set_zlim([-1,1]) ax.set_xticks([]) ax.set_yticks([]) ax.set_zticks([]) target_state = self._state_to_bloch(self.target) # Target arrow (static) target_arrow = ax.quiver(0, 0, 0, target_state[0], target_state[1], target_state[2], color='green', linewidth=2, label='Target') # Dynamic prediction arrow (update each frame) pred_arrow = ax.quiver(0, 0, 0, self.history[0][0][0], self.history[0][0][1], self.history[0][0][2], color='red', linewidth=2, label='Prediction') # Legend (only once) ax.legend() def update(frame): nonlocal pred_arrow # remove old arrow pred_arrow.remove() # draw new arrow pred_arrow = ax.quiver(0, 0, 0, self.history[frame][0][0], self.history[frame][0][1], self.history[frame][0][2], color='red', linewidth=2) ax.set_title(f"Step {frame} | Reward={self.history[frame][1]} | Gate={self.history[frame][2]}") ani = animation.FuncAnimation(fig, update, frames=len(self.history), interval=interval, repeat=False) if save_path_without_extension: ani.save(f"{save_path_without_extension}.{self.render_extension}", writer=self.writer) else: plt.show()
[docs] class BlochSphereV1(QuantumEnv): """ Single-qubit Bloch sphere environment as a graph problem for reinforcement learning. ``BlochSphereV1`` is a ``gymnasium.Env`` compatible environment where an agent controls a single qubit via a discrete set of quantum gates. The qubit state is exposed to the agent as an integer index corresponding to the discrete states |0⟩, |1⟩, |+⟩, |-⟩, |+i⟩, |-i⟩. The objective is to steer the qubit from the fixed starting initial state ``|0⟩`` to a user defined target pure state (default ``|+⟩``) within a limited number of steps by applying unitary gate actions. The environment is fully compatible with ``ValueIteration`` and ``QValueIteration`` from ``qrl.algorithms``. Key details -------------- - **Action space**: Discrete set of single-qubit gates (H,X,Z,S). - **Observation space**: Integer index corresponding to the Discrete states |0⟩, |1⟩, |+⟩, |-⟩, |+i⟩, |-i⟩. - **Reward**: Fidelity ``|⟨target | state⟩|²`` in ``[0, 1]``. - **Termination**: Success when reward exceeds ``reward_tolerance`` or truncation at ``max_steps``. Parameters ---------- - **target_state** : int, optional Target state index in [0, 5]. Defaults to 2 (|+⟩). The mapping is: 0 → |0⟩, 1 → |1⟩, 2 → |+⟩, 3 → |-⟩, 4 → |+i⟩, 5 → |-i⟩. - **max_steps** : int, optional Maximum number of steps per episode. Default is 10. - **reward_tolerance** : float, optional Fidelity threshold for successful termination. Must be in (0, 1]. Default is 0.99. - **ffmpeg** : bool, optional If True, animations are saved as MP4 via ffmpeg, else as GIFs. Default is False. Raises ------ - **ValueError** : If ``target_state`` is not in [0, 5]. - **ValueError** : If ``reward_tolerance`` is not in (0, 1]. - **ValueError** : If ``ffmpeg=True`` but ffmpeg is not installed on the system. """ def __init__( self, target_state: int = 2, max_steps: int = 10, reward_tolerance: float = 0.99, ffmpeg: bool = False, ) -> None: super().__init__() # define constants # self.STEP_PENALTY = 0.01 # self.SUCCESS_BONUS = 1.0 if not (0 <= target_state <= 5): raise ValueError("target_state must be an integer in [0, 5].") if not (0 < reward_tolerance <= 1): raise ValueError("reward_tolerance must be in (0, 1].") if ffmpeg and shutil.which("ffmpeg") is None: raise ValueError("ffmpeg not found. Install it or set ffmpeg=False.") self.target_state_index = target_state self.max_steps = max_steps self.reward_tolerance = reward_tolerance self.writer = "ffmpeg" if ffmpeg else "pillow" self.render_extension = "mp4" if ffmpeg else "gif" self.fig_array_list = [] self.observation_space = spaces.Discrete(6) self.action_space = spaces.Discrete(4) self._state_index: int = 0 self._statevector = STATE_VECTORS[0].copy() self.steps: int = 0 self.history: list[int] = [] # sequence of state indices self.terminated: bool | None = None self.truncated: bool | None = None
[docs] def reset(self, *, seed=None, options=None): """ Reset the environment to the initial state. The qubit is placed at state index 0 (|0⟩). Episode step count, history, and termination flags are cleared. Parameters ---------- seed : int or None, optional Random seed passed to the base ``gymnasium.Env`` reset. Default is None. options : dict or None, optional Additional options passed to the base reset. Default is None. Returns ------- observation : int Initial state index (always 0, corresponding to |0⟩). info : dict Dictionary containing ``fidelity``, ``gate`` (``"reset"``), and ``bloch_vector`` of the initial state. """ super().reset(seed=seed) self._state_index = 0 self._statevector = STATE_VECTORS[0].copy() self.steps = 0 self.history = [0] self.terminated = False self.truncated = False return 0, self._info()
[docs] def get_reward(self): """ Compute the reward for the current state and update termination flags. Evaluates the fidelity between the current statevector and the target statevector. Sets ``self.terminated`` if fidelity meets or exceeds ``reward_tolerance``, and ``self.truncated`` if the step limit is reached. Returns ------- float 1.0 if the current state matches the target within ``reward_tolerance``, 0.0 otherwise. """ self.terminated = self._fidelity() >= self.reward_tolerance self.truncated = self.steps >= self.max_steps return 1.0 if self.terminated else 0.0
[docs] def step(self, action: int): """ Apply a gate action and advance the episode by one step. Applies the unitary gate corresponding to ``action`` to the current statevector, updates the discrete state index via the transition table, increments the step counter, and appends the new state to history. Parameters ---------- action : int Index into ``ACTION_NAMES`` selecting the gate to apply. 0 → H, 1 → X, 2 → Z, 3 → S. Returns ------- observation : int New discrete state index after applying the gate. reward : float 1.0 if the target is reached within tolerance, 0.0 otherwise. terminated : bool True if fidelity ≥ ``reward_tolerance``. truncated : bool True if ``steps`` ≥ ``max_steps``. info : dict Dictionary containing ``fidelity``, ``gate`` name applied, and ``bloch_vector`` of the resulting state. """ gate_name = ACTION_NAMES[action] U = GATES[gate_name] self._statevector = U @ self._statevector self._state_index = int(_TRANSITIONS[self._state_index, action]) self.steps += 1 self.history.append(self._state_index) reward = self.get_reward() return self._state_index, float(reward), self.terminated, self.truncated, self._info(gate_name)
def _fidelity(self) -> float: """ Compute the fidelity between the current state and the target state. Fidelity is defined as ``|⟨target | state⟩|²``, bounded in ``[0, 1]``. A value of 1.0 indicates the current state is identical to the target. Returns ------- float Fidelity between the current statevector and the target statevector, in the range ``[0, 1]``. """ target_sv = STATE_VECTORS[self.target_state_index] return float(np.abs(np.vdot(target_sv, self._statevector)) ** 2) def _info(self, gate: str = "reset") -> dict: """ Construct the info dictionary for the current environment state. Parameters ---------- gate : str, optional Name of the gate most recently applied. Defaults to ``"reset"`` when called during environment initialization or reset. Returns ------- dict A dictionary with three keys: - ``fidelity`` : float — current fidelity with the target state. - ``gate`` : str — name of the gate applied in this step. - ``bloch_vector``: np.ndarray — Bloch vector ``(x, y, z)`` of the current state, shape ``(3,)``. """ return { "fidelity": self._fidelity(), "gate": gate, "state_index": self._state_index, } # reneder graph def _render_graph(self, agent=None, show_true_dynamics: bool = True) -> None: """ Draw the state-transition graph. When agent is provided, adds a second panel showing the agent's learned model: Left — true environment dynamics Right — agent panel: • Node color : learned state value V(s) or max_a Q(s,a) on a warm colormap (high = warm, low = cool) • Edge opacity: proportional to empirical visit count • Bold edges : greedy policy argmax_a Q(s,a) Ideally, the learnt target state should have minimal value function. Layout (both panels mirror the Bloch sphere): |0⟩ |-⟩ |+⟩ |-i⟩ |+i⟩ |1⟩ """ BG = "#1a1a1a" # figure / axes background FG = "#e0e0e0" # titles, node labels, generic text EDGE_FG = "#444440" # base edge color on dark bg LBL_FG = "#999994" # edge label color (base graph) has_agent = agent is not None if not has_agent: raise ValueError("No agent provided.") n_panels = (1 if show_true_dynamics else 0) + (1 if has_agent else 0)# number of panels: 1 for true dynamics, 1 for agent1 # number of panels: 1 for true dynamics, 1 for agent fig_w = 10 * n_panels fig, axes = plt.subplots(1, n_panels, figsize=(fig_w, 8), facecolor=BG) # figure and axes if n_panels == 1: axes = [axes] ax_true = axes[0] if show_true_dynamics else None ax_agent = axes[-1] if has_agent else None for _ax in axes: _ax.set_facecolor(BG) # set background color for each axis # shared graph structure of true environment dynamics and agent G = nx.MultiDiGraph() # MultiDiGraph is a directed graph with multiple edges between two nodes G.add_nodes_from(range(6)) # add nodes to the graph edge_gates: dict[tuple[int, int], list[str]] = {} for s in range(6): for a, name in enumerate(ACTION_NAMES): dst = int(_TRANSITIONS[s, a]) edge_gates.setdefault((s, dst), []).append(name) for (s, d), gnames in edge_gates.items(): G.add_edge(s, d, label=",".join(gnames)) pos = _GRAPH_POS # arc midpoint helper to place text labels at the true arc midpoint for each edge _graph_center = np.mean(np.array(list(pos.values())), axis=0) def _arc_midpoint(src, dst, rad=0.10): """ Return the (x, y) midpoint of the arc networkx draws. Regular edges: quadratic Bezier B(t=0.5) with control point offset perpendicularly from the edge midpoint by rad × |P2-P0|. Self-loops (src == dst): networkx draws a small tangent loop. We offset the label outward from the graph center so it sits on top of the visible loop, not on the node itself. """ p0 = np.array(pos[src]) p2 = np.array(pos[dst]) # self-loop if src == dst: outward = p0 - _graph_center norm = np.linalg.norm(outward) outward = outward / norm if norm > 1e-9 else np.array([0.0, 1.0]) return p0 + outward * 0.52 # regular arc mid = (p0 + p2) / 2.0 diff = p2 - p0 length = np.linalg.norm(diff) perp = np.array([-diff[1], diff[0]]) / length pc = mid + rad * length * perp return 0.25 * p0 + 0.5 * pc + 0.25 * p2 def _label_edges(ax, edge_label_map, rad=0.10, font_size=7, font_color=LBL_FG): """ Place text labels at the true arc midpoint for each edge. edge_label_map: dict with keys that are either (src, dst) pairs or (src, dst, key) triples (networkx MultiDiGraph convention). """ for edge_key, label in edge_label_map.items(): src, dst = edge_key[0], edge_key[1] mx, my = _arc_midpoint(src, dst, rad) ax.text( mx, my, label, fontsize=font_size, color=font_color, ha="center", va="center", bbox=dict(boxstyle="round,pad=0.15", fc=BG, ec="none", alpha=0.8), ) # helper: draw one graph panel def _draw_base(ax, node_colors, title): ax.set_aspect("equal") ax.axis("off") ax.set_title(title, fontsize=11, pad=10, color=FG) nx.draw_networkx_nodes( G, pos, ax=ax, node_size=1600, node_color=node_colors, edgecolors="#888780", linewidths=1.2, ) nx.draw_networkx_labels( G, pos, ax=ax, labels=STATE_LABELS, font_size=10, font_color="#111111", ) nx.draw_networkx_edges( G, pos, ax=ax, edge_color=EDGE_FG, width=0.8, arrowsize=12, arrowstyle="-|>", connectionstyle="arc3,rad=0.10", min_source_margin=30, min_target_margin=30, ) _label_edges( ax, edge_label_map=nx.get_edge_attributes(G, "label"), font_size=7, font_color="#888780", ) # LEFT PANEL: true dynamics (optional) if show_true_dynamics and ax_true is not None: true_colors = [] for n in range(6): if n == self._state_index and n == self.target_state_index: true_colors.append("#F5C518") elif n == self._state_index: true_colors.append("#E24B4A") elif n == self.target_state_index: true_colors.append("#3B8BD4") else: true_colors.append("#D3D1C7") _draw_base(ax_true, true_colors, "True environment dynamics") # trajectory overlay if len(self.history) > 1: traj_edges = list(zip(self.history[:-1], self.history[1:])) unique_traj = list(dict.fromkeys(traj_edges)) nx.draw_networkx_edges( G, pos, ax=ax_true, edgelist=unique_traj, edge_color="#BA7517", width=2.8, arrowsize=18, arrowstyle="-|>", connectionstyle="arc3,rad=0.10", min_source_margin=30, min_target_margin=30, ) # left legend left_legend = [ mpatches.Patch(facecolor="#E24B4A", edgecolor="#888780", label="Current"), mpatches.Patch(facecolor="#3B8BD4", edgecolor="#888780", label="Target"), mpatches.Patch(facecolor="#F5C518", edgecolor="#888780", label="Current = target"), mpatches.Patch(facecolor="#D3D1C7", edgecolor="#888780", label="Other"), ] if len(self.history) > 1: left_legend.append(mpatches.Patch(facecolor="#BA7517", label="Trajectory")) ax_true.legend(handles=left_legend, loc="lower right", fontsize=8, framealpha=0.85, facecolor="#2a2a2a", edgecolor="#555550", labelcolor=FG) # RIGHT PANEL: agent's learned model if has_agent: # for class ValueIteration from qrl.algorithms.classical if hasattr(agent, "_V"): values = agent._V.cpu().numpy().astype(float) agent_type = "VI" # for class QValueIteration from qrl.algorithms.classical elif hasattr(agent, "_Q"): values = agent._Q.max(dim=1).values.cpu().numpy().astype(float) agent_type = "QI" else: warnings.warn("No value function found in agent, using default values.", UserWarning,stacklevel=2) values = np.zeros(6) agent_type = "VI" counts_sa = agent._counts.sum(dim=2).cpu().numpy() total_steps = int(counts_sa.sum()) try: policy = agent.get_policy().cpu().numpy() except Exception: policy = np.zeros(6, dtype=int) panel_title = f"Agent's learned model ({total_steps} total steps)" # node colors from value function cmap = cm.get_cmap("YlOrRd") v_min, v_max = values.min(), values.max() v_range = v_max - v_min if v_max > v_min else 1.0 agent_node_colors = [] for n in range(6): if n == self.target_state_index: agent_node_colors.append("#3B8BD4") # target always blue else: norm_v = (values[n] - v_min) / v_range agent_node_colors.append(cmap(norm_v)) # color the nodes based on the value function # draw base (gray edges, value-colored nodes) ax_agent.set_aspect("equal") ax_agent.axis("off") ax_agent.set_title(panel_title, fontsize=9, pad=10, color=FG) nx.draw_networkx_nodes( G, pos, ax=ax_agent, node_size=1600, node_color=agent_node_colors, edgecolors="#888780", linewidths=1.2, ) # Value annotations below each node label val_sym = "V" if agent_type == "VI" else "Q" value_labels = { n: f"{STATE_LABELS[n]}\n{val_sym}={values[n]:.2f}" for n in range(6) } nx.draw_networkx_labels( G, pos, ax=ax_agent, labels=value_labels, font_size=8, font_color="#111111", ) # draw edges with opacity from visit count max_count = counts_sa.max() if counts_sa.max() > 0 else 1.0 greedy_edges, explored_edges, unexplored_edges = [], [], [] for s in range(6): for a, name in enumerate(ACTION_NAMES): dst = int(_TRANSITIONS[s, a]) count = counts_sa[s, a] edge = (s, dst) if a == policy[s]: greedy_edges.append((edge, name, count)) elif count > 0: explored_edges.append((edge, count / max_count)) else: unexplored_edges.append(edge) # Unexplored: very faint if unexplored_edges: nx.draw_networkx_edges( G, pos, ax=ax_agent, edgelist=unexplored_edges, edge_color="#666660", alpha=0.15, width=0.5, arrowsize=8, arrowstyle="-|>", connectionstyle="arc3,rad=0.10", min_source_margin=30, min_target_margin=30, ) # Explored non-greedy: alpha proportional to visit count # Draw in buckets by alpha level alpha_buckets: dict[float, list] = {} for edge, norm_count in explored_edges: alpha = round(0.2 + 0.6 * norm_count, 1) alpha_buckets.setdefault(alpha, []).append(edge) for alpha, edgelist in alpha_buckets.items(): nx.draw_networkx_edges( G, pos, ax=ax_agent, edgelist=edgelist, edge_color="#aaaaaa", alpha=alpha, width=1.0, arrowsize=12, arrowstyle="-|>", connectionstyle="arc3,rad=0.10", min_source_margin=30, min_target_margin=30, ) # Greedy policy: bold teal, labeled with gate name greedy_edgelist = [e for e, _, _ in greedy_edges] greedy_labels = {e: name for e, name, _ in greedy_edges} if greedy_edgelist: nx.draw_networkx_edges( G, pos, ax=ax_agent, edgelist=greedy_edgelist, edge_color="#0F6E56", width=3.0, arrowsize=20, arrowstyle="-|>", connectionstyle="arc3,rad=0.10", min_source_margin=30, min_target_margin=30, ) _label_edges( ax_agent, greedy_labels, font_size=8, font_color="#5DCAA5", ) # colorbar for value function # Label: show which quantity is actually displayed value_label = "V(s)" if agent_type == "VI" else "max_a Q(s,a)" # Ticks: min and max only, annotated with the owning state name min_state = int(np.argmin(values)) max_state = int(np.argmax(values)) def _plain(lbl): return lbl.replace("$","").replace("\\rangle",">").replace("{","").replace("}","") min_tick = f"{_plain(STATE_LABELS[min_state])} {v_min:.2f}" max_tick = f"{_plain(STATE_LABELS[max_state])} {v_max:.2f}" sm = cm.ScalarMappable( cmap=cmap, norm=mcolors.Normalize(vmin=v_min, vmax=v_max), ) sm.set_array([]) cbar = fig.colorbar(sm, ax=ax_agent, fraction=0.03, pad=0.04) cbar.set_label(value_label, fontsize=8, color=FG) cbar.set_ticks([v_min, v_max]) cbar.set_ticklabels([min_tick, max_tick]) cbar.ax.tick_params(labelsize=7, colors=FG) cbar.ax.yaxis.set_tick_params(color=FG) plt.setp(cbar.ax.yaxis.get_ticklabels(), color=FG) # right legend right_legend = [ mpatches.Patch(facecolor="#3B8BD4", edgecolor="#888780", label="Target state"), mlines.Line2D([], [], color="#0F6E56", linewidth=2.5, label="Greedy policy"), mlines.Line2D([], [], color="#888780", linewidth=1.0, label="Explored (alpha ∝ visits)"), mlines.Line2D([], [], color="#D3D1C7", linewidth=0.5, label="Unexplored"), ] ax_agent.legend(handles=right_legend, loc="lower right", fontsize=8, framealpha=0.85, facecolor="#2a2a2a", edgecolor="#555550", labelcolor=FG) fig.tight_layout() fig_array = self._fig_to_array(fig) self.fig_array_list.append(fig_array) plt.close(fig) # render: learning animation def _fig_to_array(self, fig): """ Convert a Matplotlib figure to a NumPy RGB array. Rasterizes the figure canvas and reshapes the resulting pixel buffer into a ``(H, W, 3)`` uint8 array suitable for use as an animation frame. Parameters ---------- fig : matplotlib.figure.Figure The figure to convert. Returns ------- np.ndarray RGB pixel array of shape ``(height, width, 3)``, dtype ``uint8``. """ fig.canvas.draw() w, h = fig.canvas.get_width_height() return np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(h, w, 3)
[docs] def render(self, save_path_without_extension, interval=600, ffmpeg=False): """ Render accumulated graph frames as an animation and save to disk. Assembles the list of graph snapshots captured by ``_render_graph()`` into a single animation. Each frame corresponds to one call to ``_render_graph()``, producing a visual record of the agent's learning progression over episodes. Parameters ---------- save_path_without_extension : str File path (without extension) where the animation will be saved. The appropriate extension (``.mp4`` or ``.gif``) is appended automatically based on the ``ffmpeg`` argument. interval : int, optional Delay between frames in milliseconds. Default is 600. ffmpeg : bool, optional If True, saves the animation as an MP4 using ffmpeg. If False, saves as a GIF using Pillow. Default is False. Raises ------ ValueError If ``_render_graph()`` has not been called and no frames are available. Returns ------- None This method produces an animation file but does not return a value. """ if not self.fig_array_list: raise ValueError("No frames to render. Call _render_graph() first.") # Match figure size exactly to the frame pixel dimensions h, w = self.fig_array_list[0].shape[:2] dpi = 100 fig, ax = plt.subplots(figsize=(w / dpi, h / dpi), dpi=dpi) fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove all padding ax.axis("off") im = ax.imshow(self.fig_array_list[0]) def _update(i): im.set_data(self.fig_array_list[i]) return [im] ani = animation.FuncAnimation( fig, _update, frames=len(self.fig_array_list), interval=interval, blit=True, ) ext = "mp4" if ffmpeg else "gif" ani.save( f"{save_path_without_extension}.{ext}", writer="ffmpeg" if ffmpeg else "pillow", dpi=dpi, ) plt.close(fig)
@property def state_index(self) -> int: """Current state index (0-5).""" return self._state_index @property def bloch_vector(self) -> np.ndarray: """Current Bloch vector (x, y, z).""" return STATE_BLOCH[self._state_index].copy()
[docs] @staticmethod def transition_table() -> np.ndarray: """ Return the deterministic state-transition table for the environment. Each entry ``T[s, a]`` gives the next state index when action ``a`` is taken from state ``s``. Rows correspond to the 6 Bloch sphere states and columns to the 4 gate actions (H, X, Z, S). Returns ------- np.ndarray Integer array of shape ``(6, 4)`` where ``T[s, a] = s'``. """ return _TRANSITIONS.copy()
def __repr__(self) -> str: """ Return a concise string representation of the environment. Displays the current state label, target state label, and the step count relative to the maximum allowed steps. Returns ------- str String of the form ``BlochSphereV1(state=<label>, target=<label>, steps=<n>/<max>)``. """ return ( f"BlochSphereV1(" f"state={STATE_LABELS[self._state_index]}, " f"target={STATE_LABELS[self.target_state_index]}, " f"steps={self.steps}/{self.max_steps})" )