diff --git a/src/vmc/models/peps.py b/src/vmc/models/peps.py
index 1ea46ef..fec795e 100644
--- a/src/vmc/models/peps.py
+++ b/src/vmc/models/peps.py
@@ -9,7 +9,6 @@
 import abc
 import functools
-import logging
 from typing import TYPE_CHECKING, Any
 
 import jax
 
@@ -18,6 +17,7 @@
 from plum import dispatch
 
+from vmc.utils.factorizations import _qr_compactwy
 from vmc.utils.utils import random_tensor, spin_to_occupancy
 
 if TYPE_CHECKING:
@@ -51,7 +51,6 @@
     "_forward_with_cache",
 ]
 
-logger = logging.getLogger(__name__)
 
 
 # =============================================================================
@@ -136,8 +135,13 @@ def _contract_theta(m: jax.Array, w: jax.Array, carry: jax.Array | None) -> tupl
         (left_dim, phys_dim, Dr, wr).
     """
     if carry is not None:
-        tmp = jnp.einsum("kdl,dpr->prkl", carry, m)
-        theta = jnp.einsum("prkl,lwpq->kqrw", tmp, w)
+        theta = jnp.einsum(
+            "kdl,dpr,lwpq->kqrw",
+            carry,
+            m,
+            w,
+            optimize=[(0, 1), (0, 1)],
+        )
         left_dim, phys_dim, Dr, wr = theta.shape
         return theta, left_dim, phys_dim, Dr, wr
     theta = jnp.einsum("dpr,lwpq->dlqrw", m, w)
@@ -230,7 +234,6 @@ def _apply_mpo_variational(
 
     This avoids SVD entirely, using only QR for canonical form maintenance.
     """
-    n_sites = len(mps)
     dtype = mps[0].dtype
     Dc = truncate_bond_dimension
 
@@ -278,7 +281,7 @@ def _init_compressed_mps(mps: tuple, mpo: tuple, Dc: int) -> list[jax.Array]:
         mat = theta.reshape(left_dim * phys_dim, Dr * wr)
 
         # QR decomposition
-        Q, R = jax.lax.linalg.qr(mat, full_matrices=False)
+        Q, R = _qr_compactwy(mat)
 
         # Truncate to Dc columns
         k = min(Dc, Q.shape[1])
@@ -370,7 +373,7 @@ def _variational_sweep_lr(
         # QR decompose: optimal = A @ C (Paeckel Sec 2.5)
         left_dim, phys_dim, right_dim = optimal.shape
         mat = optimal.reshape(left_dim * phys_dim, right_dim)
-        A, C = jax.lax.linalg.qr(mat, full_matrices=False)
+        A, C = _qr_compactwy(mat)
         new_tensor = A.reshape(left_dim, phys_dim, A.shape[1])
         new_result[i] = new_tensor
 
@@ -438,7 +441,7 @@ def _variational_sweep_rl(
         # LQ via QR on transpose: A = C @ B, A.T = B.T @ C.T
         left_dim, phys_dim, right_dim = optimal.shape
         mat = optimal.reshape(left_dim, phys_dim * right_dim)
-        Q_t, R_t = jax.lax.linalg.qr(mat.T, full_matrices=False)
+        Q_t, R_t = _qr_compactwy(mat.T)
         # B = Q_t.T has orthonormal rows, C = R_t.T
         B = Q_t.T  # (k, phys * right)
         C = R_t.T  # (left, k)
@@ -486,18 +489,17 @@ def _forward_with_cache(
     shape: tuple[int, int],
     strategy: ContractionStrategy,
 ) -> tuple[jax.Array, list[tuple]]:
-    """Forward pass that caches all intermediate boundary MPSs."""
+    """Forward pass that caches the top boundary before each row."""
     n_rows, n_cols = shape
     dtype = jnp.asarray(tensors[0][0]).dtype
 
-    top_envs = []
+    top_envs = [None] * n_rows
     boundary = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
-    top_envs.append(boundary)
 
     for row in range(n_rows):
+        top_envs[row] = boundary
         mpo = _build_row_mpo(tensors, spins[row], row, n_cols)
         boundary = strategy.apply(boundary, mpo)
-        top_envs.append(boundary)
 
     return _contract_bottom(boundary), top_envs
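The caching contract in `_forward_with_cache`: `top_envs[row]` holds the boundary MPS contracted over rows `0..row-1`, i.e. the environment *above* `row`, and the backward pass further down rebuilds the matching bottom environments. A minimal sketch of the sandwich invariant this relies on, with scalars standing in for boundary MPSs and multiplication standing in for `strategy.apply` (toy code, not the real MPS/MPO types):

```python
import jax.numpy as jnp

def forward_with_cache(rows):
    top, tops = jnp.ones(()), []
    for r in rows:
        tops.append(top)        # environment above this row, cached before applying it
        top = top * r
    return top, tops            # (amplitude, cached top environments)

def backward_with_cache(rows, tops, amp):
    bottom = jnp.ones(())
    bottoms = [None] * len(rows)
    for i in reversed(range(len(rows))):
        bottoms[i] = bottom     # environment below row i, cached before applying it
        # sandwich invariant: (env above i) * (row i) * (env below i) == amplitude
        assert jnp.allclose(tops[i] * rows[i] * bottom, amp)
        bottom = bottom * rows[i]
    return bottoms

rows = jnp.array([2.0, 3.0, 0.5])
amp, tops = forward_with_cache(rows)
bottoms = backward_with_cache(rows, tops, amp)
```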
@@ -637,15 +639,15 @@ def _compute_all_env_grads_and_energy(
     amp: jax.Array,
     shape: tuple[int, int],
     strategy: ContractionStrategy,
-    bottom_envs: list[tuple],
+    top_envs: list[tuple],
     *,
     diagonal_terms: list,
     one_site_terms: list[list[list]],
     horizontal_terms: list[list[list]],
     vertical_terms: list[list[list]],
     collect_grads: bool = True,
-) -> tuple[list[list[jax.Array]], jax.Array]:
-    """Compute gradients and local energy for a PEPS sample."""
+) -> tuple[list[list[jax.Array]], jax.Array, list[tuple]]:
+    """Backward pass: use cached top_envs, build and cache bottom_envs."""
     n_rows, n_cols = shape
     dtype = jnp.asarray(tensors[0][0]).dtype
     phys_dim = int(jnp.asarray(tensors[0][0]).shape[0])
@@ -655,17 +657,23 @@ def _compute_all_env_grads_and_energy(
         if collect_grads
         else []
     )
+    bottom_envs_cache = [None] * n_rows
     energy = jnp.zeros((), dtype=amp.dtype)
+
+    # Diagonal terms
     for term in diagonal_terms:
         idx = jnp.asarray(0, dtype=jnp.int32)
         for row, col in term.sites:
             idx = idx * phys_dim + spins[row, col]
         energy = energy + term.diag[idx]
 
-    top_env = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
-    mpo = _build_row_mpo(tensors, spins[0], 0, n_cols)
-    for row in range(n_rows):
-        bottom_env = bottom_envs[row]
+    # Backward pass: bottom → top
+    bottom_env = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
+    next_row_mpo = None
+    for row in range(n_rows - 1, -1, -1):
+        bottom_envs_cache[row] = bottom_env
+        top_env = top_envs[row]
+        mpo = _build_row_mpo(tensors, spins[row], row, n_cols)
         right_envs = _compute_right_envs(top_env, mpo, bottom_env, dtype)
         left_env = jnp.ones((1, 1, 1), dtype=dtype)
         for c in range(n_cols):
@@ -706,19 +714,18 @@ def _compute_all_env_grads_and_energy(
                 amps_flat = amps_edge.reshape(-1)
                 for term in edge_terms:
                     energy = energy + jnp.dot(term.op[:, col_idx], amps_flat) / amp
-            # Direct einsum for left_env update
             left_env = jnp.einsum(
                 "ace,aub,cduv,evf->bdf",
                 left_env, top_env[c], mpo[c], bottom_env[c],
                 optimize=[(0, 1), (0, 2), (0, 1)],
             )
+        # Vertical energy between row and row+1
         if row < n_rows - 1:
-            mpo_next = _build_row_mpo(tensors, spins[row + 1], row + 1, n_cols)
             energy = energy + _compute_row_pair_vertical_energy(
                 top_env,
-                bottom_envs[row + 1],
+                bottom_envs_cache[row + 1],
                 mpo,
-                mpo_next,
+                next_row_mpo,
                 tensors[row],
                 tensors[row + 1],
                 spins[row],
@@ -727,11 +734,10 @@ def _compute_all_env_grads_and_energy(
                 amp,
                 phys_dim,
             )
-        top_env = strategy.apply(top_env, mpo)
-        if row < n_rows - 1:
-            mpo = mpo_next
+        bottom_env = _apply_mpo_from_below(bottom_env, mpo, strategy)
+        next_row_mpo = mpo
 
-    return env_grads, energy
+    return env_grads, energy, bottom_envs_cache
 
 
 def _compute_2site_horizontal_env(
@@ -752,12 +758,16 @@ def _compute_2site_horizontal_env(
 
     Returns tensor with shape (up0, down0, mL, up1, down1, mR).
    """
-    # Contract left side: left_env (a,c,e) @ top0 (a,u,b) @ bot0 (e,d,f) -> (c,u,b,d,f)
-    tmp_left = jnp.einsum("ace,aub,edf->cubdf", left_env, top0, bot0, optimize=[(0, 1), (0, 1)])
-    # Contract right side: top1 (b,v,g) @ right_env (g,h,i) @ bot1 (f,w,i) -> (b,v,h,f,w)
-    tmp_right = jnp.einsum("bvg,ghi,fwi->bvhfw", top1, right_env, bot1, optimize=[(0, 1), (0, 1)])
-    # Contract left and right: (c,u,b,d,f) @ (b,v,h,f,w) -> (c,u,d,v,h,w)
-    env = jnp.einsum("cubdf,bvhfw->cudvhw", tmp_left, tmp_right, optimize=[(0, 1)])
+    env = jnp.einsum(
+        "ace,aub,edf,bvg,ghi,fwi->cudvhw",
+        left_env,
+        top0,
+        bot0,
+        top1,
+        right_env,
+        bot1,
+        optimize=[(0, 1), (0, 1), (1, 2), (1, 2), (0, 1)],
+    )
     # Transpose to (up0, down0, mL, up1, down1, mR)
     return jnp.transpose(env, (1, 2, 0, 3, 5, 4))
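Both einsum consolidations (here and in `_contract_theta` above) pass an explicit contraction path via `optimize=`: each `(i, j)` entry contracts operands `i` and `j` of the current operand list, and the intermediate is appended to the end of the list. A quick equivalence check against the old two-step form, with made-up bond dimensions:

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
carry = jnp.asarray(rng.normal(size=(3, 4, 5)))   # (k, d, l)
m = jnp.asarray(rng.normal(size=(4, 2, 6)))       # (d, p, r)
w = jnp.asarray(rng.normal(size=(5, 7, 2, 2)))    # (l, w, p, q)

# Old two-step contraction from _contract_theta
tmp = jnp.einsum("kdl,dpr->prkl", carry, m)
ref = jnp.einsum("prkl,lwpq->kqrw", tmp, w)

# Single einsum with an explicit path: the first (0, 1) contracts carry with m,
# the intermediate is appended, and the second (0, 1) contracts it with w.
out = jnp.einsum("kdl,dpr,lwpq->kqrw", carry, m, w, optimize=[(0, 1), (0, 1)])
assert jnp.allclose(ref, out, atol=1e-5)
```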
@@ -885,19 +895,21 @@ def grads_and_energy(
     sample: jax.Array,
     amp: jax.Array,
     operator: Any,
-    envs: list[tuple],
-) -> tuple[list[list[jax.Array]], jax.Array]:
+    top_envs: list[tuple],
+) -> tuple[list[list[jax.Array]], jax.Array, list[tuple]]:
     """Compute environment gradients and local energy for PEPS.
 
+    Uses backward pass: takes cached top_envs, returns cached bottom_envs.
+
     Args:
         model: PEPS model
         sample: flat sample array (occupancy indices)
         amp: amplitude for this configuration
         operator: LocalHamiltonian with terms
-        envs: bottom environments (computed via bottom_envs())
+        top_envs: top environments (from sweep())
 
     Returns:
-        (env_grads, energy)
+        (env_grads, energy, bottom_envs) - includes cached bottom environments for the next sweep
     """
     from vmc.operators.local_terms import bucket_terms
@@ -912,7 +924,7 @@ def grads_and_energy(
         amp,
         model.shape,
         model.strategy,
-        envs,
+        top_envs,
         diagonal_terms=diagonal_terms,
         one_site_terms=one_site_terms,
         horizontal_terms=horizontal_terms,
@@ -936,7 +948,7 @@ def sweep(
     sample: jax.Array,
     key: jax.Array,
     envs: list[tuple],
-) -> tuple[jax.Array, jax.Array, jax.Array]:
+) -> tuple[jax.Array, jax.Array, jax.Array, list[tuple]]:
     """Single Metropolis sweep for PEPS.
 
     Args:
@@ -946,7 +958,7 @@
         ...
         envs: bottom environments from previous iteration
 
     Returns:
-        (new_sample, key, amp) - flat sample array after sweep
+        (new_sample, key, amp, top_envs) - includes cached top environments
     """
     indices = sample.reshape(model.shape)
     tensors = [[jnp.asarray(t) for t in row] for row in model.tensors]
@@ -954,7 +966,9 @@
     dtype = tensors[0][0].dtype
 
     top_env = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
+    top_envs_cache = [None] * n_rows
     for row in range(n_rows):
+        top_envs_cache[row] = top_env
         bottom_env = envs[row]
         mpo_row = _build_row_mpo(tensors, indices[row], row, n_cols)
         right_envs = _compute_right_envs(top_env, mpo_row, bottom_env, dtype)
@@ -1006,7 +1020,7 @@
         top_env = model.strategy.apply(top_env, tuple(updated_row))
 
     amp = _contract_bottom(top_env)
-    return indices.reshape(-1), key, amp
+    return indices.reshape(-1), key, amp, top_envs_cache
 
 
 class PEPS(nnx.Module):
diff --git a/src/vmc/samplers/sequential.py b/src/vmc/samplers/sequential.py
index 1270273..c4cc181 100644
--- a/src/vmc/samplers/sequential.py
+++ b/src/vmc/samplers/sequential.py
@@ -287,7 +287,7 @@ def sequential_sample(
     envs = jax.vmap(lambda s: bottom_envs(model, s))(samples_flat)
 
     def sweep_once(sample, key, envs):
-        sample, key, _ = sweep(model, sample, key, envs)
+        sample, key, _, _ = sweep(model, sample, key, envs)
         envs = bottom_envs(model, sample)
         return sample, key, envs
 
@@ -550,17 +550,17 @@ def flatten_sliced_gradients(env_grads, sample, amp):
 
     def mc_sweep(sample, key, envs):
         """Single MC sweep: Metropolis sweep + gradient/energy + flatten (single vmap)."""
-        sample, key, amp = sweep(model, sample, key, envs)
-        envs = bottom_envs(model, sample)
-        env_grads, local_energy = grads_and_energy(model, sample, amp, operator, envs)
+        sample, key, amp, top_envs = sweep(model, sample, key, envs)
+        env_grads, local_energy, envs = grads_and_energy(model, sample, amp, operator, top_envs)
         grad_row, p_row = flatten_grads(env_grads, sample, amp)
         return sample, key, envs, grad_row, p_row, amp, local_energy
 
     def burn_step(carry, _):
         samples, chain_keys, envs = carry
-        samples, chain_keys, envs, _, _, _, _ = jax.vmap(mc_sweep)(
-            samples, chain_keys, envs
-        )
+        samples, chain_keys, _, _ = jax.vmap(
+            lambda sample, key, env: sweep(model, sample, key, env)
+        )(samples, chain_keys, envs)
+        envs = jax.vmap(lambda sample: bottom_envs(model, sample))(samples)
         return (samples, chain_keys, envs), None
 
     (samples_flat, chain_keys, envs), _ = jax.lax.scan(
@@ -602,4 +602,3 @@ def sample_step(carry, _):
 
     return samples_out, grads, p, key, final_samples, amps, local_energies
-
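The rewiring above saves one full boundary contraction per Monte Carlo step: `sweep` returns the top environments it already builds while sweeping downward, and `grads_and_energy` returns the bottom environments it builds on its backward pass, which seed the next sweep. A hedged sketch of the per-chain data flow (a pseudo-driver, assuming `model`, `operator`, an initial `sample`, a PRNG `key`, and `n_steps` are already set up; burn-in, vmap, and scan omitted):

```python
envs = bottom_envs(model, sample)          # computed from scratch only once
for _ in range(n_steps):
    # Metropolis sweep: consumes bottom envs, caches top envs as a by-product.
    sample, key, amp, top_envs = sweep(model, sample, key, envs)
    # Gradient/energy backward pass: consumes top envs and returns the bottom
    # envs for the updated sample, so bottom_envs() is never recomputed in the loop.
    env_grads, local_energy, envs = grads_and_energy(model, sample, amp, operator, top_envs)
```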
index 3e3952b..b9986d0 100644
--- a/src/vmc/utils/__init__.py
+++ b/src/vmc/utils/__init__.py
@@ -16,6 +16,10 @@
     occupancy_to_spin,
     spin_to_occupancy,
 )
+from vmc.utils.factorizations import (
+    _qr_compactwy,
+    _qr_cholesky,
+)
 
 # Note: vmc_utils imports are not included here to avoid circular imports.
 # Import directly from vmc.utils.vmc_utils when needed.
@@ -32,4 +36,6 @@
     "independent_set_violations",
     "occupancy_to_spin",
     "spin_to_occupancy",
+    "_qr_compactwy",
+    "_qr_cholesky",
 ]
diff --git a/src/vmc/utils/factorizations.py b/src/vmc/utils/factorizations.py
new file mode 100644
index 0000000..f2a968e
--- /dev/null
+++ b/src/vmc/utils/factorizations.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import jax
+import jax.numpy as jnp
+
+__all__ = ["_qr_compactwy", "_qr_cholesky"]
+
+
+def _qr_compactwy(a: jax.Array) -> tuple[jax.Array, jax.Array]:
+    """Householder QR via the compact WY representation.
+
+    Computes reduced ``(Q, R)`` on the trailing matrix axes, batch-polymorphically.
+    """
+    r, tau = jnp.linalg.qr(a, mode="raw")  # batched geqrf (cuSOLVER on GPU)
+    q = _householder_wy(r.mT, tau)
+    return q, jnp.triu(r.mT[..., : tau.shape[-1], :])
+
+
+def _householder_wy(r: jax.Array, tau: jax.Array) -> jax.Array:
+    """Build reduced ``Q`` from geqrf reflectors in compact WY form.
+
+    Implements ``Q = I - Y T Y^H`` on the trailing matrix axes, batch-polymorphically.
+    """
+    m = r.shape[-2]
+    k = tau.shape[-1]
+    dtype = r.dtype
+
+    Y = jnp.tril(r[..., :, :k], k=-1) + jnp.eye(m, k, dtype=dtype)
+    YHY = jnp.einsum("...ki,...kj->...ij", Y.conj(), Y, optimize=True)
+    strict_lower = jnp.tril(jnp.ones((k, k), dtype=dtype), k=-1)
+    basis = jnp.eye(k, dtype=dtype)
+
+    def update_column(j: int, T: jax.Array) -> jax.Array:
+        # Standard WY recurrence: T[:j, j] = -tau_j * T[:j, :j] @ (Y^H Y)[:j, j]
+        # and T[j, j] = tau_j; the strict-lower mask zeroes entries at and below j.
+        mask = strict_lower[j, :]
+        yhy_col = YHY[..., :, j] * mask
+        t_yhy = jnp.einsum("...ab,...b->...a", T, yhy_col, optimize=True)
+        tau_j = tau[..., j][..., None]
+        new_col = -tau_j * t_yhy * mask + tau_j * basis[j]
+        return jax.lax.dynamic_update_slice_in_dim(T, new_col[..., None], j, axis=-1)
+
+    T = jax.lax.fori_loop(
+        0,
+        k,
+        update_column,
+        jnp.zeros(tau.shape[:-1] + (k, k), dtype=dtype),
+    )
+
+    return jnp.eye(m, k, dtype=dtype) - jnp.einsum(
+        "...ik,...kl,...jl->...ij",
+        Y,
+        T,
+        Y[..., :k, :].conj(),
+        optimize=True,
+    )
+
+
+def _qr_cholesky(a: jax.Array) -> tuple[jax.Array, jax.Array]:
+    """CholeskyQR: ``R = chol(A^H A)^H`` and ``Q = A R^{-1}``.
+
+    Cheap and batch-friendly, but squares the condition number of ``a``.
+    """
+    ah = jnp.conj(a.mT)
+    gram = ah @ a
+    L = jnp.linalg.cholesky(gram)
+    q = jnp.conj(jax.scipy.linalg.solve_triangular(L, ah, lower=True).mT)
+    return q, jnp.conj(L.mT)
diff --git a/src/vmc/utils/vmc_utils.py b/src/vmc/utils/vmc_utils.py
index b166385..429e6e0 100644
--- a/src/vmc/utils/vmc_utils.py
+++ b/src/vmc/utils/vmc_utils.py
@@ -20,9 +20,8 @@
 from vmc.core.eval import _value
 from vmc.models.peps import (
     PEPS,
-    _build_row_mpo,
     _compute_all_env_grads_and_energy,
-    bottom_envs,
+    _forward_with_cache,
 )
 from vmc.operators.local_terms import LocalHamiltonian, bucket_terms
 from vmc.utils.utils import occupancy_to_spin, spin_to_occupancy
@@ -146,14 +145,14 @@ def diag_only(sample):
     def per_sample(sample, amp):
         occupancy = spin_to_occupancy(sample)
         spins = occupancy.reshape(shape)
-        envs = bottom_envs(model, occupancy)
-        _, energy = _compute_all_env_grads_and_energy(
+        _, top_envs = _forward_with_cache(tensors, spins, shape, model.strategy)
+        _, energy, _ = _compute_all_env_grads_and_energy(
            tensors,
             spins,
             amp,
             shape,
             model.strategy,
-            envs,
+            top_envs,
             diagonal_terms=diagonal_terms,
             one_site_terms=one_site_terms,
             horizontal_terms=horizontal_terms,
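A quick property check for the new factorization helpers (a sketch: shapes and tolerances are illustrative, and the import assumes the package is on the path):

```python
import jax
import jax.numpy as jnp

from vmc.utils.factorizations import _qr_cholesky, _qr_compactwy

a = jax.random.normal(jax.random.key(0), (4, 6, 3))  # batch of four 6x3 matrices

q, r = _qr_compactwy(a)
assert q.shape == (4, 6, 3) and r.shape == (4, 3, 3)
assert jnp.allclose(q @ r, a, atol=1e-5)                          # A = QR
eye = jnp.broadcast_to(jnp.eye(3), (4, 3, 3))
assert jnp.allclose(jnp.swapaxes(q, -1, -2) @ q, eye, atol=1e-5)  # Q^T Q = I

# CholeskyQR is cheaper but squares the condition number, so it is only
# trustworthy on well-conditioned input like this one.
q2, r2 = _qr_cholesky(a)
assert jnp.allclose(q2 @ r2, a, atol=1e-4)
```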