Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _scripts/poly6_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def tune_poly_sdpr(
):
# Create SDPR Layer
optlayer = SDPRLayer(
n_vars=4, constraints=self.constraints, use_dual=False, diff_qcqp=diff_qcqp
n_vars=4, constraints=self.constraints, use_dual=False, diff_qcqp=diff_qcqp ,redun_list=[2]
)

# Set up polynomial parameter tensor
Expand Down
216 changes: 208 additions & 8 deletions src/sdprlayers/layers/sdprlayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@
# Minimum Eigenvalue Ratio for Tightness Check.
ER_MIN = 1e5

# Module-level list of recent QCQP iteration snapshots, oldest first.
# _qcqp_save_conditions_history appends to it, and the diagnostic printouts
# in DiffQCQP.backward read the second-to-last entry as the "previous iteration".
_QCQP_HISTORY_BUFFER = []
# Maximum number of snapshots retained; the oldest entry is dropped once the
# buffer grows past this length.
_QCQP_MAX_HISTORY = 5

def _qcqp_save_conditions_history(xs, Hs, Q, params, mults):
    """Append a snapshot of the current QCQP iteration to the history buffer.

    The snapshot is taken before any further processing so that later
    diagnostics can compare the current iteration against the previous one.
    Every NumPy array is copied so that the stored history cannot be altered
    by subsequent in-place updates of the live data.

    Args:
        xs: Tensor of solutions; it is detached, moved to CPU and copied.
        Hs: Iterable of certificate matrices. NumPy arrays are copied, any
            other object is stored as-is.
        Q: Objective data. Copied when it is a NumPy array, otherwise stored
            as-is.
        params: Iterable of tensors; each is detached, moved to CPU and copied.
        mults: Iterable of multipliers. NumPy arrays are copied, any other
            object is stored as-is.

    Side effects:
        Mutates the module-level ``_QCQP_HISTORY_BUFFER`` in place, keeping at
        most ``_QCQP_MAX_HISTORY`` of the newest entries.
    """
    snapshot = {
        'xs': xs.detach().cpu().numpy().copy(),
        'Hs': [H.copy() if isinstance(H, np.ndarray) else H for H in Hs],
        # An array obtained through Tensor.numpy() shares storage with the
        # tensor, so copy it just like the other fields of the snapshot.
        'objective': Q.copy() if isinstance(Q, np.ndarray) else Q,
        'params': [p.detach().cpu().numpy().copy() for p in params],
        'multipliers': [m.copy() if isinstance(m, np.ndarray) else m for m in mults],
    }
    _QCQP_HISTORY_BUFFER.append(snapshot)

    # Trim in place (other code holds a reference to this same list) and drop
    # every surplus entry, not just one, in case the cap was lowered.
    excess = len(_QCQP_HISTORY_BUFFER) - _QCQP_MAX_HISTORY
    if excess > 0:
        del _QCQP_HISTORY_BUFFER[:excess]

def get_last_qcqp_entry():
    """Fetch the most recently recorded QCQP history entry.

    Returns:
        The newest element of the module-level history buffer, or ``None``
        when nothing has been recorded yet.
    """
    if not _QCQP_HISTORY_BUFFER:
        return None
    return _QCQP_HISTORY_BUFFER[-1]

class SDPRLayer(CvxpyLayer):
"""
Expand Down Expand Up @@ -330,13 +350,27 @@ def forward(self, *param_vals, ext_vars_list=None, **kwargs):
# Get slack variable and unvectorize
hs = soln[3]
Hs = [cones.unvec_symm(h, self.n_vars) for h in hs]
#check compilation warnings
for i, H in enumerate(Hs):
H_rank = np.linalg.matrix_rank(Hs, tol=1e-10)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would replace matrix_rank and eigvalsh with a single call to eigvalsh and then count the eigenvalues that are <= 1e-10. Just to be sure -- is it really necessary to do this check every time? Would it be possible to do it only when necessary (as in the cases below, inside a caught exception)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late reply, I’ll address each point.
I agree with your suggested modification of using eigvalsh instead of matrix_rank.

As shown in Lemma 5 of the paper, corank(H) = 1 is a sufficient condition for SOSC, which is one of the key assumptions of the main theorem. My intention was to check whether this condition fails before hitting the assertion errors, mainly for diagnostic purposes. If preferred, this check can be moved behind assertion errors or removed entirely.

Regarding all the edits in the sdprlayer.py file: they are not required to solve the singular Jacobian issue itself. However, during some runs I sporadically encountered assertion errors, so I added try–except blocks to help identify their causes. I’m happy to remove all these changes if you think they’re unnecessary.

H_corank = H.shape[0] - H_rank
if H_corank != 1:
H_evals = np.linalg.eigvalsh(H)
print(f"\nWARNING: Certificate matrix H {i}has corank {H_corank} (expected corank 1)")
print(f"H rank: {H_rank} (expected {H.shape[0] - 1})")
print(f"H eigenvalues (sorted): {np.sort(H_evals)}")
print(f"This may indicate numerical issues or loose relaxation")
# Check that the whole batch is tight.
alltight = True
for X in Xs:
tight, ER = self.check_tightness(X)
if not tight:
alltight = False
break

#Store current iteration data in history buffer
_qcqp_save_conditions_history(xs, Hs, param_vals_h[0].detach().cpu().numpy(), param_vals, mults)

# If using nonconvex backprop, overwrite solution IF all problems are tight.
if alltight:
# Overwrite solution using QCQP autograd function
Expand Down Expand Up @@ -421,7 +455,7 @@ def check_tightness(X, ER_min=ER_MIN):
# Check rank
sorted_eigs = np.sort(np.linalg.eigvalsh(X))
sorted_eigs = np.abs(sorted_eigs)
ER = sorted_eigs[-1] / sorted_eigs[-2]
ER = sorted_eigs[-1] / (sorted_eigs[-2]+1e-16)
tight = ER > ER_min
return tight, ER

Expand Down Expand Up @@ -495,6 +529,7 @@ class DiffQCQP(torch.autograd.Function):
@staticmethod
def forward(ctx, *params):
"""Forward function is basically a dummy to store the required information for implicit backward pass."""
global _QCQP_HISTORY_BUFFER, _QCQP_MAX_HISTORY
# keep track of which parts of the problem are parameterized
param_dict = dict(objective=False, constraints=[])
param_ind = 0
Expand Down Expand Up @@ -526,7 +561,6 @@ def forward(ctx, *params):
"All parameters have not been used in QCQP Forward!"
)
ctx.param_dict = param_dict

# Store solution and certificate matrix
ctx.xs = xs.detach().cpu().numpy()
ctx.Hs = Hs
Expand Down Expand Up @@ -611,10 +645,111 @@ def backward(ctx, grad_output):
# Pad incoming gradient (derivative of loss wrt multipliers is zero)
dz_bar = np.vstack([-grad_output[b], np.zeros((G.shape[0], 1))])

#check compilation warnings
H_rank = np.linalg.matrix_rank(H, tol=1e-10)
H_corank = H.shape[0] - H_rank
if H_corank != 1:
H_evals = np.linalg.eigvalsh(H)
print(f"\nWARNING: Certificate matrix H has corank {H_corank} (expected corank 1)")
print(f"H rank: {H_rank} (expected {H.shape[0] - 1})")
print(f"H eigenvalues (sorted): {np.sort(H_evals)}")
print(f"This may indicate numerical issues or loose relaxation")
# Print history
if len(_QCQP_HISTORY_BUFFER) >= 2:
hist = _QCQP_HISTORY_BUFFER[-2] # Second-to-last = previous iteration
print("\n" + "-"*80)
print("PREVIOUS ITERATION")
print("-"*80)
if b < len(hist['Hs']):
H_prev = hist['Hs'][b]
if isinstance(H_prev, np.ndarray):
H_prev_evals = np.linalg.eigvalsh(H_prev)
print(f"Previous H eigenvalues: {np.sort(H_prev_evals)}")
print(f"Previous H rank: {np.linalg.matrix_rank(H_prev, tol=1e-10)}")
if len(hist['params']) > 0 and ctx.param_dict['objective']:
Q_prev = hist['params'][0]
if len(Q_prev.shape) > 2:
Q_prev = Q_prev[b]
if len(ctx.objective.shape) > 2:
Q = ctx.objective[b]
else:
Q = ctx.objective
Q_prev_evals = np.linalg.eigvalsh(Q_prev)
print(f"Previous Q eigenvalues: {np.sort(Q_prev_evals)}")
print("-"*80 + "\n")

# Check that certificate matrix satisfies the first order KKT conditions
assert np.linalg.norm(H @ x) < ctx.kkt_tol, ValueError(
"First-order KKT conditions cannot be satisfied! Check Certificate matrix."
)
try:
assert np.linalg.norm(H @ x) < ctx.kkt_tol, ValueError(
"First-order KKT conditions cannot be satisfied! Check Certificate matrix."
)
except AssertionError as e:
# Print diagnostic information
print("\n" + "="*80)
print("KKT CONDITION VIOLATION DETECTED")
print("="*80)

# Certificate Matrix H diagnostics
print("\n--- Certificate Matrix H ---")
print(f"H shape: {H.shape}")
H_evals = np.linalg.eigvalsh(H)
print(f"H eigenvalues (sorted): {np.sort(H_evals)}")
# print(f"H condition number: {np.max(np.abs(H_evals)) / np.max(np.abs(H_evals[np.abs(H_evals) > 1e-10]))}")
print(f"H @ x norm: {np.linalg.norm(H @ x)}")
print(f"H @ x: {(H @ x).flatten()}")

# Objective Matrix Q diagnostics
print("\n--- Objective Matrix Q ---")
if len(ctx.objective.shape) > 2:
Q = ctx.objective[b]
else:
Q = ctx.objective
print(f"Q shape: {Q.shape}")
Q_evals = np.linalg.eigvalsh(Q)
print(f"Q eigenvalues (sorted): {np.sort(Q_evals)}")
# print(f"Q condition number: {np.max(np.abs(Q_evals)) / np.max(np.abs(Q_evals[np.abs(Q_evals) > 1e-10]))}")

# Solution vector
print("\n--- Solution Vector x ---")
print(f"x shape: {x.shape}")
print(f"x: {x.flatten()}")
print(f"||x||: {np.linalg.norm(x)}")

# Print previous iteration only
if len(_QCQP_HISTORY_BUFFER) >= 2:
hist = _QCQP_HISTORY_BUFFER[-2] # Previous iteration
print("\n" + "-"*80)
print("PREVIOUS ITERATION")
print("-"*80)

if b < len(hist['xs']):
x_prev = hist['xs'][b]
print(f"Previous x: {x_prev.flatten()}")
print(f"Previous ||x||: {np.linalg.norm(x_prev):.6e}")
print(f"Change in x: {np.linalg.norm(x - x_prev):.6e}")

if b < len(hist['Hs']):
H_prev = hist['Hs'][b]
if isinstance(H_prev, np.ndarray):
H_prev_evals = np.linalg.eigvalsh(H_prev)
print(f"Previous H eigenvalues: {np.sort(H_prev_evals)}")
print(f"Previous H rank: {np.linalg.matrix_rank(H_prev, tol=1e-10)}")
print(f"Previous H @ x_prev norm: {np.linalg.norm(H_prev @ x_prev):.6e}")
print(f"Change in H (Frobenius): {np.linalg.norm(H - H_prev, 'fro'):.6e}")

if len(hist['params']) > 0 and ctx.param_dict['objective']:
Q_prev = hist['params'][0]
if len(Q_prev.shape) > 2:
Q_prev = Q_prev[b]
Q_prev_evals = np.linalg.eigvalsh(Q_prev)
print(f"Previous Q eigenvalues: {np.sort(Q_prev_evals)}")
print(f"Change in Q (Frobenius): {np.linalg.norm(Q - Q_prev, 'fro'):.6e}")

print("-"*80)

print("="*80 + "\n")
raise e

# Solve Differential KKT System
if M.shape[0] == M.shape[1]:
# Symmetric case
Expand All @@ -635,9 +770,74 @@ def backward(ctx, grad_output):
sol = ls_sol[0][:, None]
res = ls_sol[3]
# Check that we have actually solved the differential KKT system
assert res < ctx.kkt_tol, ValueError(
"Differential KKT system residual is high. Make sure that redundant constraints are actually redundant and that the certificate matrix is correct."
)
try:
assert res < ctx.kkt_tol, ValueError(
"Differential KKT system residual is high. Make sure that redundant constraints are actually redundant and that the certificate matrix is correct."
)
except AssertionError as e:
# Print diagnostic information
print("\n" + "="*80)
print("DIFFERENTIAL KKT SYSTEM SOLVE ISSUE DETECTED")
print("="*80)

# Certificate Matrix H diagnostics
print("\n--- Certificate Matrix H ---")
print(f"H shape: {H.shape}")
H_evals = np.linalg.eigvalsh(H)
print(f"H eigenvalues (sorted): {np.sort(H_evals)}")

# Objective Matrix Q diagnostics
print("\n--- Objective Matrix Q ---")
if len(ctx.objective.shape) > 2:
Q = ctx.objective[b]
else:
Q = ctx.objective
print(f"Q shape: {Q.shape}")
Q_evals = np.linalg.eigvalsh(Q)
print(f"Q eigenvalues (sorted): {np.sort(Q_evals)}")

# Solution vector
print("\n--- Solution Vector x ---")
print(f"x shape: {x.shape}")
print(f"x: {x.flatten()}")
print(f"||x||: {np.linalg.norm(x)}")

print(f"\nKKT Solve Residual: {res} (tolerance: {ctx.kkt_tol})")

# Print previous iteration only
if len(_QCQP_HISTORY_BUFFER) >= 2:
hist = _QCQP_HISTORY_BUFFER[-2]
print("\n" + "-"*80)
print("PREVIOUS ITERATION")
print("-"*80)

if b < len(hist['xs']):
x_prev = hist['xs'][b]
print(f"Previous x: {x_prev.flatten()}")
print(f"Previous ||x||: {np.linalg.norm(x_prev):.6e}")
print(f"Change in x: {np.linalg.norm(x - x_prev):.6e}")

if b < len(hist['Hs']):
H_prev = hist['Hs'][b]
if isinstance(H_prev, np.ndarray):
H_prev_evals = np.linalg.eigvalsh(H_prev)
print(f"Previous H eigenvalues: {np.sort(H_prev_evals)}")
print(f"Previous H rank: {np.linalg.matrix_rank(H_prev, tol=1e-10)}")
print(f"Previous H @ x_prev norm: {np.linalg.norm(H_prev @ x_prev):.6e}")
print(f"Change in H (Frobenius): {np.linalg.norm(H - H_prev, 'fro'):.6e}")

if len(hist['params']) > 0 and ctx.param_dict['objective']:
Q_prev = hist['params'][0]
if len(Q_prev.shape) > 2:
Q_prev = Q_prev[b]
Q_prev_evals = np.linalg.eigvalsh(Q_prev)
print(f"Previous Q eigenvalues: {np.sort(Q_prev_evals)}")
print(f"Change in Q (Frobenius): {np.linalg.norm(Q - Q_prev, 'fro'):.6e}")

print("-"*80)

print("="*80 + "\n")
raise e
dy_bar = sol
dy_bar_1 = dy_bar[:nvars, :]
# Fill with zeros at redundant entries
Expand Down