96 changes: 55 additions & 41 deletions src/vmc/models/peps.py
@@ -9,7 +9,6 @@

import abc
import functools
-import logging
from typing import TYPE_CHECKING, Any

import jax
@@ -18,6 +17,7 @@

from plum import dispatch

+from vmc.utils.factorizations import _qr_compactwy
from vmc.utils.utils import random_tensor, spin_to_occupancy

if TYPE_CHECKING:
@@ -51,7 +51,6 @@
"_forward_with_cache",
]

-logger = logging.getLogger(__name__)


# =============================================================================
@@ -136,8 +135,13 @@ def _contract_theta(m: jax.Array, w: jax.Array, carry: jax.Array | None) -> tupl
(left_dim, phys_dim, Dr, wr).
"""
if carry is not None:
-        tmp = jnp.einsum("kdl,dpr->prkl", carry, m)
-        theta = jnp.einsum("prkl,lwpq->kqrw", tmp, w)
+        theta = jnp.einsum(
+            "kdl,dpr,lwpq->kqrw",
+            carry,
+            m,
+            w,
+            optimize=[(0, 1), (0, 1)],
+        )
left_dim, phys_dim, Dr, wr = theta.shape
return theta, left_dim, phys_dim, Dr, wr
theta = jnp.einsum("dpr,lwpq->dlqrw", m, w)
@@ -230,7 +234,6 @@ def _apply_mpo_variational(

This avoids SVD entirely, using only QR for canonical form maintenance.
"""
-    n_sites = len(mps)
dtype = mps[0].dtype
Dc = truncate_bond_dimension

@@ -278,7 +281,7 @@ def _init_compressed_mps(mps: tuple, mpo: tuple, Dc: int) -> list[jax.Array]:
mat = theta.reshape(left_dim * phys_dim, Dr * wr)

# QR decomposition
-    Q, R = jax.lax.linalg.qr(mat, full_matrices=False)
+    Q, R = _qr_compactwy(mat)

# Truncate to Dc columns
k = min(Dc, Q.shape[1])
@@ -370,7 +373,7 @@ def _variational_sweep_lr(
# QR decompose: optimal = A @ C (Paeckel Sec 2.5)
left_dim, phys_dim, right_dim = optimal.shape
mat = optimal.reshape(left_dim * phys_dim, right_dim)
-        A, C = jax.lax.linalg.qr(mat, full_matrices=False)
+        A, C = _qr_compactwy(mat)

new_tensor = A.reshape(left_dim, phys_dim, A.shape[1])
new_result[i] = new_tensor
@@ -438,7 +441,7 @@ def _variational_sweep_rl(
# LQ via QR on transpose: A = C @ B, A.T = B.T @ C.T
left_dim, phys_dim, right_dim = optimal.shape
mat = optimal.reshape(left_dim, phys_dim * right_dim)
-        Q_t, R_t = jax.lax.linalg.qr(mat.T, full_matrices=False)
+        Q_t, R_t = _qr_compactwy(mat.T)
# B = Q_t.T has orthonormal rows, C = R_t.T
B = Q_t.T # (k, phys * right)
C = R_t.T # (left, k)
@@ -486,18 +489,17 @@ def _forward_with_cache(
shape: tuple[int, int],
strategy: ContractionStrategy,
) -> tuple[jax.Array, list[tuple]]:
"""Forward pass that caches all intermediate boundary MPSs."""
"""Forward pass that caches the top boundary before each row."""
n_rows, n_cols = shape
dtype = jnp.asarray(tensors[0][0]).dtype

-    top_envs = []
+    top_envs = [None] * n_rows
    boundary = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
-    top_envs.append(boundary)

    for row in range(n_rows):
+        top_envs[row] = boundary
        mpo = _build_row_mpo(tensors, spins[row], row, n_cols)
        boundary = strategy.apply(boundary, mpo)
-        top_envs.append(boundary)

return _contract_bottom(boundary), top_envs

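The caching convention changes here: `top_envs[row]` now stores the boundary above `row`, i.e. before that row's MPO is absorbed, and the list has exactly `n_rows` entries (the old code also appended the post-row boundaries, giving `n_rows + 1`). A sketch of what the cache holds, assuming a 3-row grid:

# Hypothetical trace for n_rows = 3, with B_k = strategy.apply(., row-k MPO):
# top_envs[0] = trivial ones boundary
# top_envs[1] = B_0(trivial)
# top_envs[2] = B_1(B_0(trivial))
# amplitude  = _contract_bottom(B_2(top_envs[2]))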
@@ -637,15 +639,15 @@ def _compute_all_env_grads_and_energy(
amp: jax.Array,
shape: tuple[int, int],
strategy: ContractionStrategy,
-    bottom_envs: list[tuple],
+    top_envs: list[tuple],
*,
diagonal_terms: list,
one_site_terms: list[list[list]],
horizontal_terms: list[list[list]],
vertical_terms: list[list[list]],
collect_grads: bool = True,
-) -> tuple[list[list[jax.Array]], jax.Array]:
-    """Compute gradients and local energy for a PEPS sample."""
+) -> tuple[list[list[jax.Array]], jax.Array, list[tuple]]:
+    """Backward pass: use cached top_envs, build and cache bottom_envs."""
n_rows, n_cols = shape
dtype = jnp.asarray(tensors[0][0]).dtype
phys_dim = int(jnp.asarray(tensors[0][0]).shape[0])
@@ -655,17 +657,23 @@
if collect_grads
else []
)
+    bottom_envs_cache = [None] * n_rows
energy = jnp.zeros((), dtype=amp.dtype)

# Diagonal terms
for term in diagonal_terms:
idx = jnp.asarray(0, dtype=jnp.int32)
for row, col in term.sites:
idx = idx * phys_dim + spins[row, col]
energy = energy + term.diag[idx]

-    top_env = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
-    mpo = _build_row_mpo(tensors, spins[0], 0, n_cols)
-    for row in range(n_rows):
-        bottom_env = bottom_envs[row]
+    # Backward pass: bottom → top
+    bottom_env = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
+    next_row_mpo = None
+    for row in range(n_rows - 1, -1, -1):
+        bottom_envs_cache[row] = bottom_env
+        top_env = top_envs[row]
+        mpo = _build_row_mpo(tensors, spins[row], row, n_cols)
right_envs = _compute_right_envs(top_env, mpo, bottom_env, dtype)
left_env = jnp.ones((1, 1, 1), dtype=dtype)
for c in range(n_cols):
@@ -706,19 +714,18 @@
amps_flat = amps_edge.reshape(-1)
for term in edge_terms:
energy = energy + jnp.dot(term.op[:, col_idx], amps_flat) / amp
# Direct einsum for left_env update
left_env = jnp.einsum(
"ace,aub,cduv,evf->bdf",
left_env, top_env[c], mpo[c], bottom_env[c],
optimize=[(0, 1), (0, 2), (0, 1)],
)
# Vertical energy between row and row+1
if row < n_rows - 1:
-            mpo_next = _build_row_mpo(tensors, spins[row + 1], row + 1, n_cols)
energy = energy + _compute_row_pair_vertical_energy(
top_env,
-                bottom_envs[row + 1],
+                bottom_envs_cache[row + 1],
mpo,
-                mpo_next,
+                next_row_mpo,
tensors[row],
tensors[row + 1],
spins[row],
@@ -727,11 +734,10 @@
amp,
phys_dim,
)
-        top_env = strategy.apply(top_env, mpo)
-        if row < n_rows - 1:
-            mpo = mpo_next
+        bottom_env = _apply_mpo_from_below(bottom_env, mpo, strategy)
+        next_row_mpo = mpo

-    return env_grads, energy
+    return env_grads, energy, bottom_envs_cache


@@ -752,12 +758,16 @@ def _compute_2site_horizontal_env(

Returns tensor with shape (up0, down0, mL, up1, down1, mR).
"""
-    # Contract left side: left_env (a,c,e) @ top0 (a,u,b) @ bot0 (e,d,f) -> (c,u,b,d,f)
-    tmp_left = jnp.einsum("ace,aub,edf->cubdf", left_env, top0, bot0, optimize=[(0, 1), (0, 1)])
-    # Contract right side: top1 (b,v,g) @ right_env (g,h,i) @ bot1 (f,w,i) -> (b,v,h,f,w)
-    tmp_right = jnp.einsum("bvg,ghi,fwi->bvhfw", top1, right_env, bot1, optimize=[(0, 1), (0, 1)])
-    # Contract left and right: (c,u,b,d,f) @ (b,v,h,f,w) -> (c,u,d,v,h,w)
-    env = jnp.einsum("cubdf,bvhfw->cudvhw", tmp_left, tmp_right, optimize=[(0, 1)])
+    env = jnp.einsum(
+        "ace,aub,edf,bvg,ghi,fwi->cudvhw",
+        left_env,
+        top0,
+        bot0,
+        top1,
+        right_env,
+        bot1,
+        optimize=[(0, 1), (0, 1), (1, 2), (1, 2), (0, 1)],
+    )
# Transpose to (up0, down0, mL, up1, down1, mR)
return jnp.transpose(env, (1, 2, 0, 3, 5, 4))

@@ -885,19 +895,21 @@ def grads_and_energy(
sample: jax.Array,
amp: jax.Array,
operator: Any,
-    envs: list[tuple],
-) -> tuple[list[list[jax.Array]], jax.Array]:
+    top_envs: list[tuple],
+) -> tuple[list[list[jax.Array]], jax.Array, list[tuple]]:
"""Compute environment gradients and local energy for PEPS.

+    Uses backward pass: takes cached top_envs, returns cached bottom_envs.
+
Args:
model: PEPS model
sample: flat sample array (occupancy indices)
amp: amplitude for this configuration
operator: LocalHamiltonian with terms
-        envs: bottom environments (computed via bottom_envs())
+        top_envs: top environments (from sweep())

Returns:
-        (env_grads, energy)
+        (env_grads, energy, bottom_envs) - includes cached bottom environments for next sweep
"""
from vmc.operators.local_terms import bucket_terms

@@ -912,7 +924,7 @@
amp,
model.shape,
model.strategy,
-        envs,
+        top_envs,
diagonal_terms=diagonal_terms,
one_site_terms=one_site_terms,
horizontal_terms=horizontal_terms,
@@ -936,7 +948,7 @@ def sweep(
sample: jax.Array,
key: jax.Array,
envs: list[tuple],
-) -> tuple[jax.Array, jax.Array, jax.Array]:
+) -> tuple[jax.Array, jax.Array, jax.Array, list[tuple]]:
"""Single Metropolis sweep for PEPS.

Args:
@@ -946,15 +958,17 @@
envs: bottom environments from previous iteration

Returns:
-        (new_sample, key, amp) - flat sample array after sweep
+        (new_sample, key, amp, top_envs) - includes cached top environments
"""
indices = sample.reshape(model.shape)
tensors = [[jnp.asarray(t) for t in row] for row in model.tensors]
n_rows, n_cols = model.shape
dtype = tensors[0][0].dtype

top_env = tuple(jnp.ones((1, 1, 1), dtype=dtype) for _ in range(n_cols))
+    top_envs_cache = [None] * n_rows
for row in range(n_rows):
+        top_envs_cache[row] = top_env
bottom_env = envs[row]
mpo_row = _build_row_mpo(tensors, indices[row], row, n_cols)
right_envs = _compute_right_envs(top_env, mpo_row, bottom_env, dtype)
@@ -1006,7 +1020,7 @@ def sweep(
top_env = model.strategy.apply(top_env, tuple(updated_row))

amp = _contract_bottom(top_env)
-    return indices.reshape(-1), key, amp
+    return indices.reshape(-1), key, amp, top_envs_cache


class PEPS(nnx.Module):
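Taken together, the two methods above now form a ping-pong per Monte Carlo step: the downward Metropolis `sweep` consumes bottom environments and returns cached top environments, and the upward `grads_and_energy` pass consumes those while rebuilding the bottom environments for the next step. A minimal driver sketch (hypothetical loop variables; the real loop lives in `vmc.samplers.sequential`):

envs = bottom_envs(model, sample)  # one-time initialization
for _ in range(n_steps):
    # Down sweep: Metropolis updates, caching top environments row by row.
    sample, key, amp, top_envs = sweep(model, sample, key, envs)
    # Up sweep: local energy and gradients, re-caching bottom environments.
    env_grads, local_energy, envs = grads_and_energy(model, sample, amp, operator, top_envs)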
15 changes: 7 additions & 8 deletions src/vmc/samplers/sequential.py
@@ -287,7 +287,7 @@ def sequential_sample(
envs = jax.vmap(lambda s: bottom_envs(model, s))(samples_flat)

def sweep_once(sample, key, envs):
-        sample, key, _ = sweep(model, sample, key, envs)
+        sample, key, _, _ = sweep(model, sample, key, envs)
envs = bottom_envs(model, sample)
return sample, key, envs

@@ -550,17 +550,17 @@ def flatten_sliced_gradients(env_grads, sample, amp):

def mc_sweep(sample, key, envs):
"""Single MC sweep: Metropolis sweep + gradient/energy + flatten (single vmap)."""
-        sample, key, amp = sweep(model, sample, key, envs)
-        envs = bottom_envs(model, sample)
-        env_grads, local_energy = grads_and_energy(model, sample, amp, operator, envs)
+        sample, key, amp, top_envs = sweep(model, sample, key, envs)
+        env_grads, local_energy, envs = grads_and_energy(model, sample, amp, operator, top_envs)
grad_row, p_row = flatten_grads(env_grads, sample, amp)
return sample, key, envs, grad_row, p_row, amp, local_energy

def burn_step(carry, _):
samples, chain_keys, envs = carry
-        samples, chain_keys, envs, _, _, _, _ = jax.vmap(mc_sweep)(
-            samples, chain_keys, envs
-        )
+        samples, chain_keys, _, _ = jax.vmap(
+            lambda sample, key, env: sweep(model, sample, key, env)
+        )(samples, chain_keys, envs)
+        envs = jax.vmap(lambda sample: bottom_envs(model, sample))(samples)
return (samples, chain_keys, envs), None

(samples_flat, chain_keys, envs), _ = jax.lax.scan(
@@ -602,4 +602,3 @@ def sample_step(carry, _):

return samples_out, grads, p, key, final_samples, amps, local_energies


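Note on `burn_step` above: during burn-in no gradients or energies are needed, so it now calls the plain `sweep` (discarding the cached top environments) and rebuilds the bottom environments with `bottom_envs`, rather than paying for the full gradient/energy pass of `mc_sweep`.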
6 changes: 6 additions & 0 deletions src/vmc/utils/__init__.py
@@ -16,6 +16,10 @@
occupancy_to_spin,
spin_to_occupancy,
)
+from vmc.utils.factorizations import (
+    _qr_compactwy,
+    _qr_cholesky,
+)

# Note: vmc_utils imports are not included here to avoid circular imports.
# Import directly from vmc.utils.vmc_utils when needed.
@@ -32,4 +36,6 @@
"independent_set_violations",
"occupancy_to_spin",
"spin_to_occupancy",
"_qr_compactwy",
"_qr_cholesky",
]
61 changes: 61 additions & 0 deletions src/vmc/utils/factorizations.py
@@ -0,0 +1,61 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

__all__ = ["_qr_compactwy", "_qr_cholesky"]


def _qr_compactwy(a: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Householder QR via compact WY representation.

Computes reduced ``(Q, R)`` on trailing matrix axes, batch-polymorphically.
"""
    r, tau = jnp.linalg.qr(a, mode="raw")  # batched geqrf (cuSOLVER-backed on GPU)
q = _householder_wy(r.mT, tau)
return q, jnp.triu(r.mT[..., : tau.shape[-1], :])


def _householder_wy(r: jax.Array, tau: jax.Array) -> jax.Array:
"""Build reduced ``Q`` from geqrf reflectors in compact WY form.

Implements ``Q = I - Y T Y^H`` on trailing matrix axes, batch-polymorphically.
"""
m = r.shape[-2]
k = tau.shape[-1]
dtype = r.dtype

Y = jnp.tril(r[..., :, :k], k=-1) + jnp.eye(m, k, dtype=dtype)
YHY = jnp.einsum("...ki,...kj->...ij", Y.conj(), Y, optimize=True)
strict_lower = jnp.tril(jnp.ones((k, k), dtype=dtype), k=-1)
basis = jnp.eye(k, dtype=dtype)

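    # Schreiber-Van Loan recurrence, built one column per iteration:
    #   T[:j, j] = -tau_j * T @ (Y^H Y)[:j, j],   T[j, j] = tau_j
    # `mask` zeroes entries at rows >= j, so the same dense matvec is valid
    # for every j and the loop body stays shape-static under fori_loop/JIT.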
def update_column(j: int, T: jax.Array) -> jax.Array:
mask = strict_lower[j, :]
yhy_col = YHY[..., :, j] * mask
t_yhy = jnp.einsum("...ab,...b->...a", T, yhy_col, optimize=True)
tau_j = tau[..., j][..., None]
new_col = -tau_j * t_yhy * mask + tau_j * basis[j]
return jax.lax.dynamic_update_slice_in_dim(T, new_col[..., None], j, axis=-1)

T = jax.lax.fori_loop(
0,
k,
update_column,
jnp.zeros(tau.shape[:-1] + (k, k), dtype=dtype),
)

return jnp.eye(m, k, dtype=dtype) - jnp.einsum(
"...ik,...kl,...jl->...ij",
Y,
T,
Y[..., :k, :].conj(),
optimize=True,
)


def _qr_cholesky(a: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Reduced QR via Cholesky of the Gram matrix (``A^H A = L L^H``).

    Cheaper than Householder QR, but less accurate for ill-conditioned
    inputs, since forming the Gram matrix squares the condition number.
    """
    gram = a.mH @ a
    L = jnp.linalg.cholesky(gram)
    q = jax.scipy.linalg.solve_triangular(L, a.mH, lower=True).mH
    return q, L.mH
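A quick reconstruction check for both factorizations (a sketch; the batch shape and tolerances are illustrative, and `_qr_cholesky` is expected to be less accurate because of the squared condition number):

import jax
import jax.numpy as jnp
from vmc.utils.factorizations import _qr_compactwy, _qr_cholesky

a = jax.random.normal(jax.random.PRNGKey(0), (5, 8, 4))  # batch of tall matrices

q, r = _qr_compactwy(a)
assert q.shape == (5, 8, 4) and r.shape == (5, 4, 4)
assert jnp.allclose(q @ r, a, atol=1e-5)              # A = QR
assert jnp.allclose(q.mT @ q, jnp.eye(4), atol=1e-5)  # orthonormal columns

qc, rc = _qr_cholesky(a)
assert jnp.allclose(qc @ rc, a, atol=1e-4)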