'''
Implementation of BlochSphereV0 environment
Author: Jay Shah (@Jayshah25)
License: Apache-2.0
'''
from gymnasium import spaces
from pennylane import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import shutil
from .base__ import QuantumEnv
from .utils import GATES, RX, RY, RZ
class BlochSphereV0(QuantumEnv):
    """
    Single-qubit Bloch sphere environment for reinforcement learning.

    ``BlochSphereV0`` is a ``gymnasium.Env``-compatible environment where an agent
    controls a single qubit via a discrete set of quantum gates. The qubit state is
    represented internally as a statevector and exposed to the agent as a 3D Bloch
    vector ``(x, y, z)``.

    The objective is to steer the qubit from the fixed initial state ``|0⟩`` to a
    target pure state (default ``|+⟩``) within a limited number of steps by applying
    unitary gate actions.

    Key details
    -----------
    - **Action space**: Discrete set of single-qubit gates (Clifford + common rotations).
      Rotation actions are named ``<AXIS>_[-]pi_<k>`` and rotate by ``±pi/k`` about
      that axis, e.g. ``RX_pi_2`` applies ``RX(pi/2)`` and ``RZ_-pi_4`` applies
      ``RZ(-pi/4)``.
    - **Observation space**: Bloch vector ``(x, y, z)``, each component in ``[-1, 1]``.
    - **Reward**: Fidelity ``|⟨target | state⟩|²`` in ``[0, 1]``.
    - **Termination**: Success when the reward reaches ``reward_tolerance``
      (``reward >= reward_tolerance``) or truncation at ``max_steps``.

    Rendering
    ---------
    The ``render()`` method visualizes the Bloch sphere and the agent’s trajectory,
    showing the current state and target state as arrows in 3D.

    Input Parameters
    ----------------
    - **target_state**: Target pure state as a complex 2-vector. It is normalized on
      construction. Defaults to ``|+⟩`` when omitted or ``None``.
    - **max_steps**: Maximum number of steps per episode.
    - **reward_tolerance**: Fidelity threshold in ``(0, 1]`` for successful termination.
    - **ffmpeg**: If set to True, animations are saved as mp4 videos (requires the
      ``ffmpeg`` executable on ``PATH``), else as GIFs via Pillow. Default is False.

    Raises
    ------
    ValueError
        If ``target_state`` is not a non-zero, finite 2-element vector, if ``ffmpeg``
        is requested but not found on the system, or if ``reward_tolerance`` is not
        in ``(0, 1]``.

    See Also
    --------
    :doc:`tutorials/bloch_sphere`
    """

    def __init__(self, target_state=None, max_steps=20, reward_tolerance=0.99, ffmpeg=False):
        super().__init__()
        self.max_steps = max_steps
        self.target_state = self._normalize_target(target_state)
        # Alias of ``target_state`` kept for code that reads ``env.target``.
        self.target = self.target_state
        self.state = np.array([1, 0], dtype=complex)  # Initial State -> |0>
        self.writer = "ffmpeg" if ffmpeg else "pillow"
        self.render_extension = "mp4" if ffmpeg else "gif"
        if ffmpeg and shutil.which("ffmpeg") is None:
            raise ValueError("ffmpeg not found on system. Please install ffmpeg or set ffmpeg=False")
        # Bloch vector (x, y, z)
        self.observation_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        # Discrete action space
        self.actions = ["H", "X", "Y", "Z", "S", "SDG", "T", "TDG",
                        "RX_pi_2", "RX_pi_4", "RX_-pi_4",
                        "RY_pi_2", "RY_pi_4", "RY_-pi_4",
                        "RZ_pi_2", "RZ_pi_4", "RZ_-pi_4"]
        self.action_space = spaces.Discrete(len(self.actions))
        self.reward_tolerance = reward_tolerance
        # Written as a chained comparison so NaN is rejected as well.
        if not 0 < self.reward_tolerance <= 1:
            raise ValueError("reward_tolerance must be in (0, 1]")
        # Seed the history with the initial frame so render() works even before reset().
        self.history = [(self._state_to_bloch(self.state), 'None', 'None')]
        self.steps = 0

    @staticmethod
    def _normalize_target(target_state):
        """
        Validate and normalize the target statevector.

        Parameters
        ----------
        target_state : array-like or None
            Complex 2-element statevector, or ``None`` for the default ``|+⟩``.

        Returns
        -------
        np.ndarray
            Unit-norm complex array of shape ``(2,)``.

        Raises
        ------
        ValueError
            If the input does not have exactly two elements or has zero/non-finite norm.
        """
        if target_state is None:
            return np.array([1 / np.sqrt(2), 1 / np.sqrt(2)], dtype=complex)  # |+>
        target = np.array(target_state, dtype=complex).ravel()
        if target.shape != (2,):
            raise ValueError("target_state must be a 2-element complex statevector")
        norm = np.linalg.norm(target)
        if not np.isfinite(norm) or norm == 0:
            raise ValueError("target_state must have a finite, non-zero norm")
        # Normalizing keeps the fidelity reward within [0, 1] as documented.
        return target / norm

    def _state_to_bloch(self, state):
        """
        Convert a single-qubit statevector to its Bloch vector representation.

        Parameters
        ----------
        state : np.ndarray
            Complex 2-element statevector |ψ⟩ representing a pure qubit state.

        Returns
        -------
        np.ndarray
            Bloch vector ``(x, y, z)`` as a float32 array of shape ``(3,)``,
            with each component in the range ``[-1, 1]`` for a normalized state.
        """
        rho = np.outer(state, np.conj(state))
        x = 2 * np.real(rho[0, 1])
        y = 2 * np.imag(rho[1, 0])  # Tr(rho @ sigma_y) = 2 Im(rho[1, 0])
        z = np.real(rho[0, 0] - rho[1, 1])
        return np.array([x, y, z], dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        """
        Reset the environment to the initial state.

        The qubit is initialized to the computational basis state |0⟩.
        Episode step count and history are cleared. The target state is not changed.

        Parameters
        ----------
        seed : int or None, optional
            Accepted for Gymnasium API compatibility. It is unused because the
            environment is deterministic.
        options : dict or None, optional
            Accepted for Gymnasium API compatibility. It is unused.

        Returns
        -------
        observation : np.ndarray
            Initial Bloch vector corresponding to |0⟩, shape ``(3,)``.
        info : dict
            Empty dictionary provided for compatibility with Gymnasium API.
        """
        self.steps = 0
        self.state = np.array([1, 0], dtype=complex)  # |0>
        self.history = [(self._state_to_bloch(self.state), 'None', 'None')]
        # Previously this was hard-coded to |+>, so render() ignored a custom target.
        self.target = self.target_state
        return self._state_to_bloch(self.state), {}

    @staticmethod
    def _gate_unitary(gate):
        """
        Return the 2x2 unitary for the action name ``gate``.

        Named gates are looked up in ``GATES``. Rotation names of the form
        ``<RX|RY|RZ>_[-]pi_<k>`` are parsed into an angle of ``±pi/k`` radians.
        A missing ``_<k>`` means ``k = 1``.

        Raises
        ------
        ValueError
            If the name is neither in ``GATES`` nor a parsable rotation.
        """
        if gate in GATES:
            return GATES[gate]
        rotations = {"RX": RX, "RY": RY, "RZ": RZ}
        axis, _, angle_spec = gate.partition("_")
        if axis not in rotations:
            raise ValueError(f"Unknown gate: {gate!r}")
        # e.g. "pi_2" -> ("pi", "2"), "-pi_4" -> ("-pi", "4")
        numerator, _, denominator = angle_spec.partition("_")
        negative = numerator.startswith("-")
        base = numerator[1:] if negative else numerator
        denominator = denominator or "1"
        if base != "pi" or not denominator.isdigit() or int(denominator) == 0:
            raise ValueError(f"Cannot parse rotation angle in gate name {gate!r}")
        angle = (-1.0 if negative else 1.0) * np.pi / int(denominator)
        return rotations[axis](angle)

    def _fidelity(self, state):
        """Return ``|⟨target | state⟩|²`` as a Python float (``np.vdot`` conjugates the target)."""
        return float(np.abs(np.vdot(self.target_state, state)) ** 2)

    def get_reward(self, action):
        """
        Apply a quantum gate action and compute the resulting reward.

        This method evolves the internal qubit state (side effect on ``self.state``)
        by applying the unitary for the selected action. It then evaluates the
        fidelity with respect to the target state.

        Parameters
        ----------
        action : int
            Index of the selected action in ``self.actions``.

        Returns
        -------
        float
            Fidelity between the current state and the target state, defined as
            ``|⟨target | state⟩|²`` and bounded in ``[0, 1]``.

        Raises
        ------
        IndexError
            If ``action`` is out of range for ``self.actions``.
        """
        gate = self.actions[action]
        self.state = self._gate_unitary(gate) @ self.state  # evolve state
        return self._fidelity(self.state)

    def step(self, action):
        """
        Execute one environment step.

        Applies the selected quantum gate, updates the internal state and history,
        computes the reward, and checks termination conditions.

        Parameters
        ----------
        action : int
            Index of the selected action in ``self.actions``.

        Returns
        -------
        observation : np.ndarray
            Updated Bloch vector of the qubit state, shape ``(3,)``.
        reward : float
            Fidelity-based reward after applying the action.
        done : bool
            True if the episode has terminated due to success
            (``reward >= reward_tolerance``) or truncation (``steps >= max_steps``).
        info : dict
            Empty dictionary provided for compatibility with Gymnasium API.

        Notes
        -----
        This returns a 4-tuple rather than Gymnasium's 5-tuple
        ``(obs, reward, terminated, truncated, info)``.
        """
        reward = self.get_reward(action)
        observation = self._state_to_bloch(self.state)
        self.history.append((observation, round(reward, 3), self.actions[action]))
        self.steps += 1
        # ``>=`` so that the permitted value reward_tolerance == 1 is reachable.
        done = reward >= self.reward_tolerance or self.steps >= self.max_steps
        # Return a copy so callers cannot mutate the stored history entry.
        return observation.copy(), reward, done, {}

    @staticmethod
    def _draw_bloch_sphere(ax):
        """Draw the static Bloch-sphere scaffolding (labels, surface, axes, great circles) on a 3D ``ax``."""
        ax.set_box_aspect([1, 1, 1])
        ax.view_init(elev=20, azim=-60)
        # Add Qiskit-style Bloch sphere labels
        ax.text(0, 0, 1.1, r'$|0\rangle$', fontsize=12, color='black')
        ax.text(0, 0, -1.2, r'$|1\rangle$', fontsize=12, color='black')
        ax.text(0, 1.2, 0, r'$|+i\rangle$', fontsize=12, color='black')
        ax.text(0, -1.4, 0, r'$|-i\rangle$', fontsize=12, color='black')
        ax.text(1.2, 0, 0, r'$|+\rangle$', fontsize=12, color='black')
        ax.text(-1.4, 0, 0, r'$|-\rangle$', fontsize=12, color='black')
        # Sphere surface
        u = np.linspace(0, 2 * np.pi, 100)
        v = np.linspace(0, np.pi, 100)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones_like(u), np.cos(v))
        ax.plot_surface(x, y, z, color='lightblue', alpha=0.5, edgecolor='gray', linewidth=0.1)
        # Solid lines for X, Y and Z axes
        ax.plot([-1, 1], [0, 0], [0, 0], color="black", linewidth=1)
        ax.plot([0, 0], [-1, 1], [0, 0], color="black", linewidth=1)
        ax.plot([0, 0], [0, 0], [-1, 1], color="black", linewidth=1)
        # Great circles in the XY, XZ and YZ planes
        phi = np.linspace(0, 2 * np.pi, 200)
        ax.plot(np.cos(phi), np.sin(phi), 0, color="black", linewidth=0.8)
        ax.plot(np.cos(phi), 0 * np.cos(phi), np.sin(phi), color="black", linewidth=0.8)
        ax.plot(0 * np.cos(phi), np.cos(phi), np.sin(phi), color="black", linewidth=0.8)
        # Axes limits, no ticks
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

    def render(self, save_path_without_extension=None, interval=800):
        """
        Render the Bloch sphere trajectory as a 3D animation.

        The visualization shows:
        - A translucent Bloch sphere with labeled basis states,
        - The target Bloch vector (green, static),
        - The evolving qubit state trajectory (red, dynamic), one frame per
          entry in ``self.history``.

        Parameters
        ----------
        save_path_without_extension : str or None, optional
            Path (without file extension) to save the animation.
            If provided, the animation is saved using the configured writer
            (MP4 for FFmpeg or GIF for Pillow) and the figure is closed afterwards.
            If None, the animation is displayed interactively.
        interval : int, optional
            Delay between animation frames in milliseconds. Default is 800.

        Returns
        -------
        None
            This method produces a visualization but does not return a value.
        """
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection='3d')
        self._draw_bloch_sphere(ax)

        # Target arrow (static) - uses the configured target, not a hard-coded |+>.
        target_bloch = self._state_to_bloch(self.target_state)
        ax.quiver(0, 0, 0, target_bloch[0], target_bloch[1], target_bloch[2],
                  color='green', linewidth=2, label='Target')
        # Dynamic prediction arrow (replaced each frame)
        first_bloch = self.history[0][0]
        pred_arrow = ax.quiver(0, 0, 0, first_bloch[0], first_bloch[1], first_bloch[2],
                               color='red', linewidth=2, label='Prediction')
        # Legend (only once)
        ax.legend()

        def update(frame):
            nonlocal pred_arrow
            bloch, reward, gate = self.history[frame]
            # 3D quivers cannot be updated in place, so remove and redraw.
            pred_arrow.remove()
            pred_arrow = ax.quiver(0, 0, 0, bloch[0], bloch[1], bloch[2],
                                   color='red', linewidth=2)
            ax.set_title(f"Step {frame} | Reward={reward} | Gate={gate}")

        ani = animation.FuncAnimation(fig, update, frames=len(self.history), interval=interval, repeat=False)
        if save_path_without_extension:
            ani.save(f"{save_path_without_extension}.{self.render_extension}", writer=self.writer)
            # The figure is never shown in this branch; close it so repeated renders don't leak figures.
            plt.close(fig)
        else:
            plt.show()