# Source code for qrl.algorithms.classical.qvalue_iteration

'''
Implementation of Q-Value Iteration for discrete MDPs as QValueIteration class using PyTorch

Author: Jay Shah (@Jayshah25)

Contact: jay.shah@qrlqai.com

License: Apache-2.0
'''

from typing import Optional, Union

import gymnasium as gym
import torch

from .._base import BaseIteration


class QValueIteration(BaseIteration):
    """
    Q-Value Iteration for tabular, model-based RL over discrete MDPs.

    Maintains a state-action value function Q(s,a) and applies the Bellman
    optimality operator until convergence:

        Q[s,a] <- Σ_s' P(s'|s,a) · (R(s,a,s') + γ · max_a' Q(s',a'))

    Action selection reads directly from Q with no recomputation.
    V(s) = max_a Q(s,a) is available as a derived property.

    Compared to ValueIteration:

    - Q(s,a) is stored persistently rather than computed transiently.
    - Action selection is O(n_actions) per state rather than
      O(n_actions × n_states).
    - Q(s,a) is the natural precursor to Q-learning and function
      approximation (e.g. DQN, quantum RL agents), making this the more
      forward-compatible choice.

    Parameters
    ----------
    env : gym.Env
        Gymnasium or qrl-qai environment with discrete observation and
        action spaces.
    gamma : float
        Discount factor in [0, 1).
    num_test_episodes : int
        Informational; used by external training loops for evaluation.
    device : torch.device, optional
        Defaults to CUDA if available, else CPU.
    dtype : torch.dtype, optional
        Defaults to float32.
    """

    def __init__(
        self,
        env: gym.Env,
        gamma: float = 0.9,
        num_test_episodes: int = 20,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = torch.float32,
    ) -> None:
        super().__init__(env, gamma, num_test_episodes, device, dtype)

        # Primary stored quantity: Q[s, a], shape (n_states, n_actions)
        self._Q = torch.zeros(
            (self.n_states, self.n_actions),
            dtype=self.dtype,
            device=self.device,
        )

    # ── planning ──────────────────────────────────────────────────────────────

    def qvalue_iteration(
        self,
        max_iters: Optional[int] = None,
        tol: float = 1e-6,
    ) -> int:
        """
        Run Q-Value Iteration to convergence (or max_iters).

        Parameters
        ----------
        max_iters : int, optional
            Hard cap on Bellman updates. If None, runs until
            |Q_new - Q|_inf < tol. A value of 0 (or any non-positive value)
            performs no updates at all.
        tol : float
            Convergence threshold on the sup-norm of the Q-value change.

        Returns
        -------
        int
            Number of iterations actually performed.
        """
        P = self._get_P()  # (n_s, n_a, n_s)
        R = self._get_R()  # (n_s, n_a, n_s)
        Q = self._Q

        iters_done = 0
        # Compare against None explicitly: `max_iters or <big>` would treat an
        # explicit cap of 0 as "uncapped" because 0 is falsy.
        while max_iters is None or iters_done < max_iters:
            # V(s') = max_a' Q(s', a') — implicit value, not stored separately
            V_from_Q = Q.max(dim=1).values  # (n_s,)

            # Q[s,a] = Σ_s' P[s,a,s'] · (R[s,a,s'] + γ · V(s'))
            V_expand = V_from_Q.unsqueeze(0).unsqueeze(0)  # (1, 1, n_s)
            Q_new = (P * (R + self.gamma * V_expand)).sum(dim=2)  # (n_s, n_a)

            diff = (Q_new - Q).abs().max().item()
            Q = Q_new
            iters_done += 1

            if diff < tol:
                break

        # Single exit point: store the latest table and report the real count
        # (never the raw argument, which may be None or negative).
        self._Q = Q
        return iters_done

    # ── action selection ──────────────────────────────────────────────────────

    def select_action(self, state: Union[int, torch.Tensor]) -> int:
        """
        Greedy action: argmax_a Q(s, a). Direct table lookup — no recomputation.

        Parameters
        ----------
        state : int or 0-d Tensor
            Current state index.

        Returns
        -------
        int
            Greedy action.
        """
        if isinstance(state, torch.Tensor):
            state = int(state.item())
        return int(self._Q[state].argmax().item())

    # ── inspection ────────────────────────────────────────────────────────────

    @property
    def Q(self) -> torch.Tensor:
        """Current state-action value function, shape (n_states, n_actions)."""
        return self._Q

    @property
    def V(self) -> torch.Tensor:
        """
        State-value function derived from Q: V(s) = max_a Q(s,a).

        Shape (n_states,). Not stored — recomputed on access.
        """
        return self._Q.max(dim=1).values

    def get_policy(self) -> torch.Tensor:
        """
        Greedy policy: pi[s] = argmax_a Q(s,a).

        Returns
        -------
        torch.Tensor
            Long tensor of shape (n_states,).
        """
        return self._Q.argmax(dim=1)