diff --git a/_scripts/poly6_example.py b/_scripts/poly6_example.py index 72b2e21..0164229 100644 --- a/_scripts/poly6_example.py +++ b/_scripts/poly6_example.py @@ -68,7 +68,7 @@ def tune_poly_sdpr( ): # Create SDPR Layer optlayer = SDPRLayer( - n_vars=4, constraints=self.constraints, use_dual=False, diff_qcqp=diff_qcqp + n_vars=4, constraints=self.constraints, use_dual=False, diff_qcqp=diff_qcqp ,redun_list=[2] ) # Set up polynomial parameter tensor diff --git a/src/sdprlayers/layers/sdprlayer.py b/src/sdprlayers/layers/sdprlayer.py index 75735cf..cd1c99c 100644 --- a/src/sdprlayers/layers/sdprlayer.py +++ b/src/sdprlayers/layers/sdprlayer.py @@ -34,6 +34,26 @@ # Minimum Eigenvalue Ratio for Tightness Check. ER_MIN = 1e5 +_QCQP_HISTORY_BUFFER = [] +_QCQP_MAX_HISTORY = 5 + +def _qcqp_save_conditions_history(xs, Hs, Q, params, mults): + global _QCQP_HISTORY_BUFFER, _QCQP_MAX_HISTORY + # Store current iteration data BEFORE processing + current_iter_data = { + 'xs': xs.detach().cpu().numpy().copy(), + 'Hs': [H.copy() if isinstance(H, np.ndarray) else H for H in Hs], + 'objective': Q, + 'params': [p.detach().cpu().numpy().copy() for p in params], + 'multipliers': [m.copy() if isinstance(m, np.ndarray) else m for m in mults], + } + _QCQP_HISTORY_BUFFER.append(current_iter_data) + if len(_QCQP_HISTORY_BUFFER) > _QCQP_MAX_HISTORY: + _QCQP_HISTORY_BUFFER.pop(0) + +def get_last_qcqp_entry(): + """Return the last QCQP history entry, or None if empty.""" + return _QCQP_HISTORY_BUFFER[-1] if _QCQP_HISTORY_BUFFER else None class SDPRLayer(CvxpyLayer): """ @@ -330,6 +350,16 @@ def forward(self, *param_vals, ext_vars_list=None, **kwargs): # Get slack variable and unvectorize hs = soln[3] Hs = [cones.unvec_symm(h, self.n_vars) for h in hs] + #check compilation warnings + for i, H in enumerate(Hs): + H_rank = np.linalg.matrix_rank(Hs, tol=1e-10) + H_corank = H.shape[0] - H_rank + if H_corank != 1: + H_evals = np.linalg.eigvalsh(H) + print(f"\nWARNING: Certificate matrix H {i}has corank {H_corank} (expected corank 1)") + print(f"H rank: {H_rank} (expected {H.shape[0] - 1})") + print(f"H eigenvalues (sorted): {np.sort(H_evals)}") + print(f"This may indicate numerical issues or loose relaxation") # Check that the whole batch is tight. alltight = True for X in Xs: @@ -337,6 +367,10 @@ def forward(self, *param_vals, ext_vars_list=None, **kwargs): if not tight: alltight = False break + + #Store current iteration data in history buffer + _qcqp_save_conditions_history(xs, Hs, param_vals_h[0].detach().cpu().numpy(), param_vals, mults) + # If using nonconvex backprop, overwrite solution IF all problems are tight. if alltight: # Overwrite solution using QCQP autograd function @@ -421,7 +455,7 @@ def check_tightness(X, ER_min=ER_MIN): # Check rank sorted_eigs = np.sort(np.linalg.eigvalsh(X)) sorted_eigs = np.abs(sorted_eigs) - ER = sorted_eigs[-1] / sorted_eigs[-2] + ER = sorted_eigs[-1] / (sorted_eigs[-2]+1e-16) tight = ER > ER_min return tight, ER @@ -495,6 +529,7 @@ class DiffQCQP(torch.autograd.Function): @staticmethod def forward(ctx, *params): """Forward function is basically a dummy to store the required information for implicit backward pass.""" + global _QCQP_HISTORY_BUFFER, _QCQP_MAX_HISTORY # keep track of which parts of the problem are parameterized param_dict = dict(objective=False, constraints=[]) param_ind = 0 @@ -526,7 +561,6 @@ def forward(ctx, *params): "All parameters have not been used in QCQP Forward!" ) ctx.param_dict = param_dict - # Store solution and certificate matrix ctx.xs = xs.detach().cpu().numpy() ctx.Hs = Hs @@ -611,10 +645,111 @@ def backward(ctx, grad_output): # Pad incoming gradient (derivative of loss wrt multipliers is zero) dz_bar = np.vstack([-grad_output[b], np.zeros((G.shape[0], 1))]) + #check compilation warnings + H_rank = np.linalg.matrix_rank(H, tol=1e-10) + H_corank = H.shape[0] - H_rank + if H_corank != 1: + H_evals = np.linalg.eigvalsh(H) + print(f"\nWARNING: Certificate matrix H has corank {H_corank} (expected corank 1)") + print(f"H rank: {H_rank} (expected {H.shape[0] - 1})") + print(f"H eigenvalues (sorted): {np.sort(H_evals)}") + print(f"This may indicate numerical issues or loose relaxation") + # Print history + if len(_QCQP_HISTORY_BUFFER) >= 2: + hist = _QCQP_HISTORY_BUFFER[-2] # Second-to-last = previous iteration + print("\n" + "-"*80) + print("PREVIOUS ITERATION") + print("-"*80) + if b < len(hist['Hs']): + H_prev = hist['Hs'][b] + if isinstance(H_prev, np.ndarray): + H_prev_evals = np.linalg.eigvalsh(H_prev) + print(f"Previous H eigenvalues: {np.sort(H_prev_evals)}") + print(f"Previous H rank: {np.linalg.matrix_rank(H_prev, tol=1e-10)}") + if len(hist['params']) > 0 and ctx.param_dict['objective']: + Q_prev = hist['params'][0] + if len(Q_prev.shape) > 2: + Q_prev = Q_prev[b] + if len(ctx.objective.shape) > 2: + Q = ctx.objective[b] + else: + Q = ctx.objective + Q_prev_evals = np.linalg.eigvalsh(Q_prev) + print(f"Previous Q eigenvalues: {np.sort(Q_prev_evals)}") + print("-"*80 + "\n") + # Check that certificate matrix satisfies the first order KKT conditions - assert np.linalg.norm(H @ x) < ctx.kkt_tol, ValueError( - "First-order KKT conditions cannot be satisfied! Check Certificate matrix." - ) + try: + assert np.linalg.norm(H @ x) < ctx.kkt_tol, ValueError( + "First-order KKT conditions cannot be satisfied! Check Certificate matrix." + ) + except AssertionError as e: + # Print diagnostic information + print("\n" + "="*80) + print("KKT CONDITION VIOLATION DETECTED") + print("="*80) + + # Certificate Matrix H diagnostics + print("\n--- Certificate Matrix H ---") + print(f"H shape: {H.shape}") + H_evals = np.linalg.eigvalsh(H) + print(f"H eigenvalues (sorted): {np.sort(H_evals)}") + # print(f"H condition number: {np.max(np.abs(H_evals)) / np.max(np.abs(H_evals[np.abs(H_evals) > 1e-10]))}") + print(f"H @ x norm: {np.linalg.norm(H @ x)}") + print(f"H @ x: {(H @ x).flatten()}") + + # Objective Matrix Q diagnostics + print("\n--- Objective Matrix Q ---") + if len(ctx.objective.shape) > 2: + Q = ctx.objective[b] + else: + Q = ctx.objective + print(f"Q shape: {Q.shape}") + Q_evals = np.linalg.eigvalsh(Q) + print(f"Q eigenvalues (sorted): {np.sort(Q_evals)}") + # print(f"Q condition number: {np.max(np.abs(Q_evals)) / np.max(np.abs(Q_evals[np.abs(Q_evals) > 1e-10]))}") + + # Solution vector + print("\n--- Solution Vector x ---") + print(f"x shape: {x.shape}") + print(f"x: {x.flatten()}") + print(f"||x||: {np.linalg.norm(x)}") + + # Print previous iteration only + if len(_QCQP_HISTORY_BUFFER) >= 2: + hist = _QCQP_HISTORY_BUFFER[-2] # Previous iteration + print("\n" + "-"*80) + print("PREVIOUS ITERATION") + print("-"*80) + + if b < len(hist['xs']): + x_prev = hist['xs'][b] + print(f"Previous x: {x_prev.flatten()}") + print(f"Previous ||x||: {np.linalg.norm(x_prev):.6e}") + print(f"Change in x: {np.linalg.norm(x - x_prev):.6e}") + + if b < len(hist['Hs']): + H_prev = hist['Hs'][b] + if isinstance(H_prev, np.ndarray): + H_prev_evals = np.linalg.eigvalsh(H_prev) + print(f"Previous H eigenvalues: {np.sort(H_prev_evals)}") + print(f"Previous H rank: {np.linalg.matrix_rank(H_prev, tol=1e-10)}") + print(f"Previous H @ x_prev norm: {np.linalg.norm(H_prev @ x_prev):.6e}") + print(f"Change in H (Frobenius): {np.linalg.norm(H - H_prev, 'fro'):.6e}") + + if len(hist['params']) > 0 and ctx.param_dict['objective']: + Q_prev = hist['params'][0] + if len(Q_prev.shape) > 2: + Q_prev = Q_prev[b] + Q_prev_evals = np.linalg.eigvalsh(Q_prev) + print(f"Previous Q eigenvalues: {np.sort(Q_prev_evals)}") + print(f"Change in Q (Frobenius): {np.linalg.norm(Q - Q_prev, 'fro'):.6e}") + + print("-"*80) + + print("="*80 + "\n") + raise e + # Solve Differential KKT System if M.shape[0] == M.shape[1]: # Symmetric case @@ -635,9 +770,74 @@ def backward(ctx, grad_output): sol = ls_sol[0][:, None] res = ls_sol[3] # Check that we have actually solved the differential KKT system - assert res < ctx.kkt_tol, ValueError( - "Differential KKT system residual is high. Make sure that redundant constraints are actually redundant and that the certificate matrix is correct." - ) + try: + assert res < ctx.kkt_tol, ValueError( + "Differential KKT system residual is high. Make sure that redundant constraints are actually redundant and that the certificate matrix is correct." + ) + except AssertionError as e: + # Print diagnostic information + print("\n" + "="*80) + print("DIFFERENTIAL KKT SYSTEM SOLVE ISSUE DETECTED") + print("="*80) + + # Certificate Matrix H diagnostics + print("\n--- Certificate Matrix H ---") + print(f"H shape: {H.shape}") + H_evals = np.linalg.eigvalsh(H) + print(f"H eigenvalues (sorted): {np.sort(H_evals)}") + + # Objective Matrix Q diagnostics + print("\n--- Objective Matrix Q ---") + if len(ctx.objective.shape) > 2: + Q = ctx.objective[b] + else: + Q = ctx.objective + print(f"Q shape: {Q.shape}") + Q_evals = np.linalg.eigvalsh(Q) + print(f"Q eigenvalues (sorted): {np.sort(Q_evals)}") + + # Solution vector + print("\n--- Solution Vector x ---") + print(f"x shape: {x.shape}") + print(f"x: {x.flatten()}") + print(f"||x||: {np.linalg.norm(x)}") + + print(f"\nKKT Solve Residual: {res} (tolerance: {ctx.kkt_tol})") + + # Print previous iteration only + if len(_QCQP_HISTORY_BUFFER) >= 2: + hist = _QCQP_HISTORY_BUFFER[-2] + print("\n" + "-"*80) + print("PREVIOUS ITERATION") + print("-"*80) + + if b < len(hist['xs']): + x_prev = hist['xs'][b] + print(f"Previous x: {x_prev.flatten()}") + print(f"Previous ||x||: {np.linalg.norm(x_prev):.6e}") + print(f"Change in x: {np.linalg.norm(x - x_prev):.6e}") + + if b < len(hist['Hs']): + H_prev = hist['Hs'][b] + if isinstance(H_prev, np.ndarray): + H_prev_evals = np.linalg.eigvalsh(H_prev) + print(f"Previous H eigenvalues: {np.sort(H_prev_evals)}") + print(f"Previous H rank: {np.linalg.matrix_rank(H_prev, tol=1e-10)}") + print(f"Previous H @ x_prev norm: {np.linalg.norm(H_prev @ x_prev):.6e}") + print(f"Change in H (Frobenius): {np.linalg.norm(H - H_prev, 'fro'):.6e}") + + if len(hist['params']) > 0 and ctx.param_dict['objective']: + Q_prev = hist['params'][0] + if len(Q_prev.shape) > 2: + Q_prev = Q_prev[b] + Q_prev_evals = np.linalg.eigvalsh(Q_prev) + print(f"Previous Q eigenvalues: {np.sort(Q_prev_evals)}") + print(f"Change in Q (Frobenius): {np.linalg.norm(Q - Q_prev, 'fro'):.6e}") + + print("-"*80) + + print("="*80 + "\n") + raise e dy_bar = sol dy_bar_1 = dy_bar[:nvars, :] # Fill with zeros at redundant entries