diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 0f7d4c164c..0497cd76b4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -4180,7 +4180,7 @@ private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, { // In BatchRPCMode, the actual T-SQL query is in the first parameter and not present as the rpcName, as is the case with non-BatchRPCMode. // So input parameters start at parameters[1]. parameters[0] is the actual T-SQL Statement. rpcName is sp_executesql. - if (_RPCList[i].systemParams.Length > 1) + if (_RPCList[i].systemParams != null && _RPCList[i].systemParams.Length > 1) { _RPCList[i].needsFetchParameterEncryptionMetadata = true; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 96ce941f39..5418fc22f2 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -4304,7 +4304,7 @@ private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, { // In _batchRPCMode, the actual T-SQL query is in the first parameter and not present as the rpcName, as is the case with non-_batchRPCMode. // So input parameters start at parameters[1]. parameters[0] is the actual T-SQL Statement. rpcName is sp_executesql. - if (_RPCList[i].systemParams.Length > 1) + if (_RPCList[i].systemParams != null && _RPCList[i].systemParams.Length > 1) { _RPCList[i].needsFetchParameterEncryptionMetadata = true; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/SqlDataAdapterBatchUpdateTests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/SqlDataAdapterBatchUpdateTests.cs new file mode 100644 index 0000000000..063bdd3a45 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/SqlDataAdapterBatchUpdateTests.cs @@ -0,0 +1,247 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using System.Threading.Tasks; +using System.Collections.Generic; +using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted +{ + public sealed class SqlDataAdapterBatchUpdateTests : IClassFixture, IDisposable + { + private readonly SQLSetupStrategy _fixture; + private readonly string _tableName; + private readonly BuyerSellerTable _buyerSellerTable; + + public SqlDataAdapterBatchUpdateTests(SQLSetupStrategyCertStoreProvider context) + { + _fixture = context; + _buyerSellerTable = _fixture.BuyerSellerTable as BuyerSellerTable; + _tableName = _fixture.BuyerSellerTable.Name; + } + + // ---------- TESTS ---------- + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsTargetReadyForAeWithKeyStore))] + [ClassData(typeof(AEConnectionStringProvider))] + public async Task AdapterUpdate_BatchSizeGreaterThanOne_Succeeds(string connectionString) + { + // Arrange + TruncateTable(connectionString); + int idBase = GetUniqueIdBase(); + PopulateTable(new (int id, string s1, string s2)[] { + (idBase + 10, "123-45-6789", "987-65-4321"), + (idBase + 20, "234-56-7890", "876-54-3210"), + (idBase + 30, "345-67-8901", "765-43-2109"), + (idBase + 40, "456-78-9012", "654-32-1098"), + }, connectionString); + + using var conn = new SqlConnection(GetConnectionString(connectionString, encryptionEnabled: true)); + await conn.OpenAsync(); + + using var adapter = CreateAdapter(conn, updateBatchSize: 10); + var dataTable = BuildBuyerSellerDataTable(); + LoadCurrentRowsIntoDataTable(dataTable, conn); + + MutateForUpdate(dataTable); + + // Act - With batch updates (UpdateBatchSize > 1), this previously threw NullReferenceException due to null systemParams in batch RPC mode + var updated = await Task.Run(() => adapter.Update(dataTable)); + + // Assert + Assert.Equal(dataTable.Rows.Count, updated); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsTargetReadyForAeWithKeyStore))] + [ClassData(typeof(AEConnectionStringProvider))] + public async Task AdapterUpdate_BatchSizeOne_Succeeds(string connectionString) + { + // Arrange + TruncateTable(connectionString); + int idBase = GetUniqueIdBase(); + PopulateTable(new (int id, string s1, string s2)[] { + (idBase + 100, "123-45-6789", "987-65-4321"), + (idBase + 200, "234-56-7890", "876-54-3210"), + (idBase + 300, "345-67-8901", "765-43-2109"), + (idBase + 400, "456-78-9012", "654-32-1098"), + }, connectionString); + + using var conn = new SqlConnection(GetConnectionString(connectionString, encryptionEnabled: true)); + await conn.OpenAsync(); + + using var adapter = CreateAdapter(conn, updateBatchSize: 1); // success path + var dataTable = BuildBuyerSellerDataTable(); + LoadCurrentRowsIntoDataTable(dataTable, conn); + + MutateForUpdate(dataTable); + + // Act + var updatedRows = await Task.Run(() => adapter.Update(dataTable)); + + // Assert + Assert.Equal(dataTable.Rows.Count, updatedRows); + } + + // ---------- HELPERS ---------- + + private int GetUniqueIdBase() => Math.Abs(Guid.NewGuid().GetHashCode()) % 1000000; + + private SqlDataAdapter CreateAdapter(SqlConnection connection, int updateBatchSize) + { + var insertCmd = new SqlCommand(_buyerSellerTable.InsertProcedureName, connection) + { + CommandType = CommandType.StoredProcedure + }; + insertCmd.Parameters.AddRange(new[] + { + new SqlParameter("@BuyerSellerID", SqlDbType.Int) { SourceColumn = "BuyerSellerID" }, + new SqlParameter("@SSN1", SqlDbType.VarChar, 255) { SourceColumn = "SSN1" }, + new SqlParameter("@SSN2", SqlDbType.VarChar, 255) { SourceColumn = "SSN2" }, + }); + insertCmd.UpdatedRowSource = UpdateRowSource.None; + + var updateCmd = new SqlCommand(_buyerSellerTable.UpdateProcedureName, connection) + { + CommandType = CommandType.StoredProcedure + }; + updateCmd.Parameters.AddRange(new[] + { + new SqlParameter("@BuyerSellerID", SqlDbType.Int) { SourceColumn = "BuyerSellerID" }, + new SqlParameter("@SSN1", SqlDbType.VarChar, 255) { SourceColumn = "SSN1" }, + new SqlParameter("@SSN2", SqlDbType.VarChar, 255) { SourceColumn = "SSN2" }, + }); + updateCmd.UpdatedRowSource = UpdateRowSource.None; + + return new SqlDataAdapter + { + InsertCommand = insertCmd, + UpdateCommand = updateCmd, + UpdateBatchSize = updateBatchSize + }; + } + + private DataTable BuildBuyerSellerDataTable() + { + var dt = new DataTable(_tableName); + dt.Columns.AddRange(new[] + { + new DataColumn("BuyerSellerID", typeof(int)), + new DataColumn("SSN1", typeof(string)), + new DataColumn("SSN2", typeof(string)), + }); + dt.PrimaryKey = new[] { dt.Columns["BuyerSellerID"] }; + return dt; + } + + private void LoadCurrentRowsIntoDataTable(DataTable dt, SqlConnection conn) + { + using var cmd = new SqlCommand($"SELECT BuyerSellerID, SSN1, SSN2 FROM [dbo].[{_tableName}] ORDER BY BuyerSellerID", conn); + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + { + dt.Rows.Add(reader.GetInt32(0), reader.GetString(1), reader.GetString(2)); + } + } + + private void MutateForUpdate(DataTable dt) + { + int i = 0; + var fixedTime = new DateTime(2023, 01, 01, 12, 34, 56); + string timeStr = fixedTime.ToString("HHmm"); + foreach (DataRow row in dt.Rows) + { + i++; + row["SSN1"] = $"{i:000}-11-{timeStr}"; + row["SSN2"] = $"{i:000}-22-{timeStr}"; + } + } + + private void TruncateTable(string connectionString) + { + using var connection = new SqlConnection(GetConnectionString(connectionString, encryptionEnabled: true)); + connection.Open(); + ExecuteQuery(connection, $"DELETE FROM [dbo].[{_tableName}]"); + } + + private void ExecuteQuery(SqlConnection connection, string commandText) + { + using var cmd = new SqlCommand( + commandText, + connection: connection, + transaction: null, + columnEncryptionSetting: SqlCommandColumnEncryptionSetting.Enabled); + cmd.ExecuteNonQuery(); + } + + private void PopulateTable((int id, string s1, string s2)[] rows, string connectionString) + { + using var connection = new SqlConnection(GetConnectionString(connectionString, encryptionEnabled: true)); + connection.Open(); + + foreach (var (id, s1, s2) in rows) + { + using var cmd = new SqlCommand( + $"INSERT INTO [dbo].[{_tableName}] (BuyerSellerID, SSN1, SSN2) VALUES (@id, @s1, @s2)", + connection, + null, + SqlCommandColumnEncryptionSetting.Enabled); + + cmd.Parameters.Add(new SqlParameter("@id", SqlDbType.Int) { Value = id }); + cmd.Parameters.Add(new SqlParameter("@s1", SqlDbType.VarChar, 255) { Value = s1 }); + cmd.Parameters.Add(new SqlParameter("@s2", SqlDbType.VarChar, 255) { Value = s2 }); + + cmd.ExecuteNonQuery(); + } + } + + private string GetConnectionString(string baseConnectionString, bool encryptionEnabled) + { + var builder = new SqlConnectionStringBuilder(baseConnectionString) + { + ColumnEncryptionSetting = encryptionEnabled + ? SqlConnectionColumnEncryptionSetting.Enabled + : SqlConnectionColumnEncryptionSetting.Disabled + }; + return builder.ToString(); + } + + private void SilentRunCommand(string commandText, SqlConnection connection) + { + try + { + ExecuteQuery(connection, commandText); + } + catch (SqlException ex) + { + bool onlyObjectNotExist = true; + foreach (SqlError err in ex.Errors) + { + if (err.Number != 208) + { + onlyObjectNotExist = false; + break; + } + } + if (!onlyObjectNotExist) + { + Console.WriteLine($"SilentRunCommand: Unexpected SqlException during cleanup: {ex}"); + } + } + } + + public void Dispose() + { + foreach (string connectionString in DataTestUtility.AEConnStringsSetup) + { + using var connection = new SqlConnection(GetConnectionString(connectionString, encryptionEnabled: true)); + connection.Open(); + ExecuteQuery(connection, $"DELETE FROM [dbo].[{_tableName}]"); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs index d08d2a86be..6beed55e74 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/SQLSetupStrategy.cs @@ -19,6 +19,8 @@ public class SQLSetupStrategy : ColumnMasterKeyCertificateFixture public string ColumnMasterKeyPath { get; } public Table ApiTestTable { get; private set; } + public Table BuyerSellerTable { get; private set; } + public Table BulkCopyAEErrorMessageTestTable { get; private set; } public Table BulkCopyAETestTable { get; private set; } public Table ColumnDecryptErrorTestTable { get; private set; } @@ -133,6 +135,9 @@ protected List CreateTables(IList columnEncryptionKe ApiTestTable = new ApiTestTable(GenerateUniqueName("ApiTestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]); tables.Add(ApiTestTable); + BuyerSellerTable = new BuyerSellerTable(GenerateUniqueName("BuyerSellerTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]); + tables.Add(BuyerSellerTable); + BulkCopyAEErrorMessageTestTable = new BulkCopyAEErrorMessageTestTable(GenerateUniqueName("BulkCopyAEErrorMessageTestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]); tables.Add(BulkCopyAEErrorMessageTestTable); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/BuyerSellerTable.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/BuyerSellerTable.cs new file mode 100644 index 0000000000..a6f1d5ac7f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/BuyerSellerTable.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Data; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup +{ + public class BuyerSellerTable : Table + { + private const string ColumnEncryptionAlgorithmName = @"AEAD_AES_256_CBC_HMAC_SHA_256"; + private readonly ColumnEncryptionKey _columnEncryptionKey1; + private readonly ColumnEncryptionKey _columnEncryptionKey2; + + // ✅ ADD: Unique stored procedure names based on table name + public string InsertProcedureName => $"InsertBuyerSeller_{Name}"; + public string UpdateProcedureName => $"UpdateBuyerSeller_{Name}"; + + public BuyerSellerTable(string tableName, ColumnEncryptionKey columnEncryptionKey1, ColumnEncryptionKey columnEncryptionKey2) + : base(tableName) + { + _columnEncryptionKey1 = columnEncryptionKey1; + _columnEncryptionKey2 = columnEncryptionKey2; + } + + public override void Create(SqlConnection sqlConnection) + { + // Create the table with encrypted columns + string createTableSql = $@" + CREATE TABLE [dbo].[{Name}] + ( + [BuyerSellerID] [int] NOT NULL PRIMARY KEY, + [SSN1] [varchar](255) COLLATE Latin1_General_BIN2 ENCRYPTED WITH ( + COLUMN_ENCRYPTION_KEY = [{_columnEncryptionKey1.Name}], + ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = '{ColumnEncryptionAlgorithmName}' + ), + [SSN2] [varchar](255) COLLATE Latin1_General_BIN2 ENCRYPTED WITH ( + COLUMN_ENCRYPTION_KEY = [{_columnEncryptionKey2.Name}], + ENCRYPTION_TYPE = DETERMINISTIC, + ALGORITHM = '{ColumnEncryptionAlgorithmName}' + ) + )"; + + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = createTableSql; + command.ExecuteNonQuery(); + } + + // ✅ CHANGED: Use unique SP names + string createInsertProcSql = $@" + CREATE PROCEDURE [dbo].[{InsertProcedureName}] + @BuyerSellerID int, + @SSN1 varchar(255), + @SSN2 varchar(255) + AS + BEGIN + INSERT INTO [dbo].[{Name}] (BuyerSellerID, SSN1, SSN2) + VALUES (@BuyerSellerID, @SSN1, @SSN2) + END"; + + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = createInsertProcSql; + command.ExecuteNonQuery(); + } + + // ✅ CHANGED: Use unique SP names + string createUpdateProcSql = $@" + CREATE PROCEDURE [dbo].[{UpdateProcedureName}] + @BuyerSellerID int, + @SSN1 varchar(255), + @SSN2 varchar(255) + AS + BEGIN + UPDATE [dbo].[{Name}] + SET SSN1 = @SSN1, SSN2 = @SSN2 + WHERE BuyerSellerID = @BuyerSellerID + END"; + + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = createUpdateProcSql; + command.ExecuteNonQuery(); + } + } + + public override void Drop(SqlConnection sqlConnection) + { + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = $"IF OBJECT_ID('[dbo].[{InsertProcedureName}]', 'P') IS NOT NULL DROP PROCEDURE [dbo].[{InsertProcedureName}]"; + command.ExecuteNonQuery(); + } + + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = $"IF OBJECT_ID('[dbo].[{UpdateProcedureName}]', 'P') IS NOT NULL DROP PROCEDURE [dbo].[{UpdateProcedureName}]"; + command.ExecuteNonQuery(); + } + + // Drop table + base.Drop(sqlConnection); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 627d123834..33ae4fd3d0 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -34,6 +34,7 @@ + @@ -53,6 +54,7 @@ + @@ -387,4 +389,7 @@ xunit.runner.json + + +